From bc5713241fe9fdbf8097a8f69e5b1ced9598cfe4 Mon Sep 17 00:00:00 2001 From: Johannes Stelzer Date: Thu, 20 Jul 2023 15:45:06 +0200 Subject: [PATCH] controlnet first steps --- diffusers_holder.py | 251 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 200 insertions(+), 51 deletions(-) diff --git a/diffusers_holder.py b/diffusers_holder.py index 5a97642..7158c25 100644 --- a/diffusers_holder.py +++ b/diffusers_holder.py @@ -28,7 +28,7 @@ from typing import Optional from torch import autocast from contextlib import nullcontext from utils import interpolate_spherical -from diffusers import DiffusionPipeline +from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel from diffusers.models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -47,26 +47,24 @@ class DiffusersHolder(): # Check if valid pipe self.pipe = pipe self.device = str(pipe._execution_device) - self.init_type_pipe() - self.init_dtype() + self.init_types() self.width_latent = self.pipe.unet.config.sample_size self.height_latent = self.pipe.unet.config.sample_size - - def init_type_pipe(self): - self.type_pipe = "StableDiffusionXLPipeline" - if self.type_pipe == "StableDiffusionXLPipeline": + def init_types(self): + assert hasattr(self.pipe, "__class__"), "No valid diffusers pipeline found." + assert hasattr(self.pipe.__class__, "__name__"), "No valid diffusers pipeline found." + if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline': self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) self.use_sd_xl = True + prompt_embeds, _, _, _ = self.pipe.encode_prompt("test") else: self.use_sd_xl = False + prompt_embeds = self.pipe._encode_prompt("test", self.device, 1, True) + self.dtype = prompt_embeds.dtype - def init_dtype(self): - if self.type_pipe == "StableDiffusionXLPipeline": - prompt_embeds, _, _, _ = self.pipe.encode_prompt("test") - self.dtype = prompt_embeds.dtype def set_num_inference_steps(self, num_inference_steps): self.num_inference_steps = num_inference_steps @@ -102,6 +100,7 @@ class DiffusersHolder(): if len(self.negative_prompt) > 1: self.negative_prompt = [self.negative_prompt[0]] + def get_text_embedding(self, prompt, do_classifier_free_guidance=True): if self.use_sd_xl: pr_encoder = self.pipe.encode_prompt @@ -120,7 +119,6 @@ class DiffusersHolder(): ) return prompt_embeds - def get_noise(self, seed=420, mode=None): H = self.height_latent W = self.width_latent @@ -166,12 +164,28 @@ class DiffusersHolder(): image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0] image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0]) - - return np.asarray(image[0]) @torch.no_grad() - def run_diffusion_standard( + def run_diffusion( + self, + text_embeddings: torch.FloatTensor, + latents_start: torch.FloatTensor, + idx_start: int = 0, + list_latents_mixing=None, + mixing_coeffs=0.0, + return_image: Optional[bool] = False): + + if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline': + return self.run_diffusion_sd_xl(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image) + elif self.pipe.__class__.__name__ == 'StableDiffusionPipeline': + return self.run_diffusion_sd12x(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image) + elif self.pipe.__class__.__name__ == 'StableDiffusionControlNetPipeline': + pass + + + @torch.no_grad() + def run_diffusion_sd12x( self, text_embeddings: torch.FloatTensor, latents_start: torch.FloatTensor, @@ -204,7 +218,6 @@ class DiffusersHolder(): latents = latents_start.clone() list_latents_out = [] - num_warmup_steps = len(timesteps) - self.num_inference_steps * self.pipe.scheduler.order for i, t in enumerate(timesteps): # Set the right starting latents if i < idx_start: @@ -251,25 +264,6 @@ class DiffusersHolder(): mixing_coeffs=0.0, return_image: Optional[bool] = False): - # prompt = "photo of a house" - # self.num_inference_steps = 50 - # mixing_coeffs= 0.0 - # idx_start= 0 - # latents_start = self.get_noise() - # text_embeddings = self.pipe.encode_prompt( - # prompt, - # self.device, - # num_images_per_prompt=1, - # do_classifier_free_guidance=True, - # negative_prompt="", - # prompt_embeds=None, - # negative_prompt_embeds=None, - # pooled_prompt_embeds=None, - # negative_pooled_prompt_embeds=None, - # lora_scale=None, - # ) - - # 0. Default height and width to unet original_size = (1024, 1024) # FIXME crops_coords_top_left = (0, 0) # FIXME @@ -282,7 +276,6 @@ class DiffusersHolder(): do_classifier_free_guidance = self.guidance_scale > 1.0 # 1. Check inputs. Raise error if not correct & 2. Define call parameters - # FIXME see if check_inputs use if type(mixing_coeffs) == float: list_mixing_coeffs = (1+self.num_inference_steps) * [mixing_coeffs] elif type(mixing_coeffs) == list: @@ -332,8 +325,6 @@ class DiffusersHolder(): elif i == idx_start: latents = latents_start.clone() - - # Mix latents for crossfeeding if i > 0 and list_mixing_coeffs[i] > 0: latents_mixtarget = list_latents_mixing[i - 1].clone() @@ -374,26 +365,183 @@ class DiffusersHolder(): else: return list_latents_out + @torch.no_grad() + def run_diffusion_controlnet( + self, + conditioning: list, + latents_start: torch.FloatTensor, + idx_start: int = 0, + list_latents_mixing=None, + mixing_coeffs=0.0, + return_image: Optional[bool] = False): + + + prompt_embeds = conditioning[0] + + controlnet = self.pipe.controlnet + control_guidance_start = [0.0] + control_guidance_end = [1.0] + guess_mode = False + num_images_per_prompt = 1 + batch_size = 1 + eta = 0.0 + controlnet_conditioning_scale = 1.0 + image = Image.open("/home/lugo/glif/lora_models/pretrained_model_name_or_path/value_runwayml_stable-diffusion-v1-5_fabian/fabian_in_the_desert/img_001.jpg") + + # align format for control guidance + + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + + # 2. Define call parameters + + device = self.pipe._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = self.guidance_scale > 1.0 + + # 4. Prepare image + image = self.pipe.prepare_image( + image=image, + width=None, + height=None, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=self.device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + + # 5. Prepare timesteps + self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) + timesteps = self.pipe.scheduler.timesteps + + # 6. Prepare latent variables + generator = torch.Generator(device=self.device).manual_seed(int(420)) + latents = latents_start.clone() + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) + + # 8. Denoising loop + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t) + + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.pipe.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.pipe.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=None, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.pipe.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.pipe.image_processor.postprocess(image, output_type="pil") + return image + +#%% + +""" +steps: + x get controlnet vanilla running. + - externalize conditions + - have conditions as input (use one list) + - include latent blending + - test latent blending + - have lora and latent blending + +""" #%% if __name__ == "__main__": - pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-0.9" - pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16) - pipe.to('cuda') - # xxx + + + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-scribble", torch_dtype=torch.float16) + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ).to("cuda") + self = DiffusersHolder(pipe) - # xxx - self.set_num_inference_steps(50) - self.set_dimensions(1536, 1024) - prompt = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic" - text_embeddings = self.get_text_embedding(prompt) - generator = torch.Generator(device=self.device).manual_seed(int(420)) - latents_start = self.get_noise() - list_latents_1 = self.run_diffusion_sd_xl(text_embeddings, latents_start) - img_orig = self.latent2image(list_latents_1[-1]) + + # get text encoding + + # get image encoding + + + + + #%% + # # pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-0.9" + # pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1" + # pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16) + # pipe.to('cuda') + # # xxx + # self = DiffusersHolder(pipe) + # # xxx + # self.set_num_inference_steps(50) + # # self.set_dimensions(1536, 1024) + # prompt = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic" + # text_embeddings = self.get_text_embedding(prompt) + # generator = torch.Generator(device=self.device).manual_seed(int(420)) + # latents_start = self.get_noise() + # list_latents_1 = self.run_diffusion(text_embeddings, latents_start) + # img_orig = self.latent2image(list_latents_1[-1]) @@ -401,6 +549,7 @@ if __name__ == "__main__": """ OPEN + - rename text encodings to conditionings - other examples - kill upscaling? or keep? - cleanup