From 14bc3323b5cec9d633a72ee8f72795d2f76d54d6 Mon Sep 17 00:00:00 2001 From: lugo Date: Mon, 28 Nov 2022 15:34:18 +0100 Subject: [PATCH] mid scaling --- example1_standard.py | 8 ++--- latent_blending.py | 85 +++++++++++++++++++++++++++++++++----------- 2 files changed, 68 insertions(+), 25 deletions(-) diff --git a/example1_standard.py b/example1_standard.py index eb9bd64..5a6cec1 100644 --- a/example1_standard.py +++ b/example1_standard.py @@ -39,15 +39,15 @@ sdh = StableDiffusionHolder(fp_ckpt, fp_config, device) #%% Next let's set up all parameters -guidance_scale = 5 -quality = 'high' +quality = 'medium' fixed_seeds = [69731932, 504430820] -lb = LatentBlending(sdh, guidance_scale) +lb = LatentBlending(sdh) prompt1 = "photo of a beautiful forest covered in white flowers, ambient light, very detailed, magic" prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail" lb.set_prompt1(prompt1) lb.set_prompt2(prompt2) + lb.autosetup_branching(quality=quality) imgs_transition = lb.run_transition(fixed_seeds=fixed_seeds) @@ -58,7 +58,7 @@ fps = 60 imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps) # movie saving -fp_movie = f"movie_example1_{quality}.mp4" +fp_movie = f"movie_example1.mp4" if os.path.isfile(fp_movie): os.remove(fp_movie) ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width]) diff --git a/latent_blending.py b/latent_blending.py index 19e4ee3..55c440f 100644 --- a/latent_blending.py +++ b/latent_blending.py @@ -47,31 +47,36 @@ class LatentBlending(): def __init__( self, sdh: None, - guidance_scale: float = 7.5, + guidance_scale: float = 4, + guidance_scale_mid_damper: float = 0.5, + mid_compression_scaler: float = 2.0, ): r""" Initializes the latent blending class. Args: - FIXME XXX - height: int - Height of the desired output image. The model was trained on 512. - width: int - Width of the desired output image. The model was trained on 512. guidance_scale: float Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - seed: int - Random seed. + guidance_scale_mid_damper: float = 0.5 + Reduces the guidance scale towards the middle of the transition. + A value of 0.5 would decrease the guidance_scale towards the middle linearly by 0.5. + mid_compression_scaler: float = 2.0 + Increases the sampling density in the middle (where most changes happen). Higher value + imply more values in the middle. However the inflection point can occur outside the middle, + thus high values can give rough transitions. Values around 2 should be fine. """ self.sdh = sdh self.device = self.sdh.device self.width = self.sdh.width self.height = self.sdh.height - self.seed = 420 #use self.set_seed or fixed_seeds argument in run_transition + assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}" + self.guidance_scale_mid_damper = guidance_scale_mid_damper + self.mid_compression_scaler = mid_compression_scaler + self.seed = 420 # Run self.set_seed or fixed_seeds argument in run_transition # Initialize vars self.prompt1 = "" @@ -109,8 +114,20 @@ class LatentBlending(): r""" sets the guidance scale. """ + self.guidance_scale_base = guidance_scale self.guidance_scale = guidance_scale self.sdh.guidance_scale = guidance_scale + + def set_guidance_mid_dampening(self, fract_mixing): + r""" + Tunes the guidance scale down as a linear function of fract_mixing, + towards 0.5 the minimum will be reached. + """ + mid_factor = 1 - np.abs(fract_mixing - 0.5)/ 0.5 + max_guidance_reduction = self.guidance_scale_base * (1-self.guidance_scale_mid_damper) + guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction*mid_factor + self.guidance_scale = guidance_scale_effective + self.sdh.guidance_scale = guidance_scale_effective def set_prompt1(self, prompt: str): r""" @@ -158,6 +175,7 @@ class LatentBlending(): total number of frames nmb_mindist: int = 3 minimum distance in terms of diffusion iteratinos between subsequent injections + """ if quality == 'lowest': @@ -201,13 +219,13 @@ class LatentBlending(): list_injection_idx = list_injection_idx_clean list_nmb_branches = list_nmb_branches_clean - print(f"num_inference_steps: {num_inference_steps}") - print(f"list_injection_idx: {list_injection_idx}") - print(f"list_nmb_branches: {list_nmb_branches}") + # print(f"num_inference_steps: {num_inference_steps}") + # print(f"list_injection_idx: {list_injection_idx}") + # print(f"list_nmb_branches: {list_nmb_branches}") - self.num_inference_steps = num_inference_steps - self.list_injection_idx = list_injection_idx - self.list_nmb_branches = list_nmb_branches + list_nmb_branches = list_nmb_branches + list_injection_idx = list_injection_idx + self.setup_branching(num_inference_steps, list_nmb_branches=list_nmb_branches, list_injection_idx=list_injection_idx) def setup_branching(self, @@ -215,7 +233,7 @@ class LatentBlending(): list_nmb_branches: List[int] = None, list_injection_strength: List[float] = None, list_injection_idx: List[int] = None, - guidance_downscale: float = 1.0, + ): r""" Sets the branching structure for making transitions. @@ -229,13 +247,9 @@ class LatentBlending(): list_injection_idx: List[int]: list of injection strengths within interval [0, 1), values need to be increasing. Alternatively you can specify the list_injection_strength. - guidance_downscale: float = 1.0 - reduces the guidance scale towards the middle of the transition - """ # Assert - assert guidance_downscale>0 and guidance_downscale<=1.0, "guidance_downscale neees to be in interval (0,1]" assert not((list_injection_strength is not None) and (list_injection_idx is not None)), "suppyl either list_injection_strength or list_injection_idx" if list_injection_strength is None: @@ -262,6 +276,7 @@ class LatentBlending(): self.sdh.num_inference_steps = num_inference_steps self.list_nmb_branches = list_nmb_branches self.list_injection_idx = list_injection_idx + self.guidance_scale_mid_damper = guidance_scale_mid_damper @@ -341,7 +356,8 @@ class LatentBlending(): nmb_blocks_time = len(list_injection_idx_ext)-1 for t_block in range(nmb_blocks_time): nmb_branches = list_nmb_branches[t_block] - list_fract_mixing_current = np.linspace(0, 1, nmb_branches) + # list_fract_mixing_current = np.linspace(0, 1, nmb_branches) + list_fract_mixing_current = get_spacing(nmb_branches, self.mid_compression_scaler) self.tree_fracts.append(list_fract_mixing_current) self.tree_latents.append([None]*nmb_branches) self.tree_status.append(['untouched']*nmb_branches) @@ -403,6 +419,8 @@ class LatentBlending(): idx_stop = list_injection_idx_ext[t_block+1] fract_mixing = self.tree_fracts[t_block][idx_branch] text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) + self.set_guidance_mid_dampening(fract_mixing) + # print(f"fract_mixing {fract_mixing} guid {self.sdh.guidance_scale}") if t_block == 0: if fixed_seeds is not None: if idx_branch == 0: @@ -787,6 +805,31 @@ def add_frames_linear_interp( return list_imgs_interp +def get_spacing(nmb_points:int, scaling: float): + """ + Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5 + Args: + nmb_points: int + Number of points between [0, 1] + scaling: float + Higher values will return higher sampling density around 0.5 + + """ + if scaling < 1.7: + return np.linspace(0, 1, nmb_points) + nmb_points_per_side = nmb_points//2 + 1 + if np.mod(nmb_points, 2) != 0: # uneven case + left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5) + right_side = 1-left_side[::-1][1:] + else: + left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1] + right_side = 1-left_side[::-1] + + all_fracts = np.hstack([left_side, right_side]) + + return all_fracts + + def get_time(resolution=None): """ Helper function returning an nicely formatted time string, e.g. 221117_1620