From bc34a83008a17a23b141fcbd3a568f16ccafac6d Mon Sep 17 00:00:00 2001 From: Johannes Stelzer Date: Sun, 19 Feb 2023 15:32:37 +0100 Subject: [PATCH] masked --- latent_blending.py | 236 +++++++++++++++++++++---------------- stable_diffusion_holder.py | 39 ++---- 2 files changed, 141 insertions(+), 134 deletions(-) diff --git a/latent_blending.py b/latent_blending.py index 12ea8ea..4a5fbf1 100644 --- a/latent_blending.py +++ b/latent_blending.py @@ -123,6 +123,7 @@ class LatentBlending(): self.multi_transition_img_first = None self.multi_transition_img_last = None self.dt_per_diff = 0 + self.spatial_mask = None self.lpips = lpips.LPIPS(net='alex').cuda(self.device) @@ -277,6 +278,9 @@ class LatentBlending(): self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))] self.tree_idx_injection = [0, 0] + # Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP... + self.spatial_mask = None + # Set up branching scheme (dependent on provided compute time) list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches) @@ -296,6 +300,109 @@ class LatentBlending(): return self.tree_final_imgs + def compute_latents1(self, return_image=False): + r""" + Runs a diffusion trajectory for the first image + Args: + return_image: bool + whether to return an image or the list of latents + """ + print("starting compute_latents1") + list_conditionings = self.get_mixed_conditioning(0) + t0 = time.time() + latents_start = self.get_noise(self.seed1) + list_latents1 = self.run_diffusion( + list_conditionings, + latents_start = latents_start, + idx_start = 0 + ) + t1 = time.time() + self.dt_per_diff = (t1-t0) / self.num_inference_steps + self.tree_latents[0] = list_latents1 + if return_image: + return self.sdh.latent2image(list_latents1[-1]) + else: + return list_latents1 + + def compute_latents2(self, return_image=False): + r""" + Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory. + Args: + return_image: bool + whether to return an image or the list of latents + """ + print("starting compute_latents2") + list_conditionings = self.get_mixed_conditioning(1) + latents_start = self.get_noise(self.seed2) + # Influence from branch1 + if self.branch1_influence > 0.0: + # Set up the mixing_coeffs + idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_max_depth_influence)) + mixing_coeffs = list(np.linspace(self.branch1_influence, self.branch1_influence*self.branch1_influence_decay, idx_mixing_stop)) + mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0]) + list_latents_mixing = self.tree_latents[0] + list_latents2 = self.run_diffusion( + list_conditionings, + latents_start = latents_start, + idx_start = 0, + list_latents_mixing = list_latents_mixing, + mixing_coeffs = mixing_coeffs + ) + else: + list_latents2 = self.run_diffusion(list_conditionings, latents_start) + self.tree_latents[-1] = list_latents2 + + if return_image: + return self.sdh.latent2image(list_latents2[-1]) + else: + return list_latents2 + + + def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection): + r""" + Runs a diffusion trajectory, using the latents from the respective parents + Args: + fract_mixing: float + the fraction along the transition axis [0, 1] + b_parent1: int + index of parent1 to be used + b_parent2: int + index of parent2 to be used + idx_injection: int + the index in terms of diffusion steps, where the next insertion will start. + """ + list_conditionings = self.get_mixed_conditioning(fract_mixing) + fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1]) + # idx_reversed = self.num_inference_steps - idx_injection + + list_latents_parental_mix = [] + for i in range(self.num_inference_steps): + latents_p1 = self.tree_latents[b_parent1][i] + latents_p2 = self.tree_latents[b_parent2][i] + if latents_p1 is None or latents_p2 is None: + latents_parental = None + else: + latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental) + list_latents_parental_mix.append(latents_parental) + + idx_mixing_stop = int(round(self.num_inference_steps*self.parental_max_depth_influence)) + mixing_coeffs = idx_injection*[self.parental_influence] + nmb_mixing = idx_mixing_stop - idx_injection + if nmb_mixing > 0: + mixing_coeffs.extend(list(np.linspace(self.parental_influence, self.parental_influence*self.parental_influence_decay, nmb_mixing))) + mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0]) + + latents_start = list_latents_parental_mix[idx_injection-1] + list_latents = self.run_diffusion( + list_conditionings, + latents_start = latents_start, + idx_start = idx_injection, + list_latents_mixing = list_latents_parental_mix, + mixing_coeffs = mixing_coeffs + ) + + return list_latents + def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None): r""" Sets up the branching scheme dependent on the time that is granted for compute. @@ -419,110 +526,34 @@ class LatentBlending(): self.tree_fracts.insert(b_parent1+1, fract_mixing) self.tree_idx_injection.insert(b_parent1+1, idx_injection) - - def compute_latents1(self, return_image=False): - r""" - Runs a diffusion trajectory for the first image - Args: - return_image: bool - whether to return an image or the list of latents - """ - print("starting compute_latents1") - list_conditionings = self.get_mixed_conditioning(0) - t0 = time.time() - latents_start = self.get_noise(self.seed1) - list_latents1 = self.run_diffusion( - list_conditionings, - latents_start = latents_start, - idx_start = 0 - ) - t1 = time.time() - self.dt_per_diff = (t1-t0) / self.num_inference_steps - self.tree_latents[0] = list_latents1 - if return_image: - return self.sdh.latent2image(list_latents1[-1]) - else: - return list_latents1 - def compute_latents2(self, return_image=False): - r""" - Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory. - Args: - return_image: bool - whether to return an image or the list of latents - """ - print("starting compute_latents2") - list_conditionings = self.get_mixed_conditioning(1) - latents_start = self.get_noise(self.seed2) - # Influence from branch1 - if self.branch1_influence > 0.0: - # Set up the mixing_coeffs - idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_max_depth_influence)) - mixing_coeffs = list(np.linspace(self.branch1_influence, self.branch1_influence*self.branch1_influence_decay, idx_mixing_stop)) - mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0]) - list_latents_mixing = self.tree_latents[0] - list_latents2 = self.run_diffusion( - list_conditionings, - latents_start = latents_start, - idx_start = 0, - list_latents_mixing = list_latents_mixing, - mixing_coeffs = mixing_coeffs - ) - else: - list_latents2 = self.run_diffusion(list_conditionings, latents_start) - self.tree_latents[-1] = list_latents2 - - if return_image: - return self.sdh.latent2image(list_latents2[-1]) - else: - return list_latents2 - - - def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection): - r""" - Runs a diffusion trajectory, using the latents from the respective parents - Args: - fract_mixing: float - the fraction along the transition axis [0, 1] - b_parent1: int - index of parent1 to be used - b_parent2: int - index of parent2 to be used - idx_injection: int - the index in terms of diffusion steps, where the next insertion will start. - """ - list_conditionings = self.get_mixed_conditioning(fract_mixing) - fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1]) - # idx_reversed = self.num_inference_steps - idx_injection - - list_latents_parental_mix = [] - for i in range(self.num_inference_steps): - latents_p1 = self.tree_latents[b_parent1][i] - latents_p2 = self.tree_latents[b_parent2][i] - if latents_p1 is None or latents_p2 is None: - latents_parental = None - else: - latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental) - list_latents_parental_mix.append(latents_parental) - - idx_mixing_stop = int(round(self.num_inference_steps*self.parental_max_depth_influence)) - mixing_coeffs = idx_injection*[self.parental_influence] - nmb_mixing = idx_mixing_stop - idx_injection - if nmb_mixing > 0: - mixing_coeffs.extend(list(np.linspace(self.parental_influence, self.parental_influence*self.parental_influence_decay, nmb_mixing))) - mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0]) - - latents_start = list_latents_parental_mix[idx_injection-1] - list_latents = self.run_diffusion( - list_conditionings, - latents_start = latents_start, - idx_start = idx_injection, - list_latents_mixing = list_latents_parental_mix, - mixing_coeffs = mixing_coeffs - ) - - return list_latents + def get_spatial_mask_template(self): + shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f] + C, H, W = shape_latents + return np.ones((H, W)) + def set_spatial_mask(self, img_mask): + r""" + Helper function to #FIXME + Args: + seed: int + + """ + + shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f] + C, H, W = shape_latents + img_mask = np.asarray(img_mask) + assert len(img_mask.shape) == 2, "Currently, only 2D images are supported as mask" + img_mask = np.clip(img_mask, 0, 1) + assert img_mask.shape[0] == H, f"Your mask needs to be of dimension {H} x {W}" + assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}" + spatial_mask = torch.from_numpy(img_mask).to(device=self.device) + spatial_mask = torch.unsqueeze(spatial_mask, 0) + spatial_mask = spatial_mask.repeat((C,1,1)) + spatial_mask = torch.unsqueeze(spatial_mask, 0) + + self.spatial_mask = spatial_mask + def get_noise(self, seed): r""" @@ -585,6 +616,7 @@ class LatentBlending(): idx_start = idx_start, list_latents_mixing = list_latents_mixing, mixing_coeffs = mixing_coeffs, + spatial_mask = self.spatial_mask, return_image = return_image, ) diff --git a/stable_diffusion_holder.py b/stable_diffusion_holder.py index 8bcf28e..fcd720d 100644 --- a/stable_diffusion_holder.py +++ b/stable_diffusion_holder.py @@ -218,37 +218,6 @@ class StableDiffusionHolder: if len(self.negative_prompt) > 1: self.negative_prompt = [self.negative_prompt[0]] - def init_inpainting( - self, - image_source: Union[Image.Image, np.ndarray] = None, - mask_image: Union[Image.Image, np.ndarray] = None, - init_empty: Optional[bool] = False, - ): - r""" - Initializes inpainting with a source and maks image. - Args: - image_source: Union[Image.Image, np.ndarray] - Source image onto which the mask will be applied. - mask_image: Union[Image.Image, np.ndarray] - Mask image, value = 0 will stay untouched, value = 255 subjet to diffusion - init_empty: Optional[bool]: - Initialize inpainting with an empty image and mask, effectively disabling inpainting, - useful for generating a first image for transitions using diffusion. - """ - if not init_empty: - assert image_source is not None, "init_inpainting: you need to provide image_source" - assert mask_image is not None, "init_inpainting: you need to provide mask_image" - if type(image_source) == np.ndarray: - image_source = Image.fromarray(image_source) - self.image_source = image_source - - if type(mask_image) == np.ndarray: - mask_image = Image.fromarray(mask_image) - self.mask_image = mask_image - else: - self.mask_image = self.mask_empty - self.image_source = self.image_empty - def get_text_embedding(self, prompt): c = self.model.get_learned_conditioning(prompt) @@ -282,6 +251,7 @@ class StableDiffusionHolder: idx_start: int = 0, list_latents_mixing = None, mixing_coeffs = 0.0, + spatial_mask = None, return_image: Optional[bool] = False, ): r""" @@ -295,7 +265,7 @@ class StableDiffusionHolder: idx_start: int Index of the diffusion process start and where the latents_for_injection are injected mixing_coeff: - # FIXME + # FIXME spatial_mask return_image: Optional[bool] Optionally return image directly @@ -313,6 +283,7 @@ class StableDiffusionHolder: if np.sum(list_mixing_coeffs) > 0: assert len(list_latents_mixing) == self.num_inference_steps + precision_scope = autocast if self.precision == "autocast" else nullcontext with precision_scope("cuda"): @@ -345,6 +316,10 @@ class StableDiffusionHolder: if i > 0 and list_mixing_coeffs[i]>0: latents_mixtarget = list_latents_mixing[i-1].clone() latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i]) + + if spatial_mask is not None and list_latents_mixing is not None: + latents = interpolate_spherical(latents, list_latents_mixing[i-1], 1-spatial_mask) + # latents[:,:,-15:,:] = latents_mixtarget[:,:,-15:,:] index = total_steps - i - 1 ts = torch.full((1,), step, device=self.device, dtype=torch.long)