diff --git a/diffusers_holder.py b/diffusers_holder.py index 59a2824..bb8f615 100644 --- a/diffusers_holder.py +++ b/diffusers_holder.py @@ -17,7 +17,7 @@ import torch import numpy as np import warnings -from typing import Optional +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from utils import interpolate_spherical from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel from diffusers.models.attention_processor import ( @@ -26,6 +26,7 @@ from diffusers.models.attention_processor import ( LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import retrieve_timesteps warnings.filterwarnings('ignore') torch.backends.cudnn.benchmark = False torch.set_grad_enabled(False) @@ -45,23 +46,26 @@ class DiffusersHolder(): self.width_latent = self.pipe.unet.config.sample_size self.height_latent = self.pipe.unet.config.sample_size + self.width_img = self.width_latent * self.pipe.vae_scale_factor + self.height_img = self.height_latent * self.pipe.vae_scale_factor + def init_types(self): assert hasattr(self.pipe, "__class__"), "No valid diffusers pipeline found." assert hasattr(self.pipe.__class__, "__name__"), "No valid diffusers pipeline found." if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline': self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) - self.use_sd_xl = True prompt_embeds, _, _, _ = self.pipe.encode_prompt("test") else: - self.use_sd_xl = False prompt_embeds = self.pipe._encode_prompt("test", self.device, 1, True) self.dtype = prompt_embeds.dtype + + self.is_sdxl_turbo = 'turbo' in self.pipe._name_or_path + def set_num_inference_steps(self, num_inference_steps): self.num_inference_steps = num_inference_steps - if self.use_sd_xl: - self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) + self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) def set_dimensions(self, size_output): s = self.pipe.vae_scale_factor @@ -87,74 +91,72 @@ class DiffusersHolder(): if len(self.negative_prompt) > 1: self.negative_prompt = [self.negative_prompt[0]] - def get_text_embedding(self, prompt, do_classifier_free_guidance=True): - if self.use_sd_xl: - pr_encoder = self.pipe.encode_prompt - else: - pr_encoder = self.pipe._encode_prompt - - prompt_embeds = pr_encoder( + def get_text_embedding(self, prompt): + do_classifier_free_guidance = self.guidance_scale > 1 and self.pipe.unet.config.time_cond_proj_dim is None + text_embeddings = self.pipe.encode_prompt( prompt=prompt, - device=self.device, + prompt_2=prompt, + device=self.pipe._execution_device, num_images_per_prompt=1, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=self.negative_prompt, + negative_prompt_2=self.negative_prompt, prompt_embeds=None, negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, lora_scale=None, + clip_skip=None,#self.pipe._clip_skip, ) - return prompt_embeds + return text_embeddings def get_noise(self, seed=420): - H = self.height_latent - W = self.width_latent - C = self.pipe.unet.config.in_channels - generator = torch.Generator(device=self.device).manual_seed(int(seed)) - latents = torch.randn((1, C, H, W), generator=generator, dtype=self.dtype, device=self.device) - if self.use_sd_xl: - latents = latents * self.pipe.scheduler.init_noise_sigma + + latents = self.pipe.prepare_latents( + 1, + self.pipe.unet.config.in_channels, + self.height_img, + self.width_img, + torch.float16, + self.pipe._execution_device, + torch.Generator(device=self.device).manual_seed(int(seed)), + None, + ) + return latents + @torch.no_grad() def latent2image( self, latents: torch.FloatTensor, - convert_numpy=True): + output_type="pil"): r""" Returns an image provided a latent representation from diffusion. Args: latents: torch.FloatTensor Result of the diffusion process. - convert_numpy: if converting to numpy + output_type: "pil" or "np" """ - if self.use_sd_xl: - # make sure the VAE is in float32 mode, as it overflows in float16 - self.pipe.vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = isinstance( - self.pipe.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.pipe.vae.post_quant_conv.to(latents.dtype) - self.pipe.vae.decoder.conv_in.to(latents.dtype) - self.pipe.vae.decoder.mid_block.to(latents.dtype) - else: - latents = latents.float() - + assert output_type in ["pil", "np"] + + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.pipe.vae.dtype == torch.float16 and self.pipe.vae.config.force_upcast + + if needs_upcasting: + self.pipe.upcast_vae() + latents = latents.to(next(iter(self.pipe.vae.post_quant_conv.parameters())).dtype) + image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0] - image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])[0] - if convert_numpy: - return np.asarray(image) - else: - return image + + # cast back to fp16 if needed + if needs_upcasting: + self.pipe.vae.to(dtype=torch.float16) + + image = self.pipe.image_processor.postprocess(image, output_type=output_type)[0] + + return image + def prepare_mixing(self, mixing_coeffs, list_latents_mixing): if type(mixing_coeffs) == float: @@ -178,111 +180,94 @@ class DiffusersHolder(): mixing_coeffs=0.0, return_image: Optional[bool] = False): - if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline': - return self.run_diffusion_sd_xl(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image) - elif self.pipe.__class__.__name__ == 'StableDiffusionPipeline': - return self.run_diffusion_sd12x(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image) - elif self.pipe.__class__.__name__ == 'StableDiffusionControlNetPipeline': - pass + return self.run_diffusion_sd_xl(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image) - @torch.no_grad() - def run_diffusion_sd12x( - self, - text_embeddings: torch.FloatTensor, - latents_start: torch.FloatTensor, - idx_start: int = 0, - list_latents_mixing=None, - mixing_coeffs=0.0, - return_image: Optional[bool] = False): - list_mixing_coeffs = self.prepare_mixing() - - do_classifier_free_guidance = self.guidance_scale > 1.0 - - # accomodate different sd model types - self.pipe.scheduler.set_timesteps(self.num_inference_steps - 1, device=self.device) - timesteps = self.pipe.scheduler.timesteps - - if len(timesteps) != self.num_inference_steps: - self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) - timesteps = self.pipe.scheduler.timesteps - - latents = latents_start.clone() - list_latents_out = [] - - for i, t in enumerate(timesteps): - # Set the right starting latents - if i < idx_start: - list_latents_out.append(None) - continue - elif i == idx_start: - latents = latents_start.clone() - # Mix latents - if i > 0 and list_mixing_coeffs[i] > 0: - latents_mixtarget = list_latents_mixing[i - 1].clone() - latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i]) - - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.pipe.unet( - latent_model_input, - t, - encoder_hidden_states=text_embeddings, - return_dict=False, - )[0] - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - list_latents_out.append(latents.clone()) - - if return_image: - return self.latent2image(latents) - else: - return list_latents_out @torch.no_grad() def run_diffusion_sd_xl( - self, - text_embeddings: list, - latents_start: torch.FloatTensor, - idx_start: int = 0, - list_latents_mixing=None, - mixing_coeffs=0.0, - return_image: Optional[bool] = False): + self, + text_embeddings: tuple, + latents_start: torch.FloatTensor, + idx_start: int = 0, + list_latents_mixing=None, + mixing_coeffs=0.0, + return_image: Optional[bool] = False, + ): + + + prompt_2 = None + height = None + width = None + timesteps = None + denoising_end = None + negative_prompt_2 = None + num_images_per_prompt = 1 + eta = 0.0 + generator = None + latents = None + prompt_embeds = None + negative_prompt_embeds = None + pooled_prompt_embeds = None + negative_pooled_prompt_embeds = None + ip_adapter_image = None + output_type = "pil" + return_dict = True + cross_attention_kwargs = None + guidance_rescale = 0.0 + original_size = None + crops_coords_top_left = (0, 0) + target_size = None + negative_original_size = None + negative_crops_coords_top_left = (0, 0) + negative_target_size = None + clip_skip = None + callback = None + callback_on_step_end = None + callback_on_step_end_tensor_inputs = ["latents"] + # kwargs are additional keyword arguments and don't need a default value set here. # 0. Default height and width to unet - original_size = (self.width_img, self.height_img) - crops_coords_top_left = (0, 0) - target_size = original_size - batch_size = 1 - eta = 0.0 - num_images_per_prompt = 1 - cross_attention_kwargs = None - generator = torch.Generator(device=self.device) # dummy generator - do_classifier_free_guidance = self.guidance_scale > 1.0 + height = height or self.pipe.default_sample_size * self.pipe.vae_scale_factor + width = width or self.pipe.default_sample_size * self.pipe.vae_scale_factor - # 1. Check inputs. Raise error if not correct & 2. Define call parameters + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. skipped. + + self.pipe._guidance_scale = self.guidance_scale + self.pipe._guidance_rescale = guidance_rescale + self.pipe._clip_skip = clip_skip + self.pipe._cross_attention_kwargs = cross_attention_kwargs + self.pipe._denoising_end = denoising_end + self.pipe._interrupt = False + + # 2. Define call parameters list_mixing_coeffs = self.prepare_mixing(mixing_coeffs, list_latents_mixing) + batch_size = 1 - # 3. Encode input prompt (already encoded outside bc of mixing, just split here) - prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = text_embeddings + device = self.pipe._execution_device + + # 3. Encode input prompt + lora_scale = None + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = text_embeddings # 4. Prepare timesteps - self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) - timesteps = self.pipe.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.pipe.scheduler, self.num_inference_steps, device, timesteps) # 5. Prepare latent variables + num_channels_latents = self.pipe.unet.config.in_channels latents = latents_start.clone() list_latents_out = [] - # 6. Prepare extra step kwargs. usedummy generator - extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta) # dummy + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta) # 7. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds @@ -298,225 +283,207 @@ class DiffusersHolder(): dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim, ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self.pipe._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids - negative_add_time_ids = add_time_ids + if self.pipe.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) - add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.to(self.device) - add_text_embeds = add_text_embeds.to(self.device) - add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1) + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.pipe.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.pipe.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.pipe.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = image_embeds.to(device) # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.pipe.scheduler.order, 0) + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.pipe.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.pipe.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.pipe.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.pipe.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self.pipe._num_timesteps = len(timesteps) for i, t in enumerate(timesteps): # Set the right starting latents + # Write latents out and skip if i < idx_start: list_latents_out.append(None) continue elif i == idx_start: latents = latents_start.clone() - + # Mix latents for crossfeeding if i > 0 and list_mixing_coeffs[i] > 0: latents_mixtarget = list_latents_mixing[i - 1].clone() latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i]) + # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - # Always scale latents + latent_model_input = torch.cat([latents] * 2) if self.pipe.do_classifier_free_guidance else latents + latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds noise_pred = self.pipe.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.pipe.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance - if do_classifier_free_guidance: + if self.pipe.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - # Append latents - list_latents_out.append(latents.clone()) - - if return_image: - return self.latent2image(latents) - else: - return list_latents_out - - @torch.no_grad() - def run_diffusion_controlnet( - self, - conditioning: list, - latents_start: torch.FloatTensor, - idx_start: int = 0, - list_latents_mixing=None, - mixing_coeffs=0.0, - return_image: Optional[bool] = False): - - prompt_embeds = conditioning[0] - image = conditioning[1] - list_mixing_coeffs = self.prepare_mixing() - - controlnet = self.pipe.controlnet - control_guidance_start = [0.0] - control_guidance_end = [1.0] - guess_mode = False - num_images_per_prompt = 1 - batch_size = 1 - eta = 0.0 - controlnet_conditioning_scale = 1.0 - - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - - # 2. Define call parameters - device = self.pipe._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = self.guidance_scale > 1.0 - - # 4. Prepare image - image = self.pipe.prepare_image( - image=image, - width=None, - height=None, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=self.device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - height, width = image.shape[-2:] - - # 5. Prepare timesteps - self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) - timesteps = self.pipe.scheduler.timesteps - - # 6. Prepare latent variables - generator = torch.Generator(device=self.device).manual_seed(int(420)) - latents = latents_start.clone() - list_latents_out = [] - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta) - - # 7.1 Create tensor stating which controlnets to keep - controlnet_keep = [] - for i in range(len(timesteps)): - keeps = [ - 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) - ] - controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) - - # 8. Denoising loop - for i, t in enumerate(timesteps): - if i < idx_start: - list_latents_out.append(None) - continue - elif i == idx_start: - latents = latents_start.clone() - - # Mix latents for crossfeeding - if i > 0 and list_mixing_coeffs[i] > 0: - latents_mixtarget = list_latents_mixing[i - 1].clone() - latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i]) - - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t) - - control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] - else: - cond_scale = controlnet_conditioning_scale * controlnet_keep[i] - - down_block_res_samples, mid_block_res_sample = self.pipe.controlnet( - control_model_input, - t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=image, - conditioning_scale=cond_scale, - guess_mode=guess_mode, - return_dict=False, - ) - - if guess_mode and do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. - # To apply the output of ControlNet to both the unconditional and conditional batches, - # add 0 to the unconditional batch to keep it unchanged. - down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) - - # predict the noise residual - noise_pred = self.pipe.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=None, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - return_dict=False, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.pipe.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.pipe.do_classifier_free_guidance and self.pipe.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.pipe.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # Append latents list_latents_out.append(latents.clone()) + + if return_image: return self.latent2image(latents) else: return list_latents_out + #%% if __name__ == "__main__": from PIL import Image - #%% - pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" - pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16) - pipe.to('cuda') # xxx - - #%% + from diffusers import AutoencoderTiny + # pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" + pretrained_model_name_or_path = "stabilityai/sdxl-turbo" + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16") + pipe.to("cuda") + #% + # pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16) + # pipe.vae = pipe.vae.cuda() + #%% resanity + import time self = DiffusersHolder(pipe) + prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution" + negative_prompt = "blurry, ugly, pale" + num_inference_steps = 4 + guidance_scale = 0 + + self.set_num_inference_steps(num_inference_steps) + self.guidance_scale = guidance_scale + + prefix='turbo' + for i in range(10): + self.set_negative_prompt(negative_prompt) + + text_embeddings = self.get_text_embedding(prompt1) + latents_start = self.get_noise(np.random.randint(111111)) + + t0 = time.time() + + # img_refx = self.pipe(prompt=prompt1, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)[0] + + img_refx = self.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False) + + dt_ref = time.time() - t0 + img_refx.save(f"x_{prefix}_{i}.jpg") + + + + + # xxx - self.set_dimensions((1024, 704)) - self.set_num_inference_steps(40) - # self.set_dimensions(1536, 1024) - prompt = "Surreal painting of eerie, nebulous glow of an indigo moon, a spine-chilling spectacle unfolds; a baroque, marbled hand reaches out from a viscous, purple lake clutching a melting clock, its face distorted in a never-ending scream of hysteria, while a cluster of laughing orchids, their petals morphed into grotesque human lips, festoon a crimson tree weeping blood instead of sap, a psychedelic cat with an unnaturally playful grin and mismatched eyes lounges atop a floating vintage television showing static, an albino peacock with iridescent, crystalline feathers dances around a towering, inverted pyramid on top of which a humanoid figure with an octopus head lounges seductively, all against the backdrop of a sprawling cityscape where buildings are inverted and writhing as if alive, and the sky is punctuated by floating aquatic creatures glowing neon, adding a touch of haunting beauty to this otherwise deeply unsettling tableau" - text_embeddings = self.get_text_embedding(prompt) - generator = torch.Generator(device=self.device).manual_seed(int(420)) - latents_start = self.get_noise() - list_latents_1 = self.run_diffusion(text_embeddings, latents_start) - img_orig = self.latent2image(list_latents_1[-1]) + + # self.set_negative_prompt(negative_prompt) + # self.set_num_inference_steps(num_inference_steps) + # text_embeddings1 = self.get_text_embedding(prompt1) + # prompt_embeds1, negative_prompt_embeds1, pooled_prompt_embeds1, negative_pooled_prompt_embeds1 = text_embeddings1 + # latents_start = self.get_noise(420) + # t0 = time.time() + # img_dh = self.run_diffusion_sd_xl_resanity(text_embeddings1, latents_start, idx_start=0, return_image=True) + # dt_dh = time.time() - t0 - + + # xxxx + # #%% + + # self = DiffusersHolder(pipe) + # num_inference_steps = 4 + # self.set_num_inference_steps(num_inference_steps) + # latents_start = self.get_noise(420) + # guidance_scale = 0 + # self.guidance_scale = 0 + + # #% get embeddings1 + # prompt1 = "Photo of a colorful landscape with a blue sky with clouds" + # text_embeddings1 = self.get_text_embedding(prompt1) + # prompt_embeds1, negative_prompt_embeds1, pooled_prompt_embeds1, negative_pooled_prompt_embeds1 = text_embeddings1 + + # #% get embeddings2 + # prompt2 = "Photo of a tree" + # text_embeddings2 = self.get_text_embedding(prompt2) + # prompt_embeds2, negative_prompt_embeds2, pooled_prompt_embeds2, negative_pooled_prompt_embeds2 = text_embeddings2 + + # latents1 = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=False) + + # img1 = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=True) + # img1B = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=True) + + + + # # latents2 = self.run_diffusion_sd_xl(text_embeddings2, latents_start, idx_start=0, return_image=False) + + + # # # check if brings same image if restarted + # # img1_return = self.run_diffusion_sd_xl(text_embeddings1, latents1[idx_mix-1], idx_start=idx_start, return_image=True) + + # # mix latents + # #%% + # idx_mix = 2 + # fract=0.8 + # latents_start_mixed = interpolate_spherical(latents1[idx_mix-1], latents2[idx_mix-1], fract) + # prompt_embeds = interpolate_spherical(prompt_embeds1, prompt_embeds2, fract) + # pooled_prompt_embeds = interpolate_spherical(pooled_prompt_embeds1, pooled_prompt_embeds2, fract) + # negative_prompt_embeds = negative_prompt_embeds1 + # negative_pooled_prompt_embeds = negative_pooled_prompt_embeds1 + # text_embeddings_mix = [prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds] + + # self.run_diffusion_sd_xl(text_embeddings_mix, latents_start_mixed, idx_start=idx_start, return_image=True) + + + + diff --git a/example1_standard.py b/example1_standard.py index 125ce61..db82235 100644 --- a/example1_standard.py +++ b/example1_standard.py @@ -17,41 +17,25 @@ import torch import warnings from latent_blending import LatentBlending from diffusers_holder import DiffusersHolder -from diffusers import DiffusionPipeline +from diffusers import AutoPipelineForText2Image + warnings.filterwarnings('ignore') torch.set_grad_enabled(False) torch.backends.cudnn.benchmark = False # %% First let us spawn a stable diffusion holder. Uncomment your version of choice. -pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" -pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16) -pipe.to('cuda') +pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16") +pipe.to("cuda") + dh = DiffusersHolder(pipe) -# %% Next let's set up all parameters -depth_strength = 0.55 # Specifies how deep (in terms of diffusion iterations the first branching happens) -t_compute_max_allowed = 60 # Determines the quality of the transition in terms of compute time you grant it -num_inference_steps = 30 -size_output = (1024, 1024) -prompt1 = "underwater landscape, fish, und the sea, incredible detail, high resolution" -prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal" -negative_prompt = "blurry, ugly, pale" # Optional - -fp_movie = 'movie_example1.mp4' -duration_transition = 12 # In seconds - -# Spawn latent blending lb = LatentBlending(dh) -lb.set_prompt1(prompt1) -lb.set_prompt2(prompt2) -lb.set_dimensions(size_output) -lb.set_negative_prompt(negative_prompt) +lb.set_prompt1("photo of underwater landscape, fish, und the sea, incredible detail, high resolution") +lb.set_prompt2("rendering of an alien planet, strange plants, strange creatures, surreal") +lb.set_negative_prompt("blurry, ugly, pale") # Run latent blending -lb.run_transition( - depth_strength=depth_strength, - num_inference_steps=num_inference_steps, - t_compute_max_allowed=t_compute_max_allowed) +lb.run_transition() # Save movie -lb.write_movie_transition(fp_movie, duration_transition) +lb.write_movie_transition('movie_example1.mp4', duration_transition=12) diff --git a/example2_multitrans.py b/example2_multitrans.py index cc5c65b..320923f 100644 --- a/example2_multitrans.py +++ b/example2_multitrans.py @@ -17,24 +17,20 @@ import torch import warnings from latent_blending import LatentBlending from diffusers_holder import DiffusersHolder -from diffusers import DiffusionPipeline +from diffusers import AutoPipelineForText2Image from movie_util import concatenate_movies torch.set_grad_enabled(False) torch.backends.cudnn.benchmark = False warnings.filterwarnings('ignore') # %% First let us spawn a stable diffusion holder. Uncomment your version of choice. -pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" -pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16) +pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16") pipe.to('cuda') dh = DiffusersHolder(pipe) # %% Let's setup the multi transition fps = 30 -duration_single_trans = 20 -depth_strength = 0.25 # Specifies how deep (in terms of diffusion iterations the first branching happens) -size_output = (1280, 768) -num_inference_steps = 30 +duration_single_trans = 10 # Specify a list of prompts below list_prompts = [] @@ -45,12 +41,8 @@ list_prompts.append("photo of a house, high detail") # You can optionally specify the seeds list_seeds = [95437579, 33259350, 956051013] -t_compute_max_allowed = 20 # per segment fp_movie = 'movie_example2.mp4' lb = LatentBlending(dh) -lb.set_dimensions(size_output) -lb.dh.set_num_inference_steps(num_inference_steps) - list_movie_parts = [] for i in range(len(list_prompts) - 1): @@ -69,8 +61,6 @@ for i in range(len(list_prompts) - 1): # Run latent blending lb.run_transition( recycle_img1=recycle_img1, - depth_strength=depth_strength, - t_compute_max_allowed=t_compute_max_allowed, fixed_seeds=fixed_seeds) # Save movie diff --git a/latent_blending.py b/latent_blending.py index 21846e1..6a58e37 100644 --- a/latent_blending.py +++ b/latent_blending.py @@ -33,18 +33,11 @@ class LatentBlending(): def __init__( self, dh: None, - guidance_scale: float = 4, guidance_scale_mid_damper: float = 0.5, mid_compression_scaler: float = 1.2): r""" Initializes the latent blending class. Args: - guidance_scale: float - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. guidance_scale_mid_damper: float = 0.5 Reduces the guidance scale towards the middle of the transition. A value of 0.5 would decrease the guidance_scale towards the middle linearly by 0.5. @@ -76,36 +69,48 @@ class LatentBlending(): self.tree_status = None self.tree_final_imgs = [] - self.list_nmb_branches_prev = [] - self.list_injection_idx_prev = [] self.text_embedding1 = None self.text_embedding2 = None self.image1_lowres = None self.image2_lowres = None self.negative_prompt = None - self.num_inference_steps = self.dh.num_inference_steps - self.noise_level_upscaling = 20 - self.list_injection_idx = None - self.list_nmb_branches = None - # Mixing parameters - self.branch1_crossfeed_power = 0.3 - self.branch1_crossfeed_range = 0.3 - self.branch1_crossfeed_decay = 0.99 - - self.parental_crossfeed_power = 0.3 - self.parental_crossfeed_range = 0.6 - self.parental_crossfeed_power_decay = 0.9 - - self.set_guidance_scale(guidance_scale) + self.set_guidance_scale() self.multi_transition_img_first = None self.multi_transition_img_last = None - self.dt_per_diff = 0 - self.spatial_mask = None + self.dt_unet_step = 0 self.lpips = lpips.LPIPS(net='alex').cuda(self.device) self.set_prompt1("") self.set_prompt2("") + + self.set_branch1_crossfeed() + self.set_parental_crossfeed() + + self.set_num_inference_steps() + self.benchmark_speed() + self.set_branching() + + + + def benchmark_speed(self): + """ + Measures the time per diffusion step and for the vae decoding + """ + + text_embeddings = self.dh.get_text_embedding("test") + latents_start = self.dh.get_noise(np.random.randint(111111)) + # warmup + list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1) + # bench unet + t0 = time.time() + list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1) + self.dt_unet_step = time.time() - t0 + + # bench vae + t0 = time.time() + img = self.dh.latent2image(list_latents[-1]) + self.dt_vae = time.time() - t0 def set_dimensions(self, size_output=None): r""" @@ -115,12 +120,23 @@ class LatentBlending(): width x height Note: the size will get automatically adjusted to be divisable by 32. """ + if size_output is None: + if self.dh.is_sdxl_turbo: + size_output = (512, 512) + else: + size_output = (1024, 1024) self.dh.set_dimensions(size_output) - def set_guidance_scale(self, guidance_scale): + def set_guidance_scale(self, guidance_scale=None): r""" sets the guidance scale. """ + if guidance_scale is None: + if self.dh.is_sdxl_turbo: + guidance_scale = 0.0 + else: + guidance_scale = 4.0 + self.guidance_scale_base = guidance_scale self.guidance_scale = guidance_scale self.dh.guidance_scale = guidance_scale @@ -142,7 +158,7 @@ class LatentBlending(): self.guidance_scale = guidance_scale_effective self.dh.guidance_scale = guidance_scale_effective - def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay): + def set_branch1_crossfeed(self, crossfeed_power=0, crossfeed_range=0, crossfeed_decay=0): r""" Sets the crossfeed parameters for the first branch to the last branch. Args: @@ -157,7 +173,7 @@ class LatentBlending(): self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1) self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1) - def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay): + def set_parental_crossfeed(self, crossfeed_power=None, crossfeed_range=None, crossfeed_decay=None): r""" Sets the crossfeed parameters for all transition images (within the first and last branch). Args: @@ -168,9 +184,22 @@ class LatentBlending(): crossfeed_decay: float [0,1] Sets decay for branch1_crossfeed_power. Lower values make the decay stronger across the range. """ + + if self.dh.is_sdxl_turbo: + if crossfeed_power is None: + crossfeed_power = 1.0 + if crossfeed_range is None: + crossfeed_range = 1.0 + if crossfeed_decay is None: + crossfeed_decay = 1.0 + else: + crossfeed_power = 0.3 + crossfeed_range = 0.6 + crossfeed_decay = 0.9 + self.parental_crossfeed_power = np.clip(crossfeed_power, 0, 1) self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1) - self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1) + self.parental_crossfeed_decay = np.clip(crossfeed_decay, 0, 1) def set_prompt1(self, prompt: str): r""" @@ -209,26 +238,21 @@ class LatentBlending(): image: Image """ self.image2_lowres = image - - def run_transition( - self, - recycle_img1: Optional[bool] = False, - recycle_img2: Optional[bool] = False, - num_inference_steps: Optional[int] = 30, - depth_strength: Optional[float] = 0.3, - t_compute_max_allowed: Optional[float] = None, - nmb_max_branches: Optional[int] = None, - fixed_seeds: Optional[List[int]] = None): - r""" - Function for computing transitions. - Returns a list of transition images using spherical latent blending. - Args: - recycle_img1: Optional[bool]: - Don't recompute the latents for the first keyframe (purely prompt1). Saves compute. - recycle_img2: Optional[bool]: - Don't recompute the latents for the second keyframe (purely prompt2). Saves compute. - num_inference_steps: - Number of diffusion steps. Higher values will take more compute time. + + def set_num_inference_steps(self, num_inference_steps=None): + if self.dh.is_sdxl_turbo: + if num_inference_steps is None: + num_inference_steps = 4 + else: + if num_inference_steps is None: + num_inference_steps = 30 + + self.num_inference_steps = num_inference_steps + self.dh.set_num_inference_steps(num_inference_steps) + + def set_branching(self, depth_strength=None, t_compute_max_allowed=None, nmb_max_branches=None): + """ + Sets the branching structure of the blending tree. Default arguments depend on pipe! depth_strength: Determines how deep the first injection will happen. Deeper injections will cause (unwanted) formation of new structures, @@ -240,6 +264,45 @@ class LatentBlending(): Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better results. Use this if you want to have controllable results independent of your computer. + """ + if self.dh.is_sdxl_turbo: + assert t_compute_max_allowed is None, "time-based branching not supported for SDXL Turbo" + if depth_strength is not None: + idx_inject = int(round(self.num_inference_steps*depth_strength)) + else: + idx_inject = 2 + if nmb_max_branches is None: + nmb_max_branches = 10 + + self.list_idx_injection = [idx_inject] + self.list_nmb_stems = [nmb_max_branches] + + else: + if depth_strength is None: + depth_strength = 0.5 + if t_compute_max_allowed is None and nmb_max_branches is None: + t_compute_max_allowed = 20 + elif t_compute_max_allowed is not None and nmb_max_branches is not None: + raise ValueErorr("Either specify t_compute_max_allowed or nmb_max_branches") + + self.list_idx_injection, self.list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches) + + def run_transition( + self, + recycle_img1: Optional[bool] = False, + recycle_img2: Optional[bool] = False, + fixed_seeds: Optional[List[int]] = None): + r""" + Function for computing transitions. + Returns a list of transition images using spherical latent blending. + Args: + recycle_img1: Optional[bool]: + Don't recompute the latents for the first keyframe (purely prompt1). Saves compute. + recycle_img2: Optional[bool]: + Don't recompute the latents for the second keyframe (purely prompt2). Saves compute. + num_inference_steps: + Number of diffusion steps. Higher values will take more compute time. + fixed_seeds: Optional[List[int)]: You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2). Otherwise random seeds will be taken. @@ -248,6 +311,7 @@ class LatentBlending(): # Sanity checks first assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before' assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before' + # Random seeds if fixed_seeds is not None: @@ -259,10 +323,7 @@ class LatentBlending(): self.seed1 = fixed_seeds[0] self.seed2 = fixed_seeds[1] - # Ensure correct num_inference_steps in holder - self.num_inference_steps = num_inference_steps - self.dh.set_num_inference_steps(num_inference_steps) - + # Compute / Recycle first image if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps: list_latents1 = self.compute_latents1() @@ -280,27 +341,26 @@ class LatentBlending(): self.tree_fracts = [0.0, 1.0] self.tree_final_imgs = [self.dh.latent2image((self.tree_latents[0][-1])), self.dh.latent2image((self.tree_latents[-1][-1]))] self.tree_idx_injection = [0, 0] + self.tree_similarities = [self.get_tree_similarities] - # Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP... - self.spatial_mask = None - - # Set up branching scheme (dependent on provided compute time) - list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches) # Run iteratively, starting with the longest trajectory. # Always inserting new branches where they are needed most according to image similarity - for s_idx in tqdm(range(len(list_idx_injection))): - nmb_stems = list_nmb_stems[s_idx] - idx_injection = list_idx_injection[s_idx] + for s_idx in tqdm(range(len(self.list_idx_injection))): + nmb_stems = self.list_nmb_stems[s_idx] + idx_injection = self.list_idx_injection[s_idx] for i in range(nmb_stems): fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection) self.set_guidance_mid_dampening(fract_mixing) list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection) self.insert_into_tree(fract_mixing, idx_injection, list_latents) - # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}") + # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection} bp1 {b_parent1} bp2 {b_parent2}") return self.tree_final_imgs + + + def compute_latents1(self, return_image=False): r""" @@ -318,7 +378,7 @@ class LatentBlending(): latents_start=latents_start, idx_start=0) t1 = time.time() - self.dt_per_diff = (t1 - t0) / self.num_inference_steps + self.dt_unet_step = (t1 - t0) / self.num_inference_steps self.tree_latents[0] = list_latents1 if return_image: return self.dh.latent2image(list_latents1[-1]) @@ -388,7 +448,7 @@ class LatentBlending(): mixing_coeffs = idx_injection * [self.parental_crossfeed_power] nmb_mixing = idx_mixing_stop - idx_injection if nmb_mixing > 0: - mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_power_decay, nmb_mixing))) + mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_decay, nmb_mixing))) mixing_coeffs.extend((self.num_inference_steps - len(mixing_coeffs)) * [0]) latents_start = list_latents_parental_mix[idx_injection - 1] list_latents = self.run_diffusion( @@ -417,8 +477,10 @@ class LatentBlending(): results. Use this if you want to have controllable results independent of your computer. """ - idx_injection_base = int(round(self.num_inference_steps * depth_strength)) - list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps - 1, 3) + idx_injection_base = int(np.floor(self.num_inference_steps * depth_strength)) + + steps = int(np.ceil(self.num_inference_steps/10)) + list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps, steps) list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32) t_compute = 0 @@ -436,11 +498,11 @@ class LatentBlending(): while not stop_criterion_reached: list_compute_steps = self.num_inference_steps - list_idx_injection list_compute_steps *= list_nmb_stems - t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems) - t_compute += 2 * self.num_inference_steps * self.dt_per_diff # outer branches + t_compute = np.sum(list_compute_steps) * self.dt_unet_step + self.dt_vae * np.sum(list_nmb_stems) + t_compute += 2 * (self.num_inference_steps * self.dt_unet_step + self.dt_vae) # outer branches increase_done = False for s_idx in range(len(list_nmb_stems) - 1): - if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2: + if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 1: list_nmb_stems[s_idx] += 1 increase_done = True break @@ -471,15 +533,15 @@ class LatentBlending(): the index in terms of diffusion steps, where the next insertion will start. """ # get_lpips_similarity - similarities = [] - for i in range(len(self.tree_final_imgs) - 1): - similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1])) + similarities = self.tree_similarities + # similarities = self.get_tree_similarities() b_closest1 = np.argmax(similarities) b_closest2 = b_closest1 + 1 fract_closest1 = self.tree_fracts[b_closest1] fract_closest2 = self.tree_fracts[b_closest2] + fract_mixing = (fract_closest1 + fract_closest2) / 2 - # Ensure that the parents are indeed older! + # Ensure that the parents are indeed older b_parent1 = b_closest1 while True: if self.tree_idx_injection[b_parent1] < idx_injection: @@ -492,7 +554,6 @@ class LatentBlending(): break else: b_parent2 += 1 - fract_mixing = (fract_closest1 + fract_closest2) / 2 return fract_mixing, b_parent1, b_parent2 def insert_into_tree(self, fract_mixing, idx_injection, list_latents): @@ -506,11 +567,21 @@ class LatentBlending(): list_latents: list list of the latents to be inserted """ + img_insert = self.dh.latent2image(list_latents[-1]) + b_parent1, b_parent2 = self.get_closest_idx(fract_mixing) - self.tree_latents.insert(b_parent1 + 1, list_latents) - self.tree_final_imgs.insert(b_parent1 + 1, self.dh.latent2image(list_latents[-1])) - self.tree_fracts.insert(b_parent1 + 1, fract_mixing) - self.tree_idx_injection.insert(b_parent1 + 1, idx_injection) + left_sim = self.get_lpips_similarity(img_insert, self.tree_final_imgs[b_parent1]) + right_sim = self.get_lpips_similarity(img_insert, self.tree_final_imgs[b_parent2]) + idx_insert = b_parent1 + 1 + self.tree_latents.insert(idx_insert, list_latents) + self.tree_final_imgs.insert(idx_insert, img_insert) + self.tree_fracts.insert(idx_insert, fract_mixing) + self.tree_idx_injection.insert(idx_insert, idx_injection) + + # update similarities + self.tree_similarities[b_parent1] = left_sim + self.tree_similarities.insert(idx_insert, right_sim) + def get_noise(self, seed): r""" @@ -552,119 +623,29 @@ class LatentBlending(): self.dh.set_num_inference_steps(self.num_inference_steps) assert type(list_conditionings) is list, "list_conditionings need to be a list" - if self.dh.use_sd_xl: - text_embeddings = list_conditionings[0] - return self.dh.run_diffusion_sd_xl( - text_embeddings=text_embeddings, - latents_start=latents_start, - idx_start=idx_start, - list_latents_mixing=list_latents_mixing, - mixing_coeffs=mixing_coeffs, - return_image=return_image) + text_embeddings = list_conditionings[0] + return self.dh.run_diffusion_sd_xl( + text_embeddings=text_embeddings, + latents_start=latents_start, + idx_start=idx_start, + list_latents_mixing=list_latents_mixing, + mixing_coeffs=mixing_coeffs, + return_image=return_image) - else: - text_embeddings = list_conditionings[0] - return self.dh.run_diffusion_standard( - text_embeddings=text_embeddings, - latents_start=latents_start, - idx_start=idx_start, - list_latents_mixing=list_latents_mixing, - mixing_coeffs=mixing_coeffs, - return_image=return_image) - def run_upscaling( - self, - dp_img: str, - depth_strength: float = 0.65, - num_inference_steps: int = 100, - nmb_max_branches_highres: int = 5, - nmb_max_branches_lowres: int = 6, - duration_single_segment=3, - fps=24, - fixed_seeds: Optional[List[int]] = None): - r""" - Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition. - Args: - dp_img: str - Path to the low-res transition path (as saved in write_imgs_transition) - depth_strength: - Determines how deep the first injection will happen. - Deeper injections will cause (unwanted) formation of new structures, - more shallow values will go into alpha-blendy land. - num_inference_steps: - Number of diffusion steps. Higher values will take more compute time. - nmb_max_branches_highres: int - Number of final branches of the upscaling transition pass. Note this is the number - of branches between each pair of low-res images. - nmb_max_branches_lowres: int - Number of input low-res images, subsampling all transition images written in the low-res pass. - Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much. - duration_single_segment: float - The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total. - fps: float - frames per second of movie - fixed_seeds: Optional[List[int)]: - You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2). - Otherwise random seeds will be taken. - """ - fp_yml = os.path.join(dp_img, "lowres.yaml") - fp_movie = os.path.join(dp_img, "movie_highres.mp4") - ms = MovieSaver(fp_movie, fps=fps) - assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?" - dict_stuff = yml_load(fp_yml) - - # load lowres images - nmb_images_lowres = dict_stuff['nmb_images'] - prompt1 = dict_stuff['prompt1'] - prompt2 = dict_stuff['prompt2'] - idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres - 1, nmb_max_branches_lowres)).astype(np.int32) - imgs_lowres = [] - for i in idx_img_lowres: - fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg") - assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?" - imgs_lowres.append(Image.open(fp_img_lowres)) - - # set up upscaling - text_embeddingA = self.dh.get_text_embedding(prompt1) - text_embeddingB = self.dh.get_text_embedding(prompt2) - list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1) - for i in range(nmb_max_branches_lowres - 1): - print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}") - self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i]) - self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1 - list_fract_mixing[i]) - if i == 0: - recycle_img1 = False - else: - self.swap_forward() - recycle_img1 = True - - self.set_image1(imgs_lowres[i]) - self.set_image2(imgs_lowres[i + 1]) - - list_imgs = self.run_transition( - recycle_img1=recycle_img1, - recycle_img2=False, - num_inference_steps=num_inference_steps, - depth_strength=depth_strength, - nmb_max_branches=nmb_max_branches_highres) - list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment) - - # Save movie frame - for img in list_imgs_interp: - ms.write_frame(img) - ms.finalize() @torch.no_grad() def get_mixed_conditioning(self, fract_mixing): - if self.dh.use_sd_xl: - text_embeddings_mix = [] - for i in range(len(self.text_embedding1)): - text_embeddings_mix.append(interpolate_linear(self.text_embedding1[i], self.text_embedding2[i], fract_mixing)) - list_conditionings = [text_embeddings_mix] - else: - text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) - list_conditionings = [text_embeddings_mix] + text_embeddings_mix = [] + for i in range(len(self.text_embedding1)): + if self.text_embedding1[i] is None: + mix = None + else: + mix = interpolate_linear(self.text_embedding1[i], self.text_embedding2[i], fract_mixing) + text_embeddings_mix.append(mix) + list_conditionings = [text_embeddings_mix] + return list_conditionings @torch.no_grad() @@ -733,7 +714,7 @@ class LatentBlending(): 'num_inference_steps', 'depth_strength', 'guidance_scale', 'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt', 'branch1_crossfeed_power', 'branch1_crossfeed_range', 'branch1_crossfeed_decay' - 'parental_crossfeed_power', 'parental_crossfeed_range', 'parental_crossfeed_power_decay'] + 'parental_crossfeed_power', 'parental_crossfeed_range', 'parental_crossfeed_decay'] for v in grab_vars: if hasattr(self, v): if v == 'seed1' or v == 'seed2': @@ -797,16 +778,22 @@ class LatentBlending(): Used to determine the optimal point of insertion to create smooth transitions. High values indicate low similarity. """ - tensorA = torch.from_numpy(imgA).float().cuda(self.device) + tensorA = torch.from_numpy(np.asarray(imgA)).float().cuda(self.device) tensorA = 2 * tensorA / 255.0 - 1 tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0) - tensorB = torch.from_numpy(imgB).float().cuda(self.device) + tensorB = torch.from_numpy(np.asarray(imgB)).float().cuda(self.device) tensorB = 2 * tensorB / 255.0 - 1 tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0) lploss = self.lpips(tensorA, tensorB) lploss = float(lploss[0][0][0][0]) return lploss + def get_tree_similarities(self): + similarities = [] + for i in range(len(self.tree_final_imgs) - 1): + similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1])) + return similarities + # Auxiliary functions def get_closest_idx( self, @@ -831,3 +818,46 @@ class LatentBlending(): b_parent1 = tmp return b_parent1, b_parent2 + +#%% +if __name__ == "__main__": + + # %% First let us spawn a stable diffusion holder. Uncomment your version of choice. + from diffusers_holder import DiffusersHolder + from diffusers import DiffusionPipeline + from diffusers import AutoencoderTiny + # pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" + pretrained_model_name_or_path = "stabilityai/sdxl-turbo" + + + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16") + pipe.to("cuda") + pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16) + pipe.vae = pipe.vae.cuda() + + dh = DiffusersHolder(pipe) + # %% Next let's set up all parameters + prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution" + prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal" + negative_prompt = "blurry, ugly, pale" # Optional + + duration_transition = 12 # In seconds + + # Spawn latent blending + lb = LatentBlending(dh) + lb.set_prompt1(prompt1) + lb.set_prompt2(prompt2) + lb.set_negative_prompt(negative_prompt) + + # Run latent blending + t0 = time.time() + lb.run_transition(fixed_seeds=[420, 421]) + dt = time.time() - t0 + + # Save movie + fp_movie = f'test.mp4' + lb.write_movie_transition(fp_movie, duration_transition) + + + + diff --git a/movie_util.py b/movie_util.py index e6e0c6a..eb7e157 100644 --- a/movie_util.py +++ b/movie_util.py @@ -262,7 +262,6 @@ def add_subtitles_to_video( - class MovieReader(): r""" Class to read in a movie. diff --git a/utils.py b/utils.py index d4424ea..d89af4a 100644 --- a/utils.py +++ b/utils.py @@ -24,7 +24,7 @@ import datetime from typing import List, Union torch.set_grad_enabled(False) import yaml - +import PIL @torch.no_grad() def interpolate_spherical(p0, p1, fract_mixing: float): @@ -142,6 +142,8 @@ def add_frames_linear_interp( if nmb_frames_missing < 1: return list_imgs + if type(list_imgs[0]) == PIL.Image.Image: + list_imgs = [np.asarray(l) for l in list_imgs] list_imgs_float = [img.astype(np.float32) for img in list_imgs] # Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff