2022-11-19 18:43:57 +00:00
# Copyright 2022 Lunar Ring. All rights reserved.
2023-01-11 11:58:59 +00:00
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
2022-11-19 18:43:57 +00:00
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os , sys
dp_git = " /home/lugo/git/ "
sys . path . append ( ' util ' )
2022-12-02 12:17:13 +00:00
# sys.path.append('../stablediffusion/ldm')
2022-11-19 18:43:57 +00:00
import torch
torch . backends . cudnn . benchmark = False
import numpy as np
import warnings
warnings . filterwarnings ( ' ignore ' )
import time
import subprocess
import warnings
import torch
from tqdm . auto import tqdm
from PIL import Image
2023-01-08 09:32:58 +00:00
# import matplotlib.pyplot as plt
2022-11-19 18:43:57 +00:00
import torch
2022-11-23 12:43:33 +00:00
from movie_util import MovieSaver
2022-11-19 18:43:57 +00:00
import datetime
from typing import Callable , List , Optional , Union
import inspect
2022-11-23 12:43:33 +00:00
from threading import Thread
2022-11-19 18:43:57 +00:00
torch . set_grad_enabled ( False )
2022-11-25 14:34:41 +00:00
from omegaconf import OmegaConf
from torch import autocast
from contextlib import nullcontext
2022-12-02 12:17:13 +00:00
2022-11-28 04:07:01 +00:00
from ldm . models . diffusion . ddim import DDIMSampler
2022-12-02 12:17:13 +00:00
from ldm . util import instantiate_from_config
2023-01-08 09:32:58 +00:00
from ldm . models . diffusion . ddpm import LatentUpscaleDiffusion , LatentInpaintDiffusion
2022-11-28 04:07:01 +00:00
from stable_diffusion_holder import StableDiffusionHolder
2023-01-08 09:32:58 +00:00
import yaml
2023-02-15 17:21:00 +00:00
import lpips
2022-11-19 18:43:57 +00:00
#%%
class LatentBlending ( ) :
def __init__(
        self,
        sdh: "StableDiffusionHolder",
        guidance_scale: float = 4,
        guidance_scale_mid_damper: float = 0.5,
        mid_compression_scaler: float = 1.2,
):
    r"""
    Initializes the latent blending class.
    Args:
        sdh: StableDiffusionHolder
            Holder of the diffusion model. Its device, width, height,
            num_inference_steps and model attributes are read during initialization.
        guidance_scale: float
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        guidance_scale_mid_damper: float = 0.5
            Reduces the guidance scale towards the middle of the transition. Must lie in (0, 1];
            smaller values give a stronger reduction. The effective value per frame is computed
            in set_guidance_mid_dampening.
        mid_compression_scaler: float = 1.2
            Passed to get_spacing when the branching tree is initialized. Intended to increase the
            sampling density in the middle of the transition (where most changes happen): higher
            values imply more samples in the middle, but the inflection point can then move away
            from the middle, so high values can give rough transitions.
    """
    assert 0 < guidance_scale_mid_damper <= 1.0, f"guidance_scale_mid_damper needs to be in interval (0,1], you provided {guidance_scale_mid_damper}"
    self.sdh = sdh
    self.device = self.sdh.device
    self.width = self.sdh.width
    self.height = self.sdh.height
    self.guidance_scale_mid_damper = guidance_scale_mid_damper
    self.mid_compression_scaler = mid_compression_scaler
    self.seed1 = 0
    self.seed2 = 0

    # Initialize vars
    self.prompt1 = ""
    self.prompt2 = ""
    # Set through set_negative_prompt; None until then.
    self.negative_prompt = None
    # Latent trajectories of the two keyframes; run_transition may recycle them.
    self.tree_latents = [None, None]
    self.tree_fracts = None
    self.idx_injection = []
    self.tree_status = None
    self.tree_final_imgs = []

    self.list_nmb_branches_prev = []
    self.list_injection_idx_prev = []
    self.text_embedding1 = None
    self.text_embedding2 = None
    self.image1_lowres = None
    self.image2_lowres = None
    self.stop_diffusion = False
    self.num_inference_steps = self.sdh.num_inference_steps
    self.noise_level_upscaling = 20
    self.list_injection_idx = None
    self.list_nmb_branches = None

    # Mixing parameters (read by compute_latents2 and compute_latents_mix)
    self.branch1_influence = 0.0
    self.branch1_max_depth_influence = 0.65
    self.branch1_influence_decay = 0.8
    self.parental_influence = 0.0
    self.parental_max_depth_influence = 1.0
    self.parental_influence_decay = 1.0
    self.branch1_insertion_completed = False
    self.set_guidance_scale(guidance_scale)
    self.init_mode()
    self.multi_transition_img_first = None
    self.multi_transition_img_last = None
    # Seconds per diffusion step, measured in compute_latents1 and used for compute-time budgeting.
    self.dt_per_diff = 0
    # LPIPS perceptual metric (AlexNet backbone); presumably consumed by get_lpips_similarity.
    # `.to` works for any torch device, whereas `.cuda` fails on CPU-only setups.
    self.lpips = lpips.LPIPS(net='alex').to(self.device)
2022-11-19 18:43:57 +00:00
2022-11-25 14:34:41 +00:00
2023-01-08 09:32:58 +00:00
def init_mode(self):
    r"""
    Derive the operational mode from the class of the wrapped diffusion model.

    Sets self.mode to 'upscale' for a LatentUpscaleDiffusion model, 'inpaint'
    for a LatentInpaintDiffusion model (also clearing any image_source /
    mask_image stored on the holder), and 'standard' otherwise.
    """
    model = self.sdh.model
    if isinstance(model, LatentUpscaleDiffusion):
        self.mode = 'upscale'
        return
    if isinstance(model, LatentInpaintDiffusion):
        # Drop any source image / mask left on the holder from earlier use.
        self.sdh.image_source = None
        self.sdh.mask_image = None
        self.mode = 'inpaint'
        return
    self.mode = 'standard'
2022-11-19 18:43:57 +00:00
2022-11-28 11:43:33 +00:00
def set_guidance_scale(self, guidance_scale):
    r"""
    Store a new classifier-free guidance scale.

    The value becomes both the base value used for mid-transition dampening
    and the currently active guidance scale, and is mirrored onto the holder.
    """
    self.guidance_scale_base = self.guidance_scale = guidance_scale
    self.sdh.guidance_scale = guidance_scale
2022-11-28 14:34:18 +00:00
2023-01-08 10:48:44 +00:00
def set_negative_prompt(self, negative_prompt):
    r"""
    Set the negative prompt. Currently only a single negative prompt is supported.

    The prompt is kept on this object and forwarded to the diffusion holder.
    """
    holder = self.sdh
    self.negative_prompt = negative_prompt
    holder.set_negative_prompt(self.negative_prompt)
2022-11-28 14:34:18 +00:00
def set_guidance_mid_dampening(self, fract_mixing):
    r"""
    Tunes the guidance scale down as a linear function of fract_mixing;
    the minimum is reached at fract_mixing = 0.5.

    At the transition ends (0 or 1) the guidance equals guidance_scale_base.
    At the middle it is reduced by max(0, base * (1 - guidance_scale_mid_damper) - 1),
    so the guidance is never raised above the base value.
    Args:
        fract_mixing: float
            position along the transition axis [0, 1]
    """
    # 0 at either end of the transition, 1 exactly in the middle
    mid_factor = 1 - np.abs(fract_mixing - 0.5) / 0.5
    # Clamp at 0: with a damper close to 1 (or a small base scale) the raw expression is
    # negative, which would *increase* the guidance towards the middle.
    max_guidance_reduction = max(0.0, self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1)
    guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor
    self.guidance_scale = guidance_scale_effective
    self.sdh.guidance_scale = guidance_scale_effective
