From e889c2a0ccf5c5106fd98edb60a3ba4470ddf17d Mon Sep 17 00:00:00 2001 From: Johannes Stelzer Date: Tue, 9 Jan 2024 15:31:17 +0100 Subject: [PATCH] better branch handling --- diffusers_holder.py | 2 +- latent_blending.py | 147 ++++++++++++++++++++++++++++---------------- 2 files changed, 96 insertions(+), 53 deletions(-) diff --git a/diffusers_holder.py b/diffusers_holder.py index a1b17ab..bb8f615 100644 --- a/diffusers_holder.py +++ b/diffusers_holder.py @@ -415,7 +415,7 @@ if __name__ == "__main__": # img_refx = self.pipe(prompt=prompt1, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)[0] - img_refx = self.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=True) + img_refx = self.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False) dt_ref = time.time() - t0 img_refx.save(f"x_{prefix}_{i}.jpg") diff --git a/latent_blending.py b/latent_blending.py index ad94ea0..e298b10 100644 --- a/latent_blending.py +++ b/latent_blending.py @@ -76,17 +76,11 @@ class LatentBlending(): self.tree_status = None self.tree_final_imgs = [] - self.list_nmb_branches_prev = [] - self.list_injection_idx_prev = [] self.text_embedding1 = None self.text_embedding2 = None self.image1_lowres = None self.image2_lowres = None self.negative_prompt = None - self.num_inference_steps = self.dh.num_inference_steps - self.noise_level_upscaling = 20 - self.list_injection_idx = None - self.list_nmb_branches = None # Mixing parameters self.branch1_crossfeed_power = 0.0 @@ -100,11 +94,34 @@ class LatentBlending(): self.set_guidance_scale(guidance_scale) self.multi_transition_img_first = None self.multi_transition_img_last = None - self.dt_per_diff = 0 + self.dt_unet_step = 0 self.lpips = lpips.LPIPS(net='alex').cuda(self.device) self.set_prompt1("") self.set_prompt2("") + + self.set_num_inference_steps() + self.benchmark_speed() + self.set_branching() + + def benchmark_speed(self): + """ + Measures the time per diffusion step and for the vae decoding + """ + + text_embeddings = self.dh.get_text_embedding("test") + latents_start = self.dh.get_noise(np.random.randint(111111)) + # warmup + list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1) + # bench unet + t0 = time.time() + list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1) + self.dt_unet_step = time.time() - t0 + + # bench vae + t0 = time.time() + img = self.dh.latent2image(list_latents[-1]) + self.dt_vae = time.time() - t0 def set_dimensions(self, size_output=None): r""" @@ -208,28 +225,21 @@ class LatentBlending(): image: Image """ self.image2_lowres = image - - def run_transition( - self, - recycle_img1: Optional[bool] = False, - recycle_img2: Optional[bool] = False, - num_inference_steps: Optional[int] = 30, - list_idx_injection: Optional[int] = None, - list_nmb_stems: Optional[int] = None, - depth_strength: Optional[float] = 0.3, - t_compute_max_allowed: Optional[float] = None, - nmb_max_branches: Optional[int] = None, - fixed_seeds: Optional[List[int]] = None): - r""" - Function for computing transitions. - Returns a list of transition images using spherical latent blending. - Args: - recycle_img1: Optional[bool]: - Don't recompute the latents for the first keyframe (purely prompt1). Saves compute. - recycle_img2: Optional[bool]: - Don't recompute the latents for the second keyframe (purely prompt2). Saves compute. - num_inference_steps: - Number of diffusion steps. Higher values will take more compute time. + + def set_num_inference_steps(self, num_inference_steps=None): + if self.dh.is_sdxl_turbo: + if num_inference_steps is None: + num_inference_steps = 4 + else: + if num_inference_steps is None: + num_inference_steps = 30 + + self.num_inference_steps = num_inference_steps + self.dh.set_num_inference_steps(num_inference_steps) + + def set_branching(self, depth_strength=None, t_compute_max_allowed=None, nmb_max_branches=None): + """ + Sets the branching structure of the blending tree. Default arguments depend on pipe! depth_strength: Determines how deep the first injection will happen. Deeper injections will cause (unwanted) formation of new structures, @@ -241,6 +251,45 @@ class LatentBlending(): Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better results. Use this if you want to have controllable results independent of your computer. + """ + if self.dh.is_sdxl_turbo: + assert t_compute_max_allowed is None, "time-based branching not supported for SDXL Turbo" + if depth_strength is not None: + idx_inject = int(round(self.num_inference_steps*depth_strength)) + else: + idx_inject = 2 + if nmb_max_branches is None: + nmb_max_branches = 10 + + self.list_idx_injection = [idx_inject] + self.list_nmb_stems = [nmb_max_branches] + + else: + if depth_strength is None: + depth_strength = 0.5 + if t_compute_max_allowed is None and nmb_max_branches is None: + t_compute_max_allowed = 20 + elif t_compute_max_allowed is not None and nmb_max_branches is not None: + raise ValueErorr("Either specify t_compute_max_allowed or nmb_max_branches") + + self.list_idx_injection, self.list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches) + + def run_transition( + self, + recycle_img1: Optional[bool] = False, + recycle_img2: Optional[bool] = False, + fixed_seeds: Optional[List[int]] = None): + r""" + Function for computing transitions. + Returns a list of transition images using spherical latent blending. + Args: + recycle_img1: Optional[bool]: + Don't recompute the latents for the first keyframe (purely prompt1). Saves compute. + recycle_img2: Optional[bool]: + Don't recompute the latents for the second keyframe (purely prompt2). Saves compute. + num_inference_steps: + Number of diffusion steps. Higher values will take more compute time. + fixed_seeds: Optional[List[int)]: You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2). Otherwise random seeds will be taken. @@ -261,12 +310,7 @@ class LatentBlending(): self.seed1 = fixed_seeds[0] self.seed2 = fixed_seeds[1] - # Ensure correct num_inference_steps in holder - if self.dh.is_sdxl_turbo: - num_inference_steps = 4 #ideal results - self.num_inference_steps = num_inference_steps - self.dh.set_num_inference_steps(num_inference_steps) - + # Compute / Recycle first image if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps: list_latents1 = self.compute_latents1() @@ -291,16 +335,13 @@ class LatentBlending(): self.parental_crossfeed_power = 1.0 self.parental_crossfeed_power_decay = 1.0 self.parental_crossfeed_range = 1.0 - list_idx_injection = [2] - list_nmb_stems = [10] - else: - list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches) + # Run iteratively, starting with the longest trajectory. # Always inserting new branches where they are needed most according to image similarity - for s_idx in tqdm(range(len(list_idx_injection))): - nmb_stems = list_nmb_stems[s_idx] - idx_injection = list_idx_injection[s_idx] + for s_idx in tqdm(range(len(self.list_idx_injection))): + nmb_stems = self.list_nmb_stems[s_idx] + idx_injection = self.list_idx_injection[s_idx] for i in range(nmb_stems): fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection) @@ -310,6 +351,9 @@ class LatentBlending(): # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection} bp1 {b_parent1} bp2 {b_parent2}") return self.tree_final_imgs + + + def compute_latents1(self, return_image=False): r""" @@ -327,7 +371,7 @@ class LatentBlending(): latents_start=latents_start, idx_start=0) t1 = time.time() - self.dt_per_diff = (t1 - t0) / self.num_inference_steps + self.dt_unet_step = (t1 - t0) / self.num_inference_steps self.tree_latents[0] = list_latents1 if return_image: return self.dh.latent2image(list_latents1[-1]) @@ -447,8 +491,8 @@ class LatentBlending(): while not stop_criterion_reached: list_compute_steps = self.num_inference_steps - list_idx_injection list_compute_steps *= list_nmb_stems - t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems) - t_compute += 2 * self.num_inference_steps * self.dt_per_diff # outer branches + t_compute = np.sum(list_compute_steps) * self.dt_unet_step + self.dt_vae * np.sum(list_nmb_stems) + t_compute += 2 * (self.num_inference_steps * self.dt_unet_step + self.dt_vae) # outer branches increase_done = False for s_idx in range(len(list_nmb_stems) - 1): if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 1: @@ -765,8 +809,8 @@ if __name__ == "__main__": from diffusers_holder import DiffusersHolder from diffusers import DiffusionPipeline from diffusers import AutoencoderTiny - # pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" - pretrained_model_name_or_path = "stabilityai/sdxl-turbo" + pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" + # pretrained_model_name_or_path = "stabilityai/sdxl-turbo" pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16") @@ -776,8 +820,8 @@ if __name__ == "__main__": dh = DiffusersHolder(pipe) # %% Next let's set up all parameters - size_output = (512, 512) - # size_output = (1024, 1024) + # size_output = (512, 512) + size_output = (1024, 1024) prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution" prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal" negative_prompt = "blurry, ugly, pale" # Optional @@ -793,10 +837,9 @@ if __name__ == "__main__": lb.set_prompt2(prompt2) lb.set_dimensions(size_output) lb.set_negative_prompt(negative_prompt) - # Run latent blending - lb.run_transition(fixed_seeds=[420, 421], t_compute_max_allowed=15) + lb.run_transition(fixed_seeds=[420, 421]) # Save movie fp_movie = f'test.mp4' @@ -804,4 +847,4 @@ if __name__ == "__main__": - #%% +