new transition engine and crossfeeding

This commit is contained in:
Johannes Stelzer 2023-02-15 18:21:00 +01:00
parent 5e979818b2
commit 0371868603
3 changed files with 480 additions and 150 deletions

View File

@ -33,6 +33,13 @@ import gradio as gr
import copy
"""
TODOS:
- clean parameter handling
- three buttons: diffuse A, diffuse B, make transition
- collapse for easy mode
- transition quality in terms of render time
"""
#%%
@ -45,7 +52,7 @@ class BlendingFrontend():
self.lb = LatentBlending(sdh)
self.share = True
self.num_inference_steps = 20
self.num_inference_steps = 30
self.depth_strength = 0.25
self.seed1 = 42
self.seed2 = 420
@ -58,11 +65,13 @@ class BlendingFrontend():
self.list_settings = []
self.state_current = {}
self.showing_current = True
self.branch1_influence = 0.02
self.branch1_influence = 0.1
self.branch1_mixing_depth = 0.3
self.nmb_branches_final = 9
self.nmb_imgs_show = 5 # don't change
self.fps = 30
self.duration = 10
self.duration_video = 15
self.t_compute_max_allowed = 15
self.dict_multi_trans = {}
self.dict_multi_trans_include = {}
self.multi_trans_currently_shown = []
@ -79,114 +88,87 @@ class BlendingFrontend():
self.width = 768
# make dummy image
def save_empty_image(self):
self.fp_img_empty = 'empty.jpg'
Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
def change_depth_strength(self, value):
self.depth_strength = value
print(f"changed depth_strength to {value}")
def change_num_inference_steps(self, value):
self.num_inference_steps = value
print(f"changed num_inference_steps to {value}")
def change_guidance_scale(self, value):
self.guidance_scale = value
self.lb.set_guidance_scale(value)
print(f"changed guidance_scale to {value}")
def change_guidance_scale_mid_damper(self, value):
self.guidance_scale_mid_damper = value
print(f"changed guidance_scale_mid_damper to {value}")
def change_mid_compression_scaler(self, value):
self.mid_compression_scaler = value
print(f"changed mid_compression_scaler to {value}")
def change_branch1_influence(self, value):
self.branch1_influence = value
print(f"changed branch1_influence to {value}")
def change_height(self, value):
self.height = value
print(f"changed height to {value}")
def change_width(self, value):
self.width = value
print(f"changed width to {value}")
def change_nmb_branches_final(self, value):
self.nmb_branches_final = value
print(f"changed nmb_branches_final to {value}")
def change_duration(self, value):
self.duration = value
print(f"changed duration to {value}")
def change_fps(self, value):
self.fps = value
print(f"changed fps to {value}")
def change_negative_prompt(self, value):
self.negative_prompt = value
def change_seed1(self, value):
self.seed1 = int(value)
def change_seed2(self, value):
self.seed2 = int(value)
def randomize_seed1(self):
seed = np.random.randint(0, 10000000)
self.change_seed1(seed)
self.seed1 = int(seed)
print(f"randomize_seed1: new seed = {self.seed1}")
return seed
def randomize_seed2(self):
seed = np.random.randint(0, 10000000)
self.change_seed2(seed)
self.seed2 = int(seed)
print(f"randomize_seed2: new seed = {self.seed2}")
return seed
def compute_transition(self, prompt1, prompt2):
self.prompt1 = prompt1
self.prompt2 = prompt2
print("STARTING DIFFUSION!")
def setup_lb(self, list_ui_elem):
# Collect latent blending variables
self.state_current = self.get_state_dict()
self.lb.set_width(list_ui_elem[list_ui_keys.index('width')])
self.lb.set_height(list_ui_elem[list_ui_keys.index('height')])
self.lb.set_prompt1(list_ui_elem[list_ui_keys.index('prompt1')])
self.lb.set_prompt2(list_ui_elem[list_ui_keys.index('prompt2')])
self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
self.lb.branch1_influence = list_ui_elem[list_ui_keys.index('branch1_influence')]
self.lb.branch1_mixing_depth = list_ui_elem[list_ui_keys.index('branch1_mixing_depth')]
self.lb.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
self.duration_video = list_ui_elem[list_ui_keys.index('duration_video')]
self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')]
self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
def compute_img1(self, *args):
list_ui_elem = args
self.setup_lb(list_ui_elem)
fp_img1 = f"img1_{get_time('second')}.jpg"
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
img1.save(fp_img1)
self.save_empty_image()
return [fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
def compute_img2(self, *args):
list_ui_elem = args
self.setup_lb(list_ui_elem)
fp_img2 = f"img2_{get_time('second')}.jpg"
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
img2.save(fp_img2)
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2]
def compute_transition(self, *args):
list_ui_elem = args
self.setup_lb(list_ui_elem)
print("STARTING DIFFUSION!")
if self.use_debug:
list_imgs = [(255*np.random.rand(self.height,self.width,3)).astype(np.uint8) for l in range(5)]
list_imgs = [Image.fromarray(l) for l in list_imgs]
print("DONE! SENDING BACK RESULTS")
return list_imgs
# Collect latent blending variables
self.lb.set_width(self.width)
self.lb.set_height(self.height)
self.lb.autosetup_branching(
depth_strength = self.depth_strength,
num_inference_steps = self.num_inference_steps,
nmb_branches_final = self.nmb_branches_final,
nmb_mindist = 3)
self.lb.set_prompt1(self.prompt1)
self.lb.set_prompt2(self.prompt2)
self.lb.set_negative_prompt(self.negative_prompt)
self.lb.guidance_scale = self.guidance_scale
self.lb.guidance_scale_mid_damper = self.guidance_scale_mid_damper
self.lb.mid_compression_scaler = self.mid_compression_scaler
self.lb.branch1_influence = self.branch1_influence
fixed_seeds = [self.seed1, self.seed2]
# Run Latent Blending
imgs_transition = self.lb.run_transition(fixed_seeds=fixed_seeds)
imgs_transition = self.lb.run_transition(
recycle_img1=True,
recycle_img2=True,
num_inference_steps=self.num_inference_steps,
depth_strength=self.depth_strength,
fixed_seeds=fixed_seeds
)
print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
# Subselect the preview images (hard fixed to self.nmb_imgs_show=5)
assert np.mod((self.nmb_branches_final-self.nmb_imgs_show)/4, 1)==0, 'self.nmb_branches_final illegal value!'
idx_list = np.linspace(0, self.nmb_branches_final-1, self.nmb_imgs_show).astype(np.int32)
# Subselect three preview images
idx_img_prev = np.round(np.linspace(0, len(imgs_transition)-1, 5)[1:-1]).astype(np.int32)
list_imgs_preview = []
for j in idx_list:
for j in idx_img_prev:
list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
# Save the preview imgs as jpgs on disk so we are not sending umcompressed data around
@ -198,7 +180,7 @@ class BlendingFrontend():
self.list_fp_imgs_current.append(fp_img)
# Insert cheap frames for the movie
imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration, self.fps)
imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
# Save as movie
fp_movie = self.get_fp_movie(self.current_timestamp)
@ -330,35 +312,44 @@ if __name__ == "__main__":
sdh = StableDiffusionHolder(fp_ckpt)
self = BlendingFrontend(sdh) # Yes this is possible in python and yes it is an awesome trick
# self = BlendingFrontend(None) # Yes this is possible in python and yes it is an awesome trick
dict_ui_elem = {}
with gr.Blocks() as demo:
with gr.Row():
prompt1 = gr.Textbox(label="prompt 1")
prompt2 = gr.Textbox(label="prompt 2")
negative_prompt = gr.Textbox(label="negative prompt")
prompt2 = gr.Textbox(label="prompt 2")
with gr.Row():
nmb_branches_final = gr.Slider(5, 125, self.nmb_branches_final, step=4, label='nmb trans images', interactive=True)
duration_compute = gr.Slider(10, 40, self.duration_video, step=1, label='compute budget for transition (seconds)', interactive=True)
duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True)
height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True)
with gr.Accordion("Advanced Settings (click to expand)", open=False):
with gr.Row():
depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='branch1_influence', interactive=True)
branch1_mixing_depth = gr.Slider(0.0, 1.0, self.branch1_mixing_depth, step=0.01, label='branch1_mixing_depth', interactive=True)
with gr.Row():
num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
with gr.Row():
seed1 = gr.Number(420, label="seed 1", interactive=True)
b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
seed2 = gr.Number(420, label="seed 2", interactive=True)
b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
with gr.Row():
num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='branch1_influence', interactive=True)
guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
with gr.Row():
depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
duration = gr.Slider(0.1, 30, self.duration, step=0.1, label='video duration', interactive=True)
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
with gr.Row():
seed1 = gr.Number(42, label="seed 1", interactive=True)
b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
seed2 = gr.Number(420, label="seed 2", interactive=True)
b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
with gr.Row():
b_compute1 = gr.Button('compute first image', variant='primary')
b_compute_transition = gr.Button('compute transition', variant='primary')
b_compute2 = gr.Button('compute last image', variant='primary')
with gr.Row():
img1 = gr.Image(label="1/5")
@ -370,31 +361,40 @@ if __name__ == "__main__":
with gr.Row():
vid_transition = gr.Video()
# Bind the on-change methods
depth_strength.change(fn=self.change_depth_strength, inputs=depth_strength)
num_inference_steps.change(fn=self.change_num_inference_steps, inputs=num_inference_steps)
nmb_branches_final.change(fn=self.change_nmb_branches_final, inputs=nmb_branches_final)
# Collect all UI elemts in list to easily pass as inputs
dict_ui_elem["prompt1"] = prompt1
dict_ui_elem["negative_prompt"] = negative_prompt
dict_ui_elem["prompt2"] = prompt2
dict_ui_elem["duration_compute"] = duration_compute
dict_ui_elem["duration_video"] = duration_video
dict_ui_elem["height"] = height
dict_ui_elem["width"] = width
dict_ui_elem["depth_strength"] = depth_strength
dict_ui_elem["branch1_influence"] = branch1_influence
dict_ui_elem["branch1_mixing_depth"] = branch1_mixing_depth
guidance_scale.change(fn=self.change_guidance_scale, inputs=guidance_scale)
guidance_scale_mid_damper.change(fn=self.change_guidance_scale_mid_damper, inputs=guidance_scale_mid_damper)
dict_ui_elem["num_inference_steps"] = num_inference_steps
dict_ui_elem["guidance_scale"] = guidance_scale
dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
dict_ui_elem["seed1"] = seed1
dict_ui_elem["seed2"] = seed2
# Convert to list, as gradio doesn't seem to accept dicts
list_ui_elem = []
list_ui_keys = []
for k in dict_ui_elem.keys():
list_ui_elem.append(dict_ui_elem[k])
list_ui_keys.append(k)
self.list_ui_keys = list_ui_keys
height.change(fn=self.change_height, inputs=height)
width.change(fn=self.change_width, inputs=width)
negative_prompt.change(fn=self.change_negative_prompt, inputs=negative_prompt)
seed1.change(fn=self.change_seed1, inputs=seed1)
seed2.change(fn=self.change_seed2, inputs=seed2)
duration.change(fn=self.change_duration, inputs=duration)
branch1_influence.change(fn=self.change_branch1_influence, inputs=branch1_influence)
b_newseed1.click(self.randomize_seed1, outputs=seed1)
b_newseed2.click(self.randomize_seed2, outputs=seed2)
# b_stackforward.click(self.stack_forward,
# inputs=[prompt2, seed2],
# outputs=[img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
b_compute1.click(self.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5])
b_compute2.click(self.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5])
b_compute_transition.click(self.compute_transition,
inputs=[prompt1, prompt2],
outputs=[img1, img2, img3, img4, img5, vid_transition])
inputs=list_ui_elem,
outputs=[img2, img3, img4, vid_transition])
demo.launch(share=self.share, inbrowser=True, inline=False)