2022-11-19 18:43:57 +00:00
def set_prompt1(self, prompt: str):
    r"""
    Set the first prompt (first keyframe) and compute its text embedding.

    Underscores in the prompt are replaced by spaces before embedding.
    Args:
        prompt: str
            e.g. "ABC trending on artstation painted by Greg Rutkowski"
    """
    cleaned = prompt.replace("_", " ")
    self.prompt1 = cleaned
    self.text_embedding1 = self.get_text_embeddings(cleaned)
def set_prompt2(self, prompt: str):
    r"""
    Set the second prompt (second keyframe) and compute its text embedding.

    Underscores in the prompt are replaced by spaces before embedding.
    Args:
        prompt: str
            e.g. "XYZ trending on artstation painted by Greg Rutkowski"
    """
    cleaned = prompt.replace("_", " ")
    self.prompt2 = cleaned
    self.text_embedding2 = self.get_text_embeddings(cleaned)
2023-01-08 09:32:58 +00:00
def set_image1(self, image: Image):
    r"""
    Store the first (low-resolution) keyframe image.

    Relevant for upscaling-model transitions; in 'upscale' mode get_noise
    derives the latent size from this image's size.
    Args:
        image: Image
    """
    self.image1_lowres = image
def set_image2(self, image: Image):
    r"""
    Store the second (low-resolution) keyframe image as image2_lowres.

    Relevant for upscaling-model transitions.
    Args:
        image: Image
    """
    self.image2_lowres = image
def load_branching_profile(
        self,
        quality: str = 'medium',
        depth_strength: float = 0.65,
        nmb_frames: int = 100,
        nmb_mindist: int = 3,
):
    r"""
    Helper function to set up the branching structure automatically from a quality preset.
    Args:
        quality: str
            Determines how many diffusion steps are being made + how many branches in total.
            Tradeoff between quality and speed of computation.
            Choose: lowest, low, medium, high, ultra, upscaling_step1, upscaling_step2
        depth_strength: float = 0.65
            Determines how deep the first injection will happen.
            Deeper injections will cause (unwanted) formation of new structures,
            more shallow values will go into alpha-blendy land.
        nmb_frames: int = 100
            total number of frames; scales the branch budget of the low/medium/high/ultra presets.
        nmb_mindist: int = 3
            minimum distance in terms of diffusion iterations between subsequent injections;
            forwarded to autosetup_branching.
    Raises:
        ValueError: if quality is not one of the presets above.
    """
    # quality -> (num_inference_steps, nmb_max_branches)
    profiles = {
        'lowest': (12, 5),
        'low': (15, nmb_frames // 16),
        'medium': (30, nmb_frames // 8),
        'high': (60, nmb_frames // 4),
        'ultra': (100, nmb_frames // 2),
        'upscaling_step1': (40, 12),
        'upscaling_step2': (100, 6),
    }
    if quality not in profiles:
        raise ValueError(f"quality = '{quality}' not supported")
    num_inference_steps, nmb_max_branches = profiles[quality]
    # The branching always starts with 2 branches; a smaller budget (tiny nmb_frames)
    # would make the branching setup fail.
    nmb_max_branches = max(2, nmb_max_branches)
    self.autosetup_branching(
        depth_strength=depth_strength,
        num_inference_steps=num_inference_steps,
        nmb_max_branches=nmb_max_branches,
        nmb_mindist=nmb_mindist,
    )
2023-01-08 09:32:58 +00:00
def autosetup_branching(
        self,
        depth_strength: float = 0.65,
        num_inference_steps: int = 30,
        nmb_max_branches: int = 20,
        nmb_mindist: int = 3,
):
    r"""
    Automatically sets up the branching schedule and passes it to setup_branching.

    Injection indices are spread linearly between round(num_inference_steps * depth_strength)
    and num_inference_steps - nmb_mindist; branch counts grow logarithmically from 2 to
    nmb_max_branches. Entries that are too close (< nmb_mindist steps) to the previously kept
    injection, or that do not increase the branch count, are dropped.
    Args:
        depth_strength: float = 0.65,
            Determines how deep the first injection will happen.
            Deeper injections will cause (unwanted) formation of new structures,
            more shallow values will go into alpha-blendy land.
        num_inference_steps: int
            Number of diffusion steps. Higher values will take more compute time.
        nmb_max_branches (int): The number of diffusion-generated images
            at the end of the inference.
        nmb_mindist (int): The minimum number of diffusion steps
            between two injections.
    """
    idx_injection_first = int(np.round(num_inference_steps * depth_strength))
    idx_injection_last = num_inference_steps - nmb_mindist
    # At least one injection besides the root; with 0 only the root entry would survive
    # and would then be overwritten with nmb_max_branches below.
    nmb_injections = max(1, int(np.floor(num_inference_steps / 5)) - 1)
    list_injection_idx = [0]
    list_injection_idx.extend(np.linspace(idx_injection_first, idx_injection_last, nmb_injections).astype(int))
    list_nmb_branches = np.round(np.logspace(np.log10(2), np.log10(nmb_max_branches), nmb_injections + 1)).astype(int)

    # Cleanup. There should be at least nmb_mindist diffusion steps between each injection and list_nmb_branches increases
    list_nmb_branches_clean = [list_nmb_branches[0]]
    list_injection_idx_clean = [list_injection_idx[0]]
    for idx_injection, nmb_branches in zip(list_injection_idx[1:], list_nmb_branches[1:]):
        if idx_injection - list_injection_idx_clean[-1] >= nmb_mindist and nmb_branches > list_nmb_branches_clean[-1]:
            list_nmb_branches_clean.append(nmb_branches)
            list_injection_idx_clean.append(idx_injection)
    # The final stage always gets exactly the requested number of branches.
    list_nmb_branches_clean[-1] = nmb_max_branches

    # Plain Python ints (setup_branching type-checks the indices)
    list_injection_idx = [int(idx) for idx in list_injection_idx_clean]
    list_nmb_branches = [int(nmb) for nmb in list_nmb_branches_clean]
    print(f"autosetup_branching: num_inference_steps: {num_inference_steps} list_nmb_branches: {list_nmb_branches} list_injection_idx: {list_injection_idx}")
    self.setup_branching(num_inference_steps, list_nmb_branches=list_nmb_branches, list_injection_idx=list_injection_idx)
2022-11-28 11:41:15 +00:00
def setup_branching(self,
                    num_inference_steps: int = 30,
                    list_nmb_branches: List[int] = None,
                    list_injection_strength: List[float] = None,
                    list_injection_idx: List[int] = None,
                    ):
    r"""
    Sets the branching structure for making transitions.
    Args:
        num_inference_steps: int
            Number of diffusion steps. Larger values will take more compute time.
        list_nmb_branches: List[int]:
            list of the number of branches for each injection. The caller's list is not
            modified; if no injection at index 0 is given, a leading 2 is prepended to the
            stored copy.
        list_injection_strength: List[float]:
            list of injection strengths within interval [0, 1), values need to be increasing.
            Each is converted to round(strength * num_inference_steps).
            Alternatively you can directly specify the list_injection_idx.
        list_injection_idx: List[int]:
            list of diffusion step indices in [0, num_inference_steps) at which the
            injections happen; values should be increasing.
            Alternatively you can specify the list_injection_strength.
    """
    # Assert: exactly one way of specifying the injections
    assert not ((list_injection_strength is not None) and (list_injection_idx is not None)), "supply either list_injection_strength or list_injection_idx"
    if list_injection_strength is None:
        assert list_injection_idx is not None, "Supply either list_injection_idx or list_injection_strength"
        # np.integer also accepts numpy scalars such as np.int64 (the np.int alias was removed in NumPy 1.24)
        assert isinstance(list_injection_idx[0], (int, np.integer)), "Need to supply integers for list_injection_idx"
    if list_injection_idx is None:
        assert list_injection_strength is not None, "Supply either list_injection_idx or list_injection_strength"
        # Create the injection indexes
        list_injection_idx = [int(round(x * num_inference_steps)) for x in list_injection_strength]
        if len(list_injection_idx) > 1:
            spacing = np.diff(list_injection_idx)
            assert min(spacing) > 0, 'Injection idx needs to be increasing'
            if min(spacing) < 2:
                print("Warning: your injection spacing is very tight. consider increasing the distances")
            # Element 1 is checked because element 0 may legitimately be the int 0, e.g. [0, 0.5]
            assert isinstance(list_injection_strength[1], (float, np.floating)), "Need to supply floats for list_injection_strength"
    assert max(list_injection_idx) < num_inference_steps, "Decrease the injection index or strength"
    assert len(list_injection_idx) == len(list_nmb_branches), "Need to have same length"

    # Auto inits. Work on copies so the caller's lists are never mutated.
    list_nmb_branches = list(list_nmb_branches)
    list_injection_idx_ext = list(list_injection_idx)
    list_injection_idx_ext.append(num_inference_steps)
    # If injection at depth 0 not specified, we will start out with 2 branches
    if list_injection_idx_ext[0] != 0:
        list_injection_idx_ext.insert(0, 0)
        list_nmb_branches.insert(0, 2)
    assert list_nmb_branches[0] == 2, "Need to start with 2 branches. set list_nmb_branches[0]=2"

    # Set attributes
    self.num_inference_steps = num_inference_steps
    self.sdh.num_inference_steps = num_inference_steps
    self.list_nmb_branches = list_nmb_branches
    self.list_injection_idx = list_injection_idx
    self.list_injection_idx_ext = list_injection_idx_ext

    self.init_tree_struct()
def init_tree_struct(self):
    r"""
    (Re)build the per-time-block tree bookkeeping from the branching setup.

    For every block between consecutive entries of list_injection_idx_ext, one
    list of mixing fractions (from get_spacing), one list of empty latents and
    one list of 'untouched' status flags is appended, each sized by the
    block's entry in list_nmb_branches.
    """
    self.tree_latents, self.tree_fracts, self.tree_status = [], [], []
    self.tree_final_imgs_timing = [0] * self.list_nmb_branches[-1]
    for block in range(len(self.list_injection_idx_ext) - 1):
        width = self.list_nmb_branches[block]
        self.tree_fracts.append(get_spacing(width, self.mid_compression_scaler))
        self.tree_latents.append([None] * width)
        self.tree_status.append(['untouched'] * width)
2022-11-28 11:41:15 +00:00
def run_transition(
        self,
        recycle_img1: Optional[bool] = False,
        recycle_img2: Optional[bool] = False,
        num_inference_steps: Optional[int] = 30,
        depth_strength: Optional[float] = 0.3,
        t_compute_max_allowed: Optional[float] = None,
        nmb_max_branches: Optional[int] = None,
        fixed_seeds: Optional[List[int]] = None,
):
    r"""
    Function for computing transitions.
    Returns a list of transition images using spherical latent blending.
    Args:
        recycle_img1: Optional[bool]:
            Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
            Only honoured if latents from a previous run with the same num_inference_steps exist.
        recycle_img2: Optional[bool]:
            Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
            Only honoured if latents from a previous run with the same num_inference_steps exist.
        num_inference_steps:
            Number of diffusion steps. Higher values will take more compute time.
        depth_strength:
            Determines how deep the first injection will happen.
            Deeper injections will cause (unwanted) formation of new structures,
            more shallow values will go into alpha-blendy land.
        t_compute_max_allowed:
            Either provide t_compute_max_allowed or nmb_max_branches.
            The maximum time allowed for computation. Higher values give better results but take longer.
        nmb_max_branches: int
            Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
            results. Use this if you want to have controllable results independent
            of your computer.
        fixed_seeds: Optional[List[int]]:
            You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
            Pass the string 'randomize' to draw two new random seeds.
            If None, the current self.seed1 / self.seed2 are kept.
    Returns:
        list of decoded images, starting with the first keyframe; intermediate
        frames are inserted according to their mixing fraction.
    """
    # Sanity checks first
    assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
    assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'

    # Random seeds
    if fixed_seeds is not None:
        if isinstance(fixed_seeds, str) and fixed_seeds == 'randomize':
            fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
        else:
            assert len(fixed_seeds) == 2, "Supply a list with len = 2"
        self.seed1 = fixed_seeds[0]
        self.seed2 = fixed_seeds[1]

    # Ensure correct num_inference_steps in holder
    self.num_inference_steps = num_inference_steps
    self.sdh.num_inference_steps = num_inference_steps

    # The keyframes are the transition endpoints, where no mid-dampening applies. Undo any
    # damped guidance left over from the last mix of a previous run_transition call.
    self.guidance_scale = self.guidance_scale_base
    self.sdh.guidance_scale = self.guidance_scale_base

    # Compute / recycle first image. Stored latents are only reusable if they exist
    # (tree_latents starts out as [None, None]) and match the current step count.
    latents1_prev = self.tree_latents[0]
    if recycle_img1 and latents1_prev is not None and len(latents1_prev) == self.num_inference_steps:
        list_latents1 = latents1_prev
    else:
        list_latents1 = self.compute_latents1()

    # Compute / recycle second image
    latents2_prev = self.tree_latents[-1]
    if recycle_img2 and latents2_prev is not None and len(latents2_prev) == self.num_inference_steps:
        list_latents2 = latents2_prev
    else:
        list_latents2 = self.compute_latents2()

    # Reset the tree, injecting the edge latents1/2 we just generated/recycled
    self.tree_latents = [list_latents1, list_latents2]
    self.tree_fracts = [0.0, 1.0]
    self.tree_final_imgs = [self.sdh.latent2image(list_latents1[-1]), self.sdh.latent2image(list_latents2[-1])]
    self.tree_idx_injection = [0, 0]

    # Set up branching scheme (dependent on provided compute time or branch budget)
    list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)

    # Run iteratively, starting with the longest trajectory.
    # Always inserting new branches where they are needed most according to image similarity
    for idx_injection, nmb_stems in tqdm(zip(list_idx_injection, list_nmb_stems), total=len(list_idx_injection)):
        for _ in range(nmb_stems):
            fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
            self.set_guidance_mid_dampening(fract_mixing)
            list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
            self.insert_into_tree(fract_mixing, idx_injection, list_latents)

    return self.tree_final_imgs
def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
    r"""
    Plan at which diffusion steps new branches start and how many start at each.

    Candidate injection indices are every 3rd step from
    round(num_inference_steps * depth_strength) up to (excluding)
    num_inference_steps - 1. Stems are then added one at a time until the budget is hit.
    Exactly one budget must be given:
        t_compute_max_allowed: float
            Stop once the estimated compute time exceeds this value. The estimate uses
            self.dt_per_diff, measured while computing the first keyframe.
        nmb_max_branches: int
            Total number of frames including the two keyframes; the plan yields
            nmb_max_branches - 2 new branches.
    Args:
        depth_strength:
            Determines how deep the first injection will happen.
            Deeper injections will cause (unwanted) formation of new structures,
            more shallow values will go into alpha-blendy land.
    Returns:
        (list_idx_injection, list_nmb_stems): numpy arrays with the injection step
        indices and the number of branches to spawn at each.
    Raises:
        ValueError: if both budgets are supplied.
    """
    idx_first = int(round(self.num_inference_steps * depth_strength))
    list_idx_injection = np.arange(idx_first, self.num_inference_steps - 1, 3)
    list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)

    if nmb_max_branches is None:
        assert t_compute_max_allowed is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
        budget_is_time = True
    elif t_compute_max_allowed is None:
        budget_is_time = False
        nmb_max_branches -= 2  # discounting the outer frames
    else:
        raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")

    first_pass = True
    while True:
        # Estimated cost of the current plan: each stem diffuses from its injection index to the
        # end, plus a flat 0.15 per stem (presumably seconds of per-branch overhead).
        steps_per_stem = self.num_inference_steps - list_idx_injection
        t_compute = np.sum(steps_per_stem * list_nmb_stems) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)

        # Grow the earliest stem whose successor already has at least twice as many branches;
        # if there is none, grow the last one.
        idx_grow = next(
            (k for k in range(len(list_nmb_stems) - 1) if list_nmb_stems[k + 1] / list_nmb_stems[k] >= 2),
            -1,
        )
        list_nmb_stems[idx_grow] += 1

        if budget_is_time:
            if t_compute > t_compute_max_allowed:
                break
        elif np.sum(list_nmb_stems) >= nmb_max_branches:
            if first_pass:
                # Budget already exhausted: spread nmb_max_branches single stems evenly instead.
                list_idx_injection = np.linspace(list_idx_injection[0], list_idx_injection[-1], nmb_max_branches).astype(np.int32)
                list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
            break
        first_pass = False

    return list_idx_injection, list_nmb_stems
2023-02-15 17:21:00 +00:00
def get_mixing_parameters(self, idx_injection):
    r"""
    Decide where the next branch goes and which existing branches it mixes.

    The neighbouring pair of final images with the largest get_lpips_similarity
    value is selected (LPIPS is a perceptual distance, so presumably this is
    the most dissimilar pair — confirm against get_lpips_similarity). The new
    branch sits halfway between their mixing fractions. Its parents are found
    by walking outwards from that pair until reaching branches injected
    strictly before idx_injection.
    Args:
        idx_injection: int
            the index in terms of diffusion steps, where the next insertion will start.
    Returns:
        (fract_mixing, b_parent1, b_parent2)
    """
    imgs = self.tree_final_imgs
    similarities = [self.get_lpips_similarity(img_left, img_right) for img_left, img_right in zip(imgs[:-1], imgs[1:])]
    b_closest1 = np.argmax(similarities)
    b_closest2 = b_closest1 + 1
    fract_mixing = (self.tree_fracts[b_closest1] + self.tree_fracts[b_closest2]) / 2

    # Ensure that the parents are indeed older (injected at an earlier step)
    b_parent1 = b_closest1
    while self.tree_idx_injection[b_parent1] >= idx_injection:
        b_parent1 -= 1
    b_parent2 = b_closest2
    while self.tree_idx_injection[b_parent2] >= idx_injection:
        b_parent2 += 1
    return fract_mixing, b_parent1, b_parent2
def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
    r"""
    Insert a freshly computed branch into the trajectory tree.

    The insertion position is right after the first index returned by
    get_closest_idx(fract_mixing, self.tree_fracts) (presumably the left
    neighbour of fract_mixing — confirm). tree_latents, tree_final_imgs,
    tree_fracts and tree_idx_injection all get an entry at that same position
    so they stay aligned.
    Args:
        fract_mixing: float
            the fraction along the transition axis [0, 1]
        idx_injection: int
            the index in terms of diffusion steps, where the insertion started.
        list_latents: list
            list of the latents to be inserted; its last entry is decoded to an image.
    """
    b_parent1, _ = get_closest_idx(fract_mixing, self.tree_fracts)
    pos = b_parent1 + 1
    self.tree_latents.insert(pos, list_latents)
    self.tree_final_imgs.insert(pos, self.sdh.latent2image(list_latents[-1]))
    self.tree_fracts.insert(pos, fract_mixing)
    self.tree_idx_injection.insert(pos, idx_injection)
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
    r"""
    Run a diffusion trajectory that starts from a blend of two parent trajectories.

    For every step the parents' latents are spherically interpolated (None where
    either parent has no latent). Diffusion restarts at idx_injection from the blended
    latent of step idx_injection - 1. The parental mix keeps being mixed in with
    coefficients that are parental_influence up to idx_injection, then decay linearly
    to parental_influence * parental_influence_decay until
    round(num_inference_steps * parental_max_depth_influence), and are 0 afterwards.
    Args:
        fract_mixing: float
            the fraction along the transition axis [0, 1]
        b_parent1: int
            index of parent1 to be used
        b_parent2: int
            index of parent2 to be used
        idx_injection: int
            the index in terms of diffusion steps, where the next insertion will start.
    Returns:
        whatever run_diffusion returns (the list of latents of the new trajectory).
    """
    list_conditionings = self.get_mixed_conditioning(fract_mixing)
    fract_left = self.tree_fracts[b_parent1]
    fract_right = self.tree_fracts[b_parent2]
    # Position of fract_mixing between the two parents, rescaled to [0, 1]
    fract_mixing_parental = (fract_mixing - fract_left) / (fract_right - fract_left)

    trajectory_left = self.tree_latents[b_parent1]
    trajectory_right = self.tree_latents[b_parent2]
    list_latents_parental_mix = []
    for step in range(self.num_inference_steps):
        lat_left, lat_right = trajectory_left[step], trajectory_right[step]
        if lat_left is None or lat_right is None:
            list_latents_parental_mix.append(None)
        else:
            list_latents_parental_mix.append(interpolate_spherical(lat_left, lat_right, fract_mixing_parental))

    idx_mixing_stop = int(round(self.num_inference_steps * self.parental_max_depth_influence))
    nmb_mixing = idx_mixing_stop - idx_injection
    mixing_coeffs = [self.parental_influence] * idx_injection
    if nmb_mixing > 0:
        mixing_coeffs += list(np.linspace(self.parental_influence, self.parental_influence * self.parental_influence_decay, nmb_mixing))
    mixing_coeffs += [0] * (self.num_inference_steps - len(mixing_coeffs))

    return self.run_diffusion(
        list_conditionings,
        latents_start=list_latents_parental_mix[idx_injection - 1],
        idx_start=idx_injection,
        list_latents_mixing=list_latents_parental_mix,
        mixing_coeffs=mixing_coeffs
    )
def compute_latents1(self, return_image=False):
    r"""
    Run the full diffusion trajectory for the first keyframe (mixing fraction 0).

    Noise is drawn with self.seed1. The wall-clock time per diffusion step is stored in
    self.dt_per_diff (used for compute-time budgeting), and the trajectory is stored
    in self.tree_latents[0].
    Args:
        return_image: bool
            if True, return the decoded final latent instead of the list of latents
    """
    print("starting compute_latents1")
    conditionings = self.get_mixed_conditioning(0)
    t_start = time.time()
    list_latents1 = self.run_diffusion(
        conditionings,
        latents_start=self.get_noise(self.seed1),
        idx_start=0
    )
    self.dt_per_diff = (time.time() - t_start) / self.num_inference_steps
    self.tree_latents[0] = list_latents1
    return self.sdh.latent2image(list_latents1[-1]) if return_image else list_latents1
def compute_latents2(self, return_image=False):
    r"""
    Run the full diffusion trajectory for the second keyframe (mixing fraction 1).

    Noise is drawn with self.seed2. If branch1_influence > 0, the first keyframe's
    trajectory (self.tree_latents[0]) is mixed in, with coefficients decaying linearly
    from branch1_influence to branch1_influence * branch1_influence_decay over the first
    round(num_inference_steps * branch1_max_depth_influence) steps and 0 afterwards.
    The result is stored in self.tree_latents[-1].
    Args:
        return_image: bool
            if True, return the decoded final latent instead of the list of latents
    """
    print("starting compute_latents2")
    list_conditionings = self.get_mixed_conditioning(1)
    latents_start = self.get_noise(self.seed2)
    if not self.branch1_influence > 0.0:
        list_latents2 = self.run_diffusion(list_conditionings, latents_start)
    else:
        # Influence from branch1
        nmb_steps = self.num_inference_steps
        idx_mixing_stop = int(round(nmb_steps * self.branch1_max_depth_influence))
        decaying = np.linspace(self.branch1_influence, self.branch1_influence * self.branch1_influence_decay, idx_mixing_stop)
        mixing_coeffs = list(decaying) + [0] * (nmb_steps - idx_mixing_stop)
        list_latents2 = self.run_diffusion(
            list_conditionings,
            latents_start=latents_start,
            idx_start=0,
            list_latents_mixing=self.tree_latents[0],
            mixing_coeffs=mixing_coeffs
        )
    self.tree_latents[-1] = list_latents2
    return self.sdh.latent2image(list_latents2[-1]) if return_image else list_latents2
2023-02-16 10:48:45 +00:00
def get_noise(self, seed):
    r"""
    Helper function to get the starting noise latents for a given seed.
    The same seed always yields the same noise.
    Args:
        seed: int
            Seed for the torch.Generator created on self.sdh.device.
    Returns:
        torch.Tensor of shape (1, C, H, W) on self.sdh.device.
        'standard' mode: C = self.sdh.C, H/W = self.sdh.height/width // self.sdh.f.
        'upscale' mode: C = self.sdh.model.channels, W, H = self.image1_lowres.size.
    Raises:
        ValueError: if self.mode is neither 'standard' nor 'upscale'.
    """
    generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
    if self.mode == 'standard':
        C = self.sdh.C
        H = self.sdh.height // self.sdh.f
        W = self.sdh.width // self.sdh.f
    elif self.mode == 'upscale':
        C = self.sdh.model.channels
        W = self.image1_lowres.size[0]
        H = self.image1_lowres.size[1]
    else:
        # Previously fell through to an UnboundLocalError on C, H, W.
        raise ValueError(f"get_noise: unsupported mode {self.mode}")
    return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
2022-11-19 18:43:57 +00:00
2022-11-23 13:59:09 +00:00
2022-11-19 18:43:57 +00:00
@torch.no_grad()
def run_diffusion(
        self,
        list_conditionings,
        latents_start: torch.FloatTensor = None,
        idx_start: int = 0,
        list_latents_mixing=None,
        mixing_coeffs=0.0,
        return_image: Optional[bool] = False
        ):
    r"""
    Wrapper function for the diffusion runners of the holder (self.sdh).
    Depending on self.mode, the matching runner is executed.
    Args:
        list_conditionings: list
            All conditionings for the diffusion model. 'standard' mode uses
            list_conditionings[0] as text embeddings, 'upscale' mode expects
            [cond, uc_full].
        latents_start: torch.FloatTensor
            Latents the diffusion process starts from.
        idx_start: int
            Index of the diffusion step at which the process starts.
        list_latents_mixing: list of torch.FloatTensor
            List of latents (latent trajectory) that is mixed into this run.
        mixing_coeffs: float or list
            Coefficients, how strong each element of list_latents_mixing will be mixed in.
        return_image: Optional[bool]
            Optionally return image directly
    Raises:
        ValueError: if self.mode has no diffusion runner (e.g. 'inpaint',
            which is not ported to the mixing-based runners yet).
    """
    # Ensure correct num_inference_steps in Holder
    self.sdh.num_inference_steps = self.num_inference_steps
    assert type(list_conditionings) is list, "list_conditionings need to be a list"

    if self.mode == 'standard':
        text_embeddings = list_conditionings[0]
        return self.sdh.run_diffusion_standard(
            text_embeddings=text_embeddings,
            latents_start=latents_start,
            idx_start=idx_start,
            list_latents_mixing=list_latents_mixing,
            mixing_coeffs=mixing_coeffs,
            return_image=return_image,
            )

    elif self.mode == 'upscale':
        cond = list_conditionings[0]
        uc_full = list_conditionings[1]
        return self.sdh.run_diffusion_upscaling(
            cond,
            uc_full,
            latents_start=latents_start,
            idx_start=idx_start,
            list_latents_mixing=list_latents_mixing,
            mixing_coeffs=mixing_coeffs,
            return_image=return_image)

    # Previously this silently returned None, which only failed much later.
    raise ValueError(f"run_diffusion: unsupported mode {self.mode}")
2023-01-08 09:32:58 +00:00
2023-02-18 06:56:30 +00:00
# FIXME. new transition engine
2023-01-08 09:32:58 +00:00
def run_upscaling_step1(
        self,
        dp_img: str,
        depth_strength: float = 0.65,
        num_inference_steps: int = 30,
        nmb_max_branches: int = 10,
        fixed_seeds: Optional[List[int]] = None,
        ):
    r"""
    Runs a low-resolution latent blending transition and writes its images
    and lowres.yaml into dp_img, as input for run_upscaling_step2.
    Args:
        dp_img: str
            Path to directory where the low-res images and yaml will be saved to.
            This directory must not exist yet.
        depth_strength: float = 0.65
            Determines how deep the first injection will happen.
            Deeper injections will cause (unwanted) formation of new structures,
            more shallow values will go into alpha-blendy land.
        num_inference_steps: int = 30
            Number of diffusion steps of the lowres transition.
        nmb_max_branches: int = 10
            Maximum number of branches (images) of the lowres transition.
        fixed_seeds: Optional[List[int]]
            You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
            Otherwise random seeds will be taken.
    """
    assert self.text_embedding1 is not None, 'run set_prompt1(yourprompt1) first'
    assert self.text_embedding2 is not None, 'run set_prompt2(yourprompt2) first'
    assert not os.path.isdir(dp_img), f"directory already exists: {dp_img}"
    if fixed_seeds is None:
        fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
    # Run latent blending; previously these settings were silently ignored.
    self.run_transition(
        num_inference_steps=num_inference_steps,
        depth_strength=depth_strength,
        nmb_max_branches=nmb_max_branches,
        fixed_seeds=fixed_seeds,
        )
    # write_imgs_transition takes only dp_img and reads the images from
    # self.tree_final_imgs (passing them as well raised a TypeError).
    self.write_imgs_transition(dp_img)
    print(f"run_upscaling_step1: completed! {dp_img}")
def run_upscaling_step2(
        self,
        dp_img: str,
        depth_strength: float = 0.65,
        num_inference_steps: int = 100,
        nmb_max_branches_highres: int = 5,
        nmb_max_branches_lowres: int = 6,
        fixed_seeds: Optional[List[int]] = None,
        duration_single_segment: float = 3,
        ):
    r"""
    Upscales the low-resolution transition written by run_upscaling_step1
    and saves it as movie_highres.mp4 in dp_img.
    nmb_max_branches_lowres lowres images are picked evenly and every
    consecutive pair is rendered as one highres movie segment.
    Args:
        dp_img: str
            Directory written by run_upscaling_step1 (lowres.yaml + lowres images).
        depth_strength: float
            Passed to run_transition for every segment.
        num_inference_steps: int
            Diffusion steps for every highres segment.
        nmb_max_branches_highres: int
            Passed to run_transition as nmb_max_branches for every segment.
        nmb_max_branches_lowres: int
            Number of lowres images used as keyframes (gives nmb_max_branches_lowres-1 segments).
        fixed_seeds: Optional[List[int]]
            Currently unused; kept for interface compatibility.
        duration_single_segment: float
            Duration of each movie segment in seconds.
    """
    fp_yml = os.path.join(dp_img, "lowres.yaml")
    fp_movie = os.path.join(dp_img, "movie_highres.mp4")
    fps = 24
    assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
    dict_stuff = yml_load(fp_yml)
    # load lowres images
    nmb_images_lowres = dict_stuff['nmb_images']
    prompt1 = dict_stuff['prompt1']
    prompt2 = dict_stuff['prompt2']
    idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres - 1, nmb_max_branches_lowres)).astype(np.int32)
    imgs_lowres = []
    for i in idx_img_lowres:
        fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
        assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
        imgs_lowres.append(Image.open(fp_img_lowres))
    # set up upscaling
    text_embeddingA = self.sdh.get_text_embedding(prompt1)
    text_embeddingB = self.sdh.get_text_embedding(prompt2)
    # The keyframes are picked evenly, so keyframe k gets the prompt-mixing
    # fraction k / (nmb_max_branches_lowres - 1): one fraction per keyframe.
    list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres)
    nmb_segments = nmb_max_branches_lowres - 1

    # Open the movie only after all inputs have been validated.
    ms = MovieSaver(fp_movie, fps=fps)
    for i in range(nmb_segments):
        print(f"Starting movie segment {i+1}/{nmb_segments}")
        if i == 0:
            recycle_img1 = False
        else:
            # Reuse the latents of the previous segment's last keyframe.
            self.swap_forward()
            recycle_img1 = True
        # Set the embeddings after swap_forward, which overwrites text_embedding1.
        # Segment i runs from keyframe i to keyframe i+1.
        self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
        self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i + 1])
        self.set_image1(imgs_lowres[i])
        self.set_image2(imgs_lowres[i + 1])
        list_imgs = self.run_transition(
            recycle_img1=recycle_img1,
            recycle_img2=False,
            num_inference_steps=num_inference_steps,
            depth_strength=depth_strength,
            nmb_max_branches=nmb_max_branches_highres,
            )
        list_imgs_interp = add_frames_linear_interp(
            list_imgs, fps_target=fps, duration_target=duration_single_segment)
        # Save movie frame
        for img in list_imgs_interp:
            ms.write_frame(img)
    ms.finalize()
2022-11-25 14:34:41 +00:00
def init_inpainting(
        self,
        image_source: Union[Image.Image, np.ndarray] = None,
        mask_image: Union[Image.Image, np.ndarray] = None,
        init_empty: Optional[bool] = False,
        ):
    r"""
    Sets up inpainting from a source image and a mask image.
    Calls self.init_mode() first, then forwards all arguments unchanged to
    self.sdh.init_inpainting.
    Args:
        image_source: Union[Image.Image, np.ndarray]
            Source image onto which the mask will be applied.
        mask_image: Union[Image.Image, np.ndarray]
            Mask image; value=0 stays untouched, value=255 is subject to diffusion.
        init_empty: Optional[bool]
            Initialize inpainting with an empty image and mask, effectively disabling inpainting;
            useful for generating a first image for transitions using diffusion.
    """
    self.init_mode()
    holder = self.sdh
    holder.init_inpainting(image_source, mask_image, init_empty)
2022-11-19 18:43:57 +00:00
2023-01-08 09:32:58 +00:00
@torch.no_grad()
def get_mixed_conditioning(self, fract_mixing):
    r"""
    Builds the conditionings for a mixing fraction between keyframe 1 and 2.
    Args:
        fract_mixing: float
            0 uses text_embedding1 (and image1_lowres), 1 uses text_embedding2
            (and image2_lowres); values in between are interpolated.
    Returns:
        list: [text_embeddings] in 'standard' and 'inpaint' mode,
        [cond, uc_full] in 'upscale' mode.
    Raises:
        ValueError: for any other mode.
    """
    if self.mode not in ('standard', 'inpaint', 'upscale'):
        raise ValueError(f"mix_conditioning: unknown mode {self.mode}")

    text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
    if self.mode != 'upscale':
        return [text_embeddings_mix]

    # Upscaling: build the conditioning for both lowres images, then blend
    # their 'c_concat' entries spherically (modified in place on the first set).
    cond, uc_full = self.sdh.get_cond_upscaling(self.image1_lowres, text_embeddings_mix, self.noise_level_upscaling)
    cond_b, uc_full_b = self.sdh.get_cond_upscaling(self.image2_lowres, text_embeddings_mix, self.noise_level_upscaling)
    for target, other in ((cond, cond_b), (uc_full, uc_full_b)):
        target['c_concat'][0] = interpolate_spherical(target['c_concat'][0], other['c_concat'][0], fract_mixing)
    return [cond, uc_full]
2022-11-19 18:43:57 +00:00
@torch.no_grad()
def get_text_embeddings(
        self,
        prompt: str
        ):
    r"""
    Computes the text embeddings for a prompt string.
    Thin wrapper around self.sdh.get_text_embedding (adapted from the stable diffusion repo).
    Args:
        prompt: str
            e.g. "ABC trending on artstation painted by Old Greg."
    """
    holder = self.sdh
    embeddings = holder.get_text_embedding(prompt)
    return embeddings
2022-11-19 18:43:57 +00:00
2023-02-18 06:56:30 +00:00
def write_imgs_transition(self, dp_img):
    r"""
    Writes the transition images (self.tree_final_imgs) into the folder dp_img.
    Requires run_transition to be completed.
    Each image is saved as lowres_img_XXXX.jpg (zero-padded index), and the
    state is dumped to lowres.yaml via save_statedict.
    Args:
        dp_img: str
            Directory for the images and the yaml file; created if missing.
    """
    frames = self.tree_final_imgs
    os.makedirs(dp_img, exist_ok=True)
    for idx, frame in enumerate(frames):
        fn_frame = f"lowres_img_{str(idx).zfill(4)}.jpg"
        Image.fromarray(frame).save(os.path.join(dp_img, fn_frame))
    self.save_statedict(os.path.join(dp_img, "lowres.yaml"))
