diff --git a/diffusers_holder.py b/diffusers_holder.py
index 06c90cb..b56d2ef 100644
--- a/diffusers_holder.py
+++ b/diffusers_holder.py
@@ -56,17 +56,17 @@ class DiffusersHolder():
         assert hasattr(self.pipe.__class__, "__name__"), "No valid diffusers pipeline found."
         if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline':
             self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
-            self.use_sd_xl = True
             prompt_embeds, _, _, _ = self.pipe.encode_prompt("test")
         else:
-            self.use_sd_xl = False
             prompt_embeds = self.pipe._encode_prompt("test", self.device, 1, True)
         self.dtype = prompt_embeds.dtype
+
+        self.is_sdxl_turbo = 'turbo' in self.pipe._name_or_path
+
 
     def set_num_inference_steps(self, num_inference_steps):
         self.num_inference_steps = num_inference_steps
-        if self.use_sd_xl:
-            self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
+        self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
 
     def set_dimensions(self, size_output):
         s = self.pipe.vae_scale_factor
@@ -181,401 +181,8 @@ class DiffusersHolder():
             mixing_coeffs=0.0,
             return_image: Optional[bool] = False):
 
-        if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline':
-            return self.run_diffusion_sd_xl(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image)
-        elif self.pipe.__class__.__name__ == 'StableDiffusionPipeline':
-            return self.run_diffusion_sd12x(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image)
-        elif self.pipe.__class__.__name__ == 'StableDiffusionControlNetPipeline':
-            pass
+        return self.run_diffusion_sd_xl(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image)
 
