Merge pull request #12 from lunarring/sdxl_turbo

Sdxl turbo
Johannes Stelzer 2024-01-09 17:06:51 +01:00 committed by GitHub
commit 321d083c7e
6 changed files with 508 additions and 536 deletions
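
This merge switches the bundled examples and defaults from SDXL base to SDXL Turbo. A minimal sketch of the intended end-to-end usage, based on the updated example script in this diff (model id, class names and defaults are taken from the changes below; treat it as illustrative, not as a canonical entry point):

```python
# Illustrative sketch only; mirrors the updated example_1_standard.py in this PR.
import torch
from diffusers import AutoPipelineForText2Image
from diffusers_holder import DiffusersHolder
from latent_blending import LatentBlending

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipe.to("cuda")

lb = LatentBlending(DiffusersHolder(pipe))
lb.set_prompt1("photo of underwater landscape, fish, und the sea, incredible detail, high resolution")
lb.set_prompt2("rendering of an alien planet, strange plants, strange creatures, surreal")
lb.set_negative_prompt("blurry, ugly, pale")
lb.run_transition()                     # turbo defaults: 4 steps, guidance_scale 0
lb.write_movie_transition('movie_example1.mp4', duration_transition=12)
```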

View File

@@ -17,7 +17,7 @@ import torch
 import numpy as np
 import warnings
-from typing import Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from utils import interpolate_spherical
 from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel
 from diffusers.models.attention_processor import (
@@ -26,6 +26,7 @@ from diffusers.models.attention_processor import (
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import retrieve_timesteps
 warnings.filterwarnings('ignore')
 torch.backends.cudnn.benchmark = False
 torch.set_grad_enabled(False)
@@ -45,23 +46,26 @@ class DiffusersHolder():
         self.width_latent = self.pipe.unet.config.sample_size
         self.height_latent = self.pipe.unet.config.sample_size
+        self.width_img = self.width_latent * self.pipe.vae_scale_factor
+        self.height_img = self.height_latent * self.pipe.vae_scale_factor
 
     def init_types(self):
         assert hasattr(self.pipe, "__class__"), "No valid diffusers pipeline found."
         assert hasattr(self.pipe.__class__, "__name__"), "No valid diffusers pipeline found."
         if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline':
             self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
-            self.use_sd_xl = True
             prompt_embeds, _, _, _ = self.pipe.encode_prompt("test")
         else:
-            self.use_sd_xl = False
             prompt_embeds = self.pipe._encode_prompt("test", self.device, 1, True)
         self.dtype = prompt_embeds.dtype
+        self.is_sdxl_turbo = 'turbo' in self.pipe._name_or_path
 
     def set_num_inference_steps(self, num_inference_steps):
         self.num_inference_steps = num_inference_steps
-        if self.use_sd_xl:
-            self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
+        self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
 
     def set_dimensions(self, size_output):
         s = self.pipe.vae_scale_factor
@@ -87,74 +91,72 @@ class DiffusersHolder():
         if len(self.negative_prompt) > 1:
             self.negative_prompt = [self.negative_prompt[0]]
 
-    def get_text_embedding(self, prompt, do_classifier_free_guidance=True):
-        if self.use_sd_xl:
-            pr_encoder = self.pipe.encode_prompt
-        else:
-            pr_encoder = self.pipe._encode_prompt
-        prompt_embeds = pr_encoder(
+    def get_text_embedding(self, prompt):
+        do_classifier_free_guidance = self.guidance_scale > 1 and self.pipe.unet.config.time_cond_proj_dim is None
+        text_embeddings = self.pipe.encode_prompt(
             prompt=prompt,
-            device=self.device,
+            prompt_2=prompt,
+            device=self.pipe._execution_device,
             num_images_per_prompt=1,
             do_classifier_free_guidance=do_classifier_free_guidance,
             negative_prompt=self.negative_prompt,
+            negative_prompt_2=self.negative_prompt,
             prompt_embeds=None,
             negative_prompt_embeds=None,
+            pooled_prompt_embeds=None,
+            negative_pooled_prompt_embeds=None,
             lora_scale=None,
+            clip_skip=None,  # self.pipe._clip_skip,
         )
-        return prompt_embeds
+        return text_embeddings
 
     def get_noise(self, seed=420):
-        H = self.height_latent
-        W = self.width_latent
-        C = self.pipe.unet.config.in_channels
-        generator = torch.Generator(device=self.device).manual_seed(int(seed))
-        latents = torch.randn((1, C, H, W), generator=generator, dtype=self.dtype, device=self.device)
-        if self.use_sd_xl:
-            latents = latents * self.pipe.scheduler.init_noise_sigma
+        latents = self.pipe.prepare_latents(
+            1,
+            self.pipe.unet.config.in_channels,
+            self.height_img,
+            self.width_img,
+            torch.float16,
+            self.pipe._execution_device,
+            torch.Generator(device=self.device).manual_seed(int(seed)),
+            None,
+        )
         return latents
 
     @torch.no_grad()
     def latent2image(
             self,
             latents: torch.FloatTensor,
-            convert_numpy=True):
+            output_type="pil"):
         r"""
         Returns an image provided a latent representation from diffusion.
         Args:
             latents: torch.FloatTensor
                 Result of the diffusion process.
-            convert_numpy: if converting to numpy
+            output_type: "pil" or "np"
         """
-        if self.use_sd_xl:
-            # make sure the VAE is in float32 mode, as it overflows in float16
-            self.pipe.vae.to(dtype=torch.float32)
-            use_torch_2_0_or_xformers = isinstance(
-                self.pipe.vae.decoder.mid_block.attentions[0].processor,
-                (
-                    AttnProcessor2_0,
-                    XFormersAttnProcessor,
-                    LoRAXFormersAttnProcessor,
-                    LoRAAttnProcessor2_0,
-                ),
-            )
-            # if xformers or torch_2_0 is used attention block does not need
-            # to be in float32 which can save lots of memory
-            if use_torch_2_0_or_xformers:
-                self.pipe.vae.post_quant_conv.to(latents.dtype)
-                self.pipe.vae.decoder.conv_in.to(latents.dtype)
-                self.pipe.vae.decoder.mid_block.to(latents.dtype)
-            else:
-                latents = latents.float()
+        assert output_type in ["pil", "np"]
+
+        # make sure the VAE is in float32 mode, as it overflows in float16
+        needs_upcasting = self.pipe.vae.dtype == torch.float16 and self.pipe.vae.config.force_upcast
+        if needs_upcasting:
+            self.pipe.upcast_vae()
+            latents = latents.to(next(iter(self.pipe.vae.post_quant_conv.parameters())).dtype)
 
         image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0]
-        image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])[0]
-        if convert_numpy:
-            return np.asarray(image)
-        else:
-            return image
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.pipe.vae.to(dtype=torch.float16)
+
+        image = self.pipe.image_processor.postprocess(image, output_type=output_type)[0]
+        return image
 
     def prepare_mixing(self, mixing_coeffs, list_latents_mixing):
         if type(mixing_coeffs) == float:
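
For orientation, the rewritten holder methods above compose into a single denoise/decode round trip roughly as follows. This is a sketch against the interfaces shown in this hunk (attribute and method names come from the diff; the concrete values are illustrative):

```python
# Sketch of one full pass with the updated DiffusersHolder API (illustrative only).
dh = DiffusersHolder(pipe)
dh.guidance_scale = 0.0                      # SDXL Turbo is typically run without CFG
dh.set_num_inference_steps(4)
dh.set_negative_prompt("blurry, ugly, pale")

text_embeddings = dh.get_text_embedding("photo of a red tree")   # 4-tuple for SDXL
latents_start = dh.get_noise(seed=420)                           # via pipe.prepare_latents
list_latents = dh.run_diffusion_sd_xl(text_embeddings, latents_start)
img = dh.latent2image(list_latents[-1], output_type="pil")       # or output_type="np"
```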
@@ -178,111 +180,94 @@ class DiffusersHolder():
             mixing_coeffs=0.0,
             return_image: Optional[bool] = False):
-        if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline':
-            return self.run_diffusion_sd_xl(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image)
-        elif self.pipe.__class__.__name__ == 'StableDiffusionPipeline':
-            return self.run_diffusion_sd12x(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image)
-        elif self.pipe.__class__.__name__ == 'StableDiffusionControlNetPipeline':
-            pass
+        return self.run_diffusion_sd_xl(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image)
@torch.no_grad()
def run_diffusion_sd12x(
self,
text_embeddings: torch.FloatTensor,
latents_start: torch.FloatTensor,
idx_start: int = 0,
list_latents_mixing=None,
mixing_coeffs=0.0,
return_image: Optional[bool] = False):
list_mixing_coeffs = self.prepare_mixing()
do_classifier_free_guidance = self.guidance_scale > 1.0
# accomodate different sd model types
self.pipe.scheduler.set_timesteps(self.num_inference_steps - 1, device=self.device)
timesteps = self.pipe.scheduler.timesteps
if len(timesteps) != self.num_inference_steps:
self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
timesteps = self.pipe.scheduler.timesteps
latents = latents_start.clone()
list_latents_out = []
for i, t in enumerate(timesteps):
# Set the right starting latents
if i < idx_start:
list_latents_out.append(None)
continue
elif i == idx_start:
latents = latents_start.clone()
# Mix latents
if i > 0 and list_mixing_coeffs[i] > 0:
latents_mixtarget = list_latents_mixing[i - 1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.pipe.unet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
return_dict=False,
)[0]
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
     @torch.no_grad()
     def run_diffusion_sd_xl(
             self,
-            text_embeddings: list,
+            text_embeddings: tuple,
             latents_start: torch.FloatTensor,
             idx_start: int = 0,
             list_latents_mixing=None,
             mixing_coeffs=0.0,
-            return_image: Optional[bool] = False):
+            return_image: Optional[bool] = False,
+            ):
prompt_2 = None
height = None
width = None
timesteps = None
denoising_end = None
negative_prompt_2 = None
num_images_per_prompt = 1
eta = 0.0
generator = None
latents = None
prompt_embeds = None
negative_prompt_embeds = None
pooled_prompt_embeds = None
negative_pooled_prompt_embeds = None
ip_adapter_image = None
output_type = "pil"
return_dict = True
cross_attention_kwargs = None
guidance_rescale = 0.0
original_size = None
crops_coords_top_left = (0, 0)
target_size = None
negative_original_size = None
negative_crops_coords_top_left = (0, 0)
negative_target_size = None
clip_skip = None
callback = None
callback_on_step_end = None
callback_on_step_end_tensor_inputs = ["latents"]
# kwargs are additional keyword arguments and don't need a default value set here.
         # 0. Default height and width to unet
-        original_size = (self.width_img, self.height_img)
-        crops_coords_top_left = (0, 0)
-        target_size = original_size
-        batch_size = 1
-        eta = 0.0
-        num_images_per_prompt = 1
-        cross_attention_kwargs = None
-        generator = torch.Generator(device=self.device)  # dummy generator
-        do_classifier_free_guidance = self.guidance_scale > 1.0
-        # 1. Check inputs. Raise error if not correct & 2. Define call parameters
+        height = height or self.pipe.default_sample_size * self.pipe.vae_scale_factor
+        width = width or self.pipe.default_sample_size * self.pipe.vae_scale_factor
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        # 1. Check inputs. skipped.
+        self.pipe._guidance_scale = self.guidance_scale
+        self.pipe._guidance_rescale = guidance_rescale
+        self.pipe._clip_skip = clip_skip
+        self.pipe._cross_attention_kwargs = cross_attention_kwargs
+        self.pipe._denoising_end = denoising_end
+        self.pipe._interrupt = False
+
+        # 2. Define call parameters
         list_mixing_coeffs = self.prepare_mixing(mixing_coeffs, list_latents_mixing)
-        # 3. Encode input prompt (already encoded outside bc of mixing, just split here)
-        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = text_embeddings
+        batch_size = 1
+        device = self.pipe._execution_device
+
+        # 3. Encode input prompt
+        lora_scale = None
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = text_embeddings
 
         # 4. Prepare timesteps
-        self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
-        timesteps = self.pipe.scheduler.timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.pipe.scheduler, self.num_inference_steps, device, timesteps)
 
         # 5. Prepare latent variables
+        num_channels_latents = self.pipe.unet.config.in_channels
         latents = latents_start.clone()
         list_latents_out = []
 
-        # 6. Prepare extra step kwargs. usedummy generator
-        extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)  # dummy
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)
 
         # 7. Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
@@ -298,20 +283,50 @@ class DiffusersHolder():
             dtype=prompt_embeds.dtype,
             text_encoder_projection_dim=text_encoder_projection_dim,
         )
+        if negative_original_size is not None and negative_target_size is not None:
+            negative_add_time_ids = self.pipe._get_add_time_ids(
+                negative_original_size,
+                negative_crops_coords_top_left,
+                negative_target_size,
+                dtype=prompt_embeds.dtype,
+                text_encoder_projection_dim=text_encoder_projection_dim,
+            )
+        else:
+            negative_add_time_ids = add_time_ids
 
-        negative_add_time_ids = add_time_ids
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
-        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
-        prompt_embeds = prompt_embeds.to(self.device)
-        add_text_embeds = add_text_embeds.to(self.device)
-        add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1)
+        if self.pipe.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.pipe.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.pipe.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if self.pipe.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+            image_embeds = image_embeds.to(device)
 
         # 8. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.pipe.scheduler.order, 0)
+
+        # 9. Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.pipe.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.pipe.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.pipe.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.pipe.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
+        self.pipe._num_timesteps = len(timesteps)
         for i, t in enumerate(timesteps):
             # Set the right starting latents
+            # Write latents out and skip
             if i < idx_start:
                 list_latents_out.append(None)
                 continue
@@ -323,26 +338,34 @@ class DiffusersHolder():
                 latents_mixtarget = list_latents_mixing[i - 1].clone()
                 latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
 
             # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = torch.cat([latents] * 2) if self.pipe.do_classifier_free_guidance else latents
+
+            # Always scale latents
             latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
 
             # predict the noise residual
             added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+            if ip_adapter_image is not None:
+                added_cond_kwargs["image_embeds"] = image_embeds
             noise_pred = self.pipe.unet(
                 latent_model_input,
                 t,
                 encoder_hidden_states=prompt_embeds,
-                cross_attention_kwargs=cross_attention_kwargs,
+                timestep_cond=timestep_cond,
+                cross_attention_kwargs=self.pipe.cross_attention_kwargs,
                 added_cond_kwargs=added_cond_kwargs,
                 return_dict=False,
             )[0]
 
             # perform guidance
-            if do_classifier_free_guidance:
+            if self.pipe.do_classifier_free_guidance:
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                noise_pred = noise_pred_uncond + self.pipe.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            if self.pipe.do_classifier_free_guidance and self.pipe.guidance_rescale > 0.0:
+                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.pipe.guidance_rescale)
 
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
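
The guidance block above is the standard classifier-free guidance combination: the batched prediction is split into unconditional and text-conditioned halves and recombined with the guidance weight. A tiny self-contained sketch of just that arithmetic, on dummy tensors (illustrative, no pipeline needed):

```python
# Classifier-free guidance combination on dummy tensors (illustrative).
import torch

guidance_scale = 5.0
noise_pred = torch.randn(2, 4, 64, 64)                  # [uncond, text] stacked on the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# With guidance_scale == 1 the result equals the text-conditioned prediction.
# With guidance_scale <= 1 the holder skips this branch entirely (do_classifier_free_guidance
# is False), which is why the SDXL Turbo default of 0 also avoids the doubled batch.
```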
@@ -350,145 +373,7 @@ class DiffusersHolder():
             # Append latents
             list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
@torch.no_grad()
def run_diffusion_controlnet(
self,
conditioning: list,
latents_start: torch.FloatTensor,
idx_start: int = 0,
list_latents_mixing=None,
mixing_coeffs=0.0,
return_image: Optional[bool] = False):
prompt_embeds = conditioning[0]
image = conditioning[1]
list_mixing_coeffs = self.prepare_mixing()
controlnet = self.pipe.controlnet
control_guidance_start = [0.0]
control_guidance_end = [1.0]
guess_mode = False
num_images_per_prompt = 1
batch_size = 1
eta = 0.0
controlnet_conditioning_scale = 1.0
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
# 2. Define call parameters
device = self.pipe._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = self.guidance_scale > 1.0
# 4. Prepare image
image = self.pipe.prepare_image(
image=image,
width=None,
height=None,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=self.device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = image.shape[-2:]
# 5. Prepare timesteps
self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
timesteps = self.pipe.scheduler.timesteps
# 6. Prepare latent variables
generator = torch.Generator(device=self.device).manual_seed(int(420))
latents = latents_start.clone()
list_latents_out = []
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)
# 7.1 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
# 8. Denoising loop
for i, t in enumerate(timesteps):
if i < idx_start:
list_latents_out.append(None)
continue
elif i == idx_start:
latents = latents_start.clone()
# Mix latents for crossfeeding
if i > 0 and list_mixing_coeffs[i] > 0:
latents_mixtarget = list_latents_mixing[i - 1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
cond_scale = controlnet_conditioning_scale * controlnet_keep[i]
down_block_res_samples, mid_block_res_sample = self.pipe.controlnet(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
return_dict=False,
)
if guess_mode and do_classifier_free_guidance:
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
# predict the noise residual
noise_pred = self.pipe.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# Append latents
list_latents_out.append(latents.clone())
         if return_image:
             return self.latent2image(latents)
@@ -496,26 +381,108 @@ class DiffusersHolder():
         return list_latents_out
 #%%
 if __name__ == "__main__":
     from PIL import Image
-    #%%
-    pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
-    pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
-    pipe.to('cuda')  # xxx
-    #%%
+    from diffusers import AutoencoderTiny
+    # pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
+    pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
+    pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16")
+    pipe.to("cuda")
+    #%
+    # pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16)
+    # pipe.vae = pipe.vae.cuda()
+    #%% resanity
+    import time
     self = DiffusersHolder(pipe)
prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution"
negative_prompt = "blurry, ugly, pale"
num_inference_steps = 4
guidance_scale = 0
self.set_num_inference_steps(num_inference_steps)
self.guidance_scale = guidance_scale
prefix='turbo'
for i in range(10):
self.set_negative_prompt(negative_prompt)
text_embeddings = self.get_text_embedding(prompt1)
latents_start = self.get_noise(np.random.randint(111111))
t0 = time.time()
# img_refx = self.pipe(prompt=prompt1, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)[0]
img_refx = self.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False)
dt_ref = time.time() - t0
img_refx.save(f"x_{prefix}_{i}.jpg")
# xxx # xxx
self.set_dimensions((1024, 704))
self.set_num_inference_steps(40) # self.set_negative_prompt(negative_prompt)
# self.set_dimensions(1536, 1024) # self.set_num_inference_steps(num_inference_steps)
prompt = "Surreal painting of eerie, nebulous glow of an indigo moon, a spine-chilling spectacle unfolds; a baroque, marbled hand reaches out from a viscous, purple lake clutching a melting clock, its face distorted in a never-ending scream of hysteria, while a cluster of laughing orchids, their petals morphed into grotesque human lips, festoon a crimson tree weeping blood instead of sap, a psychedelic cat with an unnaturally playful grin and mismatched eyes lounges atop a floating vintage television showing static, an albino peacock with iridescent, crystalline feathers dances around a towering, inverted pyramid on top of which a humanoid figure with an octopus head lounges seductively, all against the backdrop of a sprawling cityscape where buildings are inverted and writhing as if alive, and the sky is punctuated by floating aquatic creatures glowing neon, adding a touch of haunting beauty to this otherwise deeply unsettling tableau" # text_embeddings1 = self.get_text_embedding(prompt1)
text_embeddings = self.get_text_embedding(prompt) # prompt_embeds1, negative_prompt_embeds1, pooled_prompt_embeds1, negative_pooled_prompt_embeds1 = text_embeddings1
generator = torch.Generator(device=self.device).manual_seed(int(420)) # latents_start = self.get_noise(420)
latents_start = self.get_noise() # t0 = time.time()
list_latents_1 = self.run_diffusion(text_embeddings, latents_start) # img_dh = self.run_diffusion_sd_xl_resanity(text_embeddings1, latents_start, idx_start=0, return_image=True)
img_orig = self.latent2image(list_latents_1[-1]) # dt_dh = time.time() - t0
# xxxx
# #%%
# self = DiffusersHolder(pipe)
# num_inference_steps = 4
# self.set_num_inference_steps(num_inference_steps)
# latents_start = self.get_noise(420)
# guidance_scale = 0
# self.guidance_scale = 0
# #% get embeddings1
# prompt1 = "Photo of a colorful landscape with a blue sky with clouds"
# text_embeddings1 = self.get_text_embedding(prompt1)
# prompt_embeds1, negative_prompt_embeds1, pooled_prompt_embeds1, negative_pooled_prompt_embeds1 = text_embeddings1
# #% get embeddings2
# prompt2 = "Photo of a tree"
# text_embeddings2 = self.get_text_embedding(prompt2)
# prompt_embeds2, negative_prompt_embeds2, pooled_prompt_embeds2, negative_pooled_prompt_embeds2 = text_embeddings2
# latents1 = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=False)
# img1 = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=True)
# img1B = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=True)
# # latents2 = self.run_diffusion_sd_xl(text_embeddings2, latents_start, idx_start=0, return_image=False)
# # # check if brings same image if restarted
# # img1_return = self.run_diffusion_sd_xl(text_embeddings1, latents1[idx_mix-1], idx_start=idx_start, return_image=True)
# # mix latents
# #%%
# idx_mix = 2
# fract=0.8
# latents_start_mixed = interpolate_spherical(latents1[idx_mix-1], latents2[idx_mix-1], fract)
# prompt_embeds = interpolate_spherical(prompt_embeds1, prompt_embeds2, fract)
# pooled_prompt_embeds = interpolate_spherical(pooled_prompt_embeds1, pooled_prompt_embeds2, fract)
# negative_prompt_embeds = negative_prompt_embeds1
# negative_pooled_prompt_embeds = negative_pooled_prompt_embeds1
# text_embeddings_mix = [prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds]
# self.run_diffusion_sd_xl(text_embeddings_mix, latents_start_mixed, idx_start=idx_start, return_image=True)
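
`interpolate_spherical` (imported from `utils`, not shown in this diff) is used for all latent mixing above. Its implementation is not part of this change; a standard slerp between two latent tensors would look roughly like the sketch below. This is an assumption for orientation only; the real `utils.interpolate_spherical` may differ in details:

```python
import torch

def interpolate_spherical_sketch(p0: torch.Tensor, p1: torch.Tensor, fract: float) -> torch.Tensor:
    """Hypothetical slerp between two latents; the real utils.interpolate_spherical may differ."""
    p0_flat, p1_flat = p0.flatten().float(), p1.flatten().float()
    cos_omega = torch.clamp(
        torch.dot(p0_flat, p1_flat) / (p0_flat.norm() * p1_flat.norm()), -1.0, 1.0)
    omega = torch.arccos(cos_omega)
    if omega.abs() < 1e-6:                      # nearly parallel: fall back to a plain lerp
        return (1.0 - fract) * p0 + fract * p1
    s0 = torch.sin((1.0 - fract) * omega) / torch.sin(omega)
    s1 = torch.sin(fract * omega) / torch.sin(omega)
    return (s0 * p0.float() + s1 * p1.float()).to(p0.dtype)
```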

View File

@@ -17,41 +17,25 @@ import torch
 import warnings
 from latent_blending import LatentBlending
 from diffusers_holder import DiffusersHolder
-from diffusers import DiffusionPipeline
+from diffusers import AutoPipelineForText2Image
 warnings.filterwarnings('ignore')
 torch.set_grad_enabled(False)
 torch.backends.cudnn.benchmark = False
 
 # %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
-pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
-pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
-pipe.to('cuda')
+pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
+pipe.to("cuda")
 dh = DiffusersHolder(pipe)
 
-# %% Next let's set up all parameters
-depth_strength = 0.55  # Specifies how deep (in terms of diffusion iterations the first branching happens)
-t_compute_max_allowed = 60  # Determines the quality of the transition in terms of compute time you grant it
-num_inference_steps = 30
-size_output = (1024, 1024)
-
-prompt1 = "underwater landscape, fish, und the sea, incredible detail, high resolution"
-prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal"
-negative_prompt = "blurry, ugly, pale"  # Optional
-
-fp_movie = 'movie_example1.mp4'
-duration_transition = 12  # In seconds
-
-# Spawn latent blending
 lb = LatentBlending(dh)
-lb.set_prompt1(prompt1)
-lb.set_prompt2(prompt2)
-lb.set_dimensions(size_output)
-lb.set_negative_prompt(negative_prompt)
+lb.set_prompt1("photo of underwater landscape, fish, und the sea, incredible detail, high resolution")
+lb.set_prompt2("rendering of an alien planet, strange plants, strange creatures, surreal")
+lb.set_negative_prompt("blurry, ugly, pale")
 
 # Run latent blending
-lb.run_transition(
-    depth_strength=depth_strength,
-    num_inference_steps=num_inference_steps,
-    t_compute_max_allowed=t_compute_max_allowed)
+lb.run_transition()
 
 # Save movie
-lb.write_movie_transition(fp_movie, duration_transition)
+lb.write_movie_transition('movie_example1.mp4', duration_transition=12)
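
The example now relies entirely on the new SDXL Turbo defaults. The setters from this PR still allow the previous level of control; a short sketch with arbitrary illustrative values:

```python
# Optional overrides (illustrative values only).
lb.set_dimensions((1024, 1024))          # turbo default is (512, 512)
lb.set_num_inference_steps(8)            # turbo default is 4
lb.set_branching(depth_strength=0.4, nmb_max_branches=12)
lb.run_transition(fixed_seeds=[420, 421])
```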

View File

@@ -17,24 +17,20 @@ import torch
 import warnings
 from latent_blending import LatentBlending
 from diffusers_holder import DiffusersHolder
-from diffusers import DiffusionPipeline
+from diffusers import AutoPipelineForText2Image
 from movie_util import concatenate_movies
 torch.set_grad_enabled(False)
 torch.backends.cudnn.benchmark = False
 warnings.filterwarnings('ignore')
 
 # %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
-pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
-pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
+pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
 pipe.to('cuda')
 dh = DiffusersHolder(pipe)
 
 # %% Let's setup the multi transition
 fps = 30
-duration_single_trans = 20
-depth_strength = 0.25  # Specifies how deep (in terms of diffusion iterations the first branching happens)
-size_output = (1280, 768)
-num_inference_steps = 30
+duration_single_trans = 10
 
 # Specify a list of prompts below
 list_prompts = []
@@ -45,12 +41,8 @@ list_prompts.append("photo of a house, high detail")
 
 # You can optionally specify the seeds
 list_seeds = [95437579, 33259350, 956051013]
-t_compute_max_allowed = 20  # per segment
 fp_movie = 'movie_example2.mp4'
 lb = LatentBlending(dh)
-lb.set_dimensions(size_output)
-lb.dh.set_num_inference_steps(num_inference_steps)
 list_movie_parts = []
 
 for i in range(len(list_prompts) - 1):
@@ -69,8 +61,6 @@ for i in range(len(list_prompts) - 1):
     # Run latent blending
     lb.run_transition(
         recycle_img1=recycle_img1,
-        depth_strength=depth_strength,
-        t_compute_max_allowed=t_compute_max_allowed,
         fixed_seeds=fixed_seeds)
 
     # Save movie
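
Only fragments of the multi-transition loop are visible in this hunk. A rough sketch of how the pieces shown here (per-segment seeds, `recycle_img1`, the movie parts list) typically fit together; the exact loop body and the `concatenate_movies` signature are inferred, so check the example file and `movie_util` before relying on it:

```python
# Structure inferred from the visible fragments; details may differ from the actual example.
list_movie_parts = []
for i in range(len(list_prompts) - 1):
    fp_movie_part = f"tmp_part_{str(i).zfill(3)}.mp4"
    fixed_seeds = list_seeds[i:i + 2]
    recycle_img1 = i > 0                      # reuse the previous segment's last branch
    lb.set_prompt1(list_prompts[i])
    lb.set_prompt2(list_prompts[i + 1])
    lb.run_transition(recycle_img1=recycle_img1, fixed_seeds=fixed_seeds)
    lb.write_movie_transition(fp_movie_part, duration_single_trans)
    list_movie_parts.append(fp_movie_part)

concatenate_movies(fp_movie, list_movie_parts)   # assumed signature: (target, list_of_parts)
```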

View File

@@ -33,18 +33,11 @@ class LatentBlending():
     def __init__(
             self,
             dh: None,
-            guidance_scale: float = 4,
             guidance_scale_mid_damper: float = 0.5,
             mid_compression_scaler: float = 1.2):
         r"""
         Initializes the latent blending class.
         Args:
-            guidance_scale: float
-                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-                usually at the expense of lower image quality.
             guidance_scale_mid_damper: float = 0.5
                 Reduces the guidance scale towards the middle of the transition.
                 A value of 0.5 would decrease the guidance_scale towards the middle linearly by 0.5.
@@ -76,37 +69,49 @@ class LatentBlending():
         self.tree_status = None
         self.tree_final_imgs = []
-        self.list_nmb_branches_prev = []
-        self.list_injection_idx_prev = []
         self.text_embedding1 = None
         self.text_embedding2 = None
         self.image1_lowres = None
         self.image2_lowres = None
         self.negative_prompt = None
-        self.num_inference_steps = self.dh.num_inference_steps
-        self.noise_level_upscaling = 20
-        self.list_injection_idx = None
-        self.list_nmb_branches = None
 
-        # Mixing parameters
-        self.branch1_crossfeed_power = 0.3
-        self.branch1_crossfeed_range = 0.3
-        self.branch1_crossfeed_decay = 0.99
-        self.parental_crossfeed_power = 0.3
-        self.parental_crossfeed_range = 0.6
-        self.parental_crossfeed_power_decay = 0.9
-        self.set_guidance_scale(guidance_scale)
+        self.set_guidance_scale()
         self.multi_transition_img_first = None
         self.multi_transition_img_last = None
-        self.dt_per_diff = 0
-        self.spatial_mask = None
+        self.dt_unet_step = 0
         self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
 
         self.set_prompt1("")
         self.set_prompt2("")
self.set_branch1_crossfeed()
self.set_parental_crossfeed()
self.set_num_inference_steps()
self.benchmark_speed()
self.set_branching()
def benchmark_speed(self):
"""
Measures the time per diffusion step and for the vae decoding
"""
text_embeddings = self.dh.get_text_embedding("test")
latents_start = self.dh.get_noise(np.random.randint(111111))
# warmup
list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1)
# bench unet
t0 = time.time()
list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1)
self.dt_unet_step = time.time() - t0
# bench vae
t0 = time.time()
img = self.dh.latent2image(list_latents[-1])
self.dt_vae = time.time() - t0
    def set_dimensions(self, size_output=None):
        r"""
        sets the size of the output video.
@@ -115,12 +120,23 @@ class LatentBlending():
             width x height
             Note: the size will get automatically adjusted to be divisable by 32.
         """
+        if size_output is None:
+            if self.dh.is_sdxl_turbo:
+                size_output = (512, 512)
+            else:
+                size_output = (1024, 1024)
         self.dh.set_dimensions(size_output)
 
-    def set_guidance_scale(self, guidance_scale):
+    def set_guidance_scale(self, guidance_scale=None):
         r"""
         sets the guidance scale.
         """
+        if guidance_scale is None:
+            if self.dh.is_sdxl_turbo:
+                guidance_scale = 0.0
+            else:
+                guidance_scale = 4.0
         self.guidance_scale_base = guidance_scale
         self.guidance_scale = guidance_scale
         self.dh.guidance_scale = guidance_scale
@@ -142,7 +158,7 @@ class LatentBlending():
         self.guidance_scale = guidance_scale_effective
         self.dh.guidance_scale = guidance_scale_effective
 
-    def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
+    def set_branch1_crossfeed(self, crossfeed_power=0, crossfeed_range=0, crossfeed_decay=0):
         r"""
         Sets the crossfeed parameters for the first branch to the last branch.
         Args:
@@ -157,7 +173,7 @@ class LatentBlending():
         self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
         self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
 
-    def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
+    def set_parental_crossfeed(self, crossfeed_power=None, crossfeed_range=None, crossfeed_decay=None):
         r"""
         Sets the crossfeed parameters for all transition images (within the first and last branch).
         Args:
@@ -168,9 +184,22 @@ class LatentBlending():
             crossfeed_decay: float [0,1]
                 Sets decay for branch1_crossfeed_power. Lower values make the decay stronger across the range.
         """
+        if self.dh.is_sdxl_turbo:
+            if crossfeed_power is None:
+                crossfeed_power = 1.0
+            if crossfeed_range is None:
+                crossfeed_range = 1.0
+            if crossfeed_decay is None:
+                crossfeed_decay = 1.0
+        else:
+            crossfeed_power = 0.3
+            crossfeed_range = 0.6
+            crossfeed_decay = 0.9
         self.parental_crossfeed_power = np.clip(crossfeed_power, 0, 1)
         self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
-        self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1)
+        self.parental_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
 
     def set_prompt1(self, prompt: str):
         r"""
@@ -210,25 +239,20 @@ class LatentBlending():
         """
         self.image2_lowres = image
 
-    def run_transition(
-            self,
-            recycle_img1: Optional[bool] = False,
-            recycle_img2: Optional[bool] = False,
-            num_inference_steps: Optional[int] = 30,
-            depth_strength: Optional[float] = 0.3,
-            t_compute_max_allowed: Optional[float] = None,
-            nmb_max_branches: Optional[int] = None,
-            fixed_seeds: Optional[List[int]] = None):
-        r"""
-        Function for computing transitions.
-        Returns a list of transition images using spherical latent blending.
-        Args:
-            recycle_img1: Optional[bool]:
-                Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
-            recycle_img2: Optional[bool]:
-                Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
-            num_inference_steps:
-                Number of diffusion steps. Higher values will take more compute time.
+    def set_num_inference_steps(self, num_inference_steps=None):
+        if self.dh.is_sdxl_turbo:
+            if num_inference_steps is None:
+                num_inference_steps = 4
+        else:
+            if num_inference_steps is None:
+                num_inference_steps = 30
+        self.num_inference_steps = num_inference_steps
+        self.dh.set_num_inference_steps(num_inference_steps)
+
+    def set_branching(self, depth_strength=None, t_compute_max_allowed=None, nmb_max_branches=None):
+        """
+        Sets the branching structure of the blending tree. Default arguments depend on pipe!
             depth_strength:
                 Determines how deep the first injection will happen.
                 Deeper injections will cause (unwanted) formation of new structures,
@@ -240,6 +264,45 @@ class LatentBlending():
                 Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
                 results. Use this if you want to have controllable results independent
                 of your computer.
"""
if self.dh.is_sdxl_turbo:
assert t_compute_max_allowed is None, "time-based branching not supported for SDXL Turbo"
if depth_strength is not None:
idx_inject = int(round(self.num_inference_steps*depth_strength))
else:
idx_inject = 2
if nmb_max_branches is None:
nmb_max_branches = 10
self.list_idx_injection = [idx_inject]
self.list_nmb_stems = [nmb_max_branches]
else:
if depth_strength is None:
depth_strength = 0.5
if t_compute_max_allowed is None and nmb_max_branches is None:
t_compute_max_allowed = 20
elif t_compute_max_allowed is not None and nmb_max_branches is not None:
raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
self.list_idx_injection, self.list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
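
For the turbo path above, the injection index is simply rounded from depth_strength times the step count and a single injection depth is used. A tiny worked example of the values this produces, assuming the 4-step turbo default:

```python
# set_branching arithmetic for SDXL Turbo (illustrative).
num_inference_steps = 4
depth_strength = 0.5
idx_inject = int(round(num_inference_steps * depth_strength))    # -> 2 (also the fallback default)
nmb_max_branches = 10                                            # fallback when not provided

list_idx_injection = [idx_inject]      # every extra branch starts at diffusion step 2
list_nmb_stems = [nmb_max_branches]    # ten transition branches are inserted at that depth
```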
def run_transition(
self,
recycle_img1: Optional[bool] = False,
recycle_img2: Optional[bool] = False,
fixed_seeds: Optional[List[int]] = None):
r"""
Function for computing transitions.
Returns a list of transition images using spherical latent blending.
Args:
recycle_img1: Optional[bool]:
Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
recycle_img2: Optional[bool]:
Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
num_inference_steps:
Number of diffusion steps. Higher values will take more compute time.
            fixed_seeds: Optional[List[int]]:
                You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
                Otherwise random seeds will be taken.
@@ -249,6 +312,7 @@ class LatentBlending():
         assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
         assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
 
         # Random seeds
         if fixed_seeds is not None:
             if fixed_seeds == 'randomize':
@@ -259,9 +323,6 @@ class LatentBlending():
                 self.seed1 = fixed_seeds[0]
                 self.seed2 = fixed_seeds[1]
 
-        # Ensure correct num_inference_steps in holder
-        self.num_inference_steps = num_inference_steps
-        self.dh.set_num_inference_steps(num_inference_steps)
 
         # Compute / Recycle first image
         if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
@@ -280,28 +341,27 @@ class LatentBlending():
         self.tree_fracts = [0.0, 1.0]
         self.tree_final_imgs = [self.dh.latent2image((self.tree_latents[0][-1])), self.dh.latent2image((self.tree_latents[-1][-1]))]
         self.tree_idx_injection = [0, 0]
+        self.tree_similarities = [self.get_tree_similarities]
 
-        # Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP...
-        self.spatial_mask = None
-
-        # Set up branching scheme (dependent on provided compute time)
-        list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
 
         # Run iteratively, starting with the longest trajectory.
         # Always inserting new branches where they are needed most according to image similarity
-        for s_idx in tqdm(range(len(list_idx_injection))):
-            nmb_stems = list_nmb_stems[s_idx]
-            idx_injection = list_idx_injection[s_idx]
+        for s_idx in tqdm(range(len(self.list_idx_injection))):
+            nmb_stems = self.list_nmb_stems[s_idx]
+            idx_injection = self.list_idx_injection[s_idx]
             for i in range(nmb_stems):
                 fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
                 self.set_guidance_mid_dampening(fract_mixing)
                 list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
                 self.insert_into_tree(fract_mixing, idx_injection, list_latents)
-                # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
+                # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection} bp1 {b_parent1} bp2 {b_parent2}")
 
         return self.tree_final_imgs
 
     def compute_latents1(self, return_image=False):
         r"""
         Runs a diffusion trajectory for the first image
@@ -318,7 +378,7 @@ class LatentBlending():
             latents_start=latents_start,
             idx_start=0)
         t1 = time.time()
-        self.dt_per_diff = (t1 - t0) / self.num_inference_steps
+        self.dt_unet_step = (t1 - t0) / self.num_inference_steps
         self.tree_latents[0] = list_latents1
         if return_image:
             return self.dh.latent2image(list_latents1[-1])
@@ -388,7 +448,7 @@ class LatentBlending():
         mixing_coeffs = idx_injection * [self.parental_crossfeed_power]
         nmb_mixing = idx_mixing_stop - idx_injection
         if nmb_mixing > 0:
-            mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_power_decay, nmb_mixing)))
+            mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_decay, nmb_mixing)))
         mixing_coeffs.extend((self.num_inference_steps - len(mixing_coeffs)) * [0])
         latents_start = list_latents_parental_mix[idx_injection - 1]
         list_latents = self.run_diffusion(
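
The parental mixing schedule built above is a flat power for the first idx_injection steps, a linear decay over the crossfeed range, then zeros. A small sketch that just materializes those coefficients (names mirror the attributes in this hunk, values are illustrative):

```python
import numpy as np

# Illustrative reconstruction of the mixing-coefficient schedule.
num_inference_steps = 30
idx_injection, idx_mixing_stop = 5, 20
power, decay = 0.3, 0.9

mixing_coeffs = idx_injection * [power]
nmb_mixing = idx_mixing_stop - idx_injection
if nmb_mixing > 0:
    mixing_coeffs.extend(list(np.linspace(power, power * decay, nmb_mixing)))
mixing_coeffs.extend((num_inference_steps - len(mixing_coeffs)) * [0])
assert len(mixing_coeffs) == num_inference_steps   # one coefficient per diffusion step
```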
@@ -417,8 +477,10 @@ class LatentBlending():
             results. Use this if you want to have controllable results independent
             of your computer.
         """
-        idx_injection_base = int(round(self.num_inference_steps * depth_strength))
-        list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps - 1, 3)
+        idx_injection_base = int(np.floor(self.num_inference_steps * depth_strength))
+        steps = int(np.ceil(self.num_inference_steps / 10))
+        list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps, steps)
         list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
         t_compute = 0
@@ -436,11 +498,11 @@ class LatentBlending():
         while not stop_criterion_reached:
             list_compute_steps = self.num_inference_steps - list_idx_injection
             list_compute_steps *= list_nmb_stems
-            t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
-            t_compute += 2 * self.num_inference_steps * self.dt_per_diff  # outer branches
+            t_compute = np.sum(list_compute_steps) * self.dt_unet_step + self.dt_vae * np.sum(list_nmb_stems)
+            t_compute += 2 * (self.num_inference_steps * self.dt_unet_step + self.dt_vae)  # outer branches
             increase_done = False
             for s_idx in range(len(list_nmb_stems) - 1):
-                if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
+                if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 1:
                     list_nmb_stems[s_idx] += 1
                     increase_done = True
                     break
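
The compute-time estimate above now uses the measured per-step UNet time and the VAE decode time from benchmark_speed instead of a hard-coded constant. Spelled out for one configuration (illustrative numbers, not measurements):

```python
import numpy as np

# Illustrative evaluation of the new t_compute estimate.
num_inference_steps = 30
dt_unet_step, dt_vae = 0.05, 0.20                # seconds, e.g. as measured by benchmark_speed()
list_idx_injection = np.array([15, 18, 21, 24, 27])
list_nmb_stems = np.array([1, 1, 2, 3, 5])

list_compute_steps = (num_inference_steps - list_idx_injection) * list_nmb_stems
t_compute = np.sum(list_compute_steps) * dt_unet_step + dt_vae * np.sum(list_nmb_stems)
t_compute += 2 * (num_inference_steps * dt_unet_step + dt_vae)   # the two outer branches
print(f"estimated transition time: {t_compute:.1f} s")
```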
@@ -471,15 +533,15 @@ class LatentBlending():
             the index in terms of diffusion steps, where the next insertion will start.
         """
         # get_lpips_similarity
-        similarities = []
-        for i in range(len(self.tree_final_imgs) - 1):
-            similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1]))
+        similarities = self.tree_similarities
+        # similarities = self.get_tree_similarities()
         b_closest1 = np.argmax(similarities)
         b_closest2 = b_closest1 + 1
         fract_closest1 = self.tree_fracts[b_closest1]
         fract_closest2 = self.tree_fracts[b_closest2]
+        fract_mixing = (fract_closest1 + fract_closest2) / 2
 
-        # Ensure that the parents are indeed older!
+        # Ensure that the parents are indeed older
         b_parent1 = b_closest1
         while True:
             if self.tree_idx_injection[b_parent1] < idx_injection:
@@ -492,7 +554,6 @@ class LatentBlending():
                 break
             else:
                 b_parent2 += 1
-        fract_mixing = (fract_closest1 + fract_closest2) / 2
         return fract_mixing, b_parent1, b_parent2
 
     def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
@@ -506,11 +567,21 @@ class LatentBlending():
             list_latents: list
                 list of the latents to be inserted
         """
+        img_insert = self.dh.latent2image(list_latents[-1])
+
         b_parent1, b_parent2 = self.get_closest_idx(fract_mixing)
-        self.tree_latents.insert(b_parent1 + 1, list_latents)
-        self.tree_final_imgs.insert(b_parent1 + 1, self.dh.latent2image(list_latents[-1]))
-        self.tree_fracts.insert(b_parent1 + 1, fract_mixing)
-        self.tree_idx_injection.insert(b_parent1 + 1, idx_injection)
+        left_sim = self.get_lpips_similarity(img_insert, self.tree_final_imgs[b_parent1])
+        right_sim = self.get_lpips_similarity(img_insert, self.tree_final_imgs[b_parent2])
+        idx_insert = b_parent1 + 1
+        self.tree_latents.insert(idx_insert, list_latents)
+        self.tree_final_imgs.insert(idx_insert, img_insert)
+        self.tree_fracts.insert(idx_insert, fract_mixing)
+        self.tree_idx_injection.insert(idx_insert, idx_injection)
+
+        # update similarities
+        self.tree_similarities[b_parent1] = left_sim
+        self.tree_similarities.insert(idx_insert, right_sim)
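
insert_into_tree now keeps the cached LPIPS similarity list in sync instead of recomputing all pairs: the new image's similarity to the left parent replaces the old pairwise entry, and one new entry is inserted for the right side. A minimal list-level sketch of that bookkeeping, with plain floats standing in for LPIPS scores:

```python
# Similarity-cache bookkeeping mirrored with plain numbers (illustrative).
tree_similarities = [0.8, 0.5, 0.9]      # pairwise similarities between neighbouring frames
b_parent1, b_parent2 = 1, 2              # new frame lands between these two branches
left_sim, right_sim = 0.7, 0.6           # similarity of the inserted frame to each parent

idx_insert = b_parent1 + 1
tree_similarities[b_parent1] = left_sim
tree_similarities.insert(idx_insert, right_sim)
# -> [0.8, 0.7, 0.6, 0.9]: the old pair (1, 2) is replaced by (1, new) and (new, 2)
```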
    def get_noise(self, seed):
        r"""
@@ -552,119 +623,29 @@ class LatentBlending():
         self.dh.set_num_inference_steps(self.num_inference_steps)
         assert type(list_conditionings) is list, "list_conditionings need to be a list"
-        if self.dh.use_sd_xl:
-            text_embeddings = list_conditionings[0]
-            return self.dh.run_diffusion_sd_xl(
-                text_embeddings=text_embeddings,
-                latents_start=latents_start,
-                idx_start=idx_start,
-                list_latents_mixing=list_latents_mixing,
-                mixing_coeffs=mixing_coeffs,
-                return_image=return_image)
+        text_embeddings = list_conditionings[0]
+        return self.dh.run_diffusion_sd_xl(
+            text_embeddings=text_embeddings,
+            latents_start=latents_start,
+            idx_start=idx_start,
+            list_latents_mixing=list_latents_mixing,
+            mixing_coeffs=mixing_coeffs,
+            return_image=return_image)
else:
text_embeddings = list_conditionings[0]
return self.dh.run_diffusion_standard(
text_embeddings=text_embeddings,
latents_start=latents_start,
idx_start=idx_start,
list_latents_mixing=list_latents_mixing,
mixing_coeffs=mixing_coeffs,
return_image=return_image)
def run_upscaling(
self,
dp_img: str,
depth_strength: float = 0.65,
num_inference_steps: int = 100,
nmb_max_branches_highres: int = 5,
nmb_max_branches_lowres: int = 6,
duration_single_segment=3,
fps=24,
fixed_seeds: Optional[List[int]] = None):
r"""
Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition.
Args:
dp_img: str
Path to the low-res transition path (as saved in write_imgs_transition)
depth_strength:
Determines how deep the first injection will happen.
Deeper injections will cause (unwanted) formation of new structures,
more shallow values will go into alpha-blendy land.
num_inference_steps:
Number of diffusion steps. Higher values will take more compute time.
nmb_max_branches_highres: int
Number of final branches of the upscaling transition pass. Note this is the number
of branches between each pair of low-res images.
nmb_max_branches_lowres: int
Number of input low-res images, subsampling all transition images written in the low-res pass.
Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
duration_single_segment: float
The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total.
fps: float
Frames per second of the output movie.
fixed_seeds: Optional[List[int]]
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
Otherwise random seeds will be taken.
"""
fp_yml = os.path.join(dp_img, "lowres.yaml")
fp_movie = os.path.join(dp_img, "movie_highres.mp4")
ms = MovieSaver(fp_movie, fps=fps)
assert os.path.isfile(fp_yml), "lowres.yaml does not exist. Did you forget run_upscaling_step1?"
dict_stuff = yml_load(fp_yml)
# load lowres images
nmb_images_lowres = dict_stuff['nmb_images']
prompt1 = dict_stuff['prompt1']
prompt2 = dict_stuff['prompt2']
idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres - 1, nmb_max_branches_lowres)).astype(np.int32)
imgs_lowres = []
for i in idx_img_lowres:
fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. Did you forget run_upscaling_step1?"
imgs_lowres.append(Image.open(fp_img_lowres))
# set up upscaling
text_embeddingA = self.dh.get_text_embedding(prompt1)
text_embeddingB = self.dh.get_text_embedding(prompt2)
list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1)
for i in range(nmb_max_branches_lowres - 1):
print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1 - list_fract_mixing[i])
if i == 0:
recycle_img1 = False
else:
self.swap_forward()
recycle_img1 = True
self.set_image1(imgs_lowres[i])
self.set_image2(imgs_lowres[i + 1])
list_imgs = self.run_transition(
recycle_img1=recycle_img1,
recycle_img2=False,
num_inference_steps=num_inference_steps,
depth_strength=depth_strength,
nmb_max_branches=nmb_max_branches_highres)
list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment)
# Save movie frame
for img in list_imgs_interp:
ms.write_frame(img)
ms.finalize()
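For reference, a usage sketch of the run_upscaling pass documented above (this method is dropped elsewhere in this PR). All names and paths are placeholders; it assumes lb wraps an x4 upscaling pipeline and that a low-res transition was previously written to dp_img via write_imgs_transition.

# Hypothetical usage sketch; paths and values are placeholders.
dp_img = "transitions/underwater_to_alien"   # directory containing lowres.yaml + lowres_img_*.jpg
lb.run_upscaling(
    dp_img,
    depth_strength=0.65,           # deeper injection -> more newly formed structures
    num_inference_steps=100,
    nmb_max_branches_highres=5,    # branches between each pair of low-res images
    nmb_max_branches_lowres=6,     # number of low-res keyframes to subsample
    duration_single_segment=3,
    fps=24)
# the resulting movie is written to dp_img/movie_highres.mp4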
@torch.no_grad() @torch.no_grad()
def get_mixed_conditioning(self, fract_mixing): def get_mixed_conditioning(self, fract_mixing):
if self.dh.use_sd_xl: text_embeddings_mix = []
text_embeddings_mix = [] for i in range(len(self.text_embedding1)):
for i in range(len(self.text_embedding1)): if self.text_embedding1[i] is None:
text_embeddings_mix.append(interpolate_linear(self.text_embedding1[i], self.text_embedding2[i], fract_mixing)) mix = None
list_conditionings = [text_embeddings_mix] else:
else: mix = interpolate_linear(self.text_embedding1[i], self.text_embedding2[i], fract_mixing)
text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) text_embeddings_mix.append(mix)
list_conditionings = [text_embeddings_mix] list_conditionings = [text_embeddings_mix]
return list_conditionings return list_conditionings
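A minimal, self-contained sketch of the None-aware linear mixing applied to the SDXL conditioning list above. interpolate_linear here is a local stand-in with the usual (1 - f) * a + f * b definition; the module itself uses the real helper.

# Illustrative sketch of get_mixed_conditioning's element-wise mixing.
import torch

def interpolate_linear(a: torch.Tensor, b: torch.Tensor, fract: float) -> torch.Tensor:
    return (1.0 - fract) * a + fract * b

def mix_conditionings(cond1, cond2, fract_mixing: float):
    mixed = []
    for c1, c2 in zip(cond1, cond2):
        # entries can be None (e.g. negative embeddings when guidance is disabled)
        mixed.append(None if c1 is None else interpolate_linear(c1, c2, fract_mixing))
    return [mixed]

e1 = [torch.zeros(1, 77, 2048), None]
e2 = [torch.ones(1, 77, 2048), None]
print(mix_conditionings(e1, e2, 0.25)[0][0].mean())   # tensor(0.2500)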
@torch.no_grad() @torch.no_grad()
@ -733,7 +714,7 @@ class LatentBlending():
'num_inference_steps', 'depth_strength', 'guidance_scale', 'num_inference_steps', 'depth_strength', 'guidance_scale',
'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt', 'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt',
'branch1_crossfeed_power', 'branch1_crossfeed_range', 'branch1_crossfeed_decay', 'branch1_crossfeed_power', 'branch1_crossfeed_range', 'branch1_crossfeed_decay',
'parental_crossfeed_power', 'parental_crossfeed_range', 'parental_crossfeed_power_decay'] 'parental_crossfeed_power', 'parental_crossfeed_range', 'parental_crossfeed_decay']
for v in grab_vars: for v in grab_vars:
if hasattr(self, v): if hasattr(self, v):
if v == 'seed1' or v == 'seed2': if v == 'seed1' or v == 'seed2':
@ -797,16 +778,22 @@ class LatentBlending():
Used to determine the optimal point of insertion to create smooth transitions. Used to determine the optimal point of insertion to create smooth transitions.
High values indicate low similarity. High values indicate low similarity.
""" """
tensorA = torch.from_numpy(imgA).float().cuda(self.device) tensorA = torch.from_numpy(np.asarray(imgA)).float().cuda(self.device)
tensorA = 2 * tensorA / 255.0 - 1 tensorA = 2 * tensorA / 255.0 - 1
tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0) tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0)
tensorB = torch.from_numpy(imgB).float().cuda(self.device) tensorB = torch.from_numpy(np.asarray(imgB)).float().cuda(self.device)
tensorB = 2 * tensorB / 255.0 - 1 tensorB = 2 * tensorB / 255.0 - 1
tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0) tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0)
lploss = self.lpips(tensorA, tensorB) lploss = self.lpips(tensorA, tensorB)
lploss = float(lploss[0][0][0][0]) lploss = float(lploss[0][0][0][0])
return lploss return lploss
def get_tree_similarities(self):
similarities = []
for i in range(len(self.tree_final_imgs) - 1):
similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1]))
return similarities
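The per-edge values returned above are LPIPS distances, so a high value means the neighbouring frames still look dissimilar. One plausible way to use them is sketched below, under the assumption that the least similar pair is refined first; the actual scheduling is decided inside run_transition, outside this hunk.

# Illustrative sketch; pick_next_insertion is a hypothetical helper, not part of the class.
import numpy as np

def pick_next_insertion(tree_fracts, tree_similarities):
    # refine the least similar neighbouring pair (largest LPIPS distance) first
    b_parent1 = int(np.argmax(tree_similarities))
    b_parent2 = b_parent1 + 1
    fract_mixing = 0.5 * (tree_fracts[b_parent1] + tree_fracts[b_parent2])
    return fract_mixing, b_parent1, b_parent2

print(pick_next_insertion([0.0, 0.5, 1.0], [0.1, 0.7]))   # (0.75, 1, 2)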
# Auxiliary functions # Auxiliary functions
def get_closest_idx( def get_closest_idx(
self, self,
@ -831,3 +818,46 @@ class LatentBlending():
b_parent1 = tmp b_parent1 = tmp
return b_parent1, b_parent2 return b_parent1, b_parent2
#%%
if __name__ == "__main__":
# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
from diffusers_holder import DiffusersHolder
from diffusers import DiffusionPipeline
from diffusers import AutoencoderTiny
# pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16")
pipe.to("cuda")
pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16)
pipe.vae = pipe.vae.cuda()
dh = DiffusersHolder(pipe)
# %% Next let's set up all parameters
prompt1 = "photo of underwater landscape, fish, and the sea, incredible detail, high resolution"
prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal"
negative_prompt = "blurry, ugly, pale" # Optional
duration_transition = 12 # In seconds
# Spawn latent blending
lb = LatentBlending(dh)
lb.set_prompt1(prompt1)
lb.set_prompt2(prompt2)
lb.set_negative_prompt(negative_prompt)
# Run latent blending
t0 = time.time()
lb.run_transition(fixed_seeds=[420, 421])
dt = time.time() - t0
# Save movie
fp_movie = 'test.mp4'
lb.write_movie_transition(fp_movie, duration_transition)

View File

@ -262,7 +262,6 @@ def add_subtitles_to_video(
class MovieReader(): class MovieReader():
r""" r"""
Class to read in a movie. Class to read in a movie.

View File

@ -24,7 +24,7 @@ import datetime
from typing import List, Union from typing import List, Union
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
import yaml import yaml
import PIL
@torch.no_grad() @torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float): def interpolate_spherical(p0, p1, fract_mixing: float):
@ -142,6 +142,8 @@ def add_frames_linear_interp(
if nmb_frames_missing < 1: if nmb_frames_missing < 1:
return list_imgs return list_imgs
if type(list_imgs[0]) == PIL.Image.Image:
list_imgs = [np.asarray(l) for l in list_imgs]
list_imgs_float = [img.astype(np.float32) for img in list_imgs] list_imgs_float = [img.astype(np.float32) for img in list_imgs]
# Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame # Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
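Finally, a simplified, self-contained sketch of what add_frames_linear_interp does with the per-gap insert counts computed above. The even split per gap is a simplification of the actual distribution (which spreads nmb_frames_missing according to mean_nmb_frames_insert); frames are assumed to be numpy uint8 arrays or PIL images.

# Simplified sketch of linear frame interpolation; not the exact distribution logic.
import numpy as np

def linear_interp_frames(list_imgs, fps: float, duration: float):
    # convert PIL images or uint8 arrays to float for blending
    frames = [np.asarray(img).astype(np.float32) for img in list_imgs]
    nmb_frames_missing = max(int(round(fps * duration)) - len(frames), 0)
    nmb_gaps = len(frames) - 1
    per_gap = nmb_frames_missing // nmb_gaps if nmb_gaps else 0
    out = []
    for a, b in zip(frames, frames[1:]):
        out.append(a)
        for k in range(1, per_gap + 1):
            f = k / (per_gap + 1)
            out.append((1 - f) * a + f * b)   # linear blend between neighbours
    out.append(frames[-1])
    return [frame.astype(np.uint8) for frame in out]

imgs = [np.zeros((8, 8, 3), np.uint8), np.full((8, 8, 3), 255, np.uint8)]
print(len(linear_interp_frames(imgs, fps=24, duration=1)))   # 24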