2023-02-18 07:19:40 +00:00
def write_movie_transition(self, fp_movie, duration_transition, fps=30):
    r"""
    Writes the transition movie to fp_movie, using the given duration and fps.
    The missing frames are linearly interpolated from self.tree_final_imgs.
    An existing file at fp_movie is removed first.
    Args:
        fp_movie: str
            file path of the final movie.
        duration_transition: float
            duration of the movie in seconds
        fps: int
            fps of the movie
    """
    # Let's get more cheap frames via linear interpolation (duration_transition*fps frames).
    # Pass by keyword: positionally, fps and duration were swapped.
    imgs_transition_ext = add_frames_linear_interp(
        self.tree_final_imgs,
        fps_target=fps,
        duration_target=duration_transition,
        )
    # Save as MP4
    if os.path.isfile(fp_movie):
        os.remove(fp_movie)
    ms = MovieSaver(fp_movie, fps=fps, shape_hw=[self.sdh.height, self.sdh.width])
    for img in tqdm(imgs_transition_ext):
        ms.write_frame(img)
    ms.finalize()
2023-01-15 15:54:28 +00:00
def save_statedict(self, fp_yml):
    r"""
    Dumps everything relevant into the yaml file fp_yml: the dict from
    get_state_dict plus 'nmb_images', the number of transition images.
    """
    frames = self.tree_final_imgs
    state = self.get_state_dict()
    state['nmb_images'] = len(frames)
    yml_save(fp_yml, state)
2023-01-11 10:36:44 +00:00
def get_state_dict(self):
    r"""
    Collects the transition parameters set on this object into a dict,
    e.g. for saving them as yaml.
    Attributes that are not set are skipped. Seeds are cast to int and
    guidance_scale to float, so they are stored as plain Python numbers.
    Returns:
        dict mapping attribute name to its value.
    """
    state_dict = {}
    # A missing comma used to fuse 'branch1_influence_decay' and
    # 'parental_influence' into one bogus name, so neither was ever saved.
    grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
                 'num_inference_steps', 'depth_strength', 'guidance_scale',
                 'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt',
                 'branch1_influence', 'branch1_max_depth_influence', 'branch1_influence_decay',
                 'parental_influence', 'parental_max_depth_influence', 'parental_influence_decay']
    for v in grab_vars:
        if not hasattr(self, v):
            continue
        if v in ('seed1', 'seed2'):
            state_dict[v] = int(getattr(self, v))
        elif v == 'guidance_scale':
            state_dict[v] = float(getattr(self, v))
        else:
            try:
                state_dict[v] = getattr(self, v)
            except Exception:
                # Best effort: skip attributes whose lookup fails.
                pass
    return state_dict
2023-01-08 09:32:58 +00:00
2022-11-19 18:43:57 +00:00
def randomize_seed(self):
    r"""
    Draws a random seed from [0, 999999999) and applies it via set_seed for a fresh start.
    """
    self.set_seed(np.random.randint(999999999))
def set_seed(self, seed: int):
    r"""
    Sets the seed for a fresh start, both on this object and on the holder (self.sdh).
    """
    self.seed = self.sdh.seed = seed
2022-11-19 18:43:57 +00:00
2023-01-09 08:59:14 +00:00
def set_width(self, width):
    r"""
    Sets the width of the resulting image, on this object and on self.sdh.
    The width has to be divisible by 64.
    """
    assert np.mod(width, 64) == 0, "set_width: value needs to be divisible by 64"
    self.width = self.sdh.width = width
def set_height(self, height):
    r"""
    Sets the height of the resulting image, on this object and on self.sdh.
    The height has to be divisible by 64.
    """
    assert np.mod(height, 64) == 0, "set_height: value needs to be divisible by 64"
    self.height = self.sdh.height = height
2022-11-29 17:03:08 +00:00
def inject_latents(self, list_latents, inject_img1=True, inject_img2=False):
    r"""
    Injects a list of latents into the tree structure, block by block.
    For every injection block t_block, the slice of list_latents between
    list_injection_idx_ext[t_block] and list_injection_idx_ext[t_block+1]
    replaces the first (inject_img1) or last (inject_img2) entry of
    tree_latents[t_block].
    NOTE(review): this indexes tree_latents[t_block][0 / -1] and uses
    list_injection_idx / list_injection_idx_ext, while compute_latents2 and
    swap_forward store a whole trajectory directly in tree_latents[0] / [-1].
    Confirm this method still matches the current tree layout.
    """
    assert inject_img1 != inject_img2, "Either inject into img1 or img2"
    assert self.tree_latents is not None, "You need to setup the branching beforehand, run autosetup_branching() or setup_branching() before"
    for t_block in range(len(self.list_injection_idx)):
        # list_injection_idx_ext holds the block boundaries (one more entry than blocks is expected)
        if inject_img1:
            self.tree_latents[t_block][0] = list_latents[self.list_injection_idx_ext[t_block]:self.list_injection_idx_ext[t_block+1]]
        if inject_img2:
            self.tree_latents[t_block][-1] = list_latents[self.list_injection_idx_ext[t_block]:self.list_injection_idx_ext[t_block+1]]
2022-11-19 18:43:57 +00:00
def swap_forward(self):
    r"""
    Moves keyframe two over to keyframe one. Useful for making a sequence of
    transitions as in run_multi_transition().
    The latent trajectory, prompt and text embedding of keyframe two become
    those of keyframe one, and the final images are cleared.
    """
    latents = self.tree_latents
    # Keyframe two's trajectory becomes keyframe one's trajectory
    latents[0] = latents[-1]
    # Carry over prompt and text embedding
    self.prompt1 = self.prompt2
    self.text_embedding1 = self.text_embedding2
    # Drop the images of the finished transition
    self.tree_final_imgs = []
2022-11-28 04:07:01 +00:00
2023-02-15 17:21:00 +00:00
def get_lpips_similarity(self, imgA, imgB):
    r"""
    Computes the image similarity between two images imgA and imgB.
    Used to determine the optimal point of insertion to create smooth transitions.
    High values indicate low similarity.
    Args:
        imgA, imgB: np.ndarray
            Images with values in [0, 255], assumed (H, W, C) layout given the
            permute below; they are rescaled to [-1, 1] as (1, C, H, W) tensors.
    Returns:
        float: the first element of the output of self.lpips.
    """
    def _to_tensor(img):
        # [0, 255] -> [-1, 1], channels first, batch dim added.
        # .to() instead of .cuda() so CPU devices work as well.
        tensor = torch.from_numpy(img).float().to(self.device)
        tensor = 2 * tensor / 255.0 - 1
        return tensor.permute([2, 0, 1]).unsqueeze(0)

    lploss = self.lpips(_to_tensor(imgA), _to_tensor(imgB))
    return float(lploss[0][0][0][0])
2022-11-19 18:43:57 +00:00
# Auxiliary functions
def get_closest_idx(
        fract_mixing: float,
        list_fract_mixing_prev: List[float],
        ):
    r"""
    Helper function to retrieve the parents for any given mixing.
    Returns the indices of the closest value <= fract_mixing and the closest
    value > fract_mixing in list_fract_mixing_prev, smaller index first.
    Example: fract_mixing = 0.4 and list_fract_mixing_prev = [0, 0.3, 0.6, 1.0]
    Will return the two closest values from list_fract_mixing_prev, i.e. [1, 2]
    """
    dist = fract_mixing - np.asarray(list_fract_mixing_prev)
    # Candidates at or below fract_mixing
    dist_below = dist.copy()
    dist_below[dist_below < 0] = np.inf
    # Candidates strictly above fract_mixing
    dist_above = -dist
    dist_above[dist_above <= 0] = np.inf
    b_parent1 = np.argmin(dist_below)
    b_parent2 = np.argmin(dist_above)
    if b_parent1 > b_parent2:
        b_parent1, b_parent2 = b_parent2, b_parent1
    return b_parent1, b_parent2
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
    r"""
    Helper function to correctly mix two random variables using spherical interpolation.
    See https://en.wikipedia.org/wiki/Slerp
    The computation is always done in float64 for the sake of extra precision.
    The result is cast back to the floating dtype of p0 (float32 if p0 is not
    a floating tensor); previously everything except fp16 came back as float32.
    Args:
        p0:
            First tensor for interpolation
        p1:
            Second tensor for interpolation
        fract_mixing: float
            Mixing coefficient of interval [0, 1].
            0 will return in p0
            1 will return in p1
            0.x will return a mix between both preserving angular velocity.
    """
    dtype_out = p0.dtype if p0.is_floating_point() else torch.float32
    p0 = p0.double()
    p1 = p1.double()
    norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
    epsilon = 1e-7
    dot = torch.sum(p0 * p1) / norm
    # Keep arccos away from +-1, where sin(theta_0) vanishes
    dot = dot.clamp(-1 + epsilon, 1 - epsilon)
    theta_0 = torch.arccos(dot)
    sin_theta_0 = torch.sin(theta_0)
    theta_t = theta_0 * fract_mixing
    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = torch.sin(theta_t) / sin_theta_0
    interp = p0 * s0 + p1 * s1
    return interp.to(dtype_out)
def interpolate_linear(p0, p1, fract_mixing):
    r"""
    Helper function to mix two variables using standard linear interpolation.
    If either input is a uint8 np.ndarray, the mix is computed in float64 and
    the result is clipped to [0, 255] and returned as uint8.
    Args:
        p0:
            First tensor/np.ndarray for interpolation
        p1:
            Second tensor/np.ndarray for interpolation
        fract_mixing: float
            Mixing coefficient of interval [0, 1].
            0 will return in p0
            1 will return in p1
            0.x will return a linear mix between both.
    """
    p0_is_uint8 = type(p0) is np.ndarray and p0.dtype == 'uint8'
    p1_is_uint8 = type(p1) is np.ndarray and p1.dtype == 'uint8'
    if p0_is_uint8:
        p0 = p0.astype(np.float64)
    if p1_is_uint8:
        p1 = p1.astype(np.float64)
    mixed = (1 - fract_mixing) * p0 + fract_mixing * p1
    if p0_is_uint8 or p1_is_uint8:
        mixed = np.clip(mixed, 0, 255).astype(np.uint8)
    return mixed
