From 994c77e92465845a87e407421771b93ec97ff73b Mon Sep 17 00:00:00 2001 From: lugo Date: Tue, 29 Nov 2022 18:03:08 +0100 Subject: [PATCH] latent injection for inpainting --- example2_inpaint.py | 12 ++-- latent_blending.py | 133 ++++++++++++++++++++++++++------------------ 2 files changed, 86 insertions(+), 59 deletions(-) diff --git a/example2_inpaint.py b/example2_inpaint.py index 8457bef..5ef84c4 100644 --- a/example2_inpaint.py +++ b/example2_inpaint.py @@ -52,21 +52,25 @@ prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, orga lb.set_prompt1(prompt1) lb.init_inpainting(init_empty=True) lb.set_seed(seed0) -image_source = lb.run_diffusion(lb.text_embedding1, return_image=True) +list_latents = lb.run_diffusion(lb.text_embedding1) +image_source = lb.sdh.latent2image(list_latents[-1]) + mask_image = 255*np.ones([512,512], dtype=np.uint8) mask_image[340:420, 170:280, ] = 0 mask_image = Image.fromarray(mask_image) #%% Next let's set up all parameters -fixed_seeds = [seed0, 280335986] +lb.inject_latents(list_latents, inject_img1=True) + +fixed_seeds = [seed0, 6579436] prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, organic, intricate, sci-fi movie, mesmerizing, scary" -prompt2 = "aerial photo of a futuristic alien temple in a coastal area, waves clashing" +prompt2 = "aerial photo of a futuristic alien temple in a blue coastal area, the sun is shining with a bright light" lb.set_prompt1(prompt1) lb.set_prompt2(prompt2) lb.init_inpainting(image_source, mask_image) -imgs_transition = lb.run_transition(fixed_seeds=fixed_seeds) +imgs_transition = lb.run_transition(recycle_img1=True, fixed_seeds=fixed_seeds) # let's get more cheap frames via linear interpolation duration_transition = 12 diff --git a/latent_blending.py b/latent_blending.py index faf126a..1a28358 100644 --- a/latent_blending.py +++ b/latent_blending.py @@ -69,11 +69,12 @@ class LatentBlending(): thus high values can give rough transitions. Values around 2 should be fine. """ + assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}" + self.sdh = sdh self.device = self.sdh.device self.width = self.sdh.width self.height = self.sdh.height - assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}" self.guidance_scale_mid_damper = guidance_scale_mid_damper self.mid_compression_scaler = mid_compression_scaler self.seed = 420 # Run self.set_seed or fixed_seeds argument in run_transition @@ -81,9 +82,9 @@ class LatentBlending(): # Initialize vars self.prompt1 = "" self.prompt2 = "" - self.tree_latents = [] - self.tree_fracts = [] - self.tree_status = [] + self.tree_latents = None + self.tree_fracts = None + self.tree_status = None self.tree_final_imgs = [] self.list_nmb_branches_prev = [] self.list_injection_idx_prev = [] @@ -100,8 +101,7 @@ class LatentBlending(): def init_mode(self, mode='standard'): r""" - Automatically sets the mode of this class, depending on the supplied pipeline. - FIXME XXX + Sets the mode of this class, either inpaint of standard. """ if mode == 'inpaint': self.sdh.image_source = None @@ -268,16 +268,46 @@ class LatentBlending(): assert max(list_injection_idx) < num_inference_steps, "Decrease the injection index or strength" assert len(list_injection_idx) == len(list_nmb_branches), "Need to have same length" assert max(list_injection_idx) < num_inference_steps,"Injection index cannot happen after last diffusion step! Decrease list_injection_idx or list_injection_strength[-1]" - + + # Auto inits + list_injection_idx_ext = list_injection_idx[:] + list_injection_idx_ext.append(num_inference_steps) + + # If injection at depth 0 not specified, we will start out with 2 branches + if list_injection_idx_ext[0] != 0: + list_injection_idx_ext.insert(0,0) + list_nmb_branches.insert(0,2) + assert list_nmb_branches[0] == 2, "Need to start with 2 branches. set list_nmb_branches[0]=2" + + # Set attributes self.num_inference_steps = num_inference_steps self.sdh.num_inference_steps = num_inference_steps self.list_nmb_branches = list_nmb_branches self.list_injection_idx = list_injection_idx + self.list_injection_idx_ext = list_injection_idx_ext - - + self.init_tree_struct() + + def init_tree_struct(self): + r""" + Initializes tree variables for holding latents etc. + """ + + self.tree_latents = [] + self.tree_fracts = [] + self.tree_status = [] + self.tree_final_imgs_timing = [0]*self.list_nmb_branches[-1] + + nmb_blocks_time = len(self.list_injection_idx_ext)-1 + for t_block in range(nmb_blocks_time): + nmb_branches = self.list_nmb_branches[t_block] + list_fract_mixing_current = get_spacing(nmb_branches, self.mid_compression_scaler) + self.tree_fracts.append(list_fract_mixing_current) + self.tree_latents.append([None]*nmb_branches) + self.tree_status.append(['untouched']*nmb_branches) + def run_transition( self, recycle_img1: Optional[bool] = False, @@ -313,57 +343,34 @@ class LatentBlending(): # Ensure correct num_inference_steps in holder self.sdh.num_inference_steps = self.num_inference_steps - # Recycling? There are requirements - if recycle_img1 or recycle_img2: - if self.list_nmb_branches_prev == []: - print("Warning. You want to recycle but there is nothing here. Disabling recycling.") - recycle_img1 = False - recycle_img2 = False - elif self.list_nmb_branches_prev != self.list_nmb_branches: - print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.") - recycle_img1 = False - recycle_img2 = False - elif self.list_injection_idx_prev != self.list_injection_idx: - print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.") - recycle_img1 = False - recycle_img2 = False + # # Recycling? There are requirements + # if recycle_img1 or recycle_img2: + # # if self.list_nmb_branches_prev == []: + # # print("Warning. You want to recycle but there is nothing here. Disabling recycling.") + # # recycle_img1 = False + # # recycle_img2 = False + # if self.list_nmb_branches_prev != self.list_nmb_branches: + # print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.") + # recycle_img1 = False + # recycle_img2 = False + # elif self.list_injection_idx_prev != self.list_injection_idx: + # print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.") + # recycle_img1 = False + # recycle_img2 = False # Make a backup for future reference self.list_nmb_branches_prev = self.list_nmb_branches[:] self.list_injection_idx_prev = self.list_injection_idx[:] - # Auto inits - list_injection_idx_ext = self.list_injection_idx[:] - list_nmb_branches = self.list_nmb_branches[:] - list_injection_idx_ext.append(self.num_inference_steps) - - # If injection at depth 0 not specified, we will start out with 2 branches - if list_injection_idx_ext[0] != 0: - list_injection_idx_ext.insert(0,0) - list_nmb_branches.insert(0,2) - assert list_nmb_branches[0] == 2, "Need to start with 2 branches. set list_nmb_branches[0]=2" - # Pre-define entire branching tree structures + self.tree_final_imgs = [None]*self.list_nmb_branches[-1] + nmb_blocks_time = len(self.list_injection_idx_ext)-1 if not recycle_img1 and not recycle_img2: - self.tree_latents = [] - self.tree_fracts = [] - self.tree_status = [] - self.tree_final_imgs = [None]*list_nmb_branches[-1] - self.tree_final_imgs_timing = [0]*list_nmb_branches[-1] - - nmb_blocks_time = len(list_injection_idx_ext)-1 - for t_block in range(nmb_blocks_time): - nmb_branches = list_nmb_branches[t_block] - # list_fract_mixing_current = np.linspace(0, 1, nmb_branches) - list_fract_mixing_current = get_spacing(nmb_branches, self.mid_compression_scaler) - self.tree_fracts.append(list_fract_mixing_current) - self.tree_latents.append([None]*nmb_branches) - self.tree_status.append(['untouched']*nmb_branches) + self.init_tree_struct() else: - self.tree_final_imgs = [None]*list_nmb_branches[-1] - nmb_blocks_time = len(list_injection_idx_ext)-1 + self.tree_final_imgs = [None]*self.list_nmb_branches[-1] for t_block in range(nmb_blocks_time): - nmb_branches = list_nmb_branches[t_block] + nmb_branches = self.list_nmb_branches[t_block] for idx_branch in range(nmb_branches): self.tree_status[t_block][idx_branch] = 'untouched' if recycle_img1: @@ -386,7 +393,7 @@ class LatentBlending(): list_compute.extend(list_local_stem[::-1]) # setup compute order: start from last leafs (the final transition images) and work way down. what parents do they need? - for idx_leaf in range(1, list_nmb_branches[-1]): + for idx_leaf in range(1, self.list_nmb_branches[-1]): list_local_stem = [] t_block = nmb_blocks_time - 1 t_block_prev = t_block - 1 @@ -414,7 +421,7 @@ class LatentBlending(): return self.tree_final_imgs # print(f"computing t_block {t_block} idx_branch {idx_branch}") - idx_stop = list_injection_idx_ext[t_block+1] + idx_stop = self.list_injection_idx_ext[t_block+1] fract_mixing = self.tree_fracts[t_block][idx_branch] text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) self.set_guidance_mid_dampening(fract_mixing) @@ -423,7 +430,7 @@ class LatentBlending(): if fixed_seeds is not None: if idx_branch == 0: self.set_seed(fixed_seeds[0]) - elif idx_branch == list_nmb_branches[0] -1: + elif idx_branch == self.list_nmb_branches[0] -1: self.set_seed(fixed_seeds[1]) list_latents = self.run_diffusion(text_embeddings_mix, idx_stop=idx_stop) else: @@ -434,7 +441,7 @@ class LatentBlending(): latents2 = latents1 else: latents2 = self.tree_latents[t_block-1][b_parent2][-1] - idx_start = list_injection_idx_ext[t_block] + idx_start = self.list_injection_idx_ext[t_block] fract_mixing_parental = (fract_mixing - self.tree_fracts[t_block-1][b_parent1]) / (self.tree_fracts[t_block-1][b_parent2] - self.tree_fracts[t_block-1][b_parent1]) latents_for_injection = interpolate_spherical(latents1, latents2, fract_mixing_parental) list_latents = self.run_diffusion(text_embeddings_mix, latents_for_injection, idx_start=idx_start, idx_stop=idx_stop) @@ -594,6 +601,22 @@ class LatentBlending(): self.seed = seed self.sdh.seed = seed + def inject_latents(self, list_latents, inject_img1=True, inject_img2=False): + r""" + Injects list of latents into tree structure. + + """ + assert inject_img1 != inject_img2, "Either inject into img1 or img2" + assert self.tree_latents is not None, "You need to setup the branching beforehand, run autosetup_branching() or setup_branching() before" + + for t_block in range(len(self.list_injection_idx)): + if inject_img1: + self.tree_latents[t_block][0] = list_latents[self.list_injection_idx_ext[t_block]:self.list_injection_idx_ext[t_block+1]] + if inject_img2: + self.tree_latents[t_block][-1] = list_latents[self.list_injection_idx_ext[t_block]:self.list_injection_idx_ext[t_block+1]] + + + def swap_forward(self): r"""