View File

@ -45,7 +45,7 @@ from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
from stable_diffusion_holder import StableDiffusionHolder
import yaml
import lpips
#%%
class LatentBlending():
def __init__(
@ -88,10 +88,13 @@ class LatentBlending():
self.prompt1 = ""
self.prompt2 = ""
self.negative_prompt = ""
self.tree_latents = None
self.tree_latents = [None, None]
self.tree_fracts = None
self.idx_injection = []
self.tree_status = None
self.tree_final_imgs = []
self.list_nmb_branches_prev = []
self.list_injection_idx_prev = []
self.text_embedding1 = None
@ -105,12 +108,15 @@ class LatentBlending():
self.list_injection_idx = None
self.list_nmb_branches = None
self.branch1_influence = 0.0
self.branch1_fract_crossfeed = 0.65
self.branch1_mixing_depth = 0.65
self.branch1_insertion_completed = False
self.set_guidance_scale(guidance_scale)
self.init_mode()
self.multi_transition_img_first = None
self.multi_transition_img_last = None
self.dt_per_diff = 0
self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
def init_mode(self):
@ -375,6 +381,187 @@ class LatentBlending():
self.tree_status.append(['untouched']*nmb_branches)
def run_transition(
self,
recycle_img1: Optional[bool] = False,
recycle_img2: Optional[bool] = False,
num_inference_steps: Optional[int] = 30,
depth_strength: Optional[float] = 0.3,
fixed_seeds: Optional[List[int]] = None,
):
# # FIXME: deal with these tree args later
# self.num_inference_steps = 30
# self.t_compute_max_allowed = 60
# Sanity checks first
assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
# Random seeds
if fixed_seeds is not None:
if fixed_seeds == 'randomize':
fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
else:
assert len(fixed_seeds)==2, "Supply a list with len = 2"
self.seed1 = fixed_seeds[0]
self.seed2 = fixed_seeds[1]
# Ensure correct num_inference_steps in holder
self.sdh.num_inference_steps = self.num_inference_steps
# Compute / Recycle first image
if not recycle_img1:
list_latents1 = self.compute_latents1()
else:
# FIXME: check if latents there...
list_latents1 = self.tree_latents[0]
# Compute / Recycle first image
if not recycle_img2:
list_latents2 = self.compute_latents2()
else:
# FIXME: check if latents there...
list_latents2 = self.tree_latents[-1]
# Reset the tree, injecting the edge latents1/2 we just generated/recycled
self.tree_latents = [list_latents1, list_latents2]
self.tree_fracts = [0.0, 1.0]
self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
self.tree_idx_injection = [0, 0]
# Set up branching scheme (dependent on provided compute time)
idx_injection_base = int(round(self.num_inference_steps*depth_strength))
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3)
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
t_compute = 0
while t_compute < self.t_compute_max_allowed:
list_compute_steps = self.num_inference_steps - list_idx_injection
list_compute_steps *= list_nmb_stems
t_compute = np.sum(list_compute_steps) * self.dt_per_diff
increase_done = False
for s_idx in range(len(list_nmb_stems)-1):
if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
list_nmb_stems[s_idx] += 1
increase_done = True
break
if not increase_done:
list_nmb_stems[-1] += 1
# print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
# Run iteratively
for s_idx in tqdm(range(len(list_idx_injection))):
nmb_stems = list_nmb_stems[s_idx]
idx_injection = list_idx_injection[s_idx]
for i in range(nmb_stems):
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
self.insert_into_tree(fract_mixing, idx_injection, list_latents)
return self.tree_final_imgs
def get_mixing_parameters(self, idx_injection):
# get_lpips_similarity
similarities = []
for i in range(len(self.tree_final_imgs)-1):
similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i+1]))
b_closest1 = np.argmax(similarities)
b_closest2 = b_closest1+1
fract_closest1 = self.tree_fracts[b_closest1]
fract_closest2 = self.tree_fracts[b_closest2]
# Ensure that the parents are indeed older!
b_parent1 = b_closest1
while True:
if self.tree_idx_injection[b_parent1] < idx_injection:
break
else:
b_parent1 -= 1
b_parent2 = b_closest2
while True:
if self.tree_idx_injection[b_parent2] < idx_injection:
break
else:
b_parent2 += 1
# print(f"\n\nb_closest: {b_closest1} {b_closest2} fract_closest1 {fract_closest1} fract_closest2 {fract_closest2}")
# print(f"b_parent: {b_parent1} {b_parent2}")
# print(f"similarities {similarities}")
# print(f"idx_injection {idx_injection} tree_idx_injection {self.tree_idx_injection}")
fract_mixing = (fract_closest1 + fract_closest2) /2
return fract_mixing, b_parent1, b_parent2
def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
# FIXME
b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
self.tree_latents.insert(b_parent1+1, list_latents)
self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
self.tree_fracts.insert(b_parent1+1, fract_mixing)
self.tree_idx_injection.insert(b_parent1+1, idx_injection)
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
# FIXME
list_conditionings = self.get_mixed_conditioning(fract_mixing)
fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
idx_reversed = self.num_inference_steps - idx_injection
latents_for_injection = interpolate_spherical(
self.tree_latents[b_parent1][-idx_reversed-1],
self.tree_latents[b_parent2][-idx_reversed-1],
fract_mixing_parental)
list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
return list_latents
def compute_latents1(self, return_image=False):
print("starting compute_latents1")
list_conditionings = [self.text_embedding1]
t0 = time.time()
list_latents1 = self.run_diffusion(list_conditionings, seed_source=self.seed1)
t1 = time.time()
self.dt_per_diff = (t1-t0) / self.num_inference_steps
self.tree_latents[0] = list_latents1
if return_image:
return self.sdh.latent2image(list_latents1[-1])
else:
return list_latents1
def compute_latents2(self, return_image=False):
print("starting compute_latents2")
list_conditionings = [self.text_embedding2
]
# Influence from branch1
if self.branch1_influence > 0.0:
self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
idx_crossfeed = int(round(self.num_inference_steps*self.branch1_mixing_depth))
list_latents2 = self.run_diffusion(
list_conditionings,
idx_start=idx_crossfeed,
latents_for_injection=self.tree_latents[0],
seed_source=self.seed2,
seed_mixing_target=self.seed1,
mixing_coeff=self.branch1_influence)
else:
list_latents2 = self.run_diffusion(list_conditionings)
self.tree_latents[-1] = list_latents2
if return_image:
return self.sdh.latent2image(list_latents2[-1])
else:
return list_latents2
def run_transition_OLD(
self,
recycle_img1: Optional[bool] = False,
recycle_img2: Optional[bool] = False,
@ -423,9 +610,9 @@ class LatentBlending():
if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0'
self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
self.branch1_fract_crossfeed = np.clip(self.branch1_fract_crossfeed, 0, 1)
self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
self.list_nmb_branches.insert(1, 2)
idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_fract_crossfeed))
idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_mixing_depth))
self.list_injection_idx_ext.insert(1, idx_crossfeed)
self.tree_fracts.insert(1, self.tree_fracts[0])
self.tree_status.insert(1, self.tree_status[0])
@ -606,6 +793,9 @@ class LatentBlending():
latents_for_injection: torch.FloatTensor = None,
idx_start: int = -1,
idx_stop: int = -1,
seed_source: int = -1,
seed_mixing_target: int = -1,
mixing_coeff: float = 0.0,
return_image: Optional[bool] = False
):
r"""
@ -620,6 +810,7 @@ class LatentBlending():
Index of the diffusion process start and where the latents_for_injection are injected
idx_stop: int
Index of the diffusion process end.
FIXME ARGS
return_image: Optional[bool]
Optionally return image directly
"""
@ -630,14 +821,23 @@ class LatentBlending():
if self.mode == 'standard':
text_embeddings = list_conditionings[0]
return self.sdh.run_diffusion_standard(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
return self.sdh.run_diffusion_standard(
text_embeddings,
latents_for_injection=latents_for_injection,
idx_start=idx_start,
idx_stop=idx_stop,
seed_source=seed_source,
seed_mixing_target=seed_mixing_target,
mixing_coeff=mixing_coeff,
return_image=return_image,
)
elif self.mode == 'inpaint':
text_embeddings = list_conditionings[0]
assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
# FIXME LONG LINE
elif self.mode == 'upscale':
cond = list_conditionings[0]
uc_full = list_conditionings[1]
@ -881,8 +1081,6 @@ class LatentBlending():
if inject_img2:
self.tree_latents[t_block][-1] = list_latents[self.list_injection_idx_ext[t_block]:self.list_injection_idx_ext[t_block+1]]
def swap_forward(self):
r"""
@ -901,6 +1099,21 @@ class LatentBlending():
self.tree_final_imgs = []
def get_lpips_similarity(self, imgA, imgB):
# FIXME
tensorA = torch.from_numpy(imgA).float().cuda(self.device)
tensorA = 2*tensorA/255.0 - 1
tensorA = tensorA.permute([2,0,1]).unsqueeze(0)
tensorB = torch.from_numpy(imgB).float().cuda(self.device)
tensorB = 2*tensorB/255.0 - 1
tensorB = tensorB.permute([2,0,1]).unsqueeze(0)
lploss = self.lpips(tensorA, tensorB)
lploss = float(lploss[0][0][0][0])
return lploss
# Auxiliary functions
def get_closest_idx(
fract_mixing: float,
@ -1169,31 +1382,79 @@ if __name__ == "__main__":
#%% First let us spawn a stable diffusion holder
device = "cuda"
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
fp_config = 'configs/v2-inference.yaml'
sdh = StableDiffusionHolder(fp_ckpt, fp_config, device)
sdh = StableDiffusionHolder(fp_ckpt)
xxx
#%% Next let's set up all parameters
quality = 'medium'
depth_strength = 0.65 # Specifies how deep (in terms of diffusion iterations the first branching happens)
fixed_seeds = [69731932, 504430820]
depth_strength = 0.3 # Specifies how deep (in terms of diffusion iterations the first branching happens)
fixed_seeds = [697164, 430214]
prompt1 = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic"
prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail"
prompt1 = "photo of a desert and a sky"
prompt2 = "photo of a tree with a lake"
duration_transition = 12 # In seconds
fps = 30
# Spawn latent blending
self = LatentBlending(sdh)
self.branch1_influence = 0.8
self.load_branching_profile(quality=quality, depth_strength=0.3)
self.set_prompt1(prompt1)
self.set_prompt2(prompt2)
# Run latent blending
imgs_transition = self.run_transition(fixed_seeds=fixed_seeds)
self.branch1_influence = 0.3
self.branch1_mixing_depth = 0.4
self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
#%%
self.branch1_influence = 0.3
self.branch1_mixing_depth = 0.5
img2 = self.compute_latents2(return_image=True)
Image.fromarray(img2)
#%%
idx_injection = 15
fract_mixing = 0.5
list_conditionings = self.get_mixed_conditioning(fract_mixing)
latents_for_injection = interpolate_spherical(self.tree_latents[0][idx_injection], self.tree_latents[-1][idx_injection], fract_mixing)
list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
img_mix = self.sdh.latent2image((list_latents[-1]))
Image.fromarray(np.concatenate((img1,img_mix,img2), axis=1)).resize((800,800//3))
#%% scheme
# init scheme
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 2)
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
#%%
list_compute_steps = self.num_inference_steps - list_idx_injection
list_compute_steps *= list_nmb_stems
t_compute = np.sum(list_compute_steps) * self.dt_per_diff
increase_done = False
for s_idx in range(len(list_nmb_stems)-1):
if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 3:
list_nmb_stems[s_idx] += 1
increase_done = True
break
if not increase_done:
list_nmb_stems[-1] += 1
print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
#%%
imgs_transition = self.tree_final_imgs
# Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
imgs_transition_ext = add_frames_linear_interp(imgs_transition, 15, fps)
# Save as MP4
fp_movie = "test.mp4"
if os.path.isfile(fp_movie):
os.remove(fp_movie)
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width])
for img in tqdm(imgs_transition_ext):
ms.write_frame(img)
ms.finalize()

View File

@ -42,7 +42,6 @@ from contextlib import nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from einops import repeat, rearrange
#%%
@ -279,24 +278,33 @@ class StableDiffusionHolder:
def run_diffusion_standard(
self,
text_embeddings: torch.FloatTensor,
latents_for_injection: torch.FloatTensor = None,
latents_for_injection = None,
idx_start: int = -1,
idx_stop: int = -1,
return_image: Optional[bool] = False
seed_source: int = -1,
seed_mixing_target: int = -1,
mixing_coeff: float = 0.0,
return_image: Optional[bool] = False,
):
r"""
Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
Depending on the mode, the correct one will be executed.
Args:
text_embeddings: torch.FloatTensor
text_embeddings: torch.FloatTensor
Text embeddings used for diffusion
latents_for_injection: torch.FloatTensor
latents_for_injection: torch.FloatTensor or list
Latents that are used for injection
idx_start: int
Index of the diffusion process start and where the latents_for_injection are injected
idx_stop: int
Index of the diffusion process end.
mixing_coeff:
# FIXME
seed_source:
# FIXME
seed_mixing:
# FIXME
return_image: Optional[bool]
Optionally return image directly
"""
@ -304,12 +312,19 @@ class StableDiffusionHolder:
if latents_for_injection is None:
do_inject_latents = False
do_mix_latents = False
else:
do_inject_latents = True
if mixing_coeff > 0.0:
do_inject_latents = False
do_mix_latents = True
assert seed_mixing_target != -1, "Set to correct seed for mixing"
else:
do_inject_latents = True
do_mix_latents = False
precision_scope = autocast if self.precision == "autocast" else nullcontext
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
generator = torch.Generator(device=self.device).manual_seed(int(seed_source))
with precision_scope("cuda"):
with self.model.ema_scope():
@ -340,6 +355,16 @@ class StableDiffusionHolder:
continue
elif i == idx_start:
latents = latents_for_injection.clone()
if do_mix_latents:
if i == 0:
generator = torch.Generator(device=self.device).manual_seed(int(seed_mixing_target))
latents_mixtarget = torch.randn(size, generator=generator, device=self.device)
if i < idx_start:
latents_mixtarget = latents_for_injection[i-1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, mixing_coeff)
if i == idx_start:
do_mix_latents = False
if i == idx_stop:
return list_latents_out
@ -576,6 +601,50 @@ class StableDiffusionHolder:
image = x_sample.astype(np.uint8)
return image
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function will always cast up to float64 for sake of extra 4.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return in p0
1 will return in p1
0.x will return a mix between both preserving angular velocity.
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'
p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1+epsilon, 1-epsilon)
theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0*s0 + p1*s1
if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()
return interp
if __name__ == "__main__":