2022-11-19 18:43:57 +00:00
def add_frames_linear_interp(
        list_imgs: List[np.ndarray],
        fps_target: Union[float, int] = None,
        duration_target: Union[float, int] = None,
        nmb_frames_target: int = None,
        ):
    r"""
    Helper function to cheaply increase the number of frames given a list of images,
    by virtue of standard linear interpolation.
    The total number of frames is fixed precisely: every gap between two images gets
    the same base number of inserted frames, and the remainder is distributed over
    randomly chosen gaps (at most one extra frame each).
    The function allows 1:1 comparisons between transitions as videos.
    Args:
        list_imgs: List[np.ndarray]
            List of images, between each image new frames will be inserted via linear interpolation.
        fps_target:
            OptionA: specify here the desired frames per second.
        duration_target:
            OptionA: specify here the desired duration of the transition in seconds.
        nmb_frames_target:
            OptionB: directly fix the total number of frames of the output.
    Returns:
        List of uint8 images with exactly round(nmb_frames_target) entries, or list_imgs
        itself if it has fewer than two images or already has enough frames.
    """
    # Sanity
    if nmb_frames_target is not None and fps_target is not None:
        raise ValueError("You cannot specify both fps_target and nmb_frames_target")
    if fps_target is None:
        assert nmb_frames_target is not None, "Either specify nmb_frames_target or fps_target and duration_target"
    if nmb_frames_target is None:
        assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
        assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
        nmb_frames_target = fps_target * duration_target
    # fps * duration may be non-integral; round so an exact count is reachable.
    nmb_frames_target = int(round(nmb_frames_target))
    # With fewer than two images there is nothing to interpolate between.
    if len(list_imgs) < 2:
        return list_imgs
    # Get number of frames that are missing
    nmb_frames_diff = len(list_imgs) - 1
    nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
    if nmb_frames_missing < 1:
        return list_imgs
    list_imgs_float = [img.astype(np.float32) for img in list_imgs]
    # Distribute missing frames exactly: base count everywhere, +1 for random gaps.
    nmb_base, nmb_remainder = divmod(nmb_frames_missing, nmb_frames_diff)
    nmb_frames_to_insert = np.full(nmb_frames_diff, nmb_base, dtype=np.int32)
    idx_extra = np.random.choice(nmb_frames_diff, size=nmb_remainder, replace=False)
    nmb_frames_to_insert[idx_extra] += 1
    list_imgs_interp = []
    for img0, img1, nmb_insert in zip(list_imgs_float[:-1], list_imgs_float[1:], nmb_frames_to_insert):
        list_imgs_interp.append(img0.astype(np.uint8))
        # Interior points only; the endpoints are the original images.
        for fract_linblend in np.linspace(0, 1, int(nmb_insert) + 2)[1:-1]:
            img_blend = (1 - fract_linblend) * img0 + fract_linblend * img1
            list_imgs_interp.append(img_blend.astype(np.uint8))
    list_imgs_interp.append(list_imgs_float[-1].astype(np.uint8))
    return list_imgs_interp
2023-01-08 09:32:58 +00:00
def get_spacing(nmb_points: int, scaling: float):
    """
    Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
    Args:
        nmb_points: int
            Number of points between [0, 1]
        scaling: float
            Higher values will return higher sampling density around 0.5;
            below 1.7 plain linear spacing is returned.
    """
    if scaling < 1.7:
        return np.linspace(0, 1, nmb_points)
    nmb_points_per_side = nmb_points // 2 + 1
    ramp = np.linspace(1, 0, nmb_points_per_side) ** scaling
    left_side = np.abs(ramp / 2 - 0.5)
    if np.mod(nmb_points, 2) == 0:
        # even: drop the 0.5 midpoint, mirror the rest
        left_side = left_side[:-1]
        right_side = 1 - left_side[::-1]
    else:
        # odd: keep the 0.5 midpoint once
        right_side = 1 - left_side[::-1][1:]
    return np.hstack([left_side, right_side])
2022-11-19 18:43:57 +00:00
def get_time(resolution=None):
    """
    Helper function returning a nicely formatted local time string, e.g. 221117_1620
    Args:
        resolution: str or None
            "day", "minute", "second" (default) or "millisecond".
    Raises:
        ValueError: for any other resolution.
    """
    if resolution is None:
        resolution = "second"
    # One timestamp for all fields; the millisecond part previously came from
    # datetime.utcnow(), which does not exist on the `datetime` module and
    # would have been UTC while the rest was local time.
    now = datetime.datetime.now()
    if resolution == "day":
        t = now.strftime('%y%m%d')
    elif resolution == "minute":
        t = now.strftime('%y%m%d_%H%M')
    elif resolution == "second":
        t = now.strftime('%y%m%d_%H%M%S')
    elif resolution == "millisecond":
        t = now.strftime('%y%m%d_%H%M%S')
        t += "_"
        t += "{:03d}".format(now.microsecond // 1000)
    else:
        raise ValueError("bad resolution provided: %s" % resolution)
    return t
2023-01-12 09:16:31 +00:00
def compare_dicts(a, b):
    """
    Compares two dictionaries a and b and returns a dictionary c with all keys
    that are present in both a and b but have different values.
    The values of a and b are stacked together in the output.
    Example:
        a = {}; a['bobo'] = 4
        b = {}; b['bobo'] = 5
        c = compare_dicts(a, b)
        c == {"bobo": [4, 5]}
    """
    return {key: [a[key], b[key]] for key in a if key in b and a[key] != b[key]}
2022-11-25 14:34:41 +00:00
2023-01-08 09:32:58 +00:00
def yml_load(fp_yml, print_fields=False):
    """
    Helper function for loading yaml files (parsed with the safe loader).
    Args:
        fp_yml: str
            Path of the yaml file.
        print_fields: bool
            If True, print every loaded key and value (previously ignored).
    Returns:
        dict with the loaded data.
    """
    with open(fp_yml) as f:
        data = yaml.load(f, Loader=yaml.loader.SafeLoader)
    dict_data = dict(data)
    print("load: loaded {}".format(fp_yml))
    if print_fields:
        for key, value in dict_data.items():
            print(f"{key}: {value}")
    return dict_data
2022-11-25 14:34:41 +00:00
2023-01-08 09:32:58 +00:00
def yml_save(fp_yml, dict_stuff):
    """
    Helper function for saving yaml files.
    Dumps dict_stuff to fp_yml in block style, keeping the key order.
    """
    with open(fp_yml, 'w') as f:
        yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
    print("yml_save: saved {}".format(fp_yml))
2022-11-23 16:46:25 +00:00
2022-11-24 10:24:23 +00:00
2022-11-21 16:23:16 +00:00
#%% le main
2022-11-19 18:43:57 +00:00
if __name__ == "__main__":
    #%% First let us spawn a stable diffusion holder
    device = "cuda"
    fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"

    sdh = StableDiffusionHolder(fp_ckpt)
    # (A stray `xxx` expression used to sit here and aborted the script with a NameError.)

    #%% Next let's set up all parameters
    depth_strength = 0.3  # Specifies how deep (in terms of diffusion iterations the first branching happens)
    fixed_seeds = [697164, 430214]

    prompt1 = "photo of a desert and a sky"
    prompt2 = "photo of a tree with a lake"

    duration_transition = 12  # In seconds
    fps = 30
    # Spawn latent blending
    self = LatentBlending(sdh)

    self.set_prompt1(prompt1)
    self.set_prompt2(prompt2)
    # Run latent blending
    self.branch1_influence = 0.3
    self.branch1_max_depth_influence = 0.4
    # self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
    self.seed1 = 21312
    img1 = self.compute_latents1(True)
    #%
    self.seed2 = 1234121
    self.branch1_influence = 0.7
    self.branch1_max_depth_influence = 0.3
    self.branch1_influence_decay = 0.3
    img2 = self.compute_latents2(True)
    # Image.fromarray(np.concatenate((img1, img2), axis=1))

    #%%
    t0 = time.time()
    self.t_compute_max_allowed = 30
    self.parental_max_depth_influence = 1.0
    self.parental_influence = 0.0
    self.parental_influence_decay = 1.0
    imgs_transition = self.run_transition(recycle_img1=True, recycle_img2=True)
    t1 = time.time()
    print(f"took: {t1 - t0}s")