2022-11-19 18:43:57 +00:00
|
|
|
# Copyright 2022 Lunar Ring. All rights reserved.
|
2023-01-11 11:58:59 +00:00
|
|
|
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
|
2022-11-19 18:43:57 +00:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2023-02-22 09:15:03 +00:00
|
|
|
import os
|
2022-11-19 18:43:57 +00:00
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
import warnings
|
|
|
|
import time
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
from PIL import Image
|
2022-11-23 12:43:33 +00:00
|
|
|
from movie_util import MovieSaver
|
2023-02-22 09:15:03 +00:00
|
|
|
from typing import List, Optional
|
2023-02-15 17:21:00 +00:00
|
|
|
import lpips
|
2023-02-22 09:15:03 +00:00
|
|
|
from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save
|
2023-11-16 14:37:02 +00:00
|
|
|
# Module-wide runtime configuration, applied once when this file is imported.
# Autograd is switched off globally.
torch.set_grad_enabled(False)
# cuDNN benchmark mode is explicitly disabled.
torch.backends.cudnn.benchmark = False
# All Python warnings are silenced for the whole process.
warnings.filterwarnings(action='ignore')
|
2023-02-22 09:15:03 +00:00
|
|
|
|
|
|
|
|
2022-11-19 18:43:57 +00:00
|
|
|
class LatentBlending():
|
|
|
|
def __init__(
        self,
        dh,
        guidance_scale: float = 4,
        guidance_scale_mid_damper: float = 0.5,
        mid_compression_scaler: float = 1.2):
    r"""
    Initializes the latent blending class.
    Args:
        dh:
            Diffusion holder object that performs the actual diffusion work.
            This class reads `dh.device` and `dh.num_inference_steps` and calls
            its setters / runners (e.g. `dh.set_dimensions`, `dh.get_noise`).
        guidance_scale: float
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        guidance_scale_mid_damper: float = 0.5
            Reduces the guidance scale towards the middle of the transition.
            Must lie in the interval (0, 1]; see set_guidance_mid_dampening for the
            exact formula that is applied.
        mid_compression_scaler: float = 1.2
            Increases the sampling density in the middle (where most changes happen). Higher value
            imply more values in the middle. However the inflection point can occur outside the middle,
            thus high values can give rough transitions.
            The value is only stored by this constructor.
    Raises:
        AssertionError: if guidance_scale_mid_damper is outside (0, 1].
    """
    assert guidance_scale_mid_damper > 0 \
        and guidance_scale_mid_damper <= 1.0, \
        f"guidance_scale_mid_damper needs to be in interval (0,1], you provided {guidance_scale_mid_damper}"

    self.dh = dh
    self.device = self.dh.device
    # Called without a size, i.e. the holder receives size_output=None.
    self.set_dimensions()

    self.guidance_scale_mid_damper = guidance_scale_mid_damper
    self.mid_compression_scaler = mid_compression_scaler
    self.seed1 = 0
    self.seed2 = 0

    # Initialize vars
    self.prompt1 = ""
    self.prompt2 = ""

    # The tree holds one latent trajectory per frame; index 0 / -1 are the two keyframes.
    self.tree_latents = [None, None]
    self.tree_fracts = None
    self.idx_injection = []
    self.tree_status = None
    self.tree_final_imgs = []

    self.list_nmb_branches_prev = []
    self.list_injection_idx_prev = []
    self.text_embedding1 = None
    self.text_embedding2 = None
    self.image1_lowres = None
    self.image2_lowres = None
    self.negative_prompt = None
    self.num_inference_steps = self.dh.num_inference_steps
    self.noise_level_upscaling = 20
    self.list_injection_idx = None
    self.list_nmb_branches = None

    # Mixing parameters
    self.branch1_crossfeed_power = 0.3
    self.branch1_crossfeed_range = 0.3
    self.branch1_crossfeed_decay = 0.99

    self.parental_crossfeed_power = 0.3
    self.parental_crossfeed_range = 0.6
    self.parental_crossfeed_power_decay = 0.9

    # Also pushes the guidance scale into the holder.
    self.set_guidance_scale(guidance_scale)
    self.multi_transition_img_first = None
    self.multi_transition_img_last = None
    self.dt_per_diff = 0
    self.spatial_mask = None
    # Perceptual metric network, moved onto the holder's device.
    self.lpips = lpips.LPIPS(net='alex').cuda(self.device)

    # Start with empty prompts so both text embeddings exist right after construction.
    self.set_prompt1("")
    self.set_prompt2("")
|
|
|
|
|
2023-10-13 09:50:53 +00:00
|
|
|
def set_dimensions(self, size_output=None):
    r"""
    Forwards the requested output size to the diffusion holder.
    Args:
        size_output: tuple
            width x height. Passed through unchanged (None by default).
            Per the original note the size is automatically adjusted to be
            divisible by 32; that adjustment is not done here, so it is
            presumably performed inside the holder.
    """
    holder = self.dh
    holder.set_dimensions(size_output)
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2022-11-28 11:43:33 +00:00
|
|
|
def set_guidance_scale(self, guidance_scale):
    r"""
    Stores the guidance scale as both the base value and the currently
    effective value, and mirrors it into the diffusion holder.
    Args:
        guidance_scale: float
            The new guidance scale.
    """
    # Base value first, then the effective value (same order as before).
    self.guidance_scale_base = self.guidance_scale = guidance_scale
    self.dh.guidance_scale = guidance_scale
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2023-01-08 10:48:44 +00:00
|
|
|
def set_negative_prompt(self, negative_prompt):
    r"""
    Remembers the negative prompt on this object and hands it to the
    diffusion holder. Only a single negative prompt is supported.
    Args:
        negative_prompt:
            The negative prompt, forwarded unchanged.
    """
    self.negative_prompt = negative_prompt
    holder = self.dh
    holder.set_negative_prompt(negative_prompt)
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2022-11-28 14:34:18 +00:00
|
|
|
def set_guidance_mid_dampening(self, fract_mixing):
    r"""
    Adjusts the effective guidance scale as a linear function of fract_mixing.
    The adjustment is zero at fract_mixing 0 and 1 and largest at 0.5.
    The result is written to self.guidance_scale and to the holder.
    Args:
        fract_mixing: float
            the fraction along the transition axis [0, 1]
    """
    base = self.guidance_scale_base
    # Triangle profile: 0 at the keyframes, 1 at the midpoint.
    mid_factor = 1 - np.abs(fract_mixing - 0.5) / 0.5
    # NOTE(review): for guidance_scale_mid_damper == 1 this evaluates to -1, i.e. the
    # scale is *raised* by 1 at the midpoint - confirm that this is intended.
    max_guidance_reduction = base * (1 - self.guidance_scale_mid_damper) - 1
    effective = base - max_guidance_reduction * mid_factor
    self.guidance_scale = effective
    self.dh.guidance_scale = effective
|
2022-11-19 18:43:57 +00:00
|
|
|
|
2023-02-20 10:44:50 +00:00
|
|
|
def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
    r"""
    Configures the crossfeed from the first branch into the last branch.
    Every value is clipped to [0, 1] before it is stored.
    Args:
        crossfeed_power: float [0,1]
            Level of cross-feeding between the first and last image branch.
        crossfeed_range: float [0,1]
            Duration of active crossfeed during development.
        crossfeed_decay: float [0,1]
            Decay for branch1_crossfeed_power. Lower values make the decay stronger across the range.
    """
    settings = (
        ("branch1_crossfeed_power", crossfeed_power),
        ("branch1_crossfeed_range", crossfeed_range),
        ("branch1_crossfeed_decay", crossfeed_decay),
    )
    for attribute, raw_value in settings:
        setattr(self, attribute, np.clip(raw_value, 0, 1))
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2023-02-20 10:44:50 +00:00
|
|
|
def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
    r"""
    Configures the crossfeed for all transition images (within the first and last branch).
    Every value is clipped to [0, 1] before it is stored.
    Args:
        crossfeed_power: float [0,1]
            Level of cross-feeding from the parental branches.
        crossfeed_range: float [0,1]
            Duration of active crossfeed during development.
        crossfeed_decay: float [0,1]
            Decay for parental_crossfeed_power (stored as parental_crossfeed_power_decay).
            Lower values make the decay stronger across the range.
    """
    settings = (
        ("parental_crossfeed_power", crossfeed_power),
        ("parental_crossfeed_range", crossfeed_range),
        ("parental_crossfeed_power_decay", crossfeed_decay),
    )
    for attribute, raw_value in settings:
        setattr(self, attribute, np.clip(raw_value, 0, 1))
