From 704433e267cccbed45b9c6ff32c31e0797c9a939 Mon Sep 17 00:00:00 2001 From: Johannes Stelzer Date: Fri, 21 Jul 2023 14:03:02 +0200 Subject: [PATCH] controlnet upd --- diffusers_holder.py | 91 ++++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 43 deletions(-) diff --git a/diffusers_holder.py b/diffusers_holder.py index 7158c25..a732c6b 100644 --- a/diffusers_holder.py +++ b/diffusers_holder.py @@ -163,9 +163,20 @@ class DiffusersHolder(): image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0] image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0]) - return np.asarray(image[0]) + def prepare_mixing(self, mixing_coeffs, list_latents_mixing): + if type(mixing_coeffs) == float: + list_mixing_coeffs = (1 + self.num_inference_steps) * [mixing_coeffs] + elif type(mixing_coeffs) == list: + assert len(mixing_coeffs) == self.num_inference_steps, f"len(mixing_coeffs) {len(mixing_coeffs)} != self.num_inference_steps {self.num_inference_steps}" + list_mixing_coeffs = mixing_coeffs + else: + raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps") + if np.sum(list_mixing_coeffs) > 0: + assert len(list_latents_mixing) == self.num_inference_steps, f"len(list_latents_mixing) {len(list_latents_mixing)} != self.num_inference_steps {self.num_inference_steps}" + return list_mixing_coeffs + @torch.no_grad() def run_diffusion( self, @@ -175,14 +186,13 @@ class DiffusersHolder(): list_latents_mixing=None, mixing_coeffs=0.0, return_image: Optional[bool] = False): - + if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline': return self.run_diffusion_sd_xl(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image) elif self.pipe.__class__.__name__ == 'StableDiffusionPipeline': return self.run_diffusion_sd12x(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image) elif self.pipe.__class__.__name__ == 'StableDiffusionControlNetPipeline': pass - @torch.no_grad() def run_diffusion_sd12x( @@ -193,28 +203,19 @@ class DiffusersHolder(): list_latents_mixing=None, mixing_coeffs=0.0, return_image: Optional[bool] = False): - - if type(mixing_coeffs) == float: - list_mixing_coeffs = (1+self.num_inference_steps) * [mixing_coeffs] - elif type(mixing_coeffs) == list: - assert len(mixing_coeffs) == self.num_inference_steps, f"len(mixing_coeffs) {len(mixing_coeffs)} != self.num_inference_steps {self.num_inference_steps}" - list_mixing_coeffs = mixing_coeffs - else: - raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps") - if np.sum(list_mixing_coeffs) > 0: - assert len(list_latents_mixing) == self.num_inference_steps, f"len(list_latents_mixing) {len(list_latents_mixing)} != self.num_inference_steps {self.num_inference_steps}" - + list_mixing_coeffs = self.prepare_mixing() + do_classifier_free_guidance = self.guidance_scale > 1.0 - - # diffusers bit wiggly - self.pipe.scheduler.set_timesteps(self.num_inference_steps-1, device=self.device) + + # accomodate different sd model types + self.pipe.scheduler.set_timesteps(self.num_inference_steps - 1, device=self.device) timesteps = self.pipe.scheduler.timesteps - + if len(timesteps) != self.num_inference_steps: self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) timesteps = self.pipe.scheduler.timesteps - + latents = latents_start.clone() list_latents_out = [] @@ -229,11 +230,11 @@ class DiffusersHolder(): if i > 0 and list_mixing_coeffs[i] > 0: latents_mixtarget = list_latents_mixing[i - 1].clone() latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i]) - + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t) - + # predict the noise residual noise_pred = self.pipe.unet( latent_model_input, @@ -248,7 +249,7 @@ class DiffusersHolder(): # compute the previous noisy sample x_t -> x_t-1 latents = self.pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0] list_latents_out.append(latents.clone()) - + if return_image: return self.latent2image(latents) else: @@ -276,17 +277,8 @@ class DiffusersHolder(): do_classifier_free_guidance = self.guidance_scale > 1.0 # 1. Check inputs. Raise error if not correct & 2. Define call parameters - if type(mixing_coeffs) == float: - list_mixing_coeffs = (1+self.num_inference_steps) * [mixing_coeffs] - elif type(mixing_coeffs) == list: - assert len(mixing_coeffs) == self.num_inference_steps, f"len(mixing_coeffs) {len(mixing_coeffs)} != self.num_inference_steps {self.num_inference_steps}" - list_mixing_coeffs = mixing_coeffs - else: - raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps") - - if np.sum(list_mixing_coeffs) > 0: - assert len(list_latents_mixing) == self.num_inference_steps, f"len(list_latents_mixing) {len(list_latents_mixing)} != self.num_inference_steps {self.num_inference_steps}" - + list_mixing_coeffs = self.prepare_mixing() + # 3. Encode input prompt (already encoded outside bc of mixing, just split here) prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = text_embeddings @@ -374,10 +366,11 @@ class DiffusersHolder(): list_latents_mixing=None, mixing_coeffs=0.0, return_image: Optional[bool] = False): - - + prompt_embeds = conditioning[0] - + image = conditioning[1] + list_mixing_coeffs = self.prepare_mixing() + controlnet = self.pipe.controlnet control_guidance_start = [0.0] control_guidance_end = [1.0] @@ -386,17 +379,14 @@ class DiffusersHolder(): batch_size = 1 eta = 0.0 controlnet_conditioning_scale = 1.0 - image = Image.open("/home/lugo/glif/lora_models/pretrained_model_name_or_path/value_runwayml_stable-diffusion-v1-5_fabian/fabian_in_the_desert/img_001.jpg") # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] # 2. Define call parameters - device = self.pipe._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -424,6 +414,7 @@ class DiffusersHolder(): # 6. Prepare latent variables generator = torch.Generator(device=self.device).manual_seed(int(420)) latents = latents_start.clone() + list_latents_out = [] # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta) @@ -439,6 +430,17 @@ class DiffusersHolder(): # 8. Denoising loop for i, t in enumerate(timesteps): + if i < idx_start: + list_latents_out.append(None) + continue + elif i == idx_start: + latents = latents_start.clone() + + # Mix latents for crossfeeding + if i > 0 and list_mixing_coeffs[i] > 0: + latents_mixtarget = list_latents_mixing[i - 1].clone() + latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i]) + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t) @@ -487,10 +489,13 @@ class DiffusersHolder(): # compute the previous noisy sample x_t -> x_t-1 latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.pipe.run_safety_checker(image, device, prompt_embeds.dtype) - image = self.pipe.image_processor.postprocess(image, output_type="pil") - return image + # Append latents + list_latents_out.append(latents.clone()) + + if return_image: + return self.latent2image(latents) + else: + return list_latents_out #%%