-    @torch.no_grad()
-    def run_diffusion_sd12x(
-            self,
-            text_embeddings: torch.FloatTensor,
-            latents_start: torch.FloatTensor,
-            idx_start: int = 0,
-            list_latents_mixing=None,
-            mixing_coeffs=0.0,
-            return_image: Optional[bool] = False):
-
-        list_mixing_coeffs = self.prepare_mixing()
-
-        do_classifier_free_guidance = self.guidance_scale > 1.0
-
-        # accomodate different sd model types
-        self.pipe.scheduler.set_timesteps(self.num_inference_steps - 1, device=self.device)
-        timesteps = self.pipe.scheduler.timesteps
-
-        if len(timesteps) != self.num_inference_steps:
-            self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
-            timesteps = self.pipe.scheduler.timesteps
-
-        latents = latents_start.clone()
-        list_latents_out = []
-
-        for i, t in enumerate(timesteps):
-            # Set the right starting latents
-            if i == idx_start:
-                latents = latents_start.clone()
-            # Mix latents
-            if i > 0 and list_mixing_coeffs[i] > 0:
-                latents_mixtarget = list_latents_mixing[i - 1].clone()
-                latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
-
-            if i < idx_start:
-                list_latents_out.append(latents)
-
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.pipe.unet(
-                latent_model_input,
-                t,
-                encoder_hidden_states=text_embeddings,
-                return_dict=False,
-            )[0]
-
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-            list_latents_out.append(latents.clone())
-
-        if return_image:
-            return self.latent2image(latents)
-        else:
-            return list_latents_out
-
-    @torch.no_grad()
-    def run_diffusion_controlnet(
-            self,
-            conditioning: list,
-            latents_start: torch.FloatTensor,
-            idx_start: int = 0,
-            list_latents_mixing=None,
-            mixing_coeffs=0.0,
-            return_image: Optional[bool] = False):
-
-        prompt_embeds = conditioning[0]
-        image = conditioning[1]
-        list_mixing_coeffs = self.prepare_mixing()
-
-        controlnet = self.pipe.controlnet
-        control_guidance_start = [0.0]
-        control_guidance_end = [1.0]
-        guess_mode = False
-        num_images_per_prompt = 1
-        batch_size = 1
-        eta = 0.0
-        controlnet_conditioning_scale = 1.0
-
-        # align format for control guidance
-        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
-            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
-        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
-            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
-        # 2. Define call parameters
-        device = self.pipe._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = self.guidance_scale > 1.0
-
-        # 4. Prepare image
-        image = self.pipe.prepare_image(
-            image=image,
-            width=None,
-            height=None,
-            batch_size=batch_size * num_images_per_prompt,
-            num_images_per_prompt=num_images_per_prompt,
-            device=self.device,
-            dtype=controlnet.dtype,
-            do_classifier_free_guidance=do_classifier_free_guidance,
-            guess_mode=guess_mode,
-        )
-        height, width = image.shape[-2:]
-
-        # 5. Prepare timesteps
-        self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
-        timesteps = self.pipe.scheduler.timesteps
-
-        # 6. Prepare latent variables
-        generator = torch.Generator(device=self.device).manual_seed(int(420))
-        latents = latents_start.clone()
-        list_latents_out = []
-
-        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-        extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)
-
-        # 7.1 Create tensor stating which controlnets to keep
-        controlnet_keep = []
-        for i in range(len(timesteps)):
-            keeps = [
-                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
-                for s, e in zip(control_guidance_start, control_guidance_end)
-            ]
-            controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
-
-        # 8. Denoising loop
-        for i, t in enumerate(timesteps):
-            if i < idx_start:
-                list_latents_out.append(None)
-                continue
-            elif i == idx_start:
-                latents = latents_start.clone()
-
-            # Mix latents for crossfeeding
-            if i > 0 and list_mixing_coeffs[i] > 0:
-                latents_mixtarget = list_latents_mixing[i - 1].clone()
-                latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
-
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
-
-            control_model_input = latent_model_input
-            controlnet_prompt_embeds = prompt_embeds
-
-            if isinstance(controlnet_keep[i], list):
-                cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
-            else:
-                cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
-
-            down_block_res_samples, mid_block_res_sample = self.pipe.controlnet(
-                control_model_input,
-                t,
-                encoder_hidden_states=controlnet_prompt_embeds,
-                controlnet_cond=image,
-                conditioning_scale=cond_scale,
-                guess_mode=guess_mode,
-                return_dict=False,
-            )
-
-            if guess_mode and do_classifier_free_guidance:
-                # Infered ControlNet only for the conditional batch.
-                # To apply the output of ControlNet to both the unconditional and conditional batches,
-                # add 0 to the unconditional batch to keep it unchanged.
-                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
-                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
-
-            # predict the noise residual
-            noise_pred = self.pipe.unet(
-                latent_model_input,
-                t,
-                encoder_hidden_states=prompt_embeds,
-                cross_attention_kwargs=None,
-                down_block_additional_residuals=down_block_res_samples,
-                mid_block_additional_residual=mid_block_res_sample,
-                return_dict=False,
-            )[0]
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
-            # Append latents
-            list_latents_out.append(latents.clone())
-
-        if return_image:
-            return self.latent2image(latents)
-        else:
-            return list_latents_out
-
-    @torch.no_grad()
-    def run_diffusion_sd_xl_turbo(
-            self,
-            text_embeddings: list,
-            latents_start: torch.FloatTensor,
-            idx_start: int = 0,
-            list_latents_mixing=None,
-            mixing_coeffs=0.0,
-            return_image: Optional[bool] = False,
-            seed=420,
-            **kwargs,
-            ):
-
-        timesteps = None
-        denoising_end = None
-        guidance_scale = 0.0
-        negative_prompt = None
-        negative_prompt_2 = None
-        num_images_per_prompt = 1
-        eta = 0.0
-        latents = None
-        prompt_embeds = None
-        negative_prompt_embeds = None
-        pooled_prompt_embeds = None
-        negative_pooled_prompt_embeds = None
-        ip_adapter_image = None
-        return_dict = True
-        cross_attention_kwargs = None
-        guidance_rescale = 0.0
-        original_size = None
-        crops_coords_top_left = (0, 0)
-        target_size = None
-        negative_original_size = None
-        negative_crops_coords_top_left = (0, 0)
-        negative_target_size = None
-        clip_skip = None
-
-        # 0. Default height and width to unet
-        height = self.pipe.default_sample_size * self.pipe.vae_scale_factor
-        width = self.pipe.default_sample_size * self.pipe.vae_scale_factor
-        list_mixing_coeffs = self.prepare_mixing(mixing_coeffs, list_latents_mixing)
-
-        original_size = (height, width)
-        target_size = (height, width)
-
-        # 1. (skipped) Check inputs. Raise error if not correct
-        self.pipe._guidance_scale = guidance_scale
-        self.pipe._guidance_rescale = guidance_rescale
-        self.pipe._clip_skip = clip_skip
-        self.pipe._cross_attention_kwargs = cross_attention_kwargs
-        self.pipe._denoising_end = denoising_end
-        self.pipe._interrupt = False
-
-        # 2. Define call parameters
-        batch_size = 1
-
-        device = self.pipe._execution_device
-
-        # 3. Encode input prompt
-        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = text_embeddings
-
-        # 4. Prepare timesteps
-        timesteps, self.num_inference_steps = retrieve_timesteps(self.pipe.scheduler, self.num_inference_steps, device, timesteps)
-
-        # 5. Prepare latent variables
-        latents = latents_start.clone()
-        list_latents_out = []
-
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-
-        extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(torch.Generator(device=self.device).manual_seed(int(0)), eta)
-
-        # 7. Prepare added time ids & embeddings
-        add_text_embeds = pooled_prompt_embeds
-        if self.pipe.text_encoder_2 is None:
-            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
-        else:
-            text_encoder_projection_dim = self.pipe.text_encoder_2.config.projection_dim
-
-        add_time_ids = self.pipe._get_add_time_ids(
-            original_size,
-            crops_coords_top_left,
-            target_size,
-            dtype=prompt_embeds.dtype,
-            text_encoder_projection_dim=text_encoder_projection_dim,
-        )
-        if negative_original_size is not None and negative_target_size is not None:
-            negative_add_time_ids = self.pipe._get_add_time_ids(
-                negative_original_size,
-                negative_crops_coords_top_left,
-                negative_target_size,
-                dtype=prompt_embeds.dtype,
-                text_encoder_projection_dim=text_encoder_projection_dim,
-            )
-        else:
-            negative_add_time_ids = add_time_ids
-
-        if self.pipe.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
-            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
-
-        prompt_embeds = prompt_embeds.to(device)
-        add_text_embeds = add_text_embeds.to(device)
-        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
-
-        if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.pipe.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.pipe.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
-            )
-            if self.pipe.do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
-            image_embeds = image_embeds.to(device)
-
-        # 8. Denoising loop
-        num_warmup_steps = max(len(timesteps) - self.num_inference_steps * self.pipe.scheduler.order, 0)
-
-        # 9. Optionally get Guidance Scale Embedding
-        timestep_cond = None
-        if self.pipe.unet.config.time_cond_proj_dim is not None:
-            guidance_scale_tensor = torch.tensor(self.pipe.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
-            timestep_cond = self.pipe.get_guidance_scale_embedding(
-                guidance_scale_tensor, embedding_dim=self.pipe.unet.config.time_cond_proj_dim
-            ).to(device=device, dtype=latents.dtype)
-
-        self.pipe._num_timesteps = len(timesteps)
-
-        for i, t in enumerate(timesteps):
-            # Set the right starting latents
-            # Write latents out and skip
-            if i < idx_start:
-                list_latents_out.append(None)
-                continue
-            elif i == idx_start:
-                latents = latents_start.clone()
-
-            # Mix latents for crossfeeding
-            if i > 0 and list_mixing_coeffs[i] > 0:
-                latents_mixtarget = list_latents_mixing[i - 1].clone()
-                latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
-
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if self.pipe.do_classifier_free_guidance else latents
-            latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-            if ip_adapter_image is not None:
-                added_cond_kwargs["image_embeds"] = image_embeds
-
-            noise_pred = self.pipe.unet(
-                latent_model_input,
-                t,
-                encoder_hidden_states=prompt_embeds,
-                timestep_cond=timestep_cond,
-                cross_attention_kwargs=self.pipe.cross_attention_kwargs,
-                added_cond_kwargs=added_cond_kwargs,
-                return_dict=False,
-            )[0]
-
-            # perform guidance
-            if self.pipe.do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + self.pipe.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            if self.pipe.do_classifier_free_guidance and self.pipe.guidance_rescale > 0.0:
-                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.pipe.guidance_rescale)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
-            # Append latents
-            list_latents_out.append(latents.clone())
-
-        if return_image:
-            return self.latent2image(latents)
-        else:
-            return list_latents_out
-
     @torch.no_grad()
diff --git a/latent_blending.py b/latent_blending.py
index 08a71c5..de6d6be 100644
--- a/latent_blending.py
+++ b/latent_blending.py
@@ -262,7 +262,7 @@ class LatentBlending():
             self.seed2 = fixed_seeds[1]
 
         # Ensure correct num_inference_steps in holder
-        if 'turbo' in self.dh.pipe._name_or_path:
+        if self.dh.is_sdxl_turbo:
             num_inference_steps = 4 #ideal results
             self.num_inference_steps = num_inference_steps
             self.dh.set_num_inference_steps(num_inference_steps)
@@ -286,16 +286,14 @@ class LatentBlending():
         self.tree_idx_injection = [0, 0]
 
         # Set up branching scheme (dependent on provided compute time)
-        if 'turbo' in self.dh.pipe._name_or_path:
+        if self.dh.is_sdxl_turbo:
             self.guidance_scale = 0.0
-            self.parental_crossfeed_power = 1.0
             self.parental_crossfeed_power_decay = 1.0
             self.parental_crossfeed_range = 1.0
             list_idx_injection = [2]
             list_nmb_stems = [10]
         else:
-
             list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
 
         # Run iteratively, starting with the longest trajectory.