|
|
|
|
|
2022-11-19 18:43:57 +00:00
|
|
|
def set_prompt1(self, prompt: str):
    r"""
    Sets the first prompt (for the first keyframe) and refreshes its text embedding.
    Underscores in the prompt are replaced by spaces before it is stored.
    Args:
        prompt: str
            ABC trending on artstation painted by Greg Rutkowski
    """
    self.prompt1 = prompt.replace("_", " ")
    self.text_embedding1 = self.get_text_embeddings(self.prompt1)
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2022-11-19 18:43:57 +00:00
|
|
|
def set_prompt2(self, prompt: str):
    r"""
    Sets the second prompt (for the second keyframe) and refreshes its text embedding.
    Underscores in the prompt are replaced by spaces before it is stored.
    Args:
        prompt: str
            XYZ trending on artstation painted by Greg Rutkowski
    """
    self.prompt2 = prompt.replace("_", " ")
    self.text_embedding2 = self.get_text_embeddings(self.prompt2)
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2023-01-08 09:32:58 +00:00
|
|
|
def set_image1(self, image: "Image.Image"):
    r"""
    Sets the first image (keyframe), relevant for the upscaling model transitions.
    The image is stored as-is in self.image1_lowres; no validation or copy is made.
    Args:
        image: Image.Image
            The low-res image for the first keyframe.
    """
    # The annotation is a string on purpose: `Image` in this file is the PIL.Image
    # *module*, the image class is `Image.Image`.
    self.image1_lowres = image
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2023-01-08 09:32:58 +00:00
|
|
|
def set_image2(self, image: "Image.Image"):
    r"""
    Sets the second image (keyframe), relevant for the upscaling model transitions.
    The image is stored as-is in self.image2_lowres; no validation or copy is made.
    Args:
        image: Image.Image
            The low-res image for the second keyframe.
    """
    # The annotation is a string on purpose: `Image` in this file is the PIL.Image
    # *module*, the image class is `Image.Image`.
    self.image2_lowres = image
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2022-11-28 11:41:15 +00:00
|
|
|
def run_transition(
        self,
        recycle_img1: Optional[bool] = False,
        recycle_img2: Optional[bool] = False,
        num_inference_steps: Optional[int] = 30,
        depth_strength: Optional[float] = 0.3,
        t_compute_max_allowed: Optional[float] = None,
        nmb_max_branches: Optional[int] = None,
        fixed_seeds: Optional[List[int]] = None):
    r"""
    Function for computing transitions.
    Returns a list of transition images using spherical latent blending.
    Args:
        recycle_img1: Optional[bool]:
            Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
            Ignored (the keyframe is recomputed) when no stored trajectory of
            matching length exists yet.
        recycle_img2: Optional[bool]:
            Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
            Ignored (the keyframe is recomputed) when no stored trajectory of
            matching length exists yet.
        num_inference_steps:
            Number of diffusion steps. Higher values will take more compute time.
        depth_strength:
            Determines how deep the first injection will happen.
            Deeper injections will cause (unwanted) formation of new structures,
            more shallow values will go into alpha-blendy land.
        t_compute_max_allowed:
            Either provide t_compute_max_allowed or nmb_max_branches.
            The maximum time allowed for computation. Higher values give better results but take longer.
        nmb_max_branches: int
            Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
            results. Use this if you want to have controllable results independent
            of your computer.
        fixed_seeds: Optional[List[int]]:
            You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
            The string 'randomize' draws two fresh random seeds. With None the
            currently stored seeds (self.seed1 / self.seed2) are left unchanged.
    Returns:
        self.tree_final_imgs, the list of images along the transition.
    """
    # Sanity checks first
    assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
    assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'

    # Random seeds
    if fixed_seeds is not None:
        # isinstance guard keeps the comparison unambiguous for array-like seeds.
        if isinstance(fixed_seeds, str) and fixed_seeds == 'randomize':
            fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
        else:
            assert len(fixed_seeds) == 2, "Supply a list with len = 2"

        self.seed1 = fixed_seeds[0]
        self.seed2 = fixed_seeds[1]

    # Ensure correct num_inference_steps in holder
    self.num_inference_steps = num_inference_steps
    self.dh.set_num_inference_steps(num_inference_steps)

    # Compute / Recycle first image.
    # Recycling needs a stored trajectory of the right length; on a fresh object the
    # slot is still None and len() would fail, so fall back to computing.
    stored_latents1 = self.tree_latents[0]
    if recycle_img1 and stored_latents1 is not None and len(stored_latents1) == self.num_inference_steps:
        list_latents1 = stored_latents1
    else:
        list_latents1 = self.compute_latents1()

    # Compute / Recycle second image (read after the first, which may feed into it).
    stored_latents2 = self.tree_latents[-1]
    if recycle_img2 and stored_latents2 is not None and len(stored_latents2) == self.num_inference_steps:
        list_latents2 = stored_latents2
    else:
        list_latents2 = self.compute_latents2()

    # Reset the tree, injecting the edge latents1/2 we just generated/recycled
    self.tree_latents = [list_latents1, list_latents2]
    self.tree_fracts = [0.0, 1.0]
    self.tree_final_imgs = [self.dh.latent2image((self.tree_latents[0][-1])), self.dh.latent2image((self.tree_latents[-1][-1]))]
    self.tree_idx_injection = [0, 0]

    # Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP...
    self.spatial_mask = None

    # Set up branching scheme (dependent on provided compute time)
    list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)

    # Run iteratively, starting with the longest trajectory.
    # Always inserting new branches where they are needed most according to image similarity
    schedule = zip(list_idx_injection, list_nmb_stems)
    for idx_injection, nmb_stems in tqdm(schedule, total=len(list_idx_injection)):
        for _ in range(nmb_stems):
            fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
            self.set_guidance_mid_dampening(fract_mixing)
            list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
            self.insert_into_tree(fract_mixing, idx_injection, list_latents)

    return self.tree_final_imgs
