parental mixing mode

Johannes Stelzer 2023-02-16 11:48:45 +01:00
parent 0371868603
commit 07a5c4ffd7
3 changed files with 259 additions and 201 deletions

File 1 of 3: the gradio frontend (BlendingFrontend)

@@ -31,15 +31,9 @@ from stable_diffusion_holder import StableDiffusionHolder
 torch.set_grad_enabled(False)
 import gradio as gr
 import copy
+from dotenv import find_dotenv, load_dotenv
-"""
-TODOS:
- - clean parameter handling
- - three buttons: diffuse A, diffuse B, make transition
- - collapse for easy mode
- - transition quality in terms of render time
-"""
 #%%
@@ -65,13 +59,17 @@ class BlendingFrontend():
         self.list_settings = []
         self.state_current = {}
         self.showing_current = True
-        self.branch1_influence = 0.1
-        self.branch1_mixing_depth = 0.3
+        self.branch1_influence = 0.3
+        self.branch1_max_depth_influence = 0.6
+        self.branch1_influence_decay = 0.3
+        self.parental_influence = 0.1
+        self.parental_max_depth_influence = 1.0
+        self.parental_influence_decay = 1.0
         self.nmb_branches_final = 9
         self.nmb_imgs_show = 5  # don't change
         self.fps = 30
-        self.duration_video = 15
-        self.t_compute_max_allowed = 15
+        self.duration_video = 10
+        self.t_compute_max_allowed = 10
         self.dict_multi_trans = {}
         self.dict_multi_trans_include = {}
         self.multi_trans_currently_shown = []
@@ -86,10 +84,21 @@ class BlendingFrontend():
         else:
             self.height = 768
             self.width = 768
+        self.init_save_dir()
+
+    def init_save_dir(self):
+        load_dotenv(find_dotenv(), verbose=False)
+        try:
+            self.dp_out = os.getenv("dp_out")
+        except Exception as e:
+            self.dp_out = ""
+
     # make dummy image
     def save_empty_image(self):
-        self.fp_img_empty = 'empty.jpg'
+        self.fp_img_empty = os.path.join(self.dp_out, 'empty.jpg')
         Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
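Output paths are now rooted in a directory read from a `.env` file via python-dotenv. A minimal sketch of the lookup as introduced above (the `dp_out` variable name comes from the diff; the `.env` contents and the fallback default are assumptions — note that `os.getenv` does not raise, it returns `None` or the given default when the key is missing):

```python
import os
from dotenv import find_dotenv, load_dotenv

# Hypothetical .env file in the project root:
#   dp_out=/tmp/latentblending_out

load_dotenv(find_dotenv(), verbose=False)
dp_out = os.getenv("dp_out", "")  # returns a default instead of raising
fp_img = os.path.join(dp_out, "empty.jpg")
print(fp_img)
```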
@@ -116,8 +125,6 @@ class BlendingFrontend():
         self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
         self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
         self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
-        self.lb.branch1_influence = list_ui_elem[list_ui_keys.index('branch1_influence')]
-        self.lb.branch1_mixing_depth = list_ui_elem[list_ui_keys.index('branch1_mixing_depth')]
         self.lb.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
         self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
         self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
@@ -125,11 +132,19 @@ class BlendingFrontend():
         self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')]
         self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
+        self.lb.branch1_influence = list_ui_elem[list_ui_keys.index('branch1_influence')]
+        self.lb.branch1_max_depth_influence = list_ui_elem[list_ui_keys.index('branch1_max_depth_influence')]
+        self.lb.branch1_influence_decay = list_ui_elem[list_ui_keys.index('branch1_influence_decay')]
+        self.lb.parental_influence = list_ui_elem[list_ui_keys.index('parental_influence')]
+        self.lb.parental_max_depth_influence = list_ui_elem[list_ui_keys.index('parental_max_depth_influence')]
+        self.lb.parental_influence_decay = list_ui_elem[list_ui_keys.index('parental_influence_decay')]

     def compute_img1(self, *args):
         list_ui_elem = args
         self.setup_lb(list_ui_elem)
-        fp_img1 = f"img1_{get_time('second')}.jpg"
+        fp_img1 = os.path.join(self.dp_out, f"img1_{get_time('second')}.jpg")
         img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
         img1.save(fp_img1)
         self.save_empty_image()
@@ -138,7 +153,7 @@ class BlendingFrontend():
     def compute_img2(self, *args):
         list_ui_elem = args
         self.setup_lb(list_ui_elem)
-        fp_img2 = f"img2_{get_time('second')}.jpg"
+        fp_img2 = os.path.join(self.dp_out, f"img2_{get_time('second')}.jpg")
         img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
         img2.save(fp_img2)
         return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2]
@@ -198,9 +213,11 @@ class BlendingFrontend():
     def get_fp_movie(self, timestamp, is_stacked=False):
         if not is_stacked:
-            return f"movie_{timestamp}.mp4"
+            fn = f"movie_{timestamp}.mp4"
         else:
-            return f"movie_stacked_{timestamp}.mp4"
+            fn = f"movie_stacked_{timestamp}.mp4"
+        fp = os.path.join(self.dp_out, fn)
+        return fp

     def stack_forward(self, prompt2, seed2):
@@ -319,32 +336,43 @@ if __name__ == "__main__":
     with gr.Blocks() as demo:
         with gr.Row():
             prompt1 = gr.Textbox(label="prompt 1")
-            negative_prompt = gr.Textbox(label="negative prompt")
             prompt2 = gr.Textbox(label="prompt 2")

         with gr.Row():
-            duration_compute = gr.Slider(10, 40, self.duration_video, step=1, label='compute budget for transition (seconds)', interactive=True)
+            duration_compute = gr.Slider(5, 45, self.t_compute_max_allowed, step=1, label='compute budget for transition (seconds)', interactive=True)
             duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True)
             height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
             width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True)

         with gr.Accordion("Advanced Settings (click to expand)", open=False):
-            with gr.Row():
-                depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
-                branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='branch1_influence', interactive=True)
-                branch1_mixing_depth = gr.Slider(0.0, 1.0, self.branch1_mixing_depth, step=0.01, label='branch1_mixing_depth', interactive=True)
-            with gr.Row():
-                num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
-                guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
-                guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
-            with gr.Row():
-                seed1 = gr.Number(420, label="seed 1", interactive=True)
-                b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
-                seed2 = gr.Number(420, label="seed 2", interactive=True)
-                b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
+            with gr.Accordion("Diffusion settings", open=True):
+                with gr.Row():
+                    num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
+                    guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
+                    negative_prompt = gr.Textbox(label="negative prompt")
+
+            with gr.Accordion("Seeds control", open=True):
+                with gr.Row():
+                    seed1 = gr.Number(420, label="seed 1", interactive=True)
+                    b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
+                    seed2 = gr.Number(420, label="seed 2", interactive=True)
+                    b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
+
+            with gr.Accordion("Crossfeeding for last image", open=True):
+                with gr.Row():
+                    branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='crossfeed power', interactive=True)
+                    branch1_max_depth_influence = gr.Slider(0.0, 1.0, self.branch1_max_depth_influence, step=0.01, label='crossfeed range', interactive=True)
+                    branch1_influence_decay = gr.Slider(0.0, 1.0, self.branch1_influence_decay, step=0.01, label='crossfeed decay', interactive=True)
+
+            with gr.Accordion("Transition settings", open=True):
+                with gr.Row():
+                    depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
+                    guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
+                    parental_influence = gr.Slider(0.0, 1.0, self.parental_influence, step=0.01, label='parental power', interactive=True)
+                    parental_max_depth_influence = gr.Slider(0.0, 1.0, self.parental_max_depth_influence, step=0.01, label='parental range', interactive=True)
+                    parental_influence_decay = gr.Slider(0.0, 1.0, self.parental_influence_decay, step=0.01, label='parental decay', interactive=True)

         with gr.Row():
             b_compute1 = gr.Button('compute first image', variant='primary')
@@ -373,7 +401,8 @@ if __name__ == "__main__":
         dict_ui_elem["depth_strength"] = depth_strength
         dict_ui_elem["branch1_influence"] = branch1_influence
-        dict_ui_elem["branch1_mixing_depth"] = branch1_mixing_depth
+        dict_ui_elem["branch1_max_depth_influence"] = branch1_max_depth_influence
+        dict_ui_elem["branch1_influence_decay"] = branch1_influence_decay

         dict_ui_elem["num_inference_steps"] = num_inference_steps
         dict_ui_elem["guidance_scale"] = guidance_scale
@@ -381,6 +410,10 @@ if __name__ == "__main__":
         dict_ui_elem["seed1"] = seed1
         dict_ui_elem["seed2"] = seed2
+        dict_ui_elem["parental_max_depth_influence"] = parental_max_depth_influence
+        dict_ui_elem["parental_influence"] = parental_influence
+        dict_ui_elem["parental_influence_decay"] = parental_influence_decay

         # Convert to list, as gradio doesn't seem to accept dicts
         list_ui_elem = []
         list_ui_keys = []
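The frontend keeps passing UI values positionally: `dict_ui_elem` is flattened into the parallel `list_ui_elem`/`list_ui_keys` lists, and `setup_lb` recovers each value via `list_ui_keys.index(key)`. A minimal sketch of that round trip:

```python
# Flatten a dict into parallel key/value lists (the diff notes gradio
# doesn't seem to accept dicts, so values travel positionally).
dict_ui_elem = {"seed1": 420, "parental_influence": 0.1}

list_ui_keys = []
list_ui_elem = []
for key, value in dict_ui_elem.items():
    list_ui_keys.append(key)
    list_ui_elem.append(value)

# Later, recover a value by key, exactly as setup_lb does:
parental_influence = list_ui_elem[list_ui_keys.index('parental_influence')]
print(parental_influence)  # 0.1
```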

File 2 of 3: the LatentBlending class

@@ -107,8 +107,16 @@ class LatentBlending():
         self.noise_level_upscaling = 20
         self.list_injection_idx = None
         self.list_nmb_branches = None
+
+        # Mixing parameters
         self.branch1_influence = 0.0
-        self.branch1_mixing_depth = 0.65
+        self.branch1_max_depth_influence = 0.65
+        self.branch1_influence_decay = 0.8
+
+        self.parental_influence = 0.0
+        self.parental_max_depth_influence = 1.0
+        self.parental_influence_decay = 1.0
+
         self.branch1_insertion_completed = False
         self.set_guidance_scale(guidance_scale)
         self.init_mode()
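Both feeding mechanisms are now governed by the same three-knob scheme, matching the slider labels in the frontend: an influence strength ("power", how strongly the source trajectory is mixed in), a maximum depth ("range", the fraction of diffusion steps up to which mixing is applied), and a decay factor ("decay", how much the strength falls off over that range). The `branch1_*` knobs crossfeed the first image's trajectory into the last image; the `parental_*` knobs feed the two parent trajectories into every intermediate branch.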
@@ -389,10 +397,6 @@ class LatentBlending():
             fixed_seeds: Optional[List[int]] = None,
             ):
-        # # FIXME: deal with these tree args later
-        # self.num_inference_steps = 30
-        # self.t_compute_max_allowed = 60

         # Sanity checks first
         assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
         assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
@@ -411,17 +415,15 @@ class LatentBlending():
         self.sdh.num_inference_steps = self.num_inference_steps

         # Compute / Recycle first image
-        if not recycle_img1:
+        if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
             list_latents1 = self.compute_latents1()
         else:
-            # FIXME: check if latents there...
             list_latents1 = self.tree_latents[0]

         # Compute / Recycle first image
-        if not recycle_img2:
+        if not recycle_img2 or len(self.tree_latents[-1]) != self.num_inference_steps:
             list_latents2 = self.compute_latents2()
         else:
-            # FIXME: check if latents there...
             list_latents2 = self.tree_latents[-1]

         # Reset the tree, injecting the edge latents1/2 we just generated/recycled
@@ -438,7 +440,7 @@ class LatentBlending():
         while t_compute < self.t_compute_max_allowed:
             list_compute_steps = self.num_inference_steps - list_idx_injection
             list_compute_steps *= list_nmb_stems
-            t_compute = np.sum(list_compute_steps) * self.dt_per_diff
+            t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15*np.sum(list_nmb_stems)
             increase_done = False
             for s_idx in range(len(list_nmb_stems)-1):
                 if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
@@ -449,13 +451,14 @@ class LatentBlending():
                     list_nmb_stems[-1] += 1
                 # print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")

-        # Run iteratively
+        # Run iteratively, always inserting new branches where they are needed most
         for s_idx in tqdm(range(len(list_idx_injection))):
             nmb_stems = list_nmb_stems[s_idx]
             idx_injection = list_idx_injection[s_idx]
             for i in range(nmb_stems):
                 fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
+                self.set_guidance_mid_dampening(fract_mixing)
                 # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
                 list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
                 self.insert_into_tree(fract_mixing, idx_injection, list_latents)
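The budget heuristic above now charges a flat ~0.15 s fixed overhead per stem on top of the per-step diffusion time, so many shallow stems no longer look artificially cheap. A self-contained sketch of the scheduling loop as it appears in the diff (`dt_per_diff` and the budget are assumed inputs; variable names follow the diff):

```python
import numpy as np

def plan_stems(num_inference_steps, idx_injection_base, dt_per_diff, t_max):
    # Candidate injection depths, every 2 diffusion steps.
    list_idx_injection = np.arange(idx_injection_base, num_inference_steps - 1, 2)
    list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
    t_compute = 0.0
    while t_compute < t_max:
        # Cost estimate: remaining diffusion steps per stem,
        # plus ~0.15s fixed overhead for every stem.
        list_compute_steps = (num_inference_steps - list_idx_injection) * list_nmb_stems
        t_compute = np.sum(list_compute_steps) * dt_per_diff + 0.15 * np.sum(list_nmb_stems)
        # Grow the stem counts shallow-first, keeping at most a 2x ratio
        # between neighboring depths.
        increase_done = False
        for s_idx in range(len(list_nmb_stems) - 1):
            if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
                list_nmb_stems[s_idx] += 1
                increase_done = True
                break
        if not increase_done:
            list_nmb_stems[-1] += 1
    return list_idx_injection, list_nmb_stems

# Example: 30 steps, injections from depth 10, 0.1s per step, 30s budget.
print(plan_stems(30, 10, 0.1, 30.0))
```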
@@ -465,6 +468,14 @@ class LatentBlending():
     def get_mixing_parameters(self, idx_injection):
+        r"""
+        Computes which parental latents should be mixed together to achieve a smooth blend.
+        As metric, we are using lpips image similarity. The insertion takes place
+        where the metric is maximal.
+        Args:
+            idx_injection: int
+                the index in terms of diffusion steps, where the next insertion will start.
+        """
         # get_lpips_similarity
         similarities = []
         for i in range(len(self.tree_final_imgs)-1):
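Only the head of `get_mixing_parameters` is visible in this hunk. Conceptually, the selection it documents could look like the sketch below: take the argmax over pairwise LPIPS distances to find the roughest segment of the current transition, then insert halfway between its two endpoints. The function body and names here are assumptions, not the file's actual code:

```python
import numpy as np

def pick_insertion_point(list_imgs, list_fracts, lpips_distance):
    # Perceptual distance between neighboring keyframes of the transition.
    similarities = [lpips_distance(list_imgs[i], list_imgs[i + 1])
                    for i in range(len(list_imgs) - 1)]
    b_parent1 = int(np.argmax(similarities))  # roughest segment
    b_parent2 = b_parent1 + 1
    # Insert halfway between the two parents on the transition axis.
    fract_mixing = 0.5 * (list_fracts[b_parent1] + list_fracts[b_parent2])
    return fract_mixing, b_parent1, b_parent2
```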
@@ -499,7 +510,16 @@ class LatentBlending():
     def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
-        # FIXME
+        r"""
+        Inserts all necessary parameters into the trajectory tree.
+        Args:
+            fract_mixing: float
+                the fraction along the transition axis [0, 1]
+            idx_injection: int
+                the index in terms of diffusion steps, where the next insertion will start.
+            list_latents: list
+                list of the latents to be inserted
+        """
         b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
         self.tree_latents.insert(b_parent1+1, list_latents)
         self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
@@ -508,23 +528,67 @@ class LatentBlending():
     def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
-        # FIXME
+        r"""
+        Runs a diffusion trajectory, using the latents from the respective parents
+        Args:
+            fract_mixing: float
+                the fraction along the transition axis [0, 1]
+            b_parent1: int
+                index of parent1 to be used
+            b_parent2: int
+                index of parent2 to be used
+            idx_injection: int
+                the index in terms of diffusion steps, where the next insertion will start.
+        """
         list_conditionings = self.get_mixed_conditioning(fract_mixing)
         fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
-        idx_reversed = self.num_inference_steps - idx_injection
-        latents_for_injection = interpolate_spherical(
-            self.tree_latents[b_parent1][-idx_reversed-1],
-            self.tree_latents[b_parent2][-idx_reversed-1],
-            fract_mixing_parental)
-        list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
+        # idx_reversed = self.num_inference_steps - idx_injection
+
+        list_latents_parental_mix = []
+        for i in range(self.num_inference_steps):
+            latents_p1 = self.tree_latents[b_parent1][i]
+            latents_p2 = self.tree_latents[b_parent2][i]
+            if latents_p1 is None or latents_p2 is None:
+                latents_parental = None
+            else:
+                latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
+            list_latents_parental_mix.append(latents_parental)
+
+        idx_mixing_stop = int(round(self.num_inference_steps*self.parental_max_depth_influence))
+        mixing_coeffs = idx_injection*[self.parental_influence]
+        nmb_mixing = idx_mixing_stop - idx_injection
+        if nmb_mixing > 0:
+            mixing_coeffs.extend(list(np.linspace(self.parental_influence, self.parental_influence*self.parental_influence_decay, nmb_mixing)))
+        mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
+
+        latents_start = list_latents_parental_mix[idx_injection-1]
+        list_latents = self.run_diffusion(
+            list_conditionings,
+            latents_start = latents_start,
+            idx_start = idx_injection,
+            list_latents_mixing = list_latents_parental_mix,
+            mixing_coeffs = mixing_coeffs
+            )
         return list_latents

     def compute_latents1(self, return_image=False):
+        r"""
+        Runs a diffusion trajectory for the first image
+        Args:
+            return_image: bool
+                whether to return an image or the list of latents
+        """
         print("starting compute_latents1")
         list_conditionings = [self.text_embedding1]
         t0 = time.time()
-        list_latents1 = self.run_diffusion(list_conditionings, seed_source=self.seed1)
+        latents_start = self.get_noise(self.seed1)
+        list_latents1 = self.run_diffusion(
+            list_conditionings,
+            latents_start = latents_start,
+            idx_start = 0
+            )
         t1 = time.time()
         self.dt_per_diff = (t1-t0) / self.num_inference_steps
         self.tree_latents[0] = list_latents1
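The `mixing_coeffs` list built in `compute_latents_mix` stays flat at `parental_influence` up to `idx_injection`, decays linearly toward `parental_influence * parental_influence_decay` until the depth set by `parental_max_depth_influence`, and is zero afterwards. A worked example with assumed values, mirroring the diff's own logic:

```python
import numpy as np

num_inference_steps = 20
idx_injection = 5
parental_influence = 0.2
parental_max_depth_influence = 0.75  # mixing stops at step 15
parental_influence_decay = 0.5       # strength decays toward 0.1

idx_mixing_stop = int(round(num_inference_steps * parental_max_depth_influence))
mixing_coeffs = idx_injection * [parental_influence]
nmb_mixing = idx_mixing_stop - idx_injection
if nmb_mixing > 0:
    mixing_coeffs.extend(np.linspace(parental_influence,
                                     parental_influence * parental_influence_decay,
                                     nmb_mixing))
mixing_coeffs.extend((num_inference_steps - len(mixing_coeffs)) * [0])

print(np.round(mixing_coeffs, 3))
# -> 5 entries at 0.2, 10 entries sliding from 0.2 down to 0.1, then 5 zeros
```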
@@ -533,25 +597,31 @@ class LatentBlending():
         else:
             return list_latents1

     def compute_latents2(self, return_image=False):
-        print("starting compute_latents2")
-        list_conditionings = [self.text_embedding2
-                              ]
+        r"""
+        Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
+        Args:
+            return_image: bool
+                whether to return an image or the list of latents
+        """
+        list_conditionings = [self.text_embedding2]
+        latents_start = self.get_noise(self.seed2)
         # Influence from branch1
         if self.branch1_influence > 0.0:
-            self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
-            self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
-            idx_crossfeed = int(round(self.num_inference_steps*self.branch1_mixing_depth))
+            # Set up the mixing_coeffs
+            idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_max_depth_influence))
+            mixing_coeffs = list(np.linspace(self.branch1_influence, self.branch1_influence*self.branch1_influence_decay, idx_mixing_stop))
+            mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
+            list_latents_mixing = self.tree_latents[0]
             list_latents2 = self.run_diffusion(
                 list_conditionings,
-                idx_start=idx_crossfeed,
-                latents_for_injection=self.tree_latents[0],
-                seed_source=self.seed2,
-                seed_mixing_target=self.seed1,
-                mixing_coeff=self.branch1_influence)
+                latents_start = latents_start,
+                idx_start = 0,
+                list_latents_mixing = list_latents_mixing,
+                mixing_coeffs = mixing_coeffs
+                )
         else:
-            list_latents2 = self.run_diffusion(list_conditionings)
+            list_latents2 = self.run_diffusion(list_conditionings, latents_start)
         self.tree_latents[-1] = list_latents2

         if return_image:
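Unlike the parental schedule, the branch1 crossfeed starts decaying at step 0: the last image is pulled toward the first image's trajectory most strongly at the very beginning of its diffusion. With the frontend defaults above (power 0.3, range 0.6, decay 0.3) and an assumed 20 steps:

```python
import numpy as np

num_inference_steps = 20
branch1_influence = 0.3             # "crossfeed power"
branch1_max_depth_influence = 0.6   # "crossfeed range": mixing stops at step 12
branch1_influence_decay = 0.3       # "crossfeed decay": strength ends at 0.09

idx_mixing_stop = int(round(num_inference_steps * branch1_max_depth_influence))
mixing_coeffs = list(np.linspace(branch1_influence,
                                 branch1_influence * branch1_influence_decay,
                                 idx_mixing_stop))
mixing_coeffs.extend((num_inference_steps - idx_mixing_stop) * [0])

print(np.round(mixing_coeffs, 3))  # 0.3 ... 0.09, then 8 zeros
```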
@@ -559,9 +629,14 @@ class LatentBlending():
         else:
             return list_latents2

+    def get_noise(self, seed):
+        generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
+        shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
+        C, H, W = shape_latents
+        return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
+
-    def run_transition_OLD(
+    def run_transition_legacy(
         self,
         recycle_img1: Optional[bool] = False,
         recycle_img2: Optional[bool] = False,
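The new `get_noise` centralizes what `run_diffusion_standard` previously did internally: seeded Gaussian noise in the model's latent shape. For a typical Stable Diffusion setup (C=4 latent channels, downsampling factor f=8 — assumed values; the holder reads them from the model at runtime), a 768x768 image gives a (1, 4, 96, 96) tensor:

```python
import torch

C, f = 4, 8            # assumed latent channels / downsampling factor
height = width = 768
generator = torch.Generator(device="cpu").manual_seed(420)  # UI default seed
latents = torch.randn((1, C, height // f, width // f),
                      generator=generator, device="cpu")
print(latents.shape)   # torch.Size([1, 4, 96, 96])
```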
@@ -569,6 +644,7 @@ class LatentBlending():
             premature_stop: Optional[int] = np.inf,
             ):
         r"""
+        Old legacy function for computing transitions.
         Returns a list of transition images using spherical latent blending.
         Args:
             recycle_img1: Optional[bool]:
@@ -610,9 +686,9 @@ class LatentBlending():
         if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
             assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0'
             self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
-            self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
+            self.branch1_max_depth_influence = np.clip(self.branch1_max_depth_influence, 0, 1)
             self.list_nmb_branches.insert(1, 2)
-            idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_mixing_depth))
+            idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_max_depth_influence))
             self.list_injection_idx_ext.insert(1, idx_crossfeed)
             self.tree_fracts.insert(1, self.tree_fracts[0])
             self.tree_status.insert(1, self.tree_status[0])
@@ -790,27 +866,27 @@ class LatentBlending():
     def run_diffusion(
             self,
             list_conditionings,
-            latents_for_injection: torch.FloatTensor = None,
-            idx_start: int = -1,
-            idx_stop: int = -1,
-            seed_source: int = -1,
-            seed_mixing_target: int = -1,
-            mixing_coeff: float = 0.0,
+            latents_start: torch.FloatTensor = None,
+            idx_start: int = 0,
+            list_latents_mixing = None,
+            mixing_coeffs = 0.0,
             return_image: Optional[bool] = False
             ):
         r"""
-        Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
+        Wrapper function for diffusion runners.
         Depending on the mode, the correct one will be executed.

         Args:
             list_conditionings: List of all conditionings for the diffusion model.
-            latents_for_injection: torch.FloatTensor
+            latents_start: torch.FloatTensor
                 Latents that are used for injection
             idx_start: int
                 Index of the diffusion process start and where the latents_for_injection are injected
-            idx_stop: int
-                Index of the diffusion process end.
-            FIXME ARGS
+            list_latents_mixing: torch.FloatTensor
+                List of latents (latent trajectories) that are used for mixing
+            mixing_coeffs: float or list
+                Coefficients, how strong each element of list_latents_mixing will be mixed in.
             return_image: Optional[bool]
                 Optionally return image directly
         """
@@ -822,26 +898,25 @@ class LatentBlending():
         if self.mode == 'standard':
             text_embeddings = list_conditionings[0]
             return self.sdh.run_diffusion_standard(
-                text_embeddings,
-                latents_for_injection=latents_for_injection,
-                idx_start=idx_start,
-                idx_stop=idx_stop,
-                seed_source=seed_source,
-                seed_mixing_target=seed_mixing_target,
-                mixing_coeff=mixing_coeff,
-                return_image=return_image,
+                text_embeddings = text_embeddings,
+                latents_start = latents_start,
+                idx_start = idx_start,
+                list_latents_mixing = list_latents_mixing,
+                mixing_coeffs = mixing_coeffs,
+                return_image = return_image,
                 )

-        elif self.mode == 'inpaint':
-            text_embeddings = list_conditionings[0]
-            assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
-            assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
-            return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
-            # FIXME LONG LINE
-        elif self.mode == 'upscale':
-            cond = list_conditionings[0]
-            uc_full = list_conditionings[1]
-            return self.sdh.run_diffusion_upscaling(cond, uc_full, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
+        # elif self.mode == 'inpaint':
+        #     text_embeddings = list_conditionings[0]
+        #     assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
+        #     assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
+        #     return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
+
+        # # FIXME LONG LINE and bad args
+        # elif self.mode == 'upscale':
+        #     cond = list_conditionings[0]
+        #     uc_full = list_conditionings[1]
+        #     return self.sdh.run_diffusion_upscaling(cond, uc_full, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)

     def run_upscaling_step1(
             self,
@@ -1100,7 +1175,11 @@ class LatentBlending():
     def get_lpips_similarity(self, imgA, imgB):
-        # FIXME
+        r"""
+        Computes the image similarity between two images imgA and imgB.
+        Used to determine the optimal point of insertion to create smooth transitions.
+        High values indicate low similarity.
+        """
         tensorA = torch.from_numpy(imgA).float().cuda(self.device)
         tensorA = 2*tensorA/255.0 - 1
         tensorA = tensorA.permute([2,0,1]).unsqueeze(0)
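The preprocessing shown here maps uint8 HWC images into the [-1, 1] NCHW tensors the lpips package expects. A self-contained sketch of the same measurement (the `net='alex'` backbone and CPU placement are assumptions for the example):

```python
import lpips
import numpy as np
import torch

loss_fn = lpips.LPIPS(net='alex')  # assumed backbone choice

def to_lpips_tensor(img):
    # uint8 HWC in [0, 255] -> float NCHW in [-1, 1]
    t = torch.from_numpy(img).float()
    t = 2 * t / 255.0 - 1
    return t.permute([2, 0, 1]).unsqueeze(0)

imgA = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
imgB = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
distance = loss_fn(to_lpips_tensor(imgA), to_lpips_tensor(imgB)).item()
print(distance)  # higher = less similar
```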
@@ -1406,55 +1485,24 @@ if __name__ == "__main__":
     # Run latent blending
     self.branch1_influence = 0.3
-    self.branch1_mixing_depth = 0.4
-    self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
-
-    #%%
-    self.branch1_influence = 0.3
-    self.branch1_mixing_depth = 0.5
-    img2 = self.compute_latents2(return_image=True)
-    Image.fromarray(img2)
+    self.branch1_max_depth_influence = 0.4
+    # self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
+    self.seed1=21312
+    img1 =self.compute_latents1(True)
+    #%
+    self.seed2=1234121
+    self.branch1_influence = 0.7
+    self.branch1_max_depth_influence = 0.3
+    self.branch1_influence_decay = 0.3
+    img2 =self.compute_latents2(True)
+    # Image.fromarray(np.concatenate((img1, img2), axis=1))

     #%%
-    idx_injection = 15
-    fract_mixing = 0.5
-    list_conditionings = self.get_mixed_conditioning(fract_mixing)
-    latents_for_injection = interpolate_spherical(self.tree_latents[0][idx_injection], self.tree_latents[-1][idx_injection], fract_mixing)
-    list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
-    img_mix = self.sdh.latent2image((list_latents[-1]))
-    Image.fromarray(np.concatenate((img1,img_mix,img2), axis=1)).resize((800,800//3))
-
-    #%% scheme
-    # init scheme
-    list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 2)
-    list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
-
-    #%%
-    list_compute_steps = self.num_inference_steps - list_idx_injection
-    list_compute_steps *= list_nmb_stems
-    t_compute = np.sum(list_compute_steps) * self.dt_per_diff
-    increase_done = False
-    for s_idx in range(len(list_nmb_stems)-1):
-        if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 3:
-            list_nmb_stems[s_idx] += 1
-            increase_done = True
-            break
-    if not increase_done:
-        list_nmb_stems[-1] += 1
-    print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
-
-    #%%
-    imgs_transition = self.tree_final_imgs
-    # Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
-    imgs_transition_ext = add_frames_linear_interp(imgs_transition, 15, fps)
-
-    # Save as MP4
-    fp_movie = "test.mp4"
-    if os.path.isfile(fp_movie):
-        os.remove(fp_movie)
-    ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width])
-    for img in tqdm(imgs_transition_ext):
-        ms.write_frame(img)
-    ms.finalize()
+    t0 = time.time()
+    self.t_compute_max_allowed = 30
+    self.parental_max_depth_influence = 1.0
+    self.parental_influence = 0.0
+    self.parental_influence_decay = 1.0
+    imgs_transition = self.run_transition(recycle_img1=True, recycle_img2=True)
+    t1 = time.time()
+    print(f"took: {t1-t0}s")

File 3 of 3: the StableDiffusionHolder class

@@ -278,12 +278,10 @@ class StableDiffusionHolder:
     def run_diffusion_standard(
             self,
             text_embeddings: torch.FloatTensor,
-            latents_for_injection = None,
-            idx_start: int = -1,
-            idx_stop: int = -1,
-            seed_source: int = -1,
-            seed_mixing_target: int = -1,
-            mixing_coeff: float = 0.0,
+            latents_start: torch.FloatTensor,
+            idx_start: int = 0,
+            list_latents_mixing = None,
+            mixing_coeffs = 0.0,
             return_image: Optional[bool] = False,
             ):
         r"""
@@ -297,34 +295,26 @@ class StableDiffusionHolder:
                 Latents that are used for injection
             idx_start: int
                 Index of the diffusion process start and where the latents_for_injection are injected
-            idx_stop: int
-                Index of the diffusion process end.
             mixing_coeff:
                 # FIXME
-            seed_source:
-                # FIXME
-            seed_mixing:
-                # FIXME
             return_image: Optional[bool]
                 Optionally return image directly
         """
-        # Asserts
-        if latents_for_injection is None:
-            do_inject_latents = False
-            do_mix_latents = False
-        else:
-            if mixing_coeff > 0.0:
-                do_inject_latents = False
-                do_mix_latents = True
-                assert seed_mixing_target != -1, "Set to correct seed for mixing"
-            else:
-                do_inject_latents = True
-                do_mix_latents = False
+        if type(mixing_coeffs) == float:
+            list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
+        elif type(mixing_coeffs) == list:
+            assert len(mixing_coeffs) == self.num_inference_steps
+            list_mixing_coeffs = mixing_coeffs
+        else:
+            raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
+
+        if np.sum(list_mixing_coeffs) > 0:
+            assert len(list_latents_mixing) == self.num_inference_steps

         precision_scope = autocast if self.precision == "autocast" else nullcontext
-        generator = torch.Generator(device=self.device).manual_seed(int(seed_source))

         with precision_scope("cuda"):
             with self.model.ema_scope():
@@ -332,14 +322,10 @@ class StableDiffusionHolder:
                     uc = self.model.get_learned_conditioning(self.negative_prompt)
                 else:
                     uc = None
-                shape_latents = [self.C, self.height // self.f, self.width // self.f]
                 self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
-                C, H, W = shape_latents
-                size = (1, C, H, W)
-                b = size[0]
-                latents = torch.randn(size, generator=generator, device=self.device)
+                latents = latents_start.clone()

                 timesteps = self.sampler.ddim_timesteps
@@ -349,29 +335,20 @@ class StableDiffusionHolder:
                 # collect latents
                 list_latents_out = []
                 for i, step in enumerate(time_range):
-                    if do_inject_latents:
-                        # Inject latent at right place
-                        if i < idx_start:
-                            continue
-                        elif i == idx_start:
-                            latents = latents_for_injection.clone()
-                    if do_mix_latents:
-                        if i == 0:
-                            generator = torch.Generator(device=self.device).manual_seed(int(seed_mixing_target))
-                            latents_mixtarget = torch.randn(size, generator=generator, device=self.device)
-                        if i < idx_start:
-                            latents_mixtarget = latents_for_injection[i-1].clone()
-                            latents = interpolate_spherical(latents, latents_mixtarget, mixing_coeff)
-                        if i == idx_start:
-                            do_mix_latents = False
-                    if i == idx_stop:
-                        return list_latents_out
-                    # print(f"diffusion iter {i}")
+                    # Set the right starting latents
+                    if i < idx_start:
+                        list_latents_out.append(None)
+                        continue
+                    elif i == idx_start:
+                        latents = latents_start.clone()
+
+                    # Mix the latents.
+                    if i > 0 and list_mixing_coeffs[i]>0:
+                        latents_mixtarget = list_latents_mixing[i-1].clone()
+                        latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])

                     index = total_steps - i - 1
-                    ts = torch.full((b,), step, device=self.device, dtype=torch.long)
+                    ts = torch.full((1,), step, device=self.device, dtype=torch.long)
                     outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
                                                       quantize_denoised=False, temperature=1.0,
                                                       noise_dropout=0.0, score_corrector=None,
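`interpolate_spherical`, used for every mix above, is the repo's slerp helper. A minimal standalone sketch of spherical interpolation between two latent tensors (details such as clamping, degenerate-angle handling, and dtype casting in the repo's version may differ):

```python
import torch

def interpolate_spherical(p0, p1, fract_mixing):
    # Slerp: rotate from p0 toward p1 along the great circle by fract_mixing.
    eps = 1e-7
    d0, d1 = p0.flatten(), p1.flatten()
    cos_omega = torch.clamp(
        torch.dot(d0, d1) / (d0.norm() * d1.norm() + eps), -1 + eps, 1 - eps)
    omega = torch.arccos(cos_omega)
    s0 = torch.sin((1.0 - fract_mixing) * omega) / torch.sin(omega)
    s1 = torch.sin(fract_mixing * omega) / torch.sin(omega)
    return s0 * p0 + s1 * p1

a, b = torch.randn(1, 4, 96, 96), torch.randn(1, 4, 96, 96)
mid = interpolate_spherical(a, b, 0.5)
print(mid.shape)  # torch.Size([1, 4, 96, 96])
```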