diff --git a/gradio_ui.py b/gradio_ui.py
index a56c418..ad6e03f 100644
--- a/gradio_ui.py
+++ b/gradio_ui.py
@@ -33,6 +33,13 @@
 import gradio as gr
 import copy
 
+"""
+TODOS:
+ - clean parameter handling
+ - three buttons: diffuse A, diffuse B, make transition
+ - collapse for easy mode
+ - transition quality in terms of render time
+"""
 
 #%%
 
@@ -45,7 +52,7 @@ class BlendingFrontend():
         self.lb = LatentBlending(sdh)
 
         self.share = True
-        self.num_inference_steps = 20
+        self.num_inference_steps = 30
         self.depth_strength = 0.25
         self.seed1 = 42
         self.seed2 = 420
@@ -58,11 +65,13 @@ class BlendingFrontend():
         self.list_settings = []
         self.state_current = {}
         self.showing_current = True
-        self.branch1_influence = 0.02
+        self.branch1_influence = 0.1
+        self.branch1_mixing_depth = 0.3
         self.nmb_branches_final = 9
         self.nmb_imgs_show = 5 # don't change
         self.fps = 30
-        self.duration = 10
+        self.duration_video = 15
+        self.t_compute_max_allowed = 15
         self.dict_multi_trans = {}
         self.dict_multi_trans_include = {}
         self.multi_trans_currently_shown = []
@@ -79,114 +88,87 @@ class BlendingFrontend():
             self.width = 768
 
     # make dummy image
+
     def save_empty_image(self):
         self.fp_img_empty = 'empty.jpg'
         Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
 
-    def change_depth_strength(self, value):
-        self.depth_strength = value
-        print(f"changed depth_strength to {value}")
-
-    def change_num_inference_steps(self, value):
-        self.num_inference_steps = value
-        print(f"changed num_inference_steps to {value}")
-
-    def change_guidance_scale(self, value):
-        self.guidance_scale = value
-        self.lb.set_guidance_scale(value)
-        print(f"changed guidance_scale to {value}")
-
-    def change_guidance_scale_mid_damper(self, value):
-        self.guidance_scale_mid_damper = value
-        print(f"changed guidance_scale_mid_damper to {value}")
-
-    def change_mid_compression_scaler(self, value):
-        self.mid_compression_scaler = value
-        print(f"changed mid_compression_scaler to {value}")
-
-    def change_branch1_influence(self, value):
-        self.branch1_influence = value
-        print(f"changed branch1_influence to {value}")
-
-    def change_height(self, value):
-        self.height = value
-        print(f"changed height to {value}")
-
-    def change_width(self, value):
-        self.width = value
-        print(f"changed width to {value}")
-
-    def change_nmb_branches_final(self, value):
-        self.nmb_branches_final = value
-        print(f"changed nmb_branches_final to {value}")
-
-    def change_duration(self, value):
-        self.duration = value
-        print(f"changed duration to {value}")
-
-    def change_fps(self, value):
-        self.fps = value
-        print(f"changed fps to {value}")
-
-    def change_negative_prompt(self, value):
-        self.negative_prompt = value
-
-    def change_seed1(self, value):
-        self.seed1 = int(value)
-
-    def change_seed2(self, value):
-        self.seed2 = int(value)
 
     def randomize_seed1(self):
         seed = np.random.randint(0, 10000000)
-        self.change_seed1(seed)
+        self.seed1 = int(seed)
         print(f"randomize_seed1: new seed = {self.seed1}")
        return seed
 
     def randomize_seed2(self):
         seed = np.random.randint(0, 10000000)
-        self.change_seed2(seed)
+        self.seed2 = int(seed)
         print(f"randomize_seed2: new seed = {self.seed2}")
         return seed
 
-    def compute_transition(self, prompt1, prompt2):
-        self.prompt1 = prompt1
-        self.prompt2 = prompt2
-        print("STARTING DIFFUSION!")
+    def setup_lb(self, list_ui_elem):
+        # Collect latent blending variables
         self.state_current = self.get_state_dict()
+        self.lb.set_width(list_ui_elem[list_ui_keys.index('width')])
+        self.lb.set_height(list_ui_elem[list_ui_keys.index('height')])
+        self.lb.set_prompt1(list_ui_elem[list_ui_keys.index('prompt1')])
+        self.lb.set_prompt2(list_ui_elem[list_ui_keys.index('prompt2')])
+        self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
+        self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
+        self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
+        self.lb.branch1_influence = list_ui_elem[list_ui_keys.index('branch1_influence')]
+        self.lb.branch1_mixing_depth = list_ui_elem[list_ui_keys.index('branch1_mixing_depth')]
+        self.lb.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
+        self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
+        self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
+        self.duration_video = list_ui_elem[list_ui_keys.index('duration_video')]
+        self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')]
+        self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
+
+
+    def compute_img1(self, *args):
+        list_ui_elem = args
+        self.setup_lb(list_ui_elem)
+        fp_img1 = f"img1_{get_time('second')}.jpg"
+        img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
+        img1.save(fp_img1)
+        self.save_empty_image()
+        return [fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
+
+    def compute_img2(self, *args):
+        list_ui_elem = args
+        self.setup_lb(list_ui_elem)
+        fp_img2 = f"img2_{get_time('second')}.jpg"
+        img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
+        img2.save(fp_img2)
+        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2]
+
+    def compute_transition(self, *args):
+        list_ui_elem = args
+        self.setup_lb(list_ui_elem)
+        print("STARTING DIFFUSION!")
 
         if self.use_debug:
             list_imgs = [(255*np.random.rand(self.height,self.width,3)).astype(np.uint8) for l in range(5)]
             list_imgs = [Image.fromarray(l) for l in list_imgs]
             print("DONE! SENDING BACK RESULTS")
             return list_imgs
 
-        # Collect latent blending variables
-        self.lb.set_width(self.width)
-        self.lb.set_height(self.height)
-        self.lb.autosetup_branching(
-            depth_strength = self.depth_strength,
-            num_inference_steps = self.num_inference_steps,
-            nmb_branches_final = self.nmb_branches_final,
-            nmb_mindist = 3)
-        self.lb.set_prompt1(self.prompt1)
-        self.lb.set_prompt2(self.prompt2)
-        self.lb.set_negative_prompt(self.negative_prompt)
-
-        self.lb.guidance_scale = self.guidance_scale
-        self.lb.guidance_scale_mid_damper = self.guidance_scale_mid_damper
-        self.lb.mid_compression_scaler = self.mid_compression_scaler
-        self.lb.branch1_influence = self.branch1_influence
         fixed_seeds = [self.seed1, self.seed2]
 
         # Run Latent Blending
-        imgs_transition = self.lb.run_transition(fixed_seeds=fixed_seeds)
+        imgs_transition = self.lb.run_transition(
+            recycle_img1=True,
+            recycle_img2=True,
+            num_inference_steps=self.num_inference_steps,
+            depth_strength=self.depth_strength,
+            fixed_seeds=fixed_seeds
+            )
         print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
 
-        # Subselect the preview images (hard fixed to self.nmb_imgs_show=5)
-        assert np.mod((self.nmb_branches_final-self.nmb_imgs_show)/4, 1)==0, 'self.nmb_branches_final illegal value!'
-        idx_list = np.linspace(0, self.nmb_branches_final-1, self.nmb_imgs_show).astype(np.int32)
+        # Subselect three preview images
+        idx_img_prev = np.round(np.linspace(0, len(imgs_transition)-1, 5)[1:-1]).astype(np.int32)
         list_imgs_preview = []
-        for j in idx_list:
+        for j in idx_img_prev:
             list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
 
         # Save the preview imgs as jpgs on disk so we are not sending umcompressed data around
@@ -198,7 +180,7 @@ class BlendingFrontend():
             self.list_fp_imgs_current.append(fp_img)
 
         # Insert cheap frames for the movie
-        imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration, self.fps)
+        imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
 
         # Save as movie
         fp_movie = self.get_fp_movie(self.current_timestamp)
@@ -330,35 +312,44 @@ if __name__ == "__main__":
     sdh = StableDiffusionHolder(fp_ckpt)
 
     self = BlendingFrontend(sdh) # Yes this is possible in python and yes it is an awesome trick
+    # self = BlendingFrontend(None) # Yes this is possible in python and yes it is an awesome trick
+
+    dict_ui_elem = {}
 
     with gr.Blocks() as demo:
         with gr.Row():
             prompt1 = gr.Textbox(label="prompt 1")
-            prompt2 = gr.Textbox(label="prompt 2")
             negative_prompt = gr.Textbox(label="negative prompt")
-
+            prompt2 = gr.Textbox(label="prompt 2")
+
         with gr.Row():
-            nmb_branches_final = gr.Slider(5, 125, self.nmb_branches_final, step=4, label='nmb trans images', interactive=True)
+            duration_compute = gr.Slider(10, 40, self.duration_video, step=1, label='compute budget for transition (seconds)', interactive=True)
+            duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True)
             height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
             width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True)
 
+        with gr.Accordion("Advanced Settings (click to expand)", open=False):
+
+            with gr.Row():
+                depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
+                branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='branch1_influence', interactive=True)
+                branch1_mixing_depth = gr.Slider(0.0, 1.0, self.branch1_mixing_depth, step=0.01, label='branch1_mixing_depth', interactive=True)
+
+            with gr.Row():
+                num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
+                guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
+                guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
+
+            with gr.Row():
+                seed1 = gr.Number(420, label="seed 1", interactive=True)
+                b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
+                seed2 = gr.Number(420, label="seed 2", interactive=True)
+                b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
+
         with gr.Row():
-            num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
-            branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='branch1_influence', interactive=True)
-            guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
-
-        with gr.Row():
-            depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
-            duration = gr.Slider(0.1, 30, self.duration, step=0.1, label='video duration', interactive=True)
-            guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
-
-        with gr.Row():
-            seed1 = gr.Number(42, label="seed 1", interactive=True)
-            b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
-            seed2 = gr.Number(420, label="seed 2", interactive=True)
-            b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
-        with gr.Row():
+            b_compute1 = gr.Button('compute first image', variant='primary')
             b_compute_transition = gr.Button('compute transition', variant='primary')
+            b_compute2 = gr.Button('compute last image', variant='primary')
 
         with gr.Row():
             img1 = gr.Image(label="1/5")
@@ -370,31 +361,40 @@ if __name__ == "__main__":
         with gr.Row():
             vid_transition = gr.Video()
 
-        # Bind the on-change methods
-        depth_strength.change(fn=self.change_depth_strength, inputs=depth_strength)
-        num_inference_steps.change(fn=self.change_num_inference_steps, inputs=num_inference_steps)
-        nmb_branches_final.change(fn=self.change_nmb_branches_final, inputs=nmb_branches_final)
+        # Collect all UI elements in a list to easily pass them as inputs
+        dict_ui_elem["prompt1"] = prompt1
+        dict_ui_elem["negative_prompt"] = negative_prompt
+        dict_ui_elem["prompt2"] = prompt2
+
+        dict_ui_elem["duration_compute"] = duration_compute
+        dict_ui_elem["duration_video"] = duration_video
+        dict_ui_elem["height"] = height
+        dict_ui_elem["width"] = width
+
+        dict_ui_elem["depth_strength"] = depth_strength
+        dict_ui_elem["branch1_influence"] = branch1_influence
+        dict_ui_elem["branch1_mixing_depth"] = branch1_mixing_depth
 
-        guidance_scale.change(fn=self.change_guidance_scale, inputs=guidance_scale)
-        guidance_scale_mid_damper.change(fn=self.change_guidance_scale_mid_damper, inputs=guidance_scale_mid_damper)
+        dict_ui_elem["num_inference_steps"] = num_inference_steps
+        dict_ui_elem["guidance_scale"] = guidance_scale
+        dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
+        dict_ui_elem["seed1"] = seed1
+        dict_ui_elem["seed2"] = seed2
+
+        # Convert to list, as gradio doesn't seem to accept dicts
+        list_ui_elem = []
+        list_ui_keys = []
+        for k in dict_ui_elem.keys():
+            list_ui_elem.append(dict_ui_elem[k])
+            list_ui_keys.append(k)
+        self.list_ui_keys = list_ui_keys
 
-        height.change(fn=self.change_height, inputs=height)
-        width.change(fn=self.change_width, inputs=width)
-        negative_prompt.change(fn=self.change_negative_prompt, inputs=negative_prompt)
-        seed1.change(fn=self.change_seed1, inputs=seed1)
-        seed2.change(fn=self.change_seed2, inputs=seed2)
-        duration.change(fn=self.change_duration, inputs=duration)
-        branch1_influence.change(fn=self.change_branch1_influence, inputs=branch1_influence)
-
         b_newseed1.click(self.randomize_seed1, outputs=seed1)
         b_newseed2.click(self.randomize_seed2, outputs=seed2)
-        # b_stackforward.click(self.stack_forward,
-        #                      inputs=[prompt2, seed2],
-        #                      outputs=[img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
+        b_compute1.click(self.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5])
+        b_compute2.click(self.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5])
         b_compute_transition.click(self.compute_transition,
-                           inputs=[prompt1, prompt2],
-                           outputs=[img1, img2, img3, img4, img5, vid_transition])
+                           inputs=list_ui_elem,
+                           outputs=[img2, img3, img4, vid_transition])
-
-
     demo.launch(share=self.share, inbrowser=True, inline=False)
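
Worth a note on the pattern introduced above: Gradio hands a callback its component values positionally, so gradio_ui.py keeps a parallel list_ui_keys and setup_lb() recovers each value by name via .index(). A minimal self-contained sketch of that lookup pattern (the handler name and the two keys here are illustrative, not taken from the patch):

    list_ui_keys = ["prompt1", "seed1"]     # fixed order, filled once while the UI is built

    def handler(*args):                     # gradio passes component values positionally
        list_ui_elem = args
        prompt1 = list_ui_elem[list_ui_keys.index("prompt1")]
        seed1 = int(list_ui_elem[list_ui_keys.index("seed1")])
        return prompt1, seed1
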
diff --git a/latent_blending.py b/latent_blending.py
index b053024..6364a11 100644
--- a/latent_blending.py
+++ b/latent_blending.py
@@ -45,7 +45,7 @@ from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
 from stable_diffusion_holder import StableDiffusionHolder
 import yaml
-
+import lpips
 #%%
 class LatentBlending():
     def __init__(
@@ -88,10 +88,13 @@ class LatentBlending():
         self.prompt1 = ""
         self.prompt2 = ""
         self.negative_prompt = ""
-        self.tree_latents = None
+
+        self.tree_latents = [None, None]
         self.tree_fracts = None
+        self.idx_injection = []
         self.tree_status = None
         self.tree_final_imgs = []
+
         self.list_nmb_branches_prev = []
         self.list_injection_idx_prev = []
         self.text_embedding1 = None
@@ -105,12 +108,15 @@ class LatentBlending():
         self.list_injection_idx = None
         self.list_nmb_branches = None
         self.branch1_influence = 0.0
-        self.branch1_fract_crossfeed = 0.65
+        self.branch1_mixing_depth = 0.65
         self.branch1_insertion_completed = False
         self.set_guidance_scale(guidance_scale)
         self.init_mode()
         self.multi_transition_img_first = None
         self.multi_transition_img_last = None
+        self.dt_per_diff = 0
+
+        self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
 
     def init_mode(self):
@@ -375,6 +381,187 @@ class LatentBlending():
             self.tree_status.append(['untouched']*nmb_branches)
 
     def run_transition(
+            self,
+            recycle_img1: Optional[bool] = False,
+            recycle_img2: Optional[bool] = False,
+            num_inference_steps: Optional[int] = 30,
+            depth_strength: Optional[float] = 0.3,
+            fixed_seeds: Optional[List[int]] = None,
+            ):
+
+        # # FIXME: deal with these tree args later
+        # self.num_inference_steps = 30
+        # self.t_compute_max_allowed = 60
+
+        # Sanity checks first
+        assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
+        assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
+
+        # Random seeds
+        if fixed_seeds is not None:
+            if fixed_seeds == 'randomize':
+                fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
+            else:
+                assert len(fixed_seeds)==2, "Supply a list with len = 2"
+
+            self.seed1 = fixed_seeds[0]
+            self.seed2 = fixed_seeds[1]
+
+        # Ensure correct num_inference_steps in holder
+        self.sdh.num_inference_steps = self.num_inference_steps
+
+        # Compute / Recycle first image
+        if not recycle_img1:
+            list_latents1 = self.compute_latents1()
+        else:
+            # FIXME: check if the latents are actually there
+            list_latents1 = self.tree_latents[0]
+
+        # Compute / Recycle second image
+        if not recycle_img2:
+            list_latents2 = self.compute_latents2()
+        else:
+            # FIXME: check if the latents are actually there
+            list_latents2 = self.tree_latents[-1]
+
+        # Reset the tree, injecting the edge latents1/2 we just generated/recycled
+        self.tree_latents = [list_latents1, list_latents2]
+        self.tree_fracts = [0.0, 1.0]
+        self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
+        self.tree_idx_injection = [0, 0]
+
+        # Set up branching scheme (dependent on provided compute time)
+        idx_injection_base = int(round(self.num_inference_steps*depth_strength))
+        list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3)
+        list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
+        t_compute = 0
+        while t_compute < self.t_compute_max_allowed:
+            list_compute_steps = self.num_inference_steps - list_idx_injection
+            list_compute_steps *= list_nmb_stems
+            t_compute = np.sum(list_compute_steps) * self.dt_per_diff
+            increase_done = False
+            for s_idx in range(len(list_nmb_stems)-1):
+                if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
+                    list_nmb_stems[s_idx] += 1
+                    increase_done = True
+                    break
+            if not increase_done:
+                list_nmb_stems[-1] += 1
+            # print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
+
+        # Run iteratively
+        for s_idx in tqdm(range(len(list_idx_injection))):
+            nmb_stems = list_nmb_stems[s_idx]
+            idx_injection = list_idx_injection[s_idx]
+
+            for i in range(nmb_stems):
+                fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
+                # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
+                list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
+                self.insert_into_tree(fract_mixing, idx_injection, list_latents)
+
+        return self.tree_final_imgs
+
+
+    def get_mixing_parameters(self, idx_injection):
+        # get_lpips_similarity
+        similarities = []
+        for i in range(len(self.tree_final_imgs)-1):
+            similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i+1]))
+        b_closest1 = np.argmax(similarities)
+        b_closest2 = b_closest1+1
+        fract_closest1 = self.tree_fracts[b_closest1]
+        fract_closest2 = self.tree_fracts[b_closest2]
+
+        # Ensure that the parents are indeed older!
+        b_parent1 = b_closest1
+        while True:
+            if self.tree_idx_injection[b_parent1] < idx_injection:
+                break
+            else:
+                b_parent1 -= 1
+
+        b_parent2 = b_closest2
+        while True:
+            if self.tree_idx_injection[b_parent2] < idx_injection:
+                break
+            else:
+                b_parent2 += 1
+
+        # print(f"\n\nb_closest: {b_closest1} {b_closest2} fract_closest1 {fract_closest1} fract_closest2 {fract_closest2}")
+        # print(f"b_parent: {b_parent1} {b_parent2}")
+        # print(f"similarities {similarities}")
+        # print(f"idx_injection {idx_injection} tree_idx_injection {self.tree_idx_injection}")
+
+        fract_mixing = (fract_closest1 + fract_closest2) / 2
+        return fract_mixing, b_parent1, b_parent2
+
+
+    def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
+        # FIXME
+        b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
+        self.tree_latents.insert(b_parent1+1, list_latents)
+        self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
+        self.tree_fracts.insert(b_parent1+1, fract_mixing)
+        self.tree_idx_injection.insert(b_parent1+1, idx_injection)
+
+
+    def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
+        # FIXME
+        list_conditionings = self.get_mixed_conditioning(fract_mixing)
+        fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
+        idx_reversed = self.num_inference_steps - idx_injection
+        latents_for_injection = interpolate_spherical(
+            self.tree_latents[b_parent1][-idx_reversed-1],
+            self.tree_latents[b_parent2][-idx_reversed-1],
+            fract_mixing_parental)
+        list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
+        return list_latents
+
+
+    def compute_latents1(self, return_image=False):
+        print("starting compute_latents1")
+        list_conditionings = [self.text_embedding1]
+        t0 = time.time()
+        list_latents1 = self.run_diffusion(list_conditionings, seed_source=self.seed1)
+        t1 = time.time()
+        self.dt_per_diff = (t1-t0) / self.num_inference_steps
+        self.tree_latents[0] = list_latents1
+        if return_image:
+            return self.sdh.latent2image(list_latents1[-1])
+        else:
+            return list_latents1
+
+
+    def compute_latents2(self, return_image=False):
+        print("starting compute_latents2")
+        list_conditionings = [self.text_embedding2]
+        # Influence from branch1
+        if self.branch1_influence > 0.0:
+            self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
+            self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
+            idx_crossfeed = int(round(self.num_inference_steps*self.branch1_mixing_depth))
+            list_latents2 = self.run_diffusion(
+                list_conditionings,
+                idx_start=idx_crossfeed,
+                latents_for_injection=self.tree_latents[0],
+                seed_source=self.seed2,
+                seed_mixing_target=self.seed1,
+                mixing_coeff=self.branch1_influence)
+        else:
+            list_latents2 = self.run_diffusion(list_conditionings)
+        self.tree_latents[-1] = list_latents2
+
+        if return_image:
+            return self.sdh.latent2image(list_latents2[-1])
+        else:
+            return list_latents2
+
+
+
+    def run_transition_OLD(
             self,
             recycle_img1: Optional[bool] = False,
             recycle_img2: Optional[bool] = False,
@@ -423,9 +610,9 @@ class LatentBlending():
         if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
             assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0'
             self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
-            self.branch1_fract_crossfeed = np.clip(self.branch1_fract_crossfeed, 0, 1)
+            self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
             self.list_nmb_branches.insert(1, 2)
-            idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_fract_crossfeed))
+            idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_mixing_depth))
             self.list_injection_idx_ext.insert(1, idx_crossfeed)
             self.tree_fracts.insert(1, self.tree_fracts[0])
             self.tree_status.insert(1, self.tree_status[0])
@@ -606,6 +793,9 @@ class LatentBlending():
             latents_for_injection: torch.FloatTensor = None,
             idx_start: int = -1,
             idx_stop: int = -1,
+            seed_source: int = -1,
+            seed_mixing_target: int = -1,
+            mixing_coeff: float = 0.0,
             return_image: Optional[bool] = False
             ):
         r"""
@@ -620,6 +810,7 @@ class LatentBlending():
                 Index of the diffusion process start and where the latents_for_injection are injected
             idx_stop: int
                 Index of the diffusion process end.
+            seed_source: int
+                Seed for the noise of this diffusion run.
+            seed_mixing_target: int
+                Seed of the latents that get mixed in (only used if mixing_coeff > 0).
+            mixing_coeff: float
+                Spherical-interpolation weight for mixing in the latents of seed_mixing_target.
             return_image: Optional[bool]
                 Optionally return image directly
         """
@@ -630,14 +821,23 @@ class LatentBlending():
 
         if self.mode == 'standard':
             text_embeddings = list_conditionings[0]
-            return self.sdh.run_diffusion_standard(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
+            return self.sdh.run_diffusion_standard(
+                text_embeddings,
+                latents_for_injection=latents_for_injection,
+                idx_start=idx_start,
+                idx_stop=idx_stop,
+                seed_source=seed_source,
+                seed_mixing_target=seed_mixing_target,
+                mixing_coeff=mixing_coeff,
+                return_image=return_image,
+                )
 
         elif self.mode == 'inpaint':
             text_embeddings = list_conditionings[0]
             assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
             assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
             return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
-
+            # FIXME LONG LINE
         elif self.mode == 'upscale':
             cond = list_conditionings[0]
             uc_full = list_conditionings[1]
@@ -881,8 +1081,6 @@ class LatentBlending():
             if inject_img2:
                 self.tree_latents[t_block][-1] = list_latents[self.list_injection_idx_ext[t_block]:self.list_injection_idx_ext[t_block+1]]
 
-
-
     def swap_forward(self):
         r"""
@@ -901,6 +1099,21 @@ class LatentBlending():
         self.tree_final_imgs = []
 
 
+    def get_lpips_similarity(self, imgA, imgB):
+        # FIXME
+        tensorA = torch.from_numpy(imgA).float().cuda(self.device)
+        tensorA = 2*tensorA/255.0 - 1
+        tensorA = tensorA.permute([2,0,1]).unsqueeze(0)
+
+        tensorB = torch.from_numpy(imgB).float().cuda(self.device)
+        tensorB = 2*tensorB/255.0 - 1
+        tensorB = tensorB.permute([2,0,1]).unsqueeze(0)
+        lploss = self.lpips(tensorA, tensorB)
+        lploss = float(lploss[0][0][0][0])
+
+        return lploss
+
+
 # Auxiliary functions
 def get_closest_idx(
         fract_mixing: float,
@@ -1169,31 +1382,79 @@ if __name__ == "__main__":
 
     #%% First let us spawn a stable diffusion holder
     device = "cuda"
     fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
-    fp_config = 'configs/v2-inference.yaml'
-    sdh = StableDiffusionHolder(fp_ckpt, fp_config, device)
+    sdh = StableDiffusionHolder(fp_ckpt)
 
     xxx
 
     #%% Next let's set up all parameters
-    quality = 'medium'
-    depth_strength = 0.65 # Specifies how deep (in terms of diffusion iterations the first branching happens)
-    fixed_seeds = [69731932, 504430820]
+    depth_strength = 0.3 # Specifies how deep (in terms of diffusion iterations) the first branching happens
+    fixed_seeds = [697164, 430214]
 
-    prompt1 = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic"
-    prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail"
+    prompt1 = "photo of a desert and a sky"
+    prompt2 = "photo of a tree with a lake"
 
     duration_transition = 12 # In seconds
     fps = 30
 
     # Spawn latent blending
     self = LatentBlending(sdh)
-    self.branch1_influence = 0.8
-    self.load_branching_profile(quality=quality, depth_strength=0.3)
+
     self.set_prompt1(prompt1)
     self.set_prompt2(prompt2)
 
     # Run latent blending
-    imgs_transition = self.run_transition(fixed_seeds=fixed_seeds)
-
+    self.branch1_influence = 0.3
+    self.branch1_mixing_depth = 0.4
+    self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
+
+    #%%
+    self.branch1_influence = 0.3
+    self.branch1_mixing_depth = 0.5
+    img2 = self.compute_latents2(return_image=True)
+    Image.fromarray(img2)
+
+    #%%
+    idx_injection = 15
+    fract_mixing = 0.5
+    list_conditionings = self.get_mixed_conditioning(fract_mixing)
+    latents_for_injection = interpolate_spherical(self.tree_latents[0][idx_injection], self.tree_latents[-1][idx_injection], fract_mixing)
+    list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
+    img_mix = self.sdh.latent2image((list_latents[-1]))
+
+    Image.fromarray(np.concatenate((img1,img_mix,img2), axis=1)).resize((800,800//3))
+
+    #%% scheme
+    # init scheme
+    list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 2)
+    list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
+
+    #%%
+    list_compute_steps = self.num_inference_steps - list_idx_injection
+    list_compute_steps *= list_nmb_stems
+    t_compute = np.sum(list_compute_steps) * self.dt_per_diff
+    increase_done = False
+    for s_idx in range(len(list_nmb_stems)-1):
+        if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 3:
+            list_nmb_stems[s_idx] += 1
+            increase_done = True
+            break
+    if not increase_done:
+        list_nmb_stems[-1] += 1
+    print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
+
+
+    #%%
+    imgs_transition = self.tree_final_imgs
+    # Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
+    imgs_transition_ext = add_frames_linear_interp(imgs_transition, 15, fps)
+
+    # Save as MP4
+    fp_movie = "test.mp4"
+    if os.path.isfile(fp_movie):
+        os.remove(fp_movie)
+    ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width])
+    for img in tqdm(imgs_transition_ext):
+        ms.write_frame(img)
+    ms.finalize()
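
The scheduling loop inside the new run_transition() is the core of the compute-budget idea: stems keep getting added (topping up a shallower injection level whenever the next level already has at least twice as many stems, otherwise adding at the deepest level) until the estimated render time exceeds t_compute_max_allowed. A standalone sketch with illustrative numbers (num_inference_steps, dt_per_diff and the budget below are made up, not the patch defaults):

    import numpy as np

    num_inference_steps = 30
    depth_strength = 0.3
    dt_per_diff = 0.1             # measured seconds per denoising step (set in compute_latents1)
    t_compute_max_allowed = 20    # compute budget in seconds

    idx_injection_base = int(round(num_inference_steps * depth_strength))
    list_idx_injection = np.arange(idx_injection_base, num_inference_steps - 1, 3)
    list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
    t_compute = 0
    while t_compute < t_compute_max_allowed:
        # One stem injected at depth d costs (num_inference_steps - d) denoising steps
        list_compute_steps = (num_inference_steps - list_idx_injection) * list_nmb_stems
        t_compute = np.sum(list_compute_steps) * dt_per_diff
        # Top up an earlier level if the next one has >= 2x as many stems, else deepen the last level
        increase_done = False
        for s_idx in range(len(list_nmb_stems) - 1):
            if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
                list_nmb_stems[s_idx] += 1
                increase_done = True
                break
        if not increase_done:
            list_nmb_stems[-1] += 1
    print(list_idx_injection, list_nmb_stems, round(float(t_compute), 1))
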
diff --git a/stable_diffusion_holder.py b/stable_diffusion_holder.py
index f6f5da9..8cae407 100644
--- a/stable_diffusion_holder.py
+++ b/stable_diffusion_holder.py
@@ -42,7 +42,6 @@ from contextlib import nullcontext
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from einops import repeat, rearrange
-
 
 #%%
 
@@ -279,24 +278,33 @@ class StableDiffusionHolder:
     def run_diffusion_standard(
             self,
             text_embeddings: torch.FloatTensor,
-            latents_for_injection: torch.FloatTensor = None,
+            latents_for_injection = None,
             idx_start: int = -1,
             idx_stop: int = -1,
-            return_image: Optional[bool] = False
+            seed_source: int = -1,
+            seed_mixing_target: int = -1,
+            mixing_coeff: float = 0.0,
+            return_image: Optional[bool] = False,
             ):
         r"""
         Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
         Depending on the mode, the correct one will be executed.
 
         Args:
-            text_embeddings: torch.FloatTensor
+            text_embeddings: torch.FloatTensor
                 Text embeddings used for diffusion
-            latents_for_injection: torch.FloatTensor
+            latents_for_injection: torch.FloatTensor or list
                 Latents that are used for injection
             idx_start: int
                 Index of the diffusion process start and where the latents_for_injection are injected
             idx_stop: int
                 Index of the diffusion process end.
+            mixing_coeff: float
+                Spherical-interpolation weight for mixing in the latents of seed_mixing_target.
+            seed_source: int
+                Seed for the noise of this diffusion run.
+            seed_mixing_target: int
+                Seed of the latents that get mixed in (only used if mixing_coeff > 0).
             return_image: Optional[bool]
                 Optionally return image directly
         """
@@ -304,12 +312,19 @@ class StableDiffusionHolder:
 
         if latents_for_injection is None:
             do_inject_latents = False
+            do_mix_latents = False
         else:
-            do_inject_latents = True
+            if mixing_coeff > 0.0:
+                do_inject_latents = False
+                do_mix_latents = True
+                assert seed_mixing_target != -1, "Set to correct seed for mixing"
+            else:
+                do_inject_latents = True
+                do_mix_latents = False
 
         precision_scope = autocast if self.precision == "autocast" else nullcontext
-        generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
+        generator = torch.Generator(device=self.device).manual_seed(int(seed_source))
 
         with precision_scope("cuda"):
             with self.model.ema_scope():
@@ -340,6 +355,16 @@ class StableDiffusionHolder:
                         continue
                     elif i == idx_start:
                         latents = latents_for_injection.clone()
+                    if do_mix_latents:
+                        if i == 0:
+                            # Draw the latents of the mixing target from its own seed
+                            generator = torch.Generator(device=self.device).manual_seed(int(seed_mixing_target))
+                            latents_mixtarget = torch.randn(size, generator=generator, device=self.device)
+                        if i < idx_start:
+                            latents_mixtarget = latents_for_injection[i-1].clone()
+                        latents = interpolate_spherical(latents, latents_mixtarget, mixing_coeff)
+
+                    if i == idx_start:
+                        do_mix_latents = False
 
                     if i == idx_stop:
                         return list_latents_out
@@ -576,6 +601,50 @@ class StableDiffusionHolder:
             image = x_sample.astype(np.uint8)
         return image
 
+@torch.no_grad()
+def interpolate_spherical(p0, p1, fract_mixing: float):
+    r"""
+    Helper function to correctly mix two random variables using spherical interpolation.
+    See https://en.wikipedia.org/wiki/Slerp
+    The function will always cast up to float64 for the sake of extra precision.
+    Args:
+        p0:
+            First tensor for interpolation
+        p1:
+            Second tensor for interpolation
+        fract_mixing: float
+            Mixing coefficient of interval [0, 1].
+            0 will return p0
+            1 will return p1
+            0.x will return a mix between both, preserving angular velocity.
+    """
+
+    if p0.dtype == torch.float16:
+        recast_to = 'fp16'
+    else:
+        recast_to = 'fp32'
+
+    p0 = p0.double()
+    p1 = p1.double()
+    norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
+    epsilon = 1e-7
+    dot = torch.sum(p0 * p1) / norm
+    dot = dot.clamp(-1+epsilon, 1-epsilon)
+
+    theta_0 = torch.arccos(dot)
+    sin_theta_0 = torch.sin(theta_0)
+    theta_t = theta_0 * fract_mixing
+    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
+    s1 = torch.sin(theta_t) / sin_theta_0
+    interp = p0*s0 + p1*s1
+
+    if recast_to == 'fp16':
+        interp = interp.half()
+    elif recast_to == 'fp32':
+        interp = interp.float()
+
+    return interp
+
 
 if __name__ == "__main__":
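
Taken together, the three files support the following minimal flow for the new budget-based API; a sketch that follows the __main__ example in latent_blending.py above (the checkpoint path is machine-specific and the lpips package must be installed; the resulting frames can be interpolated and written to disk with add_frames_linear_interp and MovieSaver exactly as in that __main__ block):

    from stable_diffusion_holder import StableDiffusionHolder
    from latent_blending import LatentBlending

    fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"  # adjust to your setup
    sdh = StableDiffusionHolder(fp_ckpt)

    lb = LatentBlending(sdh)
    lb.set_prompt1("photo of a desert and a sky")
    lb.set_prompt2("photo of a tree with a lake")
    lb.branch1_influence = 0.3       # how strongly the first branch's latents are mixed into the second
    lb.branch1_mixing_depth = 0.4    # how deep into the diffusion the crossfeed happens
    lb.t_compute_max_allowed = 30    # compute budget in seconds; a larger budget yields more transition frames

    # Returns the list of transition images (tree_final_imgs), ordered from prompt1 to prompt2
    imgs_transition = lb.run_transition(depth_strength=0.3, fixed_seeds=[697164, 430214])
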