|
|
|
|
|
2023-02-19 14:32:37 +00:00
|
|
|
def compute_latents1(self, return_image=False):
    r"""
    Runs a full diffusion trajectory for the first image (mixing fraction 0,
    noise from self.seed1). The trajectory is stored in self.tree_latents[0] and
    the measured time per diffusion step in self.dt_per_diff.
    Args:
        return_image: bool
            whether to return an image or the list of latents
    """
    print("starting compute_latents1")
    conditionings = self.get_mixed_conditioning(0)
    # The timed section covers the noise generation and the diffusion run.
    started = time.time()
    trajectory = self.run_diffusion(
        conditionings,
        latents_start=self.get_noise(self.seed1),
        idx_start=0)
    finished = time.time()
    self.dt_per_diff = (finished - started) / self.num_inference_steps
    self.tree_latents[0] = trajectory
    if not return_image:
        return trajectory
    # Decode only the final latent of the trajectory.
    return self.dh.latent2image(trajectory[-1])
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2023-02-19 14:32:37 +00:00
|
|
|
def compute_latents2(self, return_image=False):
    r"""
    Runs a full diffusion trajectory for the last image (mixing fraction 1,
    noise from self.seed2). When branch1_crossfeed_power is positive, the stored
    trajectory of the first image (self.tree_latents[0]) is mixed in.
    The result is stored in self.tree_latents[-1].
    Args:
        return_image: bool
            whether to return an image or the list of latents
    """
    print("starting compute_latents2")
    conditionings = self.get_mixed_conditioning(1)
    noise = self.get_noise(self.seed2)

    if self.branch1_crossfeed_power > 0.0:
        # Influence from branch1: coefficients run from the full power down to
        # power * decay over the first part of the steps, zero afterwards.
        steps = self.num_inference_steps
        power_start = self.branch1_crossfeed_power
        idx_mixing_stop = int(round(steps * self.branch1_crossfeed_range))
        coeffs = list(np.linspace(power_start, power_start * self.branch1_crossfeed_decay, idx_mixing_stop))
        coeffs.extend((steps - idx_mixing_stop) * [0])
        trajectory = self.run_diffusion(
            conditionings,
            latents_start=noise,
            idx_start=0,
            list_latents_mixing=self.tree_latents[0],
            mixing_coeffs=coeffs)
    else:
        # No crossfeed: plain diffusion run.
        trajectory = self.run_diffusion(conditionings, noise)

    self.tree_latents[-1] = trajectory
    if not return_image:
        return trajectory
    return self.dh.latent2image(trajectory[-1])
|
2023-02-19 14:32:37 +00:00
|
|
|
|
2023-02-22 09:15:03 +00:00
|
|
|
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
    r"""
    Runs a diffusion trajectory that starts from, and is crossfed by, the
    spherically interpolated latents of the two parent branches.
    Args:
        fract_mixing: float
            the fraction along the transition axis [0, 1]
        b_parent1: int
            index of parent1 to be used
        b_parent2: int
            index of parent2 to be used
        idx_injection: int
            the index in terms of diffusion steps, where the next insertion will start.
    Returns:
        The list of latents returned by run_diffusion.
    """
    conditionings = self.get_mixed_conditioning(fract_mixing)
    steps = self.num_inference_steps

    # Position of the new branch relative to its two parents.
    fract_p1 = self.tree_fracts[b_parent1]
    fract_mixing_parental = (fract_mixing - fract_p1) / (self.tree_fracts[b_parent2] - fract_p1)

    trajectory_p1 = self.tree_latents[b_parent1]
    trajectory_p2 = self.tree_latents[b_parent2]

    def _blend(latents_p1, latents_p2):
        # A step missing in either parent stays missing in the mix.
        if latents_p1 is None or latents_p2 is None:
            return None
        return interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)

    parental_mix = [_blend(trajectory_p1[step], trajectory_p2[step]) for step in range(steps)]

    # Coefficients: constant power before the injection, a linear ramp towards
    # power * decay until the crossfeed range ends, zero for the remaining steps.
    power = self.parental_crossfeed_power
    idx_mixing_stop = int(round(steps * self.parental_crossfeed_range))
    coeffs = idx_injection * [power]
    nmb_mixing = idx_mixing_stop - idx_injection
    if nmb_mixing > 0:
        coeffs.extend(list(np.linspace(power, power * self.parental_crossfeed_power_decay, nmb_mixing)))
    coeffs.extend((steps - len(coeffs)) * [0])

    # Diffusion starts from the mixed latent one step before the injection index.
    return self.run_diffusion(
        conditionings,
        latents_start=parental_mix[idx_injection - 1],
        idx_start=idx_injection,
        list_latents_mixing=parental_mix,
        mixing_coeffs=coeffs)
|
|
|
|
|
2023-02-18 06:56:30 +00:00
|
|
|
def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
    r"""
    Sets up the branching scheme: at which diffusion steps new branches are
    injected and how many branches (stems) are started at each of them.
    The time estimate is based on self.dt_per_diff.
    Exactly one of t_compute_max_allowed / nmb_max_branches must be provided.
    Args:
        depth_strength:
            Determines how deep the first injection will happen.
            Deeper injections will cause (unwanted) formation of new structures,
            more shallow values will go into alpha-blendy land.
        t_compute_max_allowed: float
            The maximum time allowed for computation. Higher values give better results
            but take longer. Use this if you want to fix your waiting time for the results.
        nmb_max_branches: int
            The maximum number of branches to be computed. Higher values give better
            results. Use this if you want to have controllable results independent
            of your computer.
    Returns:
        Tuple (list_idx_injection, list_nmb_stems) of equally long integer arrays.
    """
    steps = self.num_inference_steps
    first_idx = int(round(steps * depth_strength))
    # Candidate injection indices: every third step from the first injection on.
    list_idx_injection = np.arange(first_idx, steps - 1, 3)
    list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)

    limit_by_time = nmb_max_branches is None
    if limit_by_time:
        assert t_compute_max_allowed is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
    elif t_compute_max_allowed is not None:
        raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
    else:
        nmb_max_branches -= 2  # Discounting the outer frames

    first_pass = True
    while True:
        # Cost estimate of the scheme *before* this pass adds another stem.
        compute_steps = (steps - list_idx_injection) * list_nmb_stems
        t_compute = np.sum(compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
        t_compute += 2 * steps * self.dt_per_diff  # outer branches

        # Add one stem: at the first level whose successor has at least twice as
        # many stems, otherwise at the last level.
        for level in range(len(list_nmb_stems) - 1):
            if list_nmb_stems[level + 1] / list_nmb_stems[level] >= 2:
                list_nmb_stems[level] += 1
                break
        else:
            list_nmb_stems[-1] += 1

        if limit_by_time:
            if t_compute > t_compute_max_allowed:
                break
        elif np.sum(list_nmb_stems) >= nmb_max_branches:
            if first_pass:
                # Need to undersample.
                list_idx_injection = np.linspace(list_idx_injection[0], list_idx_injection[-1], nmb_max_branches).astype(np.int32)
                list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
            break
        first_pass = False

    return list_idx_injection, list_nmb_stems
|
2023-02-15 17:21:00 +00:00
|
|
|
|
|
|
|
def get_mixing_parameters(self, idx_injection):
    r"""
    Computes which parental latents should be mixed together to achieve a smooth blend.
    The lpips metric (self.get_lpips_similarity) is evaluated for every pair of
    neighbouring final images; the new branch is placed halfway between the pair
    for which the metric is maximal.
    Args:
        idx_injection: int
            the index in terms of diffusion steps, where the next insertion will start.
    Returns:
        Tuple (fract_mixing, b_parent1, b_parent2).
    """
    images = self.tree_final_imgs
    scores = [self.get_lpips_similarity(left, right) for left, right in zip(images, images[1:])]
    b_closest1 = np.argmax(scores)
    b_closest2 = b_closest1 + 1

    # Ensure that the parents are indeed older: walk outwards until a branch is
    # found that was injected at an earlier diffusion step.
    b_parent1 = b_closest1
    while not self.tree_idx_injection[b_parent1] < idx_injection:
        b_parent1 -= 1
    b_parent2 = b_closest2
    while not self.tree_idx_injection[b_parent2] < idx_injection:
        b_parent2 += 1

    # The mixing fraction is the midpoint of the closest pair, not of the parents.
    fract_mixing = (self.tree_fracts[b_closest1] + self.tree_fracts[b_closest2]) / 2
    return fract_mixing, b_parent1, b_parent2
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2023-02-15 17:21:00 +00:00
|
|
|
def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
    r"""
    Inserts a new branch into the trajectory tree, directly after the first
    index returned by self.get_closest_idx(fract_mixing). The latents, the
    decoded final image, the mixing fraction and the injection index are all
    inserted at that same position.
    Args:
        fract_mixing: float
            the fraction along the transition axis [0, 1]
        idx_injection: int
            the index in terms of diffusion steps, where the next insertion will start.
        list_latents: list
            list of the latents to be inserted
    """
    # Only the first of the two returned indices is needed.
    left_neighbour, _ = self.get_closest_idx(fract_mixing)
    position = left_neighbour + 1
    self.tree_latents.insert(position, list_latents)
    self.tree_final_imgs.insert(position, self.dh.latent2image(list_latents[-1]))
    self.tree_fracts.insert(position, fract_mixing)
    self.tree_idx_injection.insert(position, idx_injection)
|
|
|
|
|
2023-02-16 10:48:45 +00:00
|
|
|
def get_noise(self, seed):
    r"""
    Helper function to get noise given seed; delegates to the diffusion holder.
    Args:
        seed: int
    Returns:
        Whatever self.dh.get_noise returns for this seed.
    """
    noise = self.dh.get_noise(seed)
    return noise
|
2022-11-19 18:43:57 +00:00
|
|
|
|
|
|
|
@torch.no_grad()
def run_diffusion(
        self,
        list_conditionings,
        latents_start: torch.FloatTensor = None,
        idx_start: int = 0,
        list_latents_mixing=None,
        mixing_coeffs=0.0,
        return_image: Optional[bool] = False):
    r"""
    Wrapper function for diffusion runners.
    Dispatches to the holder's SD-XL runner when self.dh.use_sd_xl is set and
    to the standard runner otherwise; both receive identical arguments.

    Args:
        list_conditionings: list
            List of all conditionings for the diffusion model.
            Only the first element is used, as the text embeddings.
        latents_start: torch.FloatTensor
            Latents that are used for injection
        idx_start: int
            Index of the diffusion process start and where the latents_for_injection are injected
        list_latents_mixing: torch.FloatTensor
            List of latents (latent trajectories) that are used for mixing
        mixing_coeffs: float or list
            Coefficients, how strong each element of list_latents_mixing will be mixed in.
        return_image: Optional[bool]
            Optionally return image directly
    Returns:
        The value returned by the selected holder runner.
    """
    holder = self.dh
    # Ensure correct num_inference_steps in Holder
    holder.set_num_inference_steps(self.num_inference_steps)
    assert type(list_conditionings) is list, "list_conditionings need to be a list"

    use_xl = holder.use_sd_xl
    text_embeddings = list_conditionings[0]
    runner = holder.run_diffusion_sd_xl if use_xl else holder.run_diffusion_standard
    return runner(
        text_embeddings=text_embeddings,
        latents_start=latents_start,
        idx_start=idx_start,
        list_latents_mixing=list_latents_mixing,
        mixing_coeffs=mixing_coeffs,
        return_image=return_image)
|
2023-01-08 09:32:58 +00:00
|
|
|
|
2023-02-18 07:44:28 +00:00
|
|
|
def run_upscaling(
        self,
        dp_img: str,
        depth_strength: float = 0.65,
        num_inference_steps: int = 100,
        nmb_max_branches_highres: int = 5,
        nmb_max_branches_lowres: int = 6,
        duration_single_segment=3,
        fps=24,
        fixed_seeds: Optional[List[int]] = None):
    r"""
    Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model
    and save the results using write_imgs_transition.
    The resulting movie is written to "movie_highres.mp4" inside dp_img.

    Args:
        dp_img: str
            Path to the low-res transition path (as saved in write_imgs_transition).
            Must contain "lowres.yaml" and the "lowres_img_XXXX.jpg" files.
        depth_strength:
            Determines how deep the first injection will happen.
            Deeper injections will cause (unwanted) formation of new structures,
            more shallow values will go into alpha-blendy land.
        num_inference_steps:
            Number of diffusion steps. Higher values will take more compute time.
        nmb_max_branches_highres: int
            Number of final branches of the upscaling transition pass. Note this is the number
            of branches between each pair of low-res images.
        nmb_max_branches_lowres: int
            Number of input low-res images, subsampling all transition images written in the low-res pass.
            Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
        duration_single_segment: float
            The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total.
        fps: float
            frames per second of movie
        fixed_seeds: Optional[List[int]]:
            Documented as two seeds for the first and second keyframe (prompt1 and prompt2).
            NOTE: this method currently never reads the argument; it is kept for interface compatibility.

    Raises:
        AssertionError: if lowres.yaml or one of the selected low-res images is missing in dp_img.
    """
    fp_yml = os.path.join(dp_img, "lowres.yaml")
    fp_movie = os.path.join(dp_img, "movie_highres.mp4")
    assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget write_imgs_transition?"
    dict_stuff = yml_load(fp_yml)

    # load lowres images
    nmb_images_lowres = dict_stuff['nmb_images']
    prompt1 = dict_stuff['prompt1']
    prompt2 = dict_stuff['prompt2']
    # Evenly subsample nmb_max_branches_lowres indices out of all written low-res images
    idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres - 1, nmb_max_branches_lowres)).astype(np.int32)
    imgs_lowres = []
    for i in idx_img_lowres:
        fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
        assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget write_imgs_transition?"
        imgs_lowres.append(Image.open(fp_img_lowres))

    # Only create the movie writer once all inputs have been validated, so a failed
    # precondition never touches the output file.
    ms = MovieSaver(fp_movie, fps=fps)

    # set up upscaling
    text_embeddingA = self.dh.get_text_embedding(prompt1)
    text_embeddingB = self.dh.get_text_embedding(prompt2)
    nmb_segments = nmb_max_branches_lowres - 1
    list_fract_mixing = np.linspace(0, 1, nmb_segments)
    for i in range(nmb_segments):
        print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
        # NOTE(review): text_embedding2 is mixed at 1 - fract (not at the next fract), and for i > 0
        # swap_forward() below overwrites text_embedding1 with text_embedding2 again.
        # Behavior is kept as-is; confirm that this is the intended conditioning per segment.
        self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
        self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1 - list_fract_mixing[i])
        if i == 0:
            recycle_img1 = False
        else:
            # The end keyframe of the previous segment becomes the start keyframe of this one
            self.swap_forward()
            recycle_img1 = True

        self.set_image1(imgs_lowres[i])
        self.set_image2(imgs_lowres[i + 1])

        list_imgs = self.run_transition(
            recycle_img1=recycle_img1,
            recycle_img2=False,
            num_inference_steps=num_inference_steps,
            depth_strength=depth_strength,
            nmb_max_branches=nmb_max_branches_highres)
        list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment)

        # Save movie frame
        for img in list_imgs_interp:
            ms.write_frame(img)
    ms.finalize()
