diffusers, forced sd xl

This commit is contained in:
Johannes Stelzer 2023-07-20 13:49:19 +02:00
parent 1c60f7df4f
commit 76f89cb836
4 changed files with 541 additions and 112 deletions

416
diffusers_holder.py Normal file
View File

@ -0,0 +1,416 @@
# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from PIL import Image
import torch
from typing import Optional
from torch import autocast
from contextlib import nullcontext
from utils import interpolate_spherical
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
class DiffusersHolder():
def __init__(self, pipe):
# Base settings
self.negative_prompt = ""
self.guidance_scale = 5.0
self.num_inference_steps = 30
# Check if valid pipe
self.pipe = pipe
self.device = str(pipe._execution_device)
self.init_type_pipe()
self.init_dtype()
self.width_latent = self.pipe.unet.config.sample_size
self.height_latent = self.pipe.unet.config.sample_size
def init_type_pipe(self):
self.type_pipe = "StableDiffusionXLPipeline"
if self.type_pipe == "StableDiffusionXLPipeline":
self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
self.use_sd_xl = True
else:
self.use_sd_xl = False
def init_dtype(self):
if self.type_pipe == "StableDiffusionXLPipeline":
prompt_embeds, _, _, _ = self.pipe.encode_prompt("test")
self.dtype = prompt_embeds.dtype
def set_num_inference_steps(self, num_inference_steps):
self.num_inference_steps = num_inference_steps
if self.use_sd_xl:
self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
def set_dimensions(self, width, height):
s = self.pipe.vae_scale_factor
if width is None:
self.width_latent = self.pipe.unet.config.sample_size
self.width_img = self.width_latent * self.pipe.vae_scale_factor
else:
self.width_img = int(round(width / s) * s)
self.width_latent = int(self.width_img / s)
if height is None:
self.height_latent = self.pipe.unet.config.sample_size
self.height_img = self.width_latent * self.pipe.vae_scale_factor
else:
self.height_img = int(round(height / s) * s)
self.height_latent = int(self.height_img / s)
def set_negative_prompt(self, negative_prompt):
r"""Set the negative prompt. Currenty only one negative prompt is supported
"""
if isinstance(negative_prompt, str):
self.negative_prompt = [negative_prompt]
else:
self.negative_prompt = negative_prompt
if len(self.negative_prompt) > 1:
self.negative_prompt = [self.negative_prompt[0]]
def get_text_embedding(self, prompt, do_classifier_free_guidance=True):
if self.use_sd_xl:
pr_encoder = self.pipe.encode_prompt
else:
pr_encoder = self.pipe._encode_prompt
prompt_embeds = pr_encoder(
prompt,
self.device,
1,
do_classifier_free_guidance,
negative_prompt=self.negative_prompt,
prompt_embeds=None,
negative_prompt_embeds=None,
lora_scale=None,
)
return prompt_embeds
def get_noise(self, seed=420, mode=None):
H = self.height_latent
W = self.width_latent
C = self.pipe.unet.config.in_channels
generator = torch.Generator(device=self.device).manual_seed(int(seed))
latents = torch.randn((1, C, H, W), generator=generator, dtype=self.dtype, device=self.device)
if self.use_sd_xl:
latents = latents * self.pipe.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def latent2image(
self,
latents: torch.FloatTensor):
r"""
Returns an image provided a latent representation from diffusion.
Args:
latents: torch.FloatTensor
Result of the diffusion process.
"""
if self.use_sd_xl:
# make sure the VAE is in float32 mode, as it overflows in float16
self.pipe.vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
self.pipe.vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
self.pipe.vae.post_quant_conv.to(latents.dtype)
self.pipe.vae.decoder.conv_in.to(latents.dtype)
self.pipe.vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()
image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0]
image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])
return np.asarray(image[0])
@torch.no_grad()
def run_diffusion_standard(
self,
text_embeddings: torch.FloatTensor,
latents_start: torch.FloatTensor,
idx_start: int = 0,
list_latents_mixing=None,
mixing_coeffs=0.0,
return_image: Optional[bool] = False):
if type(mixing_coeffs) == float:
list_mixing_coeffs = (1+self.num_inference_steps) * [mixing_coeffs]
elif type(mixing_coeffs) == list:
assert len(mixing_coeffs) == self.num_inference_steps, f"len(mixing_coeffs) {len(mixing_coeffs)} != self.num_inference_steps {self.num_inference_steps}"
list_mixing_coeffs = mixing_coeffs
else:
raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
if np.sum(list_mixing_coeffs) > 0:
assert len(list_latents_mixing) == self.num_inference_steps, f"len(list_latents_mixing) {len(list_latents_mixing)} != self.num_inference_steps {self.num_inference_steps}"
do_classifier_free_guidance = self.guidance_scale > 1.0
# diffusers bit wiggly
self.pipe.scheduler.set_timesteps(self.num_inference_steps-1, device=self.device)
timesteps = self.pipe.scheduler.timesteps
if len(timesteps) != self.num_inference_steps:
self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
timesteps = self.pipe.scheduler.timesteps
latents = latents_start.clone()
list_latents_out = []
num_warmup_steps = len(timesteps) - self.num_inference_steps * self.pipe.scheduler.order
for i, t in enumerate(timesteps):
# Set the right starting latents
if i < idx_start:
list_latents_out.append(None)
continue
elif i == idx_start:
latents = latents_start.clone()
# Mix latents
if i > 0 and list_mixing_coeffs[i] > 0:
latents_mixtarget = list_latents_mixing[i - 1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.pipe.unet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
return_dict=False,
)[0]
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
@torch.no_grad()
def run_diffusion_sd_xl(
self,
text_embeddings: list,
latents_start: torch.FloatTensor,
idx_start: int = 0,
list_latents_mixing=None,
mixing_coeffs=0.0,
return_image: Optional[bool] = False):
# prompt = "photo of a house"
# self.num_inference_steps = 50
# mixing_coeffs= 0.0
# idx_start= 0
# latents_start = self.get_noise()
# text_embeddings = self.pipe.encode_prompt(
# prompt,
# self.device,
# num_images_per_prompt=1,
# do_classifier_free_guidance=True,
# negative_prompt="",
# prompt_embeds=None,
# negative_prompt_embeds=None,
# pooled_prompt_embeds=None,
# negative_pooled_prompt_embeds=None,
# lora_scale=None,
# )
# 0. Default height and width to unet
original_size = (1024, 1024) # FIXME
crops_coords_top_left = (0, 0) # FIXME
target_size = original_size
batch_size = 1
eta = 0.0
num_images_per_prompt = 1
cross_attention_kwargs = None
generator = torch.Generator(device=self.device) # dummy generator
do_classifier_free_guidance = self.guidance_scale > 1.0
# 1. Check inputs. Raise error if not correct & 2. Define call parameters
# FIXME see if check_inputs use
if type(mixing_coeffs) == float:
list_mixing_coeffs = (1+self.num_inference_steps) * [mixing_coeffs]
elif type(mixing_coeffs) == list:
assert len(mixing_coeffs) == self.num_inference_steps, f"len(mixing_coeffs) {len(mixing_coeffs)} != self.num_inference_steps {self.num_inference_steps}"
list_mixing_coeffs = mixing_coeffs
else:
raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
if np.sum(list_mixing_coeffs) > 0:
assert len(list_latents_mixing) == self.num_inference_steps, f"len(list_latents_mixing) {len(list_latents_mixing)} != self.num_inference_steps {self.num_inference_steps}"
# 3. Encode input prompt (already encoded outside bc of mixing, just split here)
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = text_embeddings
# 4. Prepare timesteps
self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
timesteps = self.pipe.scheduler.timesteps
# 5. Prepare latent variables
latents = latents_start.clone()
list_latents_out = []
# 6. Prepare extra step kwargs. usedummy generator
extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta) # dummy
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
add_time_ids = self.pipe._get_add_time_ids(
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(self.device)
add_text_embeds = add_text_embeds.to(self.device)
add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1)
# 8. Denoising loop
for i, t in enumerate(timesteps):
# Set the right starting latents
if i < idx_start:
list_latents_out.append(None)
continue
elif i == idx_start:
latents = latents_start.clone()
# Mix latents for crossfeeding
if i > 0 and list_mixing_coeffs[i] > 0:
latents_mixtarget = list_latents_mixing[i - 1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Always scale latents
latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = self.pipe.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# FIXME guidance_rescale disabled
# compute the previous noisy sample x_t -> x_t-1
latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# Append latents
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
#%%
if __name__ == "__main__":
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-0.9"
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
pipe.to('cuda')
# xxx
self = DiffusersHolder(pipe)
# xxx
self.set_num_inference_steps(50)
self.set_dimensions(1536, 1024)
prompt = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic"
text_embeddings = self.get_text_embedding(prompt)
generator = torch.Generator(device=self.device).manual_seed(int(420))
latents_start = self.get_noise()
list_latents_1 = self.run_diffusion_sd_xl(text_embeddings, latents_start)
img_orig = self.latent2image(list_latents_1[-1])
# %%
"""
OPEN
- other examples
- kill upscaling? or keep?
- cleanup
- ldh
- sdh class
- diffusion holder
- check linting
- check docstrings
- fix readme
"""

View File

@ -20,33 +20,37 @@ import warnings
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
import warnings import warnings
from latent_blending import LatentBlending from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder from diffusers_holder import DiffusersHolder
from huggingface_hub import hf_hub_download from diffusers import DiffusionPipeline
# %% First let us spawn a stable diffusion holder. Uncomment your version of choice. # %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt") # dh = DiffusersHolder("stabilityai/stable-diffusion-xl-base-0.9")
fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt") pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-0.9"
sdh = StableDiffusionHolder(fp_ckpt) pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
pipe.to('cuda')
dh = DiffusersHolder(pipe)
# %% Next let's set up all parameters # %% Next let's set up all parameters
depth_strength = 0.65 # Specifies how deep (in terms of diffusion iterations the first branching happens) depth_strength = 0.55 # Specifies how deep (in terms of diffusion iterations the first branching happens)
t_compute_max_allowed = 15 # Determines the quality of the transition in terms of compute time you grant it t_compute_max_allowed = 60 # Determines the quality of the transition in terms of compute time you grant it
fixed_seeds = [69731932, 504430820] fixed_seeds = [6913192, 504443080]
num_inference_steps = 50
prompt1 = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic" prompt1 = "underwater landscape, fish, und the sea, incredible detail, high resolution"
prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail" prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal"
fp_movie = 'movie_example1.mp4' fp_movie = 'movie_example1.mp4'
duration_transition = 12 # In seconds duration_transition = 12 # In seconds
# Spawn latent blending # Spawn latent blending
lb = LatentBlending(sdh) lb = LatentBlending(dh)
lb.set_prompt1(prompt1) lb.set_prompt1(prompt1)
lb.set_prompt2(prompt2) lb.set_prompt2(prompt2)
lb.set_dimensions(1536, 1024)
# Run latent blending # Run latent blending
lb.run_transition( lb.run_transition(
depth_strength=depth_strength, depth_strength=depth_strength,
num_inference_steps=num_inference_steps,
t_compute_max_allowed=t_compute_max_allowed, t_compute_max_allowed=t_compute_max_allowed,
fixed_seeds=fixed_seeds) fixed_seeds=fixed_seeds)

View File

@ -26,7 +26,6 @@ from tqdm.auto import tqdm
from PIL import Image from PIL import Image
from movie_util import MovieSaver from movie_util import MovieSaver
from typing import List, Optional from typing import List, Optional
from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
import lpips import lpips
from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save
@ -34,7 +33,7 @@ from utils import interpolate_spherical, interpolate_linear, add_frames_linear_i
class LatentBlending(): class LatentBlending():
def __init__( def __init__(
self, self,
sdh: None, dh: None,
guidance_scale: float = 4, guidance_scale: float = 4,
guidance_scale_mid_damper: float = 0.5, guidance_scale_mid_damper: float = 0.5,
mid_compression_scaler: float = 1.2): mid_compression_scaler: float = 1.2):
@ -59,10 +58,10 @@ class LatentBlending():
and guidance_scale_mid_damper <= 1.0, \ and guidance_scale_mid_damper <= 1.0, \
f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}" f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
self.sdh = sdh self.dh = dh
self.device = self.sdh.device self.device = self.dh.device
self.width = self.sdh.width self.set_dimensions()
self.height = self.sdh.height
self.guidance_scale_mid_damper = guidance_scale_mid_damper self.guidance_scale_mid_damper = guidance_scale_mid_damper
self.mid_compression_scaler = mid_compression_scaler self.mid_compression_scaler = mid_compression_scaler
self.seed1 = 0 self.seed1 = 0
@ -86,40 +85,49 @@ class LatentBlending():
self.image1_lowres = None self.image1_lowres = None
self.image2_lowres = None self.image2_lowres = None
self.negative_prompt = None self.negative_prompt = None
self.num_inference_steps = self.sdh.num_inference_steps self.num_inference_steps = self.dh.num_inference_steps
self.noise_level_upscaling = 20 self.noise_level_upscaling = 20
self.list_injection_idx = None self.list_injection_idx = None
self.list_nmb_branches = None self.list_nmb_branches = None
# Mixing parameters # Mixing parameters
self.branch1_crossfeed_power = 0.1 self.branch1_crossfeed_power = 0.05
self.branch1_crossfeed_range = 0.6 self.branch1_crossfeed_range = 0.4
self.branch1_crossfeed_decay = 0.8 self.branch1_crossfeed_decay = 0.9
self.parental_crossfeed_power = 0.1 self.parental_crossfeed_power = 0.1
self.parental_crossfeed_range = 0.8 self.parental_crossfeed_range = 0.8
self.parental_crossfeed_power_decay = 0.8 self.parental_crossfeed_power_decay = 0.8
self.set_guidance_scale(guidance_scale) self.set_guidance_scale(guidance_scale)
self.init_mode() self.mode = 'standard'
# self.init_mode()
self.multi_transition_img_first = None self.multi_transition_img_first = None
self.multi_transition_img_last = None self.multi_transition_img_last = None
self.dt_per_diff = 0 self.dt_per_diff = 0
self.spatial_mask = None self.spatial_mask = None
self.lpips = lpips.LPIPS(net='alex').cuda(self.device) self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
def init_mode(self): self.set_prompt1("")
r""" self.set_prompt2("")
Sets the operational mode. Currently supported are standard, inpainting and x4 upscaling.
""" # def init_mode(self):
if isinstance(self.sdh.model, LatentUpscaleDiffusion): # r"""
self.mode = 'upscale' # Sets the operational mode. Currently supported are standard, inpainting and x4 upscaling.
elif isinstance(self.sdh.model, LatentInpaintDiffusion): # """
self.sdh.image_source = None # if isinstance(self.dh.model, LatentUpscaleDiffusion):
self.sdh.mask_image = None # self.mode = 'upscale'
self.mode = 'inpaint' # elif isinstance(self.dh.model, LatentInpaintDiffusion):
else: # self.dh.image_source = None
self.mode = 'standard' # self.dh.mask_image = None
# self.mode = 'inpaint'
# else:
# self.mode = 'standard'
def set_dimensions(self, width=None, height=None):
self.dh.set_dimensions(width, height)
def set_guidance_scale(self, guidance_scale): def set_guidance_scale(self, guidance_scale):
r""" r"""
@ -127,13 +135,13 @@ class LatentBlending():
""" """
self.guidance_scale_base = guidance_scale self.guidance_scale_base = guidance_scale
self.guidance_scale = guidance_scale self.guidance_scale = guidance_scale
self.sdh.guidance_scale = guidance_scale self.dh.guidance_scale = guidance_scale
def set_negative_prompt(self, negative_prompt): def set_negative_prompt(self, negative_prompt):
r"""Set the negative prompt. Currenty only one negative prompt is supported r"""Set the negative prompt. Currenty only one negative prompt is supported
""" """
self.negative_prompt = negative_prompt self.negative_prompt = negative_prompt
self.sdh.set_negative_prompt(negative_prompt) self.dh.set_negative_prompt(negative_prompt)
def set_guidance_mid_dampening(self, fract_mixing): def set_guidance_mid_dampening(self, fract_mixing):
r""" r"""
@ -144,7 +152,7 @@ class LatentBlending():
max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1 max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1
guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor
self.guidance_scale = guidance_scale_effective self.guidance_scale = guidance_scale_effective
self.sdh.guidance_scale = guidance_scale_effective self.dh.guidance_scale = guidance_scale_effective
def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay): def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
r""" r"""
@ -265,7 +273,7 @@ class LatentBlending():
# Ensure correct num_inference_steps in holder # Ensure correct num_inference_steps in holder
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self.sdh.num_inference_steps = num_inference_steps self.dh.set_num_inference_steps(num_inference_steps)
# Compute / Recycle first image # Compute / Recycle first image
if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps: if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
@ -282,7 +290,7 @@ class LatentBlending():
# Reset the tree, injecting the edge latents1/2 we just generated/recycled # Reset the tree, injecting the edge latents1/2 we just generated/recycled
self.tree_latents = [list_latents1, list_latents2] self.tree_latents = [list_latents1, list_latents2]
self.tree_fracts = [0.0, 1.0] self.tree_fracts = [0.0, 1.0]
self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))] self.tree_final_imgs = [self.dh.latent2image((self.tree_latents[0][-1])), self.dh.latent2image((self.tree_latents[-1][-1]))]
self.tree_idx_injection = [0, 0] self.tree_idx_injection = [0, 0]
# Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP... # Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP...
@ -325,7 +333,7 @@ class LatentBlending():
self.dt_per_diff = (t1 - t0) / self.num_inference_steps self.dt_per_diff = (t1 - t0) / self.num_inference_steps
self.tree_latents[0] = list_latents1 self.tree_latents[0] = list_latents1
if return_image: if return_image:
return self.sdh.latent2image(list_latents1[-1]) return self.dh.latent2image(list_latents1[-1])
else: else:
return list_latents1 return list_latents1
@ -357,7 +365,7 @@ class LatentBlending():
self.tree_latents[-1] = list_latents2 self.tree_latents[-1] = list_latents2
if return_image: if return_image:
return self.sdh.latent2image(list_latents2[-1]) return self.dh.latent2image(list_latents2[-1])
else: else:
return list_latents2 return list_latents2
@ -511,55 +519,17 @@ class LatentBlending():
""" """
b_parent1, b_parent2 = self.get_closest_idx(fract_mixing) b_parent1, b_parent2 = self.get_closest_idx(fract_mixing)
self.tree_latents.insert(b_parent1 + 1, list_latents) self.tree_latents.insert(b_parent1 + 1, list_latents)
self.tree_final_imgs.insert(b_parent1 + 1, self.sdh.latent2image(list_latents[-1])) self.tree_final_imgs.insert(b_parent1 + 1, self.dh.latent2image(list_latents[-1]))
self.tree_fracts.insert(b_parent1 + 1, fract_mixing) self.tree_fracts.insert(b_parent1 + 1, fract_mixing)
self.tree_idx_injection.insert(b_parent1 + 1, idx_injection) self.tree_idx_injection.insert(b_parent1 + 1, idx_injection)
def get_spatial_mask_template(self):
r"""
Experimental helper function to get a spatial mask template.
"""
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
C, H, W = shape_latents
return np.ones((H, W))
def set_spatial_mask(self, img_mask):
r"""
Experimental helper function to set a spatial mask.
The mask forces latents to be overwritten.
Args:
img_mask:
mask image [0,1]. You can get a template using get_spatial_mask_template
"""
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
C, H, W = shape_latents
img_mask = np.asarray(img_mask)
assert len(img_mask.shape) == 2, "Currently, only 2D images are supported as mask"
img_mask = np.clip(img_mask, 0, 1)
assert img_mask.shape[0] == H, f"Your mask needs to be of dimension {H} x {W}"
assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
spatial_mask = torch.unsqueeze(spatial_mask, 0)
spatial_mask = spatial_mask.repeat((C, 1, 1))
spatial_mask = torch.unsqueeze(spatial_mask, 0)
self.spatial_mask = spatial_mask
def get_noise(self, seed): def get_noise(self, seed):
r""" r"""
Helper function to get noise given seed. Helper function to get noise given seed.
Args: Args:
seed: int seed: int
""" """
generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed)) return self.dh.get_noise(seed, self.mode)
if self.mode == 'standard':
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
C, H, W = shape_latents
elif self.mode == 'upscale':
w = self.image1_lowres.size[0]
h = self.image1_lowres.size[1]
shape_latents = [self.sdh.model.channels, h, w]
C, H, W = shape_latents
return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
@torch.no_grad() @torch.no_grad()
def run_diffusion( def run_diffusion(
@ -590,32 +560,41 @@ class LatentBlending():
""" """
# Ensure correct num_inference_steps in Holder # Ensure correct num_inference_steps in Holder
self.sdh.num_inference_steps = self.num_inference_steps self.dh.set_num_inference_steps(self.num_inference_steps)
assert type(list_conditionings) is list, "list_conditionings need to be a list" assert type(list_conditionings) is list, "list_conditionings need to be a list"
if self.mode == 'standard': if self.dh.use_sd_xl:
text_embeddings = list_conditionings[0] text_embeddings = list_conditionings[0]
return self.sdh.run_diffusion_standard( return self.dh.run_diffusion_sd_xl(
text_embeddings=text_embeddings, text_embeddings=text_embeddings,
latents_start=latents_start, latents_start=latents_start,
idx_start=idx_start, idx_start=idx_start,
list_latents_mixing=list_latents_mixing, list_latents_mixing=list_latents_mixing,
mixing_coeffs=mixing_coeffs, mixing_coeffs=mixing_coeffs,
spatial_mask=self.spatial_mask,
return_image=return_image) return_image=return_image)
elif self.mode == 'upscale': else:
cond = list_conditionings[0] text_embeddings = list_conditionings[0]
uc_full = list_conditionings[1] return self.dh.run_diffusion_standard(
return self.sdh.run_diffusion_upscaling( text_embeddings=text_embeddings,
cond,
uc_full,
latents_start=latents_start, latents_start=latents_start,
idx_start=idx_start, idx_start=idx_start,
list_latents_mixing=list_latents_mixing, list_latents_mixing=list_latents_mixing,
mixing_coeffs=mixing_coeffs, mixing_coeffs=mixing_coeffs,
return_image=return_image) return_image=return_image)
# elif self.mode == 'upscale':
# cond = list_conditionings[0]
# uc_full = list_conditionings[1]
# return self.dh.run_diffusion_upscaling(
# cond,
# uc_full,
# latents_start=latents_start,
# idx_start=idx_start,
# list_latents_mixing=list_latents_mixing,
# mixing_coeffs=mixing_coeffs,
# return_image=return_image)
def run_upscaling( def run_upscaling(
self, self,
dp_img: str, dp_img: str,
@ -670,8 +649,8 @@ class LatentBlending():
imgs_lowres.append(Image.open(fp_img_lowres)) imgs_lowres.append(Image.open(fp_img_lowres))
# set up upscaling # set up upscaling
text_embeddingA = self.sdh.get_text_embedding(prompt1) text_embeddingA = self.dh.get_text_embedding(prompt1)
text_embeddingB = self.sdh.get_text_embedding(prompt2) text_embeddingB = self.dh.get_text_embedding(prompt2)
list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1) list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1)
for i in range(nmb_max_branches_lowres - 1): for i in range(nmb_max_branches_lowres - 1):
print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}") print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
@ -701,23 +680,35 @@ class LatentBlending():
@torch.no_grad() @torch.no_grad()
def get_mixed_conditioning(self, fract_mixing): def get_mixed_conditioning(self, fract_mixing):
if self.mode == 'standard': if self.dh.use_sd_xl:
text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) text_embeddings_mix = []
for i in range(len(self.text_embedding1)):
text_embeddings_mix.append(interpolate_linear(self.text_embedding1[i], self.text_embedding2[i], fract_mixing))
list_conditionings = [text_embeddings_mix] list_conditionings = [text_embeddings_mix]
elif self.mode == 'inpaint':
text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
list_conditionings = [text_embeddings_mix]
elif self.mode == 'upscale':
text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
cond, uc_full = self.sdh.get_cond_upscaling(self.image1_lowres, text_embeddings_mix, self.noise_level_upscaling)
condB, uc_fullB = self.sdh.get_cond_upscaling(self.image2_lowres, text_embeddings_mix, self.noise_level_upscaling)
cond['c_concat'][0] = interpolate_spherical(cond['c_concat'][0], condB['c_concat'][0], fract_mixing)
uc_full['c_concat'][0] = interpolate_spherical(uc_full['c_concat'][0], uc_fullB['c_concat'][0], fract_mixing)
list_conditionings = [cond, uc_full]
else: else:
raise ValueError(f"mix_conditioning: unknown mode {self.mode}") text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
list_conditionings = [text_embeddings_mix]
return list_conditionings return list_conditionings
# @torch.no_grad()
# def get_mixed_conditioning(self, fract_mixing):
# if self.mode == 'standard':
# text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
# list_conditionings = [text_embeddings_mix]
# elif self.mode == 'inpaint':
# text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
# list_conditionings = [text_embeddings_mix]
# elif self.mode == 'upscale':
# text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
# cond, uc_full = self.dh.get_cond_upscaling(self.image1_lowres, text_embeddings_mix, self.noise_level_upscaling)
# condB, uc_fullB = self.dh.get_cond_upscaling(self.image2_lowres, text_embeddings_mix, self.noise_level_upscaling)
# cond['c_concat'][0] = interpolate_spherical(cond['c_concat'][0], condB['c_concat'][0], fract_mixing)
# uc_full['c_concat'][0] = interpolate_spherical(uc_full['c_concat'][0], uc_fullB['c_concat'][0], fract_mixing)
# list_conditionings = [cond, uc_full]
# else:
# raise ValueError(f"mix_conditioning: unknown mode {self.mode}")
# return list_conditionings
@torch.no_grad() @torch.no_grad()
def get_text_embeddings( def get_text_embeddings(
self, self,
@ -729,7 +720,7 @@ class LatentBlending():
prompt: str prompt: str
ABC trending on artstation painted by Old Greg. ABC trending on artstation painted by Old Greg.
""" """
return self.sdh.get_text_embedding(prompt) return self.dh.get_text_embedding(prompt)
def write_imgs_transition(self, dp_img): def write_imgs_transition(self, dp_img):
r""" r"""
@ -766,7 +757,7 @@ class LatentBlending():
# Save as MP4 # Save as MP4
if os.path.isfile(fp_movie): if os.path.isfile(fp_movie):
os.remove(fp_movie) os.remove(fp_movie)
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[self.sdh.height, self.sdh.width]) ms = MovieSaver(fp_movie, fps=fps, shape_hw=[self.dh.height_img, self.dh.width_img])
for img in tqdm(imgs_transition_ext): for img in tqdm(imgs_transition_ext):
ms.write_frame(img) ms.write_frame(img)
ms.finalize() ms.finalize()
@ -811,7 +802,7 @@ class LatentBlending():
Set a the seed for a fresh start. Set a the seed for a fresh start.
""" """
self.seed = seed self.seed = seed
self.sdh.seed = seed self.dh.seed = seed
def set_width(self, width): def set_width(self, width):
r""" r"""
@ -819,7 +810,7 @@ class LatentBlending():
""" """
assert np.mod(width, 64) == 0, "set_width: value needs to be divisible by 64" assert np.mod(width, 64) == 0, "set_width: value needs to be divisible by 64"
self.width = width self.width = width
self.sdh.width = width self.dh.width = width
def set_height(self, height): def set_height(self, height):
r""" r"""
@ -827,7 +818,7 @@ class LatentBlending():
""" """
assert np.mod(height, 64) == 0, "set_height: value needs to be divisible by 64" assert np.mod(height, 64) == 0, "set_height: value needs to be divisible by 64"
self.height = height self.height = height
self.sdh.height = height self.dh.height = height
def swap_forward(self): def swap_forward(self):
r""" r"""

View File

@ -156,6 +156,24 @@ class StableDiffusionHolder:
self.height = 512 self.height = 512
self.width = 512 self.width = 512
def get_noise(self, seed, mode='standard'):
r"""
Helper function to get noise given seed.
Args:
seed: int
"""
generator = torch.Generator(device=self.device).manual_seed(int(seed))
if mode == 'standard':
shape_latents = [self.C, self.height // self.f, self.width // self.f]
C, H, W = shape_latents
elif mode == 'upscale':
w = self.image1_lowres.size[0]
h = self.image1_lowres.size[1]
shape_latents = [self.model.channels, h, w]
C, H, W = shape_latents
return torch.randn((1, C, H, W), generator=generator, device=self.device)
def set_negative_prompt(self, negative_prompt): def set_negative_prompt(self, negative_prompt):
r"""Set the negative prompt. Currenty only one negative prompt is supported r"""Set the negative prompt. Currenty only one negative prompt is supported
""" """