reorganization
This commit is contained in:
4
latentblending/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .blending_engine import BlendingEngine
from .diffusers_holder import DiffusersHolder
from .movie_utils import MovieSaver
from .utils import interpolate_spherical, add_frames_linear_interp, interpolate_linear, get_spacing, get_time, yml_load, yml_save
848
latentblending/blending_engine.py
Normal file
@@ -0,0 +1,848 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
import time
|
||||
from tqdm.auto import tqdm
|
||||
from PIL import Image
|
||||
from latentblending.movie_util import MovieSaver
|
||||
from typing import List, Optional
|
||||
import lpips
|
||||
from latentblending.utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save
|
||||
warnings.filterwarnings('ignore')
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
class BlendingEngine():
|
||||
def __init__(
|
||||
self,
|
||||
dh: None,
|
||||
guidance_scale_mid_damper: float = 0.5,
|
||||
mid_compression_scaler: float = 1.2):
|
||||
r"""
|
||||
Initializes the latent blending class.
|
||||
Args:
|
||||
guidance_scale_mid_damper: float = 0.5
|
||||
Reduces the guidance scale towards the middle of the transition.
|
||||
A value of 0.5 would decrease the guidance_scale towards the middle linearly by 0.5.
|
||||
mid_compression_scaler: float = 1.2
|
||||
Increases the sampling density in the middle (where most changes happen). Higher values
|
||||
imply more samples in the middle. However, the inflection point can occur outside the middle,
|
||||
thus high values can give rough transitions. Values around 2 should be fine.
|
||||
"""
|
||||
assert guidance_scale_mid_damper > 0 \
|
||||
and guidance_scale_mid_damper <= 1.0, \
|
||||
f"guidance_scale_mid_damper needs to be in interval (0,1], you provided {guidance_scale_mid_damper}"
|
||||
|
||||
self.dh = dh
|
||||
self.device = self.dh.device
|
||||
self.set_dimensions()
|
||||
|
||||
self.guidance_scale_mid_damper = guidance_scale_mid_damper
|
||||
self.mid_compression_scaler = mid_compression_scaler
|
||||
self.seed1 = 0
|
||||
self.seed2 = 0
|
||||
|
||||
# Initialize vars
|
||||
self.prompt1 = ""
|
||||
self.prompt2 = ""
|
||||
|
||||
self.tree_latents = [None, None]
|
||||
self.tree_fracts = None
|
||||
self.idx_injection = []
|
||||
self.tree_status = None
|
||||
self.tree_final_imgs = []
|
||||
|
||||
self.text_embedding1 = None
|
||||
self.text_embedding2 = None
|
||||
self.image1_lowres = None
|
||||
self.image2_lowres = None
|
||||
self.negative_prompt = None
|
||||
|
||||
self.set_guidance_scale()
|
||||
self.multi_transition_img_first = None
|
||||
self.multi_transition_img_last = None
|
||||
self.dt_unet_step = 0
|
||||
self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
|
||||
|
||||
self.set_prompt1("")
|
||||
self.set_prompt2("")
|
||||
|
||||
self.set_branch1_crossfeed()
|
||||
self.set_parental_crossfeed()
|
||||
|
||||
self.set_num_inference_steps()
|
||||
self.benchmark_speed()
|
||||
self.set_branching()
|
||||
|
||||
|
||||
|
||||
def benchmark_speed(self):
|
||||
"""
|
||||
Measures the time per diffusion step and for the vae decoding
|
||||
"""
|
||||
|
||||
text_embeddings = self.dh.get_text_embedding("test")
|
||||
latents_start = self.dh.get_noise(np.random.randint(111111))
|
||||
# warmup
|
||||
list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1)
|
||||
# bench unet
|
||||
t0 = time.time()
|
||||
list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1)
|
||||
self.dt_unet_step = time.time() - t0
|
||||
|
||||
# bench vae
|
||||
t0 = time.time()
|
||||
img = self.dh.latent2image(list_latents[-1])
|
||||
self.dt_vae = time.time() - t0
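# Note: the dt_unet_step and dt_vae estimates measured here are later used by
# get_time_based_branching() to predict the total compute time of a branching scheme.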
|
||||
|
||||
def set_dimensions(self, size_output=None):
|
||||
r"""
|
||||
sets the size of the output video.
|
||||
Args:
|
||||
size_output: tuple
|
||||
width x height
|
||||
Note: the size will get automatically adjusted to be divisible by 32.
|
||||
"""
|
||||
if size_output is None:
|
||||
if self.dh.is_sdxl_turbo:
|
||||
size_output = (512, 512)
|
||||
else:
|
||||
size_output = (1024, 1024)
|
||||
self.dh.set_dimensions(size_output)
|
||||
|
||||
def set_guidance_scale(self, guidance_scale=None):
|
||||
r"""
|
||||
sets the guidance scale.
|
||||
"""
|
||||
if guidance_scale is None:
|
||||
if self.dh.is_sdxl_turbo:
|
||||
guidance_scale = 0.0
|
||||
else:
|
||||
guidance_scale = 4.0
|
||||
|
||||
self.guidance_scale_base = guidance_scale
|
||||
self.guidance_scale = guidance_scale
|
||||
self.dh.guidance_scale = guidance_scale
|
||||
|
||||
def set_negative_prompt(self, negative_prompt):
|
||||
r"""Set the negative prompt. Currently only one negative prompt is supported
|
||||
"""
|
||||
self.negative_prompt = negative_prompt
|
||||
self.dh.set_negative_prompt(negative_prompt)
|
||||
|
||||
def set_guidance_mid_dampening(self, fract_mixing):
|
||||
r"""
|
||||
Tunes the guidance scale down as a linear function of fract_mixing,
|
||||
reaching its minimum at fract_mixing = 0.5.
|
||||
"""
|
||||
mid_factor = 1 - np.abs(fract_mixing - 0.5) / 0.5
|
||||
max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1
|
||||
guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor
|
||||
self.guidance_scale = guidance_scale_effective
|
||||
self.dh.guidance_scale = guidance_scale_effective
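# Worked example (assuming guidance_scale_base=4.0 and guidance_scale_mid_damper=0.5):
# max_guidance_reduction = 4.0 * (1 - 0.5) - 1 = 1.0, so the effective guidance scale
# stays at 4.0 for fract_mixing of 0 or 1 and drops linearly to 3.0 at fract_mixing=0.5.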
|
||||
|
||||
def set_branch1_crossfeed(self, crossfeed_power=0, crossfeed_range=0, crossfeed_decay=0):
|
||||
r"""
|
||||
Sets the crossfeed parameters for the first branch to the last branch.
|
||||
Args:
|
||||
crossfeed_power: float [0,1]
|
||||
Controls the level of cross-feeding between the first and last image branch.
|
||||
crossfeed_range: float [0,1]
|
||||
Sets the duration of active crossfeed during development.
|
||||
crossfeed_decay: float [0,1]
|
||||
Sets decay for branch1_crossfeed_power. Lower values make the decay stronger across the range.
|
||||
"""
|
||||
self.branch1_crossfeed_power = np.clip(crossfeed_power, 0, 1)
|
||||
self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
|
||||
self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
|
||||
|
||||
def set_parental_crossfeed(self, crossfeed_power=None, crossfeed_range=None, crossfeed_decay=None):
|
||||
r"""
|
||||
Sets the crossfeed parameters for all transition images (within the first and last branch).
|
||||
Args:
|
||||
crossfeed_power: float [0,1]
|
||||
Controls the level of cross-feeding from the parental branches
|
||||
crossfeed_range: float [0,1]
|
||||
Sets the duration of active crossfeed during development.
|
||||
crossfeed_decay: float [0,1]
|
||||
Sets decay for parental_crossfeed_power. Lower values make the decay stronger across the range.
|
||||
"""
|
||||
|
||||
if self.dh.is_sdxl_turbo:
|
||||
if crossfeed_power is None:
|
||||
crossfeed_power = 1.0
|
||||
if crossfeed_range is None:
|
||||
crossfeed_range = 1.0
|
||||
if crossfeed_decay is None:
|
||||
crossfeed_decay = 1.0
|
||||
else:
|
||||
if crossfeed_power is None:
    crossfeed_power = 0.3
if crossfeed_range is None:
    crossfeed_range = 0.6
if crossfeed_decay is None:
    crossfeed_decay = 0.9
|
||||
|
||||
self.parental_crossfeed_power = np.clip(crossfeed_power, 0, 1)
|
||||
self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
|
||||
self.parental_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
|
||||
|
||||
def set_prompt1(self, prompt: str):
|
||||
r"""
|
||||
Sets the first prompt (for the first keyframe) including text embeddings.
|
||||
Args:
|
||||
prompt: str
|
||||
ABC trending on artstation painted by Greg Rutkowski
|
||||
"""
|
||||
prompt = prompt.replace("_", " ")
|
||||
self.prompt1 = prompt
|
||||
self.text_embedding1 = self.get_text_embeddings(self.prompt1)
|
||||
|
||||
def set_prompt2(self, prompt: str):
|
||||
r"""
|
||||
Sets the second prompt (for the second keyframe) including text embeddings.
|
||||
Args:
|
||||
prompt: str
|
||||
XYZ trending on artstation painted by Greg Rutkowski
|
||||
"""
|
||||
prompt = prompt.replace("_", " ")
|
||||
self.prompt2 = prompt
|
||||
self.text_embedding2 = self.get_text_embeddings(self.prompt2)
|
||||
|
||||
def set_image1(self, image: Image):
|
||||
r"""
|
||||
Sets the first image (keyframe), relevant for the upscaling model transitions.
|
||||
Args:
|
||||
image: Image
|
||||
"""
|
||||
self.image1_lowres = image
|
||||
|
||||
def set_image2(self, image: Image):
|
||||
r"""
|
||||
Sets the second image (keyframe), relevant for the upscaling model transitions.
|
||||
Args:
|
||||
image: Image
|
||||
"""
|
||||
self.image2_lowres = image
|
||||
|
||||
def set_num_inference_steps(self, num_inference_steps=None):
|
||||
if self.dh.is_sdxl_turbo:
|
||||
if num_inference_steps is None:
|
||||
num_inference_steps = 4
|
||||
else:
|
||||
if num_inference_steps is None:
|
||||
num_inference_steps = 30
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.dh.set_num_inference_steps(num_inference_steps)
|
||||
|
||||
def set_branching(self, depth_strength=None, t_compute_max_allowed=None, nmb_max_branches=None):
|
||||
"""
|
||||
Sets the branching structure of the blending tree. Default arguments depend on pipe!
|
||||
depth_strength:
|
||||
Determines how deep the first injection will happen.
|
||||
Deeper injections will cause (unwanted) formation of new structures,
|
||||
more shallow values will go into alpha-blendy land.
|
||||
t_compute_max_allowed:
|
||||
Either provide t_compute_max_allowed or nmb_max_branches.
|
||||
The maximum time allowed for computation. Higher values give better results but take longer.
|
||||
nmb_max_branches: int
|
||||
Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
|
||||
results. Use this if you want to have controllable results independent
|
||||
of your computer.
|
||||
"""
|
||||
if self.dh.is_sdxl_turbo:
|
||||
assert t_compute_max_allowed is None, "time-based branching not supported for SDXL Turbo"
|
||||
if depth_strength is not None:
|
||||
idx_inject = int(round(self.num_inference_steps*depth_strength))
|
||||
else:
|
||||
idx_inject = 2
|
||||
if nmb_max_branches is None:
|
||||
nmb_max_branches = 10
|
||||
|
||||
self.list_idx_injection = [idx_inject]
|
||||
self.list_nmb_stems = [nmb_max_branches]
|
||||
|
||||
else:
|
||||
if depth_strength is None:
|
||||
depth_strength = 0.5
|
||||
if t_compute_max_allowed is None and nmb_max_branches is None:
|
||||
t_compute_max_allowed = 20
|
||||
elif t_compute_max_allowed is not None and nmb_max_branches is not None:
|
||||
raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
|
||||
|
||||
self.list_idx_injection, self.list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
|
||||
|
||||
def run_transition(
|
||||
self,
|
||||
recycle_img1: Optional[bool] = False,
|
||||
recycle_img2: Optional[bool] = False,
|
||||
fixed_seeds: Optional[List[int]] = None):
|
||||
r"""
|
||||
Function for computing transitions.
|
||||
Returns a list of transition images using spherical latent blending.
|
||||
Args:
|
||||
recycle_img1: Optional[bool]:
|
||||
Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
|
||||
recycle_img2: Optional[bool]:
|
||||
Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
|
||||
|
||||
|
||||
fixed_seeds: Optional[List[int]]:
|
||||
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
|
||||
Otherwise random seeds will be taken.
|
||||
"""
|
||||
|
||||
# Sanity checks first
|
||||
assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
|
||||
assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
|
||||
|
||||
|
||||
# Random seeds
|
||||
if fixed_seeds is not None:
|
||||
if fixed_seeds == 'randomize':
|
||||
fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
|
||||
else:
|
||||
assert len(fixed_seeds) == 2, "Supply a list with len = 2"
|
||||
|
||||
self.seed1 = fixed_seeds[0]
|
||||
self.seed2 = fixed_seeds[1]
|
||||
|
||||
|
||||
# Compute / Recycle first image
|
||||
if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
|
||||
list_latents1 = self.compute_latents1()
|
||||
else:
|
||||
list_latents1 = self.tree_latents[0]
|
||||
|
||||
# Compute / Recycle second image
|
||||
if not recycle_img2 or len(self.tree_latents[-1]) != self.num_inference_steps:
|
||||
list_latents2 = self.compute_latents2()
|
||||
else:
|
||||
list_latents2 = self.tree_latents[-1]
|
||||
|
||||
# Reset the tree, injecting the edge latents1/2 we just generated/recycled
|
||||
self.tree_latents = [list_latents1, list_latents2]
|
||||
self.tree_fracts = [0.0, 1.0]
|
||||
self.tree_final_imgs = [self.dh.latent2image((self.tree_latents[0][-1])), self.dh.latent2image((self.tree_latents[-1][-1]))]
|
||||
self.tree_idx_injection = [0, 0]
|
||||
self.tree_similarities = self.get_tree_similarities()
|
||||
|
||||
|
||||
# Run iteratively, starting with the longest trajectory.
|
||||
# Always inserting new branches where they are needed most according to image similarity
|
||||
for s_idx in tqdm(range(len(self.list_idx_injection))):
|
||||
nmb_stems = self.list_nmb_stems[s_idx]
|
||||
idx_injection = self.list_idx_injection[s_idx]
|
||||
|
||||
for i in range(nmb_stems):
|
||||
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
|
||||
self.set_guidance_mid_dampening(fract_mixing)
|
||||
list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
|
||||
self.insert_into_tree(fract_mixing, idx_injection, list_latents)
|
||||
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection} bp1 {b_parent1} bp2 {b_parent2}")
|
||||
|
||||
return self.tree_final_imgs
|
||||
|
||||
|
||||
|
||||
|
||||
def compute_latents1(self, return_image=False):
|
||||
r"""
|
||||
Runs a diffusion trajectory for the first image
|
||||
Args:
|
||||
return_image: bool
|
||||
whether to return an image or the list of latents
|
||||
"""
|
||||
print("starting compute_latents1")
|
||||
list_conditionings = self.get_mixed_conditioning(0)
|
||||
t0 = time.time()
|
||||
latents_start = self.get_noise(self.seed1)
|
||||
list_latents1 = self.run_diffusion(
|
||||
list_conditionings,
|
||||
latents_start=latents_start,
|
||||
idx_start=0)
|
||||
t1 = time.time()
|
||||
self.dt_unet_step = (t1 - t0) / self.num_inference_steps
|
||||
self.tree_latents[0] = list_latents1
|
||||
if return_image:
|
||||
return self.dh.latent2image(list_latents1[-1])
|
||||
else:
|
||||
return list_latents1
|
||||
|
||||
def compute_latents2(self, return_image=False):
|
||||
r"""
|
||||
Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
|
||||
Args:
|
||||
return_image: bool
|
||||
whether to return an image or the list of latents
|
||||
"""
|
||||
print("starting compute_latents2")
|
||||
list_conditionings = self.get_mixed_conditioning(1)
|
||||
latents_start = self.get_noise(self.seed2)
|
||||
# Influence from branch1
|
||||
if self.branch1_crossfeed_power > 0.0:
|
||||
# Set up the mixing_coeffs
|
||||
idx_mixing_stop = int(round(self.num_inference_steps * self.branch1_crossfeed_range))
|
||||
mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power * self.branch1_crossfeed_decay, idx_mixing_stop))
|
||||
mixing_coeffs.extend((self.num_inference_steps - idx_mixing_stop) * [0])
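# Illustrative schedule (assumed numbers): with num_inference_steps=30,
# branch1_crossfeed_power=0.4, crossfeed_range=0.6 and crossfeed_decay=0.5,
# idx_mixing_stop=18, so the first 18 coefficients ramp linearly from 0.4 down
# to 0.2 and the remaining 12 are 0 (no crossfeed for the late denoising steps).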
|
||||
list_latents_mixing = self.tree_latents[0]
|
||||
list_latents2 = self.run_diffusion(
|
||||
list_conditionings,
|
||||
latents_start=latents_start,
|
||||
idx_start=0,
|
||||
list_latents_mixing=list_latents_mixing,
|
||||
mixing_coeffs=mixing_coeffs)
|
||||
else:
|
||||
list_latents2 = self.run_diffusion(list_conditionings, latents_start)
|
||||
self.tree_latents[-1] = list_latents2
|
||||
|
||||
if return_image:
|
||||
return self.dh.latent2image(list_latents2[-1])
|
||||
else:
|
||||
return list_latents2
|
||||
|
||||
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
|
||||
r"""
|
||||
Runs a diffusion trajectory, using the latents from the respective parents
|
||||
Args:
|
||||
fract_mixing: float
|
||||
the fraction along the transition axis [0, 1]
|
||||
b_parent1: int
|
||||
index of parent1 to be used
|
||||
b_parent2: int
|
||||
index of parent2 to be used
|
||||
idx_injection: int
|
||||
the index in terms of diffusion steps, where the next insertion will start.
|
||||
"""
|
||||
list_conditionings = self.get_mixed_conditioning(fract_mixing)
|
||||
fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
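# Example: a new branch at fract_mixing=0.5 between parents at fractions 0.25 and 1.0
# gives fract_mixing_parental = (0.5 - 0.25) / (1.0 - 0.25) = 1/3.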
|
||||
# idx_reversed = self.num_inference_steps - idx_injection
|
||||
|
||||
list_latents_parental_mix = []
|
||||
for i in range(self.num_inference_steps):
|
||||
latents_p1 = self.tree_latents[b_parent1][i]
|
||||
latents_p2 = self.tree_latents[b_parent2][i]
|
||||
if latents_p1 is None or latents_p2 is None:
|
||||
latents_parental = None
|
||||
else:
|
||||
latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
|
||||
list_latents_parental_mix.append(latents_parental)
|
||||
|
||||
idx_mixing_stop = int(round(self.num_inference_steps * self.parental_crossfeed_range))
|
||||
mixing_coeffs = idx_injection * [self.parental_crossfeed_power]
|
||||
nmb_mixing = idx_mixing_stop - idx_injection
|
||||
if nmb_mixing > 0:
|
||||
mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_decay, nmb_mixing)))
|
||||
mixing_coeffs.extend((self.num_inference_steps - len(mixing_coeffs)) * [0])
|
||||
latents_start = list_latents_parental_mix[idx_injection - 1]
|
||||
list_latents = self.run_diffusion(
|
||||
list_conditionings,
|
||||
latents_start=latents_start,
|
||||
idx_start=idx_injection,
|
||||
list_latents_mixing=list_latents_parental_mix,
|
||||
mixing_coeffs=mixing_coeffs)
|
||||
return list_latents
|
||||
|
||||
def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
|
||||
r"""
|
||||
Sets up the branching scheme dependent on the time that is granted for compute.
|
||||
The scheme uses an estimation derived from the first image's computation speed.
|
||||
Either provide t_compute_max_allowed or nmb_max_branches
|
||||
Args:
|
||||
depth_strength:
|
||||
Determines how deep the first injection will happen.
|
||||
Deeper injections will cause (unwanted) formation of new structures,
|
||||
more shallow values will go into alpha-blendy land.
|
||||
t_compute_max_allowed: float
|
||||
The maximum time allowed for computation. Higher values give better results
|
||||
but take longer. Use this if you want to fix your waiting time for the results.
|
||||
nmb_max_branches: int
|
||||
The maximum number of branches to be computed. Higher values give better
|
||||
results. Use this if you want to have controllable results independent
|
||||
of your computer.
|
||||
"""
|
||||
idx_injection_base = int(np.floor(self.num_inference_steps * depth_strength))
|
||||
|
||||
steps = int(np.ceil(self.num_inference_steps/10))
|
||||
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps, steps)
|
||||
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
|
||||
t_compute = 0
|
||||
|
||||
if nmb_max_branches is None:
|
||||
assert t_compute_max_allowed is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
|
||||
stop_criterion = "t_compute_max_allowed"
|
||||
elif t_compute_max_allowed is None:
|
||||
assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
|
||||
stop_criterion = "nmb_max_branches"
|
||||
nmb_max_branches -= 2 # Discounting the outer frames
|
||||
else:
|
||||
raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
|
||||
stop_criterion_reached = False
|
||||
is_first_iteration = True
|
||||
while not stop_criterion_reached:
|
||||
list_compute_steps = self.num_inference_steps - list_idx_injection
|
||||
list_compute_steps *= list_nmb_stems
|
||||
t_compute = np.sum(list_compute_steps) * self.dt_unet_step + self.dt_vae * np.sum(list_nmb_stems)
|
||||
t_compute += 2 * (self.num_inference_steps * self.dt_unet_step + self.dt_vae) # outer branches
|
||||
increase_done = False
|
||||
for s_idx in range(len(list_nmb_stems) - 1):
|
||||
if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 1:
|
||||
list_nmb_stems[s_idx] += 1
|
||||
increase_done = True
|
||||
break
|
||||
if not increase_done:
|
||||
list_nmb_stems[-1] += 1
|
||||
|
||||
if stop_criterion == "t_compute_max_allowed" and t_compute > t_compute_max_allowed:
|
||||
stop_criterion_reached = True
|
||||
elif stop_criterion == "nmb_max_branches" and np.sum(list_nmb_stems) >= nmb_max_branches:
|
||||
stop_criterion_reached = True
|
||||
if is_first_iteration:
|
||||
# Need to undersample.
|
||||
list_idx_injection = np.linspace(list_idx_injection[0], list_idx_injection[-1], nmb_max_branches).astype(np.int32)
|
||||
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
|
||||
else:
|
||||
is_first_iteration = False
|
||||
|
||||
# print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
|
||||
return list_idx_injection, list_nmb_stems
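# Illustrative outcome (actual numbers depend on the benchmarked dt_unet_step / dt_vae):
# with num_inference_steps=30 and depth_strength=0.5, list_idx_injection becomes
# [15, 18, 21, 24, 27]; list_nmb_stems starts as all ones and is incremented until the
# estimated compute time exceeds t_compute_max_allowed (or nmb_max_branches is reached).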
|
||||
|
||||
def get_mixing_parameters(self, idx_injection):
|
||||
r"""
|
||||
Computes which parental latents should be mixed together to achieve a smooth blend.
|
||||
As metric, we are using lpips image similarity. The insertion takes place
|
||||
where the metric is maximal.
|
||||
Args:
|
||||
idx_injection: int
|
||||
the index in terms of diffusion steps, where the next insertion will start.
|
||||
"""
|
||||
# get_lpips_similarity
|
||||
similarities = self.tree_similarities
|
||||
# similarities = self.get_tree_similarities()
|
||||
b_closest1 = np.argmax(similarities)
|
||||
b_closest2 = b_closest1 + 1
|
||||
fract_closest1 = self.tree_fracts[b_closest1]
|
||||
fract_closest2 = self.tree_fracts[b_closest2]
|
||||
fract_mixing = (fract_closest1 + fract_closest2) / 2
|
||||
|
||||
# Ensure that the parents are indeed older
|
||||
b_parent1 = b_closest1
|
||||
while True:
|
||||
if self.tree_idx_injection[b_parent1] < idx_injection:
|
||||
break
|
||||
else:
|
||||
b_parent1 -= 1
|
||||
b_parent2 = b_closest2
|
||||
while True:
|
||||
if self.tree_idx_injection[b_parent2] < idx_injection:
|
||||
break
|
||||
else:
|
||||
b_parent2 += 1
|
||||
return fract_mixing, b_parent1, b_parent2
|
||||
|
||||
def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
|
||||
r"""
|
||||
Inserts all necessary parameters into the trajectory tree.
|
||||
Args:
|
||||
fract_mixing: float
|
||||
the fraction along the transition axis [0, 1]
|
||||
idx_injection: int
|
||||
the index in terms of diffusion steps, where the next insertion will start.
|
||||
list_latents: list
|
||||
list of the latents to be inserted
|
||||
"""
|
||||
img_insert = self.dh.latent2image(list_latents[-1])
|
||||
|
||||
b_parent1, b_parent2 = self.get_closest_idx(fract_mixing)
|
||||
left_sim = self.get_lpips_similarity(img_insert, self.tree_final_imgs[b_parent1])
|
||||
right_sim = self.get_lpips_similarity(img_insert, self.tree_final_imgs[b_parent2])
|
||||
idx_insert = b_parent1 + 1
|
||||
self.tree_latents.insert(idx_insert, list_latents)
|
||||
self.tree_final_imgs.insert(idx_insert, img_insert)
|
||||
self.tree_fracts.insert(idx_insert, fract_mixing)
|
||||
self.tree_idx_injection.insert(idx_insert, idx_injection)
|
||||
|
||||
# update similarities
|
||||
self.tree_similarities[b_parent1] = left_sim
|
||||
self.tree_similarities.insert(idx_insert, right_sim)
|
||||
|
||||
|
||||
def get_noise(self, seed):
|
||||
r"""
|
||||
Helper function to get noise given seed.
|
||||
Args:
|
||||
seed: int
|
||||
"""
|
||||
return self.dh.get_noise(seed)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_diffusion(
|
||||
self,
|
||||
list_conditionings,
|
||||
latents_start: torch.FloatTensor = None,
|
||||
idx_start: int = 0,
|
||||
list_latents_mixing=None,
|
||||
mixing_coeffs=0.0,
|
||||
return_image: Optional[bool] = False):
|
||||
r"""
|
||||
Wrapper function for diffusion runners.
|
||||
Depending on the mode, the correct one will be executed.
|
||||
|
||||
Args:
|
||||
list_conditionings: list
|
||||
List of all conditionings for the diffusion model.
|
||||
latents_start: torch.FloatTensor
|
||||
Latents that are used for injection
|
||||
idx_start: int
|
||||
Index of the diffusion process start and where the latents_for_injection are injected
|
||||
list_latents_mixing: torch.FloatTensor
|
||||
List of latents (latent trajectories) that are used for mixing
|
||||
mixing_coeffs: float or list
|
||||
Coefficients, how strong each element of list_latents_mixing will be mixed in.
|
||||
return_image: Optional[bool]
|
||||
Optionally return image directly
|
||||
"""
|
||||
|
||||
# Ensure correct num_inference_steps in Holder
|
||||
self.dh.set_num_inference_steps(self.num_inference_steps)
|
||||
assert type(list_conditionings) is list, "list_conditionings need to be a list"
|
||||
|
||||
text_embeddings = list_conditionings[0]
|
||||
return self.dh.run_diffusion_sd_xl(
|
||||
text_embeddings=text_embeddings,
|
||||
latents_start=latents_start,
|
||||
idx_start=idx_start,
|
||||
list_latents_mixing=list_latents_mixing,
|
||||
mixing_coeffs=mixing_coeffs,
|
||||
return_image=return_image)
|
||||
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_mixed_conditioning(self, fract_mixing):
|
||||
text_embeddings_mix = []
|
||||
for i in range(len(self.text_embedding1)):
|
||||
if self.text_embedding1[i] is None:
|
||||
mix = None
|
||||
else:
|
||||
mix = interpolate_linear(self.text_embedding1[i], self.text_embedding2[i], fract_mixing)
|
||||
text_embeddings_mix.append(mix)
|
||||
list_conditionings = [text_embeddings_mix]
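# Example (assuming interpolate_linear is a standard lerp): at fract_mixing=0.25 each
# mixed embedding is 0.75 * text_embedding1[i] + 0.25 * text_embedding2[i].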
|
||||
|
||||
return list_conditionings
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_embeddings(
|
||||
self,
|
||||
prompt: str):
|
||||
r"""
|
||||
Computes the text embeddings provided a string with a prompts.
|
||||
Adapted from stable diffusion repo
|
||||
Args:
|
||||
prompt: str
|
||||
ABC trending on artstation painted by Old Greg.
|
||||
"""
|
||||
return self.dh.get_text_embedding(prompt)
|
||||
|
||||
def write_imgs_transition(self, dp_img):
|
||||
r"""
|
||||
Writes the transition images into the folder dp_img.
|
||||
Requires run_transition to be completed.
|
||||
Args:
|
||||
dp_img: str
|
||||
Directory, into which the transition images, yaml file and latents are written.
|
||||
"""
|
||||
imgs_transition = self.tree_final_imgs
|
||||
os.makedirs(dp_img, exist_ok=True)
|
||||
for i, img in enumerate(imgs_transition):
|
||||
img_leaf = Image.fromarray(img)
|
||||
img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
|
||||
fp_yml = os.path.join(dp_img, "lowres.yaml")
|
||||
self.save_statedict(fp_yml)
|
||||
|
||||
def write_movie_transition(self, fp_movie, duration_transition, fps=30):
|
||||
r"""
|
||||
Writes the transition movie to fp_movie, using the given duration and fps.
|
||||
The missing frames are linearly interpolated.
|
||||
Args:
|
||||
fp_movie: str
|
||||
file path of the final movie.
|
||||
duration_transition: float
|
||||
duration of the movie in seconds
|
||||
fps: int
|
||||
fps of the movie
|
||||
"""
|
||||
|
||||
# Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
|
||||
imgs_transition_ext = add_frames_linear_interp(self.tree_final_imgs, duration_transition, fps)
|
||||
|
||||
# Save as MP4
|
||||
if os.path.isfile(fp_movie):
|
||||
os.remove(fp_movie)
|
||||
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[self.dh.height_img, self.dh.width_img])
|
||||
for img in tqdm(imgs_transition_ext):
|
||||
ms.write_frame(img)
|
||||
ms.finalize()
|
||||
|
||||
def save_statedict(self, fp_yml):
|
||||
# Dump everything relevant into yaml
|
||||
imgs_transition = self.tree_final_imgs
|
||||
state_dict = self.get_state_dict()
|
||||
state_dict['nmb_images'] = len(imgs_transition)
|
||||
yml_save(fp_yml, state_dict)
|
||||
|
||||
def get_state_dict(self):
|
||||
state_dict = {}
|
||||
grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
|
||||
'num_inference_steps', 'depth_strength', 'guidance_scale',
|
||||
'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt',
|
||||
'branch1_crossfeed_power', 'branch1_crossfeed_range', 'branch1_crossfeed_decay',
|
||||
'parental_crossfeed_power', 'parental_crossfeed_range', 'parental_crossfeed_decay']
|
||||
for v in grab_vars:
|
||||
if hasattr(self, v):
|
||||
if v == 'seed1' or v == 'seed2':
|
||||
state_dict[v] = int(getattr(self, v))
|
||||
elif v == 'guidance_scale':
|
||||
state_dict[v] = float(getattr(self, v))
|
||||
|
||||
else:
|
||||
try:
|
||||
state_dict[v] = getattr(self, v)
|
||||
except Exception:
|
||||
pass
|
||||
return state_dict
|
||||
|
||||
def randomize_seed(self):
|
||||
r"""
|
||||
Set a random seed for a fresh start.
|
||||
"""
|
||||
seed = np.random.randint(999999999)
|
||||
self.set_seed(seed)
|
||||
|
||||
def set_seed(self, seed: int):
|
||||
r"""
|
||||
Set the seed for a fresh start.
|
||||
"""
|
||||
self.seed = seed
|
||||
self.dh.seed = seed
|
||||
|
||||
def set_width(self, width):
|
||||
r"""
|
||||
Set the width of the resulting image.
|
||||
"""
|
||||
assert np.mod(width, 64) == 0, "set_width: value needs to be divisible by 64"
|
||||
self.width = width
|
||||
self.dh.width = width
|
||||
|
||||
def set_height(self, height):
|
||||
r"""
|
||||
Set the height of the resulting image.
|
||||
"""
|
||||
assert np.mod(height, 64) == 0, "set_height: value needs to be divisible by 64"
|
||||
self.height = height
|
||||
self.dh.height = height
|
||||
|
||||
def swap_forward(self):
|
||||
r"""
|
||||
Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions
|
||||
as in run_multi_transition()
|
||||
"""
|
||||
# Move over all latents
|
||||
self.tree_latents[0] = self.tree_latents[-1]
|
||||
# Move over prompts and text embeddings
|
||||
self.prompt1 = self.prompt2
|
||||
self.text_embedding1 = self.text_embedding2
|
||||
# Final cleanup for extra sanity
|
||||
self.tree_final_imgs = []
|
||||
|
||||
def get_lpips_similarity(self, imgA, imgB):
|
||||
r"""
|
||||
Computes the image similarity between two images imgA and imgB.
|
||||
Used to determine the optimal point of insertion to create smooth transitions.
|
||||
High values indicate low similarity.
|
||||
"""
|
||||
tensorA = torch.from_numpy(np.asarray(imgA)).float().cuda(self.device)
|
||||
tensorA = 2 * tensorA / 255.0 - 1
|
||||
tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0)
|
||||
tensorB = torch.from_numpy(np.asarray(imgB)).float().cuda(self.device)
|
||||
tensorB = 2 * tensorB / 255.0 - 1
|
||||
tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0)
|
||||
lploss = self.lpips(tensorA, tensorB)
|
||||
lploss = float(lploss[0][0][0][0])
|
||||
return lploss
|
||||
|
||||
def get_tree_similarities(self):
|
||||
similarities = []
|
||||
for i in range(len(self.tree_final_imgs) - 1):
|
||||
similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1]))
|
||||
return similarities
|
||||
|
||||
# Auxiliary functions
|
||||
def get_closest_idx(
|
||||
self,
|
||||
fract_mixing: float):
|
||||
r"""
|
||||
Helper function to retrieve the parents for any given mixing.
|
||||
Example: fract_mixing = 0.4 and self.tree_fracts = [0, 0.3, 0.6, 1.0]
|
||||
Will return the two closest values here, i.e. [1, 2]
|
||||
"""
|
||||
|
||||
pdist = fract_mixing - np.asarray(self.tree_fracts)
|
||||
pdist_pos = pdist.copy()
|
||||
pdist_pos[pdist_pos < 0] = np.inf
|
||||
b_parent1 = np.argmin(pdist_pos)
|
||||
pdist_neg = -pdist.copy()
|
||||
pdist_neg[pdist_neg <= 0] = np.inf
|
||||
b_parent2 = np.argmin(pdist_neg)
|
||||
|
||||
if b_parent1 > b_parent2:
|
||||
tmp = b_parent2
|
||||
b_parent2 = b_parent1
|
||||
b_parent1 = tmp
|
||||
|
||||
return b_parent1, b_parent2
|
||||
|
||||
#%%
|
||||
if __name__ == "__main__":
|
||||
|
||||
# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
|
||||
from diffusers_holder import DiffusersHolder
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import AutoencoderTiny
|
||||
# pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
|
||||
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16")
|
||||
pipe.to("cuda")
|
||||
pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16)
|
||||
pipe.vae = pipe.vae.cuda()
|
||||
|
||||
dh = DiffusersHolder(pipe)
|
||||
# %% Next let's set up all parameters
|
||||
prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution"
|
||||
prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal"
|
||||
negative_prompt = "blurry, ugly, pale" # Optional
|
||||
|
||||
duration_transition = 12 # In seconds
|
||||
|
||||
# Spawn latent blending
|
||||
lb = BlendingEngine(dh)
|
||||
lb.set_prompt1(prompt1)
|
||||
lb.set_prompt2(prompt2)
|
||||
lb.set_negative_prompt(negative_prompt)
|
||||
|
||||
# Run latent blending
|
||||
t0 = time.time()
|
||||
lb.run_transition(fixed_seeds=[420, 421])
|
||||
dt = time.time() - t0
|
||||
|
||||
# Save movie
|
||||
fp_movie = f'test.mp4'
|
||||
lb.write_movie_transition(fp_movie, duration_transition)
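# Optionally also dump the individual transition frames plus a yaml state file
# (sketch; "frames_test" is just an illustrative output directory).
lb.write_imgs_transition("frames_test")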
|
||||
|
||||
|
||||
|
||||
|
474
latentblending/diffusers_holder.py
Normal file
@@ -0,0 +1,474 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from latentblending.utils import interpolate_spherical
|
||||
from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
# ImageProjection and rescale_noise_cfg are referenced in run_diffusion_sd_xl below
from diffusers.models import ImageProjection
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import retrieve_timesteps, rescale_noise_cfg
|
||||
warnings.filterwarnings('ignore')
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
class DiffusersHolder():
|
||||
def __init__(self, pipe):
|
||||
# Base settings
|
||||
self.negative_prompt = ""
|
||||
self.guidance_scale = 5.0
|
||||
self.num_inference_steps = 30
|
||||
|
||||
# Check if valid pipe
|
||||
self.pipe = pipe
|
||||
self.device = str(pipe._execution_device)
|
||||
self.init_types()
|
||||
|
||||
self.width_latent = self.pipe.unet.config.sample_size
|
||||
self.height_latent = self.pipe.unet.config.sample_size
|
||||
self.width_img = self.width_latent * self.pipe.vae_scale_factor
|
||||
self.height_img = self.height_latent * self.pipe.vae_scale_factor
|
||||
|
||||
|
||||
def init_types(self):
|
||||
assert hasattr(self.pipe, "__class__"), "No valid diffusers pipeline found."
|
||||
assert hasattr(self.pipe.__class__, "__name__"), "No valid diffusers pipeline found."
|
||||
if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline':
|
||||
self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
|
||||
prompt_embeds, _, _, _ = self.pipe.encode_prompt("test")
|
||||
else:
|
||||
prompt_embeds = self.pipe._encode_prompt("test", self.device, 1, True)
|
||||
self.dtype = prompt_embeds.dtype
|
||||
|
||||
self.is_sdxl_turbo = 'turbo' in self.pipe._name_or_path
|
||||
|
||||
|
||||
def set_num_inference_steps(self, num_inference_steps):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
|
||||
|
||||
def set_dimensions(self, size_output):
|
||||
s = self.pipe.vae_scale_factor
|
||||
if size_output is None:
|
||||
width = self.pipe.unet.config.sample_size
|
||||
height = self.pipe.unet.config.sample_size
|
||||
else:
|
||||
width, height = size_output
|
||||
self.width_img = int(round(width / s) * s)
|
||||
self.width_latent = int(self.width_img / s)
|
||||
self.height_img = int(round(height / s) * s)
|
||||
self.height_latent = int(self.height_img / s)
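# Example: with vae_scale_factor=8, size_output=(1023, 601) is rounded to a
# 1024x600 image, i.e. a 128x75 latent.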
|
||||
print(f"set_dimensions to width={self.width_img} and height={self.height_img}")
|
||||
|
||||
def set_negative_prompt(self, negative_prompt):
|
||||
r"""Set the negative prompt. Currently only one negative prompt is supported
|
||||
"""
|
||||
if isinstance(negative_prompt, str):
|
||||
self.negative_prompt = [negative_prompt]
|
||||
else:
|
||||
self.negative_prompt = negative_prompt
|
||||
|
||||
if len(self.negative_prompt) > 1:
|
||||
self.negative_prompt = [self.negative_prompt[0]]
|
||||
|
||||
def get_text_embedding(self, prompt):
|
||||
do_classifier_free_guidance = self.guidance_scale > 1 and self.pipe.unet.config.time_cond_proj_dim is None
|
||||
text_embeddings = self.pipe.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt,
|
||||
device=self.pipe._execution_device,
|
||||
num_images_per_prompt=1,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=self.negative_prompt,
|
||||
negative_prompt_2=self.negative_prompt,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
lora_scale=None,
|
||||
clip_skip=None,#self.pipe._clip_skip,
|
||||
)
|
||||
return text_embeddings
|
||||
|
||||
def get_noise(self, seed=420):
|
||||
|
||||
latents = self.pipe.prepare_latents(
|
||||
1,
|
||||
self.pipe.unet.config.in_channels,
|
||||
self.height_img,
|
||||
self.width_img,
|
||||
torch.float16,
|
||||
self.pipe._execution_device,
|
||||
torch.Generator(device=self.device).manual_seed(int(seed)),
|
||||
None,
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def latent2image(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
output_type="pil"):
|
||||
r"""
|
||||
Returns an image provided a latent representation from diffusion.
|
||||
Args:
|
||||
latents: torch.FloatTensor
|
||||
Result of the diffusion process.
|
||||
output_type: "pil" or "np"
|
||||
"""
|
||||
assert output_type in ["pil", "np"]
|
||||
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.pipe.vae.dtype == torch.float16 and self.pipe.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.pipe.upcast_vae()
|
||||
latents = latents.to(next(iter(self.pipe.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.pipe.vae.to(dtype=torch.float16)
|
||||
|
||||
image = self.pipe.image_processor.postprocess(image, output_type=output_type)[0]
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def prepare_mixing(self, mixing_coeffs, list_latents_mixing):
|
||||
if type(mixing_coeffs) == float:
|
||||
list_mixing_coeffs = (1 + self.num_inference_steps) * [mixing_coeffs]
|
||||
elif type(mixing_coeffs) == list:
|
||||
assert len(mixing_coeffs) == self.num_inference_steps, f"len(mixing_coeffs) {len(mixing_coeffs)} != self.num_inference_steps {self.num_inference_steps}"
|
||||
list_mixing_coeffs = mixing_coeffs
|
||||
else:
|
||||
raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
|
||||
if np.sum(list_mixing_coeffs) > 0:
|
||||
assert len(list_latents_mixing) == self.num_inference_steps, f"len(list_latents_mixing) {len(list_latents_mixing)} != self.num_inference_steps {self.num_inference_steps}"
|
||||
return list_mixing_coeffs
|
||||
|
||||
@torch.no_grad()
|
||||
def run_diffusion(
|
||||
self,
|
||||
text_embeddings: torch.FloatTensor,
|
||||
latents_start: torch.FloatTensor,
|
||||
idx_start: int = 0,
|
||||
list_latents_mixing=None,
|
||||
mixing_coeffs=0.0,
|
||||
return_image: Optional[bool] = False):
|
||||
|
||||
return self.run_diffusion_sd_xl(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image)
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def run_diffusion_sd_xl(
|
||||
self,
|
||||
text_embeddings: tuple,
|
||||
latents_start: torch.FloatTensor,
|
||||
idx_start: int = 0,
|
||||
list_latents_mixing=None,
|
||||
mixing_coeffs=0.0,
|
||||
return_image: Optional[bool] = False,
|
||||
):
|
||||
|
||||
|
||||
prompt_2 = None
|
||||
height = None
|
||||
width = None
|
||||
timesteps = None
|
||||
denoising_end = None
|
||||
negative_prompt_2 = None
|
||||
num_images_per_prompt = 1
|
||||
eta = 0.0
|
||||
generator = None
|
||||
latents = None
|
||||
prompt_embeds = None
|
||||
negative_prompt_embeds = None
|
||||
pooled_prompt_embeds = None
|
||||
negative_pooled_prompt_embeds = None
|
||||
ip_adapter_image = None
|
||||
output_type = "pil"
|
||||
return_dict = True
|
||||
cross_attention_kwargs = None
|
||||
guidance_rescale = 0.0
|
||||
original_size = None
|
||||
crops_coords_top_left = (0, 0)
|
||||
target_size = None
|
||||
negative_original_size = None
|
||||
negative_crops_coords_top_left = (0, 0)
|
||||
negative_target_size = None
|
||||
clip_skip = None
|
||||
callback = None
|
||||
callback_on_step_end = None
|
||||
callback_on_step_end_tensor_inputs = ["latents"]
|
||||
# kwargs are additional keyword arguments and don't need a default value set here.
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.pipe.default_sample_size * self.pipe.vae_scale_factor
|
||||
width = width or self.pipe.default_sample_size * self.pipe.vae_scale_factor
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 1. Check inputs. skipped.
|
||||
|
||||
self.pipe._guidance_scale = self.guidance_scale
|
||||
self.pipe._guidance_rescale = guidance_rescale
|
||||
self.pipe._clip_skip = clip_skip
|
||||
self.pipe._cross_attention_kwargs = cross_attention_kwargs
|
||||
self.pipe._denoising_end = denoising_end
|
||||
self.pipe._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
list_mixing_coeffs = self.prepare_mixing(mixing_coeffs, list_latents_mixing)
|
||||
batch_size = 1
|
||||
|
||||
device = self.pipe._execution_device
|
||||
|
||||
# 3. Encode input prompt
|
||||
lora_scale = None
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = text_embeddings
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.pipe.scheduler, self.num_inference_steps, device, timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.pipe.unet.config.in_channels
|
||||
latents = latents_start.clone()
|
||||
list_latents_out = []
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.pipe.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
else:
|
||||
text_encoder_projection_dim = self.pipe.text_encoder_2.config.projection_dim
|
||||
|
||||
add_time_ids = self.pipe._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self.pipe._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
if self.pipe.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.pipe.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.pipe.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.pipe.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
image_embeds = image_embeds.to(device)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.pipe.scheduler.order, 0)
|
||||
|
||||
# 9. Optionally get Guidance Scale Embedding
|
||||
timestep_cond = None
|
||||
if self.pipe.unet.config.time_cond_proj_dim is not None:
|
||||
guidance_scale_tensor = torch.tensor(self.pipe.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
||||
timestep_cond = self.pipe.get_guidance_scale_embedding(
|
||||
guidance_scale_tensor, embedding_dim=self.pipe.unet.config.time_cond_proj_dim
|
||||
).to(device=device, dtype=latents.dtype)
|
||||
|
||||
self.pipe._num_timesteps = len(timesteps)
|
||||
for i, t in enumerate(timesteps):
|
||||
# Set the right starting latents
|
||||
# Write latents out and skip
|
||||
if i < idx_start:
|
||||
list_latents_out.append(None)
|
||||
continue
|
||||
elif i == idx_start:
|
||||
latents = latents_start.clone()
|
||||
|
||||
# Mix latents for crossfeeding
|
||||
if i > 0 and list_mixing_coeffs[i] > 0:
|
||||
latents_mixtarget = list_latents_mixing[i - 1].clone()
|
||||
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
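# Illustrative behaviour (assumed idx_start=3, mixing coefficient 0.2): the first three
# entries of list_latents_out stay None, denoising resumes from latents_start at step 3,
# and at each mixed step the current latents are pulled 20% of the way (spherically)
# towards the parent trajectory's latents from the previous step.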
|
||||
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.pipe.do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
if ip_adapter_image is not None:
|
||||
added_cond_kwargs["image_embeds"] = image_embeds
|
||||
noise_pred = self.pipe.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep_cond=timestep_cond,
|
||||
cross_attention_kwargs=self.pipe.cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.pipe.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.pipe.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if self.pipe.do_classifier_free_guidance and self.pipe.guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.pipe.guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# Append latents
|
||||
list_latents_out.append(latents.clone())
|
||||
|
||||
|
||||
|
||||
if return_image:
|
||||
return self.latent2image(latents)
|
||||
else:
|
||||
return list_latents_out
|
||||
|
||||
|
||||
|
||||
#%%
|
||||
if __name__ == "__main__":
|
||||
from PIL import Image
|
||||
from diffusers import AutoencoderTiny
|
||||
# pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16")
|
||||
pipe.to("cuda")
|
||||
#%
|
||||
# pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16)
|
||||
# pipe.vae = pipe.vae.cuda()
|
||||
#%% resanity
|
||||
import time
|
||||
self = DiffusersHolder(pipe)
|
||||
prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution"
|
||||
negative_prompt = "blurry, ugly, pale"
|
||||
num_inference_steps = 4
|
||||
guidance_scale = 0
|
||||
|
||||
self.set_num_inference_steps(num_inference_steps)
|
||||
self.guidance_scale = guidance_scale
|
||||
|
||||
prefix='turbo'
|
||||
for i in range(10):
|
||||
self.set_negative_prompt(negative_prompt)
|
||||
|
||||
text_embeddings = self.get_text_embedding(prompt1)
|
||||
latents_start = self.get_noise(np.random.randint(111111))
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# img_refx = self.pipe(prompt=prompt1, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)[0]
|
||||
|
||||
img_refx = self.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=True)
|
||||
|
||||
dt_ref = time.time() - t0
|
||||
img_refx.save(f"x_{prefix}_{i}.jpg")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# xxx
|
||||
|
||||
# self.set_negative_prompt(negative_prompt)
|
||||
# self.set_num_inference_steps(num_inference_steps)
|
||||
# text_embeddings1 = self.get_text_embedding(prompt1)
|
||||
# prompt_embeds1, negative_prompt_embeds1, pooled_prompt_embeds1, negative_pooled_prompt_embeds1 = text_embeddings1
|
||||
# latents_start = self.get_noise(420)
|
||||
# t0 = time.time()
|
||||
# img_dh = self.run_diffusion_sd_xl_resanity(text_embeddings1, latents_start, idx_start=0, return_image=True)
|
||||
# dt_dh = time.time() - t0
|
||||
|
||||
|
||||
|
||||
|
||||
# xxxx
|
||||
# #%%
|
||||
|
||||
# self = DiffusersHolder(pipe)
|
||||
# num_inference_steps = 4
|
||||
# self.set_num_inference_steps(num_inference_steps)
|
||||
# latents_start = self.get_noise(420)
|
||||
# guidance_scale = 0
|
||||
# self.guidance_scale = 0
|
||||
|
||||
# #% get embeddings1
|
||||
# prompt1 = "Photo of a colorful landscape with a blue sky with clouds"
|
||||
# text_embeddings1 = self.get_text_embedding(prompt1)
|
||||
# prompt_embeds1, negative_prompt_embeds1, pooled_prompt_embeds1, negative_pooled_prompt_embeds1 = text_embeddings1
|
||||
|
||||
# #% get embeddings2
|
||||
# prompt2 = "Photo of a tree"
|
||||
# text_embeddings2 = self.get_text_embedding(prompt2)
|
||||
# prompt_embeds2, negative_prompt_embeds2, pooled_prompt_embeds2, negative_pooled_prompt_embeds2 = text_embeddings2
|
||||
|
||||
# latents1 = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=False)
|
||||
|
||||
# img1 = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=True)
|
||||
# img1B = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=True)
|
||||
|
||||
|
||||
|
||||
# # latents2 = self.run_diffusion_sd_xl(text_embeddings2, latents_start, idx_start=0, return_image=False)
|
||||
|
||||
|
||||
# # # check if brings same image if restarted
|
||||
# # img1_return = self.run_diffusion_sd_xl(text_embeddings1, latents1[idx_mix-1], idx_start=idx_start, return_image=True)
|
||||
|
||||
# # mix latents
|
||||
# #%%
|
||||
# idx_mix = 2
|
||||
# fract=0.8
|
||||
# latents_start_mixed = interpolate_spherical(latents1[idx_mix-1], latents2[idx_mix-1], fract)
|
||||
# prompt_embeds = interpolate_spherical(prompt_embeds1, prompt_embeds2, fract)
|
||||
# pooled_prompt_embeds = interpolate_spherical(pooled_prompt_embeds1, pooled_prompt_embeds2, fract)
|
||||
# negative_prompt_embeds = negative_prompt_embeds1
|
||||
# negative_pooled_prompt_embeds = negative_pooled_prompt_embeds1
|
||||
# text_embeddings_mix = [prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds]
|
||||
|
||||
# self.run_diffusion_sd_xl(text_embeddings_mix, latents_start_mixed, idx_start=idx_start, return_image=True)
|
||||
|
||||
|
||||
|
||||
|
26
latentblending/example1_standard.py
Normal file
@@ -0,0 +1,26 @@
import torch
import warnings
from blending_engine import BlendingEngine
from diffusers_holder import DiffusersHolder
from diffusers import AutoPipelineForText2Image

warnings.filterwarnings('ignore')
torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = False

# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipe.to("cuda")

dh = DiffusersHolder(pipe)

lb = BlendingEngine(dh)
lb.set_prompt1("photo of underwater landscape, fish, und the sea, incredible detail, high resolution")
lb.set_prompt2("rendering of an alien planet, strange plants, strange creatures, surreal")
lb.set_negative_prompt("blurry, ugly, pale")

# Run latent blending
lb.run_transition()

# Save movie
lb.write_movie_transition('movie_example1.mp4', duration_transition=12)
56
latentblending/example2_multitrans.py
Normal file
@@ -0,0 +1,56 @@
import torch
import warnings
from blending_engine import BlendingEngine
from diffusers_holder import DiffusersHolder
from diffusers import AutoPipelineForText2Image
from movie_util import concatenate_movies
torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = False
warnings.filterwarnings('ignore')

# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipe.to('cuda')
dh = DiffusersHolder(pipe)

# %% Let's set up the multi transition
fps = 30
duration_single_trans = 10

# Specify a list of prompts below
list_prompts = []
list_prompts.append("Photo of a house, high detail")
list_prompts.append("Photo of an elephant in african savannah")
list_prompts.append("photo of a house, high detail")

# You can optionally specify the seeds
list_seeds = [95437579, 33259350, 956051013]
fp_movie = 'movie_example2.mp4'
lb = BlendingEngine(dh)

list_movie_parts = []
for i in range(len(list_prompts) - 1):
    # For a multi transition we can save some computation time and recycle the latents
    if i == 0:
        lb.set_prompt1(list_prompts[i])
        lb.set_prompt2(list_prompts[i + 1])
        recycle_img1 = False
    else:
        lb.swap_forward()
        lb.set_prompt2(list_prompts[i + 1])
        recycle_img1 = True

    fp_movie_part = f"tmp_part_{str(i).zfill(3)}.mp4"
    fixed_seeds = list_seeds[i:i + 2]
    # Run latent blending
    lb.run_transition(
        recycle_img1=recycle_img1,
        fixed_seeds=fixed_seeds)

    # Save movie
    lb.write_movie_transition(fp_movie_part, duration_single_trans)
    list_movie_parts.append(fp_movie_part)

# Finally, concatenate the result
concatenate_movies(fp_movie, list_movie_parts)
500
latentblending/gradio_ui.py
Normal file
@@ -0,0 +1,500 @@
|
||||
# Copyright 2022 Lunar Ring. All rights reserved.
|
||||
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import torch
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.set_grad_enabled(False)
|
||||
import numpy as np
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
from tqdm.auto import tqdm
|
||||
from PIL import Image
|
||||
from movie_util import MovieSaver, concatenate_movies
|
||||
from latent_blending import LatentBlending
|
||||
from stable_diffusion_holder import StableDiffusionHolder
|
||||
import gradio as gr
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
import shutil
|
||||
import uuid
|
||||
from utils import get_time, add_frames_linear_interp
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
class BlendingFrontend():
|
||||
def __init__(
|
||||
self,
|
||||
sdh,
|
||||
share=False):
|
||||
r"""
|
||||
Gradio Helper Class to collect UI data and start latent blending.
|
||||
Args:
|
||||
sdh:
|
||||
StableDiffusionHolder
|
||||
share: bool
|
||||
Set true to get a shareable gradio link (e.g. for running a remote server)
|
||||
"""
|
||||
self.share = share
|
||||
|
||||
# UI Defaults
|
||||
self.num_inference_steps = 30
|
||||
self.depth_strength = 0.25
|
||||
self.seed1 = 420
|
||||
self.seed2 = 420
|
||||
self.prompt1 = ""
|
||||
self.prompt2 = ""
|
||||
self.negative_prompt = ""
|
||||
self.fps = 30
|
||||
self.duration_video = 8
|
||||
self.t_compute_max_allowed = 10
|
||||
|
||||
self.lb = LatentBlending(sdh)
|
||||
self.lb.sdh.num_inference_steps = self.num_inference_steps
|
||||
self.init_parameters_from_lb()
|
||||
self.init_save_dir()
|
||||
|
||||
# Vars
|
||||
self.list_fp_imgs_current = []
|
||||
self.recycle_img1 = False
|
||||
self.recycle_img2 = False
|
||||
self.list_all_segments = []
|
||||
self.dp_session = ""
|
||||
self.user_id = None
|
||||
|
||||
def init_parameters_from_lb(self):
|
||||
r"""
|
||||
Automatically init parameters from latentblending instance
|
||||
"""
|
||||
self.height = self.lb.sdh.height
|
||||
self.width = self.lb.sdh.width
|
||||
self.guidance_scale = self.lb.guidance_scale
|
||||
self.guidance_scale_mid_damper = self.lb.guidance_scale_mid_damper
|
||||
self.mid_compression_scaler = self.lb.mid_compression_scaler
|
||||
self.branch1_crossfeed_power = self.lb.branch1_crossfeed_power
|
||||
self.branch1_crossfeed_range = self.lb.branch1_crossfeed_range
|
||||
self.branch1_crossfeed_decay = self.lb.branch1_crossfeed_decay
|
||||
self.parental_crossfeed_power = self.lb.parental_crossfeed_power
|
||||
self.parental_crossfeed_range = self.lb.parental_crossfeed_range
|
||||
self.parental_crossfeed_power_decay = self.lb.parental_crossfeed_power_decay
|
||||
|
||||
def init_save_dir(self):
|
||||
r"""
|
||||
Initializes the directory where stuff is being saved.
|
||||
You can specify this directory in a ".env" file in your latentblending root, setting
|
||||
DIR_OUT='/path/to/saving'
|
||||
"""
|
||||
load_dotenv(find_dotenv(), verbose=False)
|
||||
self.dp_out = os.getenv("DIR_OUT")
|
||||
if self.dp_out is None:
|
||||
self.dp_out = ""
|
||||
self.dp_imgs = os.path.join(self.dp_out, "imgs")
|
||||
os.makedirs(self.dp_imgs, exist_ok=True)
|
||||
self.dp_movies = os.path.join(self.dp_out, "movies")
|
||||
os.makedirs(self.dp_movies, exist_ok=True)
|
||||
self.save_empty_image()
|
||||
|
||||
def save_empty_image(self):
|
||||
r"""
|
||||
Saves an empty/black dummy image.
|
||||
"""
|
||||
self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
|
||||
Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
|
||||
|
||||
def randomize_seed1(self):
|
||||
r"""
|
||||
Randomizes the first seed
|
||||
"""
|
||||
seed = np.random.randint(0, 10000000)
|
||||
self.seed1 = int(seed)
|
||||
print(f"randomize_seed1: new seed = {self.seed1}")
|
||||
return seed
|
||||
|
||||
def randomize_seed2(self):
|
||||
r"""
|
||||
Randomizes the second seed
|
||||
"""
|
||||
seed = np.random.randint(0, 10000000)
|
||||
self.seed2 = int(seed)
|
||||
print(f"randomize_seed2: new seed = {self.seed2}")
|
||||
return seed
|
||||
|
||||
def setup_lb(self, list_ui_vals):
|
||||
r"""
|
||||
Sets all parameters from the UI. Since gradio does not support passing dictionaries,
|
||||
we instead pass the keys (list_ui_keys, a global) and the values (list_ui_vals) separately.
|
||||
"""
|
||||
# Collect latent blending variables
|
||||
self.lb.set_width(list_ui_vals[list_ui_keys.index('width')])
|
||||
self.lb.set_height(list_ui_vals[list_ui_keys.index('height')])
|
||||
self.lb.set_prompt1(list_ui_vals[list_ui_keys.index('prompt1')])
|
||||
self.lb.set_prompt2(list_ui_vals[list_ui_keys.index('prompt2')])
|
||||
self.lb.set_negative_prompt(list_ui_vals[list_ui_keys.index('negative_prompt')])
|
||||
self.lb.guidance_scale = list_ui_vals[list_ui_keys.index('guidance_scale')]
|
||||
self.lb.guidance_scale_mid_damper = list_ui_vals[list_ui_keys.index('guidance_scale_mid_damper')]
|
||||
self.t_compute_max_allowed = list_ui_vals[list_ui_keys.index('duration_compute')]
|
||||
self.lb.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
|
||||
self.lb.sdh.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
|
||||
self.duration_video = list_ui_vals[list_ui_keys.index('duration_video')]
|
||||
self.lb.seed1 = list_ui_vals[list_ui_keys.index('seed1')]
|
||||
self.lb.seed2 = list_ui_vals[list_ui_keys.index('seed2')]
|
||||
self.lb.branch1_crossfeed_power = list_ui_vals[list_ui_keys.index('branch1_crossfeed_power')]
|
||||
self.lb.branch1_crossfeed_range = list_ui_vals[list_ui_keys.index('branch1_crossfeed_range')]
|
||||
self.lb.branch1_crossfeed_decay = list_ui_vals[list_ui_keys.index('branch1_crossfeed_decay')]
|
||||
self.lb.parental_crossfeed_power = list_ui_vals[list_ui_keys.index('parental_crossfeed_power')]
|
||||
self.lb.parental_crossfeed_range = list_ui_vals[list_ui_keys.index('parental_crossfeed_range')]
|
||||
self.lb.parental_crossfeed_power_decay = list_ui_vals[list_ui_keys.index('parental_crossfeed_power_decay')]
|
||||
self.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
|
||||
self.depth_strength = list_ui_vals[list_ui_keys.index('depth_strength')]
|
||||
|
||||
if len(list_ui_vals[list_ui_keys.index('user_id')]) > 1:
|
||||
self.user_id = list_ui_vals[list_ui_keys.index('user_id')]
|
||||
else:
|
||||
# generate new user id
|
||||
self.user_id = uuid.uuid4().hex
|
||||
print(f"made new user_id: {self.user_id} at {get_time('second')}")
|
||||
|
||||
def save_latents(self, fp_latents, list_latents):
|
||||
r"""
|
||||
Saves a latent trajectory on disk, in npy format.
|
||||
"""
|
||||
list_latents_cpu = [l.cpu().numpy() for l in list_latents]
|
||||
np.save(fp_latents, list_latents_cpu)
|
||||
|
||||
def load_latents(self, fp_latents):
|
||||
r"""
|
||||
Loads a latent trajectory from disk, converts to torch tensor.
|
||||
"""
|
||||
list_latents_cpu = np.load(fp_latents)
|
||||
list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu]
|
||||
return list_latents
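# Round-trip sketch for the two latent helpers above (the path and shapes are placeholders):
# latents = [torch.zeros(1, 4, 64, 64) for _ in range(3)]
# self.save_latents("/tmp/latents_demo.npy", latents)
# restored = self.load_latents("/tmp/latents_demo.npy")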
|
||||
|
||||
def compute_img1(self, *args):
|
||||
r"""
|
||||
Computes the first transition image and returns it for display.
|
||||
Sets all other transition images and last image to empty (as they are obsolete with this operation)
|
||||
"""
|
||||
list_ui_vals = args
|
||||
self.setup_lb(list_ui_vals)
|
||||
fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}")
|
||||
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
|
||||
img1.save(fp_img1 + ".jpg")
|
||||
self.save_latents(fp_img1 + ".npy", self.lb.tree_latents[0])
|
||||
self.recycle_img1 = True
|
||||
self.recycle_img2 = False
|
||||
return [fp_img1 + ".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
|
||||
|
||||
def compute_img2(self, *args):
|
||||
r"""
|
||||
Computes the last transition image and returns it for display.
|
||||
Sets all other transition images to empty (as they are obsolete with this operation)
|
||||
"""
|
||||
if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")): # don't do anything
|
||||
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
|
||||
list_ui_vals = args
|
||||
self.setup_lb(list_ui_vals)
|
||||
|
||||
self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
|
||||
fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}")
|
||||
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
|
||||
img2.save(fp_img2 + '.jpg')
|
||||
self.save_latents(fp_img2 + ".npy", self.lb.tree_latents[-1])
|
||||
self.recycle_img2 = True
|
||||
# fixme save seeds. change filenames?
|
||||
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2 + ".jpg", self.user_id]
|
||||
|
||||
def compute_transition(self, *args):
|
||||
r"""
|
||||
Computes transition images and movie.
|
||||
"""
|
||||
list_ui_vals = args
|
||||
self.setup_lb(list_ui_vals)
|
||||
print("STARTING TRANSITION...")
|
||||
fixed_seeds = [self.seed1, self.seed2]
|
||||
# Inject loaded latents (other user interference)
|
||||
self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
|
||||
self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
|
||||
imgs_transition = self.lb.run_transition(
|
||||
recycle_img1=self.recycle_img1,
|
||||
recycle_img2=self.recycle_img2,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
depth_strength=self.depth_strength,
|
||||
t_compute_max_allowed=self.t_compute_max_allowed,
|
||||
fixed_seeds=fixed_seeds)
|
||||
print(f"Latent Blending pass finished ({get_time('second')}). Resulted in {len(imgs_transition)} images")
|
||||
|
||||
# Subselect three preview images
|
||||
idx_img_prev = np.round(np.linspace(0, len(imgs_transition) - 1, 5)[1:-1]).astype(np.int32)
|
||||
|
||||
list_imgs_preview = []
|
||||
for j in idx_img_prev:
|
||||
list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
|
||||
|
||||
# Save the preview imgs as jpgs on disk so we are not sending uncompressed data around
|
||||
current_timestamp = get_time('second')
|
||||
self.list_fp_imgs_current = []
|
||||
for i in range(len(list_imgs_preview)):
|
||||
fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{current_timestamp}.jpg")
|
||||
list_imgs_preview[i].save(fp_img)
|
||||
self.list_fp_imgs_current.append(fp_img)
|
||||
# Insert cheap frames for the movie
|
||||
imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
|
||||
|
||||
# Save as movie
|
||||
self.fp_movie = self.get_fp_video_last()
|
||||
if os.path.isfile(self.fp_movie):
|
||||
os.remove(self.fp_movie)
|
||||
ms = MovieSaver(self.fp_movie, fps=self.fps)
|
||||
for img in tqdm(imgs_transition_ext):
|
||||
ms.write_frame(img)
|
||||
ms.finalize()
|
||||
print("DONE SAVING MOVIE! SENDING BACK...")
|
||||
|
||||
# Assemble output, updating the preview images and the movie
|
||||
list_return = self.list_fp_imgs_current + [self.fp_movie]
|
||||
return list_return
|
||||
|
||||
def stack_forward(self, prompt2, seed2):
|
||||
r"""
|
||||
Allows generating multi-segment movies. Sets last image -> first image with all
|
||||
relevant parameters.
|
||||
"""
|
||||
# Save preview images, prompts and seeds into dictionary for stacking
|
||||
if len(self.list_all_segments) == 0:
|
||||
timestamp_session = get_time('second')
|
||||
self.dp_session = os.path.join(self.dp_out, f"session_{timestamp_session}")
|
||||
os.makedirs(self.dp_session)
|
||||
|
||||
idx_segment = len(self.list_all_segments)
|
||||
dp_segment = os.path.join(self.dp_session, f"segment_{str(idx_segment).zfill(3)}")
|
||||
|
||||
self.list_all_segments.append(dp_segment)
|
||||
self.lb.write_imgs_transition(dp_segment)
|
||||
|
||||
fp_movie_last = self.get_fp_video_last()
|
||||
fp_movie_next = self.get_fp_video_next()
|
||||
|
||||
shutil.copyfile(fp_movie_last, fp_movie_next)
|
||||
|
||||
self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
|
||||
self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
|
||||
self.lb.swap_forward()
|
||||
|
||||
shutil.copyfile(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"), os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
|
||||
fp_multi = self.multi_concat()
|
||||
list_out = [fp_multi]
|
||||
|
||||
list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")])
|
||||
list_out.extend([self.fp_img_empty] * 4)
|
||||
list_out.append(gr.update(interactive=False, value=prompt2))
|
||||
list_out.append(gr.update(interactive=False, value=seed2))
|
||||
list_out.append("")
|
||||
list_out.append(np.random.randint(0, 10000000))
|
||||
print(f"stack_forward: fp_multi {fp_multi}")
|
||||
return list_out
|
||||
|
||||
def multi_concat(self):
|
||||
r"""
|
||||
Concatenates all stacked segments into one long movie.
|
||||
"""
|
||||
list_fp_movies = self.get_fp_video_all()
|
||||
# Concatenate movies and save
|
||||
fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4")
|
||||
concatenate_movies(fp_final, list_fp_movies)
|
||||
return fp_final
|
||||
|
||||
def get_fp_video_all(self):
|
||||
r"""
|
||||
Collects all stacked movie segments.
|
||||
"""
|
||||
list_all = os.listdir(self.dp_movies)
|
||||
str_beg = f"movie_{self.user_id}_"
|
||||
list_user = [l for l in list_all if str_beg in l]
|
||||
list_user.sort()
|
||||
list_user = [os.path.join(self.dp_movies, l) for l in list_user]
|
||||
return list_user
|
||||
|
||||
def get_fp_video_next(self):
|
||||
r"""
|
||||
Gets the filepath of the next movie segment.
|
||||
"""
|
||||
list_videos = self.get_fp_video_all()
|
||||
if len(list_videos) == 0:
|
||||
idx_next = 0
|
||||
else:
|
||||
idx_next = len(list_videos)
|
||||
fp_video_next = os.path.join(self.dp_movies, f"movie_{self.user_id}_{str(idx_next).zfill(3)}.mp4")
|
||||
return fp_video_next
|
||||
|
||||
def get_fp_video_last(self):
|
||||
r"""
|
||||
Gets the current video that was saved.
|
||||
"""
|
||||
fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4")
|
||||
return fp_video_last
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
|
||||
fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
|
||||
bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
|
||||
# self = BlendingFrontend(None)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML("""<h1>Latent Blending</h1>
|
||||
<p>Create butter-smooth transitions between prompts, powered by stable diffusion</p>
|
||||
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
|
||||
<br/>
|
||||
<a href="https://huggingface.co/spaces/lunarring/latentblending?duplicate=true">
|
||||
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
||||
</p>""")
|
||||
|
||||
with gr.Row():
|
||||
prompt1 = gr.Textbox(label="prompt 1")
|
||||
prompt2 = gr.Textbox(label="prompt 2")
|
||||
|
||||
with gr.Row():
|
||||
duration_compute = gr.Slider(10, 25, bf.t_compute_max_allowed, step=1, label='waiting time', interactive=True)
|
||||
duration_video = gr.Slider(1, 100, bf.duration_video, step=0.1, label='video duration', interactive=True)
|
||||
height = gr.Slider(256, 1024, bf.height, step=128, label='height', interactive=True)
|
||||
width = gr.Slider(256, 1024, bf.width, step=128, label='width', interactive=True)
|
||||
|
||||
with gr.Accordion("Advanced Settings (click to expand)", open=False):
|
||||
|
||||
with gr.Accordion("Diffusion settings", open=True):
|
||||
with gr.Row():
|
||||
num_inference_steps = gr.Slider(5, 100, bf.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
|
||||
guidance_scale = gr.Slider(1, 25, bf.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
|
||||
negative_prompt = gr.Textbox(label="negative prompt")
|
||||
|
||||
with gr.Accordion("Seed control: adjust seeds for first and last images", open=True):
|
||||
with gr.Row():
|
||||
b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
|
||||
seed1 = gr.Number(bf.seed1, label="seed 1", interactive=True)
|
||||
seed2 = gr.Number(bf.seed2, label="seed 2", interactive=True)
|
||||
b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
|
||||
|
||||
with gr.Accordion("Last image crossfeeding.", open=True):
|
||||
with gr.Row():
|
||||
branch1_crossfeed_power = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_power, step=0.01, label='branch1 crossfeed power', interactive=True)
|
||||
branch1_crossfeed_range = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_range, step=0.01, label='branch1 crossfeed range', interactive=True)
|
||||
branch1_crossfeed_decay = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_decay, step=0.01, label='branch1 crossfeed decay', interactive=True)
|
||||
|
||||
with gr.Accordion("Transition settings", open=True):
|
||||
with gr.Row():
|
||||
parental_crossfeed_power = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power, step=0.01, label='parental crossfeed power', interactive=True)
|
||||
parental_crossfeed_range = gr.Slider(0.0, 1.0, bf.parental_crossfeed_range, step=0.01, label='parental crossfeed range', interactive=True)
|
||||
parental_crossfeed_power_decay = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power_decay, step=0.01, label='parental crossfeed decay', interactive=True)
|
||||
with gr.Row():
|
||||
depth_strength = gr.Slider(0.01, 0.99, bf.depth_strength, step=0.01, label='depth_strength', interactive=True)
|
||||
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, bf.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
|
||||
|
||||
with gr.Row():
|
||||
b_compute1 = gr.Button('step1: compute first image', variant='primary')
|
||||
b_compute2 = gr.Button('step2: compute last image', variant='primary')
|
||||
b_compute_transition = gr.Button('step3: compute transition', variant='primary')
|
||||
|
||||
with gr.Row():
|
||||
img1 = gr.Image(label="1/5")
|
||||
img2 = gr.Image(label="2/5", show_progress=False)
|
||||
img3 = gr.Image(label="3/5", show_progress=False)
|
||||
img4 = gr.Image(label="4/5", show_progress=False)
|
||||
img5 = gr.Image(label="5/5")
|
||||
|
||||
with gr.Row():
|
||||
vid_single = gr.Video(label="current single trans")
|
||||
vid_multi = gr.Video(label="concatenated multi trans")
|
||||
|
||||
with gr.Row():
|
||||
b_stackforward = gr.Button('append last movie segment (left) to multi movie (right)', variant='primary')
|
||||
|
||||
with gr.Row():
|
||||
gr.Markdown(
|
||||
"""
|
||||
# Parameters
|
||||
## Main
|
||||
- waiting time: maximum compute time allowed for the transition. higher values = better quality
|
||||
- video duration: seconds per segment
|
||||
- height/width: in pixels
|
||||
|
||||
## Diffusion settings
|
||||
- num_inference_steps: number of diffusion steps
|
||||
- guidance_scale: latent blending seems to prefer lower values here
|
||||
- negative prompt: enter negative prompt here, applied for all images
|
||||
|
||||
## Last image crossfeeding
|
||||
- branch1_crossfeed_power: Controls the level of cross-feeding between the first and last image branch. For preserving structures.
|
||||
- branch1_crossfeed_range: Sets the duration of active crossfeed during development. High values enforce strong structural similarity.
|
||||
- branch1_crossfeed_decay: Sets decay for branch1_crossfeed_power. Lower values make the decay stronger across the range.
|
||||
|
||||
## Transition settings
|
||||
- parental_crossfeed_power: Similar to branch1_crossfeed_power, but applied to the images within the transition.
|
||||
- parental_crossfeed_range: Similar to branch1_crossfeed_range, but applied to the images within the transition.
|
||||
- parental_crossfeed_power_decay: Similar to branch1_crossfeed_decay, but applied to the images within the transition.
|
||||
- depth_strength: Determines when the blending process begins, in terms of diffusion steps. Lower values are more inventive but can introduce unwanted motion.
|
||||
- guidance_scale_mid_damper: Decreases the guidance scale in the middle of a transition.
|
||||
""")
|
||||
|
||||
with gr.Row():
|
||||
user_id = gr.Textbox(label="user id", interactive=False)
|
||||
|
||||
# Collect all UI elements in a list so they can easily be passed as inputs to gradio
|
||||
dict_ui_elem = {}
|
||||
dict_ui_elem["prompt1"] = prompt1
|
||||
dict_ui_elem["negative_prompt"] = negative_prompt
|
||||
dict_ui_elem["prompt2"] = prompt2
|
||||
|
||||
dict_ui_elem["duration_compute"] = duration_compute
|
||||
dict_ui_elem["duration_video"] = duration_video
|
||||
dict_ui_elem["height"] = height
|
||||
dict_ui_elem["width"] = width
|
||||
|
||||
dict_ui_elem["depth_strength"] = depth_strength
|
||||
dict_ui_elem["branch1_crossfeed_power"] = branch1_crossfeed_power
|
||||
dict_ui_elem["branch1_crossfeed_range"] = branch1_crossfeed_range
|
||||
dict_ui_elem["branch1_crossfeed_decay"] = branch1_crossfeed_decay
|
||||
|
||||
dict_ui_elem["num_inference_steps"] = num_inference_steps
|
||||
dict_ui_elem["guidance_scale"] = guidance_scale
|
||||
dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
|
||||
dict_ui_elem["seed1"] = seed1
|
||||
dict_ui_elem["seed2"] = seed2
|
||||
|
||||
dict_ui_elem["parental_crossfeed_range"] = parental_crossfeed_range
|
||||
dict_ui_elem["parental_crossfeed_power"] = parental_crossfeed_power
|
||||
dict_ui_elem["parental_crossfeed_power_decay"] = parental_crossfeed_power_decay
|
||||
dict_ui_elem["user_id"] = user_id
|
||||
|
||||
# Convert to list, as gradio doesn't seem to accept dicts
|
||||
list_ui_vals = []
|
||||
list_ui_keys = []
|
||||
for k in dict_ui_elem.keys():
|
||||
list_ui_vals.append(dict_ui_elem[k])
|
||||
list_ui_keys.append(k)
|
||||
bf.list_ui_keys = list_ui_keys
|
||||
|
||||
b_newseed1.click(bf.randomize_seed1, outputs=seed1)
|
||||
b_newseed2.click(bf.randomize_seed2, outputs=seed2)
|
||||
b_compute1.click(bf.compute_img1, inputs=list_ui_vals, outputs=[img1, img2, img3, img4, img5, user_id])
|
||||
b_compute2.click(bf.compute_img2, inputs=list_ui_vals, outputs=[img2, img3, img4, img5, user_id])
|
||||
b_compute_transition.click(bf.compute_transition,
|
||||
inputs=list_ui_vals,
|
||||
outputs=[img2, img3, img4, vid_single])
|
||||
|
||||
b_stackforward.click(bf.stack_forward,
|
||||
inputs=[prompt2, seed2],
|
||||
outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
|
||||
|
||||
demo.launch(share=bf.share, inbrowser=True, inline=False)
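# Sketch: to serve the UI remotely, construct the frontend with share=True
# (the checkpoint is downloaded exactly as in the __main__ block above).
# bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt), share=True)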
|
301
latentblending/movie_util.py
Normal file
301
latentblending/movie_util.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# Copyright 2022 Lunar Ring. All rights reserved.
|
||||
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import subprocess
|
||||
import os
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import cv2
|
||||
from typing import List
|
||||
import ffmpeg # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg
|
||||
|
||||
|
||||
class MovieSaver():
|
||||
def __init__(
|
||||
self,
|
||||
fp_out: str,
|
||||
fps: int = 24,
|
||||
shape_hw: List[int] = None,
|
||||
crf: int = 21,
|
||||
codec: str = 'libx264',
|
||||
preset: str = 'fast',
|
||||
pix_fmt: str = 'yuv420p',
|
||||
silent_ffmpeg: bool = True):
|
||||
r"""
|
||||
Initializes movie saver class - a human friendly ffmpeg wrapper.
|
||||
After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x).
|
||||
Don't forget to finalize the movie file with moviesaver.finalize().
|
||||
Args:
|
||||
fp_out: str
|
||||
Output file name. If it already exists, it will be deleted.
|
||||
fps: int
|
||||
Frames per second.
|
||||
shape_hw: List[int, int]
|
||||
Output shape, optional argument. Can be initialized automatically when first frame is written.
|
||||
crf: int
|
||||
ffmpeg doc: the range of the CRF scale is 0–51, where 0 is lossless
|
||||
(for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
|
||||
A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
|
||||
Consider 17 or 18 to be visually lossless or nearly so;
|
||||
it should look the same or nearly the same as the input but it isn't technically lossless.
|
||||
The range is exponential, so increasing the CRF value +6 results in
|
||||
roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
|
||||
codec: str
|
||||
Video codec passed to ffmpeg, e.g. 'libx264'.
|
||||
preset: str
|
||||
Choose between ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow.
|
||||
ffmpeg doc: A preset is a collection of options that will provide a certain encoding speed
|
||||
to compression ratio. A slower preset will provide better compression
|
||||
(compression is quality per filesize).
|
||||
This means that, for example, if you target a certain file size or constant bit rate,
|
||||
you will achieve better quality with a slower preset. Similarly, for constant quality encoding,
|
||||
you will simply save bitrate by choosing a slower preset.
|
||||
pix_fmt: str
|
||||
Pixel format. Run 'ffmpeg -pix_fmts' in your shell to see all options.
|
||||
silent_ffmpeg: bool
|
||||
Suppress the output from ffmpeg.
|
||||
"""
|
||||
if len(os.path.split(fp_out)[0]) > 0:
|
||||
assert os.path.isdir(os.path.split(fp_out)[0]), "Directory does not exist!"
|
||||
|
||||
self.fp_out = fp_out
|
||||
self.fps = fps
|
||||
self.crf = crf
|
||||
self.pix_fmt = pix_fmt
|
||||
self.codec = codec
|
||||
self.preset = preset
|
||||
self.silent_ffmpeg = silent_ffmpeg
|
||||
|
||||
if os.path.isfile(fp_out):
|
||||
os.remove(fp_out)
|
||||
|
||||
self.init_done = False
|
||||
self.nmb_frames = 0
|
||||
if shape_hw is None:
|
||||
self.shape_hw = [-1, 1]
|
||||
else:
|
||||
if len(shape_hw) == 2:
|
||||
shape_hw.append(3)
|
||||
self.shape_hw = shape_hw
|
||||
self.initialize()
|
||||
|
||||
print(f"MovieSaver initialized. fps={fps} crf={crf} pix_fmt={pix_fmt} codec={codec} preset={preset}")
|
||||
|
||||
def initialize(self):
|
||||
args = (
|
||||
ffmpeg
|
||||
.input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(self.shape_hw[1], self.shape_hw[0]), framerate=self.fps)
|
||||
.output(self.fp_out, crf=self.crf, pix_fmt=self.pix_fmt, c=self.codec, preset=self.preset)
|
||||
.overwrite_output()
|
||||
.compile()
|
||||
)
|
||||
if self.silent_ffmpeg:
|
||||
self.ffmpg_process = subprocess.Popen(args, stdin=subprocess.PIPE, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
|
||||
else:
|
||||
self.ffmpg_process = subprocess.Popen(args, stdin=subprocess.PIPE)
|
||||
self.init_done = True
|
||||
self.shape_hw = tuple(self.shape_hw)
|
||||
print(f"Initialization done. Movie shape: {self.shape_hw}")
|
||||
|
||||
def write_frame(self, out_frame: np.ndarray):
|
||||
r"""
|
||||
Function to dump a numpy array as frame of a movie.
|
||||
Args:
|
||||
out_frame: np.ndarray
|
||||
Numpy array, in np.uint8 format. Convert with x.astype(np.uint8).
|
||||
Dim 0: y
|
||||
Dim 1: x
|
||||
Dim 2: RGB
|
||||
"""
|
||||
assert out_frame.dtype == np.uint8, "Convert to np.uint8 before"
|
||||
assert len(out_frame.shape) == 3, "out_frame needs to be three dimensional, Y X C"
|
||||
assert out_frame.shape[2] == 3, f"need three color channels, but you provided {out_frame.shape[2]}."
|
||||
|
||||
if not self.init_done:
|
||||
self.shape_hw = out_frame.shape
|
||||
self.initialize()
|
||||
|
||||
assert self.shape_hw == out_frame.shape, f"You cannot change the image size after init. Initialized with {self.shape_hw}, out_frame {out_frame.shape}"
|
||||
|
||||
# write frame
|
||||
self.ffmpg_process.stdin.write(
|
||||
out_frame
|
||||
.astype(np.uint8)
|
||||
.tobytes()
|
||||
)
|
||||
|
||||
self.nmb_frames += 1
|
||||
|
||||
def finalize(self):
|
||||
r"""
|
||||
Call this function to finalize the movie. If you forget to call it your movie will be garbage.
|
||||
"""
|
||||
if self.nmb_frames == 0:
|
||||
print("You did not write any frames yet! nmb_frames = 0. Cannot save.")
|
||||
return
|
||||
self.ffmpg_process.stdin.close()
|
||||
self.ffmpg_process.wait()
|
||||
duration = int(self.nmb_frames / self.fps)
|
||||
print(f"Movie saved, {duration}s playtime, watch here: \n{self.fp_out}")
|
||||
|
||||
|
||||
def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
|
||||
r"""
|
||||
Concatenate multiple movie segments into one long movie, using ffmpeg.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fp_final : str
|
||||
Full path of the final movie file. Should end with .mp4
|
||||
list_fp_movies : list[str]
|
||||
List of full paths of movie segments.
|
||||
"""
|
||||
assert fp_final[-4] == ".", f"fp_final seems to miss a file extension: {fp_final}"
|
||||
for fp in list_fp_movies:
|
||||
assert os.path.isfile(fp), f"Input movie does not exist: {fp}"
|
||||
assert os.path.getsize(fp) > 100, f"Input movie seems empty: {fp}"
|
||||
|
||||
if os.path.isfile(fp_final):
|
||||
os.remove(fp_final)
|
||||
|
||||
# make a list for ffmpeg
|
||||
list_concat = []
|
||||
for fp_part in list_fp_movies:
|
||||
list_concat.append(f"""file '{fp_part}'""")
|
||||
|
||||
# save this list
|
||||
fp_list = "tmp_move.txt"
|
||||
with open(fp_list, "w") as fa:
|
||||
for item in list_concat:
|
||||
fa.write("%s\n" % item)
|
||||
|
||||
cmd = f'ffmpeg -f concat -safe 0 -i {fp_list} -c copy {fp_final}'
|
||||
subprocess.call(cmd, shell=True)
|
||||
os.remove(fp_list)
|
||||
if os.path.isfile(fp_final):
|
||||
print(f"concatenate_movies: success! Watch here: {fp_final}")
|
||||
|
||||
|
||||
def add_sound(fp_final, fp_silentmovie, fp_sound):
|
||||
cmd = f'ffmpeg -i {fp_silentmovie} -i {fp_sound} -c copy -map 0:v:0 -map 1:a:0 {fp_final}'
|
||||
subprocess.call(cmd, shell=True)
|
||||
if os.path.isfile(fp_final):
|
||||
print(f"add_sound: success! Watch here: {fp_final}")
|
||||
|
||||
|
||||
def add_subtitles_to_video(
|
||||
fp_input: str,
|
||||
fp_output: str,
|
||||
subtitles: list,
|
||||
fontsize: int = 50,
|
||||
font_name: str = "Arial",
|
||||
color: str = 'yellow'
|
||||
):
|
||||
from moviepy.editor import VideoFileClip, TextClip, CompositeVideoClip
|
||||
r"""
|
||||
Function to add subtitles to a video.
|
||||
|
||||
Args:
|
||||
fp_input (str): File path of the input video.
|
||||
fp_output (str): File path of the output video with subtitles.
|
||||
subtitles (list): List of dictionaries containing subtitle information
|
||||
(start, duration, text). Example:
|
||||
subtitles = [
|
||||
{"start": 1, "duration": 3, "text": "hello test"},
|
||||
{"start": 4, "duration": 2, "text": "this works"},
|
||||
]
|
||||
fontsize (int): Font size of the subtitles.
|
||||
font_name (str): Font name of the subtitles.
|
||||
color (str): Color of the subtitles.
|
||||
"""
|
||||
|
||||
# Check if the input file exists
|
||||
if not os.path.isfile(fp_input):
|
||||
raise FileNotFoundError(f"Input file not found: {fp_input}")
|
||||
|
||||
# Check the subtitles format and sort them by the start time
|
||||
time_points = []
|
||||
for subtitle in subtitles:
|
||||
if not isinstance(subtitle, dict):
|
||||
raise ValueError("Each subtitle must be a dictionary containing 'start', 'duration' and 'text'.")
|
||||
if not all(key in subtitle for key in ["start", "duration", "text"]):
|
||||
raise ValueError("Each subtitle dictionary must contain 'start', 'duration' and 'text'.")
|
||||
if subtitle['start'] < 0 or subtitle['duration'] <= 0:
|
||||
raise ValueError("'start' should be non-negative and 'duration' should be positive.")
|
||||
time_points.append((subtitle['start'], subtitle['start'] + subtitle['duration']))
|
||||
|
||||
# Check for overlaps
|
||||
time_points.sort()
|
||||
for i in range(1, len(time_points)):
|
||||
if time_points[i][0] < time_points[i - 1][1]:
|
||||
raise ValueError("Subtitle time intervals should not overlap.")
|
||||
|
||||
# Load the video clip
|
||||
video = VideoFileClip(fp_input)
|
||||
|
||||
# Create a list to store subtitle clips
|
||||
subtitle_clips = []
|
||||
|
||||
# Loop through the subtitle information and create TextClip for each
|
||||
for subtitle in subtitles:
|
||||
text_clip = TextClip(subtitle["text"], fontsize=fontsize, color=color, font=font_name)
|
||||
text_clip = text_clip.set_position(('center', 'bottom')).set_start(subtitle["start"]).set_duration(subtitle["duration"])
|
||||
subtitle_clips.append(text_clip)
|
||||
|
||||
# Overlay the subtitles on the video
|
||||
video = CompositeVideoClip([video] + subtitle_clips)
|
||||
|
||||
# Write the final clip to a new file
|
||||
video.write_videofile(fp_output)
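# Usage sketch for add_subtitles_to_video (the file paths are placeholders):
# subtitles = [
#     {"start": 1, "duration": 3, "text": "hello test"},
#     {"start": 4, "duration": 2, "text": "this works"},
# ]
# add_subtitles_to_video("movie_example1.mp4", "movie_example1_subs.mp4", subtitles)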
|
||||
|
||||
|
||||
|
||||
class MovieReader():
|
||||
r"""
|
||||
Class to read in a movie.
|
||||
"""
|
||||
|
||||
def __init__(self, fp_movie):
|
||||
self.video_player_object = cv2.VideoCapture(fp_movie)
|
||||
self.nmb_frames = int(self.video_player_object.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
self.fps_movie = int(self.video_player_object.get(cv2.CAP_PROP_FPS))
|
||||
self.shape = [100, 100, 3]
|
||||
self.shape_is_set = False
|
||||
|
||||
def get_next_frame(self):
|
||||
success, image = self.video_player_object.read()
|
||||
if success:
|
||||
if not self.shape_is_set:
|
||||
self.shape_is_set = True
|
||||
self.shape = image.shape
|
||||
return image
|
||||
else:
|
||||
return np.zeros(self.shape)
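# Usage sketch for MovieReader (the input path is a placeholder; frames come back
# in OpenCV's BGR channel order):
# mr = MovieReader("movie_example1.mp4")
# frames = [mr.get_next_frame() for _ in range(mr.nmb_frames)]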
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fps = 2
|
||||
list_fp_movies = []
|
||||
for k in range(4):
|
||||
fp_movie = f"/tmp/my_random_movie_{k}.mp4"
|
||||
list_fp_movies.append(fp_movie)
|
||||
ms = MovieSaver(fp_movie, fps=fps)
|
||||
for fn in tqdm(range(30)):
|
||||
img = (np.random.rand(512, 1024, 3) * 255).astype(np.uint8)
|
||||
ms.write_frame(img)
|
||||
ms.finalize()
|
||||
|
||||
fp_final = "/tmp/my_concatenated_movie.mp4"
|
||||
concatenate_movies(fp_final, list_fp_movies)
|
262
latentblending/utils.py
Normal file
262
latentblending/utils.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# Copyright 2022 Lunar Ring. All rights reserved.
|
||||
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
torch.backends.cudnn.benchmark = False
|
||||
import numpy as np
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
import time
|
||||
import warnings
|
||||
import datetime
|
||||
from typing import List, Union
|
||||
torch.set_grad_enabled(False)
|
||||
import yaml
|
||||
import PIL
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate_spherical(p0, p1, fract_mixing: float):
|
||||
r"""
|
||||
Helper function to correctly mix two random variables using spherical interpolation.
|
||||
See https://en.wikipedia.org/wiki/Slerp
|
||||
The function will always cast up to float64 for the sake of extra precision.
|
||||
Args:
|
||||
p0:
|
||||
First tensor for interpolation
|
||||
p1:
|
||||
Second tensor for interpolation
|
||||
fract_mixing: float
|
||||
Mixing coefficient of interval [0, 1].
|
||||
0 will return p0
|
||||
1 will return p1
|
||||
0.x will return a mix between both preserving angular velocity.
|
||||
"""
|
||||
|
||||
if p0.dtype == torch.float16:
|
||||
recast_to = 'fp16'
|
||||
else:
|
||||
recast_to = 'fp32'
|
||||
|
||||
p0 = p0.double()
|
||||
p1 = p1.double()
|
||||
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
|
||||
epsilon = 1e-7
|
||||
dot = torch.sum(p0 * p1) / norm
|
||||
dot = dot.clamp(-1 + epsilon, 1 - epsilon)
|
||||
|
||||
theta_0 = torch.arccos(dot)
|
||||
sin_theta_0 = torch.sin(theta_0)
|
||||
theta_t = theta_0 * fract_mixing
|
||||
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = torch.sin(theta_t) / sin_theta_0
|
||||
interp = p0 * s0 + p1 * s1
|
||||
|
||||
if recast_to == 'fp16':
|
||||
interp = interp.half()
|
||||
elif recast_to == 'fp32':
|
||||
interp = interp.float()
|
||||
|
||||
return interp
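# Quick sanity check (sketch): slerp follows
#   interp = sin((1 - t) * theta) / sin(theta) * p0 + sin(t * theta) / sin(theta) * p1,
# so the endpoints are reproduced exactly (up to the float32 recast).
# p0, p1 = torch.randn(100), torch.randn(100)
# assert torch.allclose(interpolate_spherical(p0, p1, 0.0), p0, atol=1e-5)
# assert torch.allclose(interpolate_spherical(p0, p1, 1.0), p1, atol=1e-5)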
|
||||
|
||||
|
||||
def interpolate_linear(p0, p1, fract_mixing):
|
||||
r"""
|
||||
Helper function to mix two variables using standard linear interpolation.
|
||||
Args:
|
||||
p0:
|
||||
First tensor / np.ndarray for interpolation
|
||||
p1:
|
||||
Second tensor / np.ndarray for interpolation
|
||||
fract_mixing: float
|
||||
Mixing coefficient of interval [0, 1].
|
||||
0 will return p0
|
||||
1 will return p1
|
||||
0.x will return a linear mix between both.
|
||||
"""
|
||||
reconvert_uint8 = False
|
||||
if type(p0) is np.ndarray and p0.dtype == 'uint8':
|
||||
reconvert_uint8 = True
|
||||
p0 = p0.astype(np.float64)
|
||||
|
||||
if type(p1) is np.ndarray and p1.dtype == 'uint8':
|
||||
reconvert_uint8 = True
|
||||
p1 = p1.astype(np.float64)
|
||||
|
||||
interp = (1 - fract_mixing) * p0 + fract_mixing * p1
|
||||
|
||||
if reconvert_uint8:
|
||||
interp = np.clip(interp, 0, 255).astype(np.uint8)
|
||||
|
||||
return interp
|
||||
|
||||
|
||||
def add_frames_linear_interp(
|
||||
list_imgs: List[np.ndarray],
|
||||
fps_target: Union[float, int] = None,
|
||||
duration_target: Union[float, int] = None,
|
||||
nmb_frames_target: int = None):
|
||||
r"""
|
||||
Helper function to cheaply increase the number of frames given a list of images,
|
||||
by virtue of standard linear interpolation.
|
||||
The number of inserted frames will be automatically adjusted so that the total number
|
||||
of frames can be fixed precisely, using a random shuffling technique.
|
||||
The function allows 1:1 comparisons between transitions as videos.
|
||||
|
||||
Args:
|
||||
list_imgs: List[np.ndarray]
|
||||
List of images, between each image new frames will be inserted via linear interpolation.
|
||||
fps_target:
|
||||
OptionA: specify here the desired frames per second.
|
||||
duration_target:
|
||||
OptionA: specify here the desired duration of the transition in seconds.
|
||||
nmb_frames_target:
|
||||
OptionB: directly fix the total number of frames of the output.
|
||||
"""
|
||||
|
||||
# Sanity
|
||||
if nmb_frames_target is not None and fps_target is not None:
|
||||
raise ValueError("You cannot specify both fps_target and nmb_frames_target")
|
||||
if fps_target is None:
|
||||
assert nmb_frames_target is not None, "Either specify fps_target (with duration_target) or nmb_frames_target"
|
||||
if nmb_frames_target is None:
|
||||
assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
|
||||
assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
|
||||
nmb_frames_target = fps_target * duration_target
|
||||
|
||||
# Get number of frames that are missing
|
||||
nmb_frames_diff = len(list_imgs) - 1
|
||||
nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
|
||||
|
||||
if nmb_frames_missing < 1:
|
||||
return list_imgs
|
||||
|
||||
if type(list_imgs[0]) == PIL.Image.Image:
|
||||
list_imgs = [np.asarray(l) for l in list_imgs]
|
||||
list_imgs_float = [img.astype(np.float32) for img in list_imgs]
|
||||
# Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
|
||||
mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
|
||||
constfact = np.floor(mean_nmb_frames_insert)
|
||||
remainder_x = 1 - (mean_nmb_frames_insert - constfact)
|
||||
nmb_iter = 0
|
||||
while True:
|
||||
nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
|
||||
nmb_frames_to_insert[nmb_frames_to_insert <= remainder_x] = 0
|
||||
nmb_frames_to_insert[nmb_frames_to_insert > remainder_x] = 1
|
||||
nmb_frames_to_insert += constfact
|
||||
if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
|
||||
break
|
||||
nmb_iter += 1
|
||||
if nmb_iter > 100000:
|
||||
print("add_frames_linear_interp: issue with inserting the right number of frames")
|
||||
break
|
||||
|
||||
nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
|
||||
list_imgs_interp = []
|
||||
for i in range(len(list_imgs_float) - 1):
|
||||
img0 = list_imgs_float[i]
|
||||
img1 = list_imgs_float[i + 1]
|
||||
list_imgs_interp.append(img0.astype(np.uint8))
|
||||
list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i] + 2)[1:-1]
|
||||
for fract_linblend in list_fracts_linblend:
|
||||
img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
|
||||
list_imgs_interp.append(img_blend.astype(np.uint8))
|
||||
if i == len(list_imgs_float) - 2:
|
||||
list_imgs_interp.append(img1.astype(np.uint8))
|
||||
|
||||
return list_imgs_interp
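# Usage sketch: stretch a short 10-frame transition to exactly 8 s at 30 fps
# (the input frames here are synthetic uint8 noise).
# frames = [(np.random.rand(64, 64, 3) * 255).astype(np.uint8) for _ in range(10)]
# frames_ext = add_frames_linear_interp(frames, fps_target=30, duration_target=8)
# assert len(frames_ext) == 30 * 8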
|
||||
|
||||
|
||||
def get_spacing(nmb_points: int, scaling: float):
|
||||
"""
|
||||
Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
|
||||
Args:
|
||||
nmb_points: int
|
||||
Number of points between [0, 1]
|
||||
scaling: float
|
||||
Higher values will return higher sampling density around 0.5
|
||||
"""
|
||||
if scaling < 1.7:
|
||||
return np.linspace(0, 1, nmb_points)
|
||||
nmb_points_per_side = nmb_points // 2 + 1
|
||||
if np.mod(nmb_points, 2) != 0: # Uneven case
|
||||
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
|
||||
right_side = 1 - left_side[::-1][1:]
|
||||
else:
|
||||
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
|
||||
right_side = 1 - left_side[::-1]
|
||||
all_fracts = np.hstack([left_side, right_side])
|
||||
return all_fracts
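# Usage sketch: scaling values below 1.7 fall back to a uniform grid, larger values
# concentrate the samples around 0.5.
# get_spacing(7, 1.0)  # evenly spaced between 0 and 1
# get_spacing(7, 3.0)  # denser sampling around the midpoint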
|
||||
|
||||
|
||||
def get_time(resolution=None):
|
||||
"""
|
||||
Helper function returning a nicely formatted time string, e.g. 221117_1620
|
||||
"""
|
||||
if resolution is None:
|
||||
resolution = "second"
|
||||
if resolution == "day":
|
||||
t = time.strftime('%y%m%d', time.localtime())
|
||||
elif resolution == "minute":
|
||||
t = time.strftime('%y%m%d_%H%M', time.localtime())
|
||||
elif resolution == "second":
|
||||
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
|
||||
elif resolution == "millisecond":
|
||||
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
|
||||
t += "_"
|
||||
t += str("{:03d}".format(int(int(datetime.datetime.utcnow().strftime('%f')) / 1000)))
|
||||
else:
|
||||
raise ValueError("bad resolution provided: %s" % resolution)
|
||||
return t
|
||||
|
||||
|
||||
def compare_dicts(a, b):
|
||||
"""
|
||||
Compares two dictionaries a and b and returns a dictionary c, with all
|
||||
key/value pairs whose keys are shared between a and b but whose values differ.
|
||||
The values of a and b are stacked together in the output.
|
||||
Example:
|
||||
a = {}; a['bobo'] = 4
|
||||
b = {}; b['bobo'] = 5
|
||||
c = compare_dicts(a, b)
|
||||
c = {"bobo",[4,5]}
|
||||
"""
|
||||
c = {}
|
||||
for key in a.keys():
|
||||
if key in b.keys():
|
||||
val_a = a[key]
|
||||
val_b = b[key]
|
||||
if val_a != val_b:
|
||||
c[key] = [val_a, val_b]
|
||||
return c
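# Usage sketch, mirroring the docstring example above:
# a = {"bobo": 4, "same": 1}
# b = {"bobo": 5, "same": 1}
# compare_dicts(a, b)  # -> {"bobo": [4, 5]}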
|
||||
|
||||
|
||||
def yml_load(fp_yml, print_fields=False):
|
||||
"""
|
||||
Helper function for loading yaml files
|
||||
"""
|
||||
with open(fp_yml) as f:
|
||||
data = yaml.load(f, Loader=yaml.loader.SafeLoader)
|
||||
dict_data = dict(data)
|
||||
print("load: loaded {}".format(fp_yml))
|
||||
return dict_data
|
||||
|
||||
|
||||
def yml_save(fp_yml, dict_stuff):
|
||||
"""
|
||||
Helper function for saving yaml files
|
||||
"""
|
||||
with open(fp_yml, 'w') as f:
|
||||
yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
|
||||
print("yml_save: saved {}".format(fp_yml))
|