|
2022-11-25 14:34:41 +00:00
|
|
|
|
2023-01-08 09:32:58 +00:00
|
|
|
@torch.no_grad()
def get_mixed_conditioning(self, fract_mixing):
    r"""
    Linearly interpolates between self.text_embedding1 and self.text_embedding2.
    For SDXL holders (self.dh.use_sd_xl) the embeddings are indexed element by element
    and each pair is interpolated separately.

    Args:
        fract_mixing:
            Mixing value handed to interpolate_linear.

    Returns:
        A one-element list holding the mixed embedding (SDXL: the list of mixed embeddings).
    """
    if not self.dh.use_sd_xl:
        return [interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)]

    mixed_parts = [
        interpolate_linear(part1, self.text_embedding2[idx], fract_mixing)
        for idx, part1 in enumerate(self.text_embedding1)
    ]
    return [mixed_parts]
|
|
|
|
|
2022-11-19 18:43:57 +00:00
|
|
|
@torch.no_grad()
def get_text_embeddings(
        self,
        prompt: str):
    r"""
    Returns the text embedding for a prompt by delegating to the diffusers holder
    (self.dh.get_text_embedding).

    Args:
        prompt: str
            The prompt text, e.g. "ABC trending on artstation painted by Old Greg."
    """
    embedding = self.dh.get_text_embedding(prompt)
    return embedding
|
2022-11-19 18:43:57 +00:00
|
|
|
|
2023-02-18 06:56:30 +00:00
|
|
|
def write_imgs_transition(self, dp_img):
    r"""
    Dumps every image in self.tree_final_imgs as "lowres_img_XXXX.jpg" into dp_img
    (created if missing) and stores the state as "lowres.yaml" next to them.
    Requires run_transition to be completed.

    Args:
        dp_img: str
            Target directory for the transition images and the yaml file.
    """
    frames = self.tree_final_imgs
    os.makedirs(dp_img, exist_ok=True)
    for idx, frame in enumerate(frames):
        fp_frame = os.path.join(dp_img, f"lowres_img_{str(idx).zfill(4)}.jpg")
        Image.fromarray(frame).save(fp_frame)
    self.save_statedict(os.path.join(dp_img, "lowres.yaml"))
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2023-02-18 07:19:40 +00:00
|
|
|
def write_movie_transition(self, fp_movie, duration_transition, fps=30):
    r"""
    Writes the transition movie to fp_movie, using the given duration and fps.
    The missing frames are linearly interpolated from self.tree_final_imgs.

    Args:
        fp_movie: str
            Path of the movie file to write. An existing file at this path is removed first.
        duration_transition: float
            duration of the movie in seconds
        fps: int
            fps of the movie
    """

    # Let's get more cheap frames via linear interpolation (duration_transition*fps frames).
    # Arguments are passed in the same (fps, duration) order that run_upscaling uses for this helper.
    imgs_transition_ext = add_frames_linear_interp(self.tree_final_imgs, fps, duration_transition)

    # Save as MP4
    if os.path.isfile(fp_movie):
        os.remove(fp_movie)
    ms = MovieSaver(fp_movie, fps=fps, shape_hw=[self.dh.height_img, self.dh.width_img])
    for img in tqdm(imgs_transition_ext):
        ms.write_frame(img)
    ms.finalize()
|
|
|
|
|
2023-01-15 15:54:28 +00:00
|
|
|
def save_statedict(self, fp_yml):
    r"""
    Saves the result of get_state_dict(), extended by the number of final images
    ('nmb_images'), to the yaml file fp_yml.

    Args:
        fp_yml: str
            Path of the yaml file to write.
    """
    final_imgs = self.tree_final_imgs
    payload = {**self.get_state_dict(), 'nmb_images': len(final_imgs)}
    yml_save(fp_yml, payload)
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2023-01-11 10:36:44 +00:00
|
|
|
def get_state_dict(self):
    r"""
    Collects the settings of this object into a plain dict (e.g. for dumping into yaml).
    Only attributes that actually exist on the object are included.
    seed1/seed2 are cast to int and guidance_scale to float; all other values are stored unchanged.

    Returns:
        dict mapping attribute name -> value.
    """
    state_dict = {}
    # One name per entry: keep the trailing commas, otherwise adjacent string literals
    # get silently concatenated into a single (non-existing) attribute name.
    grab_vars = [
        'prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
        'num_inference_steps', 'depth_strength', 'guidance_scale',
        'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt',
        'branch1_crossfeed_power', 'branch1_crossfeed_range', 'branch1_crossfeed_decay',
        'parental_crossfeed_power', 'parental_crossfeed_range', 'parental_crossfeed_power_decay',
    ]
    for v in grab_vars:
        if not hasattr(self, v):
            continue
        value = getattr(self, v)
        if v in ('seed1', 'seed2'):
            value = int(value)
        elif v == 'guidance_scale':
            value = float(value)
        state_dict[v] = value
    return state_dict
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2022-11-19 18:43:57 +00:00
|
|
|
def randomize_seed(self):
    r"""
    Draws a random seed from numpy's global random state (below 999999999)
    and applies it via set_seed.
    """
    self.set_seed(np.random.randint(999999999))
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2022-11-19 18:43:57 +00:00
|
|
|
def set_seed(self, seed: int):
    r"""
    Stores the given seed on this object and on the diffusers holder (self.dh).

    Args:
        seed: int
            The seed to store.
    """
    for target in (self, self.dh):
        target.seed = seed
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2023-01-09 08:59:14 +00:00
|
|
|
def set_width(self, width):
    r"""
    Stores the image width on this object and on the diffusers holder (self.dh).

    Args:
        width:
            Image width; has to be a multiple of 64 (checked via assert).
    """
    remainder = np.mod(width, 64)
    assert remainder == 0, "set_width: value needs to be divisible by 64"
    self.width = width
    self.dh.width = width
|
2023-02-22 09:15:03 +00:00
|
|
|
|
2023-01-09 08:59:14 +00:00
|
|
|
def set_height(self, height):
    r"""
    Stores the image height on this object and on the diffusers holder (self.dh).

    Args:
        height:
            Image height; has to be a multiple of 64 (checked via assert).
    """
    remainder = np.mod(height, 64)
    assert remainder == 0, "set_height: value needs to be divisible by 64"
    self.height = height
    self.dh.height = height
|
2022-11-19 18:43:57 +00:00
|
|
|
|
|
|
|
def swap_forward(self):
    r"""
    Turns keyframe two into keyframe one, so that the next transition can start
    where the previous one ended (chained transitions, as in run_multi_transition()).
    """
    # The last latent trajectory becomes the first one
    final_latents = self.tree_latents[-1]
    self.tree_latents[0] = final_latents
    # Prompt and text embedding of keyframe two take over slot one
    self.prompt1 = self.prompt2
    self.text_embedding1 = self.text_embedding2
    # Drop the final images of the previous transition
    self.tree_final_imgs = []
|
|
|
|
|
2023-02-15 17:21:00 +00:00
|
|
|
def get_lpips_similarity(self, imgA, imgB):
    r"""
    Computes the image similarity between two images imgA and imgB using self.lpips.
    Used to determine the optimal point of insertion to create smooth transitions.
    High values indicate low similarity.

    Args:
        imgA, imgB:
            Images convertible with np.asarray; the permute below expects three axes
            with the channels last. Values are rescaled as if they were in [0, 255].

    Returns:
        float: the first element of the self.lpips output.
    """
    def _to_lpips_input(img):
        # .to() instead of .cuda() so that the tensor follows self.device, also for non-CUDA devices
        tensor = torch.from_numpy(np.asarray(img)).float().to(self.device)
        # Rescale [0, 255] -> [-1, 1]
        tensor = 2 * tensor / 255.0 - 1
        # Channels first, plus a leading batch axis
        return tensor.permute([2, 0, 1]).unsqueeze(0)

    lploss = self.lpips(_to_lpips_input(imgA), _to_lpips_input(imgB))
    return float(lploss[0][0][0][0])
|
2023-02-22 09:15:03 +00:00
|
|
|
|
|
|
|
# Auxiliary functions
|
|
|
|
def get_closest_idx(
        self,
        fract_mixing: float):
    r"""
    Finds the indices of the two entries in self.tree_fracts that enclose fract_mixing:
    the closest entry at or below it and the closest entry strictly above it.
    Example: fract_mixing = 0.4 and self.tree_fracts = [0, 0.3, 0.6, 1.0]
    gives the index pair (1, 2).

    Returns:
        Tuple of two indices, the smaller one first.
    """
    deltas = fract_mixing - np.asarray(self.tree_fracts)

    # Candidates at or below fract_mixing: distances >= 0, everything else is masked out
    dist_below = deltas.copy()
    dist_below[dist_below < 0] = np.inf
    idx_below = np.argmin(dist_below)

    # Candidates strictly above fract_mixing: flipped distances > 0, everything else is masked out
    dist_above = np.negative(deltas)
    dist_above[dist_above <= 0] = np.inf
    idx_above = np.argmin(dist_above)

    if idx_below > idx_above:
        return idx_above, idx_below
    return idx_below, idx_above
|
2024-01-06 17:16:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    # %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
    from diffusers_holder import DiffusersHolder
    from diffusers import DiffusionPipeline, AutoencoderTiny

    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
    pipe.to("cuda")
    # Optional: swap in the tiny autoencoder
    # pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16)
    # pipe.vae = pipe.vae.cuda()

    dh = DiffusersHolder(pipe)

    # %% Next let's set up all parameters
    # How deep (in terms of diffusion iterations) the first branching happens
    depth_strength = 0.5
    # Compute time granted to the transition; determines its quality
    t_compute_max_allowed = 3
    num_inference_steps = 4
    size_output = (512, 512)

    prompt1 = "underwater landscape, fish, und the sea, incredible detail, high resolution"
    prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal"
    negative_prompt = "blurry, ugly, pale"  # Optional

    fp_movie = 'movie_example1.mp4'
    duration_transition = 12  # In seconds

    # Spawn latent blending and configure it
    lb = LatentBlending(dh)
    lb.set_prompt1(prompt1)
    lb.set_prompt2(prompt2)
    lb.set_dimensions(size_output)
    lb.set_negative_prompt(negative_prompt)
    lb.set_guidance_scale(0)

    # Crossfeed settings, applied in the listed order
    crossfeed_settings = {
        'branch1_crossfeed_power': 0.3,
        'branch1_crossfeed_range': 0.6,
        'branch1_crossfeed_decay': 0.99,
        'parental_crossfeed_power': 0.8,
        'parental_crossfeed_power_decay': 1.0,
        'parental_crossfeed_range': 1.0,
    }
    for name, value in crossfeed_settings.items():
        setattr(lb, name, value)

    # Run latent blending
    lb.run_transition(
        depth_strength=depth_strength,
        num_inference_steps=num_inference_steps,
        t_compute_max_allowed=t_compute_max_allowed)

    # Save movie
    lb.write_movie_transition(fp_movie, duration_transition)

    # %%
    # Open notes:
    #   checkout sizes
    #   checkout good tree for num inference steps
    #   checkout that good nmb inference step given
|