From 76f89cb8367503dffbf88ec8f1cbca02e256f81c Mon Sep 17 00:00:00 2001 From: Johannes Stelzer Date: Thu, 20 Jul 2023 13:49:19 +0200 Subject: [PATCH] diffusers, forced sd xl --- diffusers_holder.py | 416 +++++++++++++++++++++++++++++++++++++ example1_standard.py | 28 +-- latent_blending.py | 191 ++++++++--------- stable_diffusion_holder.py | 18 ++ 4 files changed, 541 insertions(+), 112 deletions(-) create mode 100644 diffusers_holder.py diff --git a/diffusers_holder.py b/diffusers_holder.py new file mode 100644 index 0000000..5a97642 --- /dev/null +++ b/diffusers_holder.py @@ -0,0 +1,416 @@ +# Copyright 2022 Lunar Ring. All rights reserved. +# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import torch +torch.backends.cudnn.benchmark = False +torch.set_grad_enabled(False) +import numpy as np +import warnings +warnings.filterwarnings('ignore') +import warnings +import torch +from PIL import Image +import torch +from typing import Optional +from torch import autocast +from contextlib import nullcontext +from utils import interpolate_spherical +from diffusers import DiffusionPipeline +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) + + +class DiffusersHolder(): + def __init__(self, pipe): + # Base settings + self.negative_prompt = "" + self.guidance_scale = 5.0 + self.num_inference_steps = 30 + + # Check if valid pipe + self.pipe = pipe + self.device = str(pipe._execution_device) + self.init_type_pipe() + self.init_dtype() + + self.width_latent = self.pipe.unet.config.sample_size + self.height_latent = self.pipe.unet.config.sample_size + + + + def init_type_pipe(self): + self.type_pipe = "StableDiffusionXLPipeline" + if self.type_pipe == "StableDiffusionXLPipeline": + self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) + self.use_sd_xl = True + else: + self.use_sd_xl = False + + def init_dtype(self): + if self.type_pipe == "StableDiffusionXLPipeline": + prompt_embeds, _, _, _ = self.pipe.encode_prompt("test") + self.dtype = prompt_embeds.dtype + + def set_num_inference_steps(self, num_inference_steps): + self.num_inference_steps = num_inference_steps + if self.use_sd_xl: + self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) + + + def set_dimensions(self, width, height): + s = self.pipe.vae_scale_factor + if width is None: + self.width_latent = self.pipe.unet.config.sample_size + self.width_img = self.width_latent * self.pipe.vae_scale_factor + else: + self.width_img = int(round(width / s) * s) + self.width_latent = int(self.width_img / s) + + if height is None: + self.height_latent = self.pipe.unet.config.sample_size + self.height_img = self.width_latent * self.pipe.vae_scale_factor + else: + self.height_img = int(round(height / s) * s) + self.height_latent = int(self.height_img / s) + + + def set_negative_prompt(self, negative_prompt): + r"""Set the 
negative prompt. Currenty only one negative prompt is supported + """ + if isinstance(negative_prompt, str): + self.negative_prompt = [negative_prompt] + else: + self.negative_prompt = negative_prompt + + if len(self.negative_prompt) > 1: + self.negative_prompt = [self.negative_prompt[0]] + + def get_text_embedding(self, prompt, do_classifier_free_guidance=True): + if self.use_sd_xl: + pr_encoder = self.pipe.encode_prompt + else: + pr_encoder = self.pipe._encode_prompt + + prompt_embeds = pr_encoder( + prompt, + self.device, + 1, + do_classifier_free_guidance, + negative_prompt=self.negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + lora_scale=None, + ) + return prompt_embeds + + + def get_noise(self, seed=420, mode=None): + H = self.height_latent + W = self.width_latent + C = self.pipe.unet.config.in_channels + generator = torch.Generator(device=self.device).manual_seed(int(seed)) + latents = torch.randn((1, C, H, W), generator=generator, dtype=self.dtype, device=self.device) + if self.use_sd_xl: + latents = latents * self.pipe.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def latent2image( + self, + latents: torch.FloatTensor): + r""" + Returns an image provided a latent representation from diffusion. + Args: + latents: torch.FloatTensor + Result of the diffusion process. + """ + if self.use_sd_xl: + # make sure the VAE is in float32 mode, as it overflows in float16 + self.pipe.vae.to(dtype=torch.float32) + + use_torch_2_0_or_xformers = isinstance( + self.pipe.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.pipe.vae.post_quant_conv.to(latents.dtype) + self.pipe.vae.decoder.conv_in.to(latents.dtype) + self.pipe.vae.decoder.mid_block.to(latents.dtype) + else: + latents = latents.float() + + image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0] + image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0]) + + + + return np.asarray(image[0]) + + @torch.no_grad() + def run_diffusion_standard( + self, + text_embeddings: torch.FloatTensor, + latents_start: torch.FloatTensor, + idx_start: int = 0, + list_latents_mixing=None, + mixing_coeffs=0.0, + return_image: Optional[bool] = False): + + if type(mixing_coeffs) == float: + list_mixing_coeffs = (1+self.num_inference_steps) * [mixing_coeffs] + elif type(mixing_coeffs) == list: + assert len(mixing_coeffs) == self.num_inference_steps, f"len(mixing_coeffs) {len(mixing_coeffs)} != self.num_inference_steps {self.num_inference_steps}" + list_mixing_coeffs = mixing_coeffs + else: + raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps") + + if np.sum(list_mixing_coeffs) > 0: + assert len(list_latents_mixing) == self.num_inference_steps, f"len(list_latents_mixing) {len(list_latents_mixing)} != self.num_inference_steps {self.num_inference_steps}" + + do_classifier_free_guidance = self.guidance_scale > 1.0 + + # diffusers bit wiggly + self.pipe.scheduler.set_timesteps(self.num_inference_steps-1, device=self.device) + timesteps = self.pipe.scheduler.timesteps + + if len(timesteps) != self.num_inference_steps: + self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) + timesteps = 
self.pipe.scheduler.timesteps + + latents = latents_start.clone() + list_latents_out = [] + + num_warmup_steps = len(timesteps) - self.num_inference_steps * self.pipe.scheduler.order + for i, t in enumerate(timesteps): + # Set the right starting latents + if i < idx_start: + list_latents_out.append(None) + continue + elif i == idx_start: + latents = latents_start.clone() + # Mix latents + if i > 0 and list_mixing_coeffs[i] > 0: + latents_mixtarget = list_latents_mixing[i - 1].clone() + latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i]) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.pipe.unet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + return_dict=False, + )[0] + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + list_latents_out.append(latents.clone()) + + if return_image: + return self.latent2image(latents) + else: + return list_latents_out + + @torch.no_grad() + def run_diffusion_sd_xl( + self, + text_embeddings: list, + latents_start: torch.FloatTensor, + idx_start: int = 0, + list_latents_mixing=None, + mixing_coeffs=0.0, + return_image: Optional[bool] = False): + + # prompt = "photo of a house" + # self.num_inference_steps = 50 + # mixing_coeffs= 0.0 + # idx_start= 0 + # latents_start = self.get_noise() + # text_embeddings = self.pipe.encode_prompt( + # prompt, + # self.device, + # num_images_per_prompt=1, + # do_classifier_free_guidance=True, + # negative_prompt="", + # prompt_embeds=None, + # negative_prompt_embeds=None, + # pooled_prompt_embeds=None, + # negative_pooled_prompt_embeds=None, + # lora_scale=None, + # ) + + + # 0. Default height and width to unet + original_size = (1024, 1024) # FIXME + crops_coords_top_left = (0, 0) # FIXME + target_size = original_size + batch_size = 1 + eta = 0.0 + num_images_per_prompt = 1 + cross_attention_kwargs = None + generator = torch.Generator(device=self.device) # dummy generator + do_classifier_free_guidance = self.guidance_scale > 1.0 + + # 1. Check inputs. Raise error if not correct & 2. Define call parameters + # FIXME see if check_inputs use + if type(mixing_coeffs) == float: + list_mixing_coeffs = (1+self.num_inference_steps) * [mixing_coeffs] + elif type(mixing_coeffs) == list: + assert len(mixing_coeffs) == self.num_inference_steps, f"len(mixing_coeffs) {len(mixing_coeffs)} != self.num_inference_steps {self.num_inference_steps}" + list_mixing_coeffs = mixing_coeffs + else: + raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps") + + if np.sum(list_mixing_coeffs) > 0: + assert len(list_latents_mixing) == self.num_inference_steps, f"len(list_latents_mixing) {len(list_latents_mixing)} != self.num_inference_steps {self.num_inference_steps}" + + # 3. Encode input prompt (already encoded outside bc of mixing, just split here) + prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = text_embeddings + + # 4. 
Prepare timesteps + self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device) + timesteps = self.pipe.scheduler.timesteps + + # 5. Prepare latent variables + latents = latents_start.clone() + list_latents_out = [] + + # 6. Prepare extra step kwargs. usedummy generator + extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta) # dummy + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self.pipe._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(self.device) + add_text_embeds = add_text_embeds.to(self.device) + add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + for i, t in enumerate(timesteps): + # Set the right starting latents + if i < idx_start: + list_latents_out.append(None) + continue + elif i == idx_start: + latents = latents_start.clone() + + + + # Mix latents for crossfeeding + if i > 0 and list_mixing_coeffs[i] > 0: + latents_mixtarget = list_latents_mixing[i - 1].clone() + latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i]) + + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + # Always scale latents + latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.pipe.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # FIXME guidance_rescale disabled + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # Append latents + list_latents_out.append(latents.clone()) + + if return_image: + return self.latent2image(latents) + else: + return list_latents_out + + + + +#%% + +if __name__ == "__main__": + pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-0.9" + pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16) + pipe.to('cuda') + # xxx + self = DiffusersHolder(pipe) + # xxx + self.set_num_inference_steps(50) + self.set_dimensions(1536, 1024) + prompt = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic" + text_embeddings = self.get_text_embedding(prompt) + generator = torch.Generator(device=self.device).manual_seed(int(420)) + latents_start = self.get_noise() + list_latents_1 = self.run_diffusion_sd_xl(text_embeddings, latents_start) + img_orig = self.latent2image(list_latents_1[-1]) + + + + # %% + + """ + OPEN + - other examples + - kill upscaling? or keep? 
+ - cleanup + - ldh + - sdh class + - diffusion holder + - check linting + - check docstrings + - fix readme + """ + + + diff --git a/example1_standard.py b/example1_standard.py index 704f93a..fae0a18 100644 --- a/example1_standard.py +++ b/example1_standard.py @@ -20,33 +20,37 @@ import warnings warnings.filterwarnings('ignore') import warnings from latent_blending import LatentBlending -from stable_diffusion_holder import StableDiffusionHolder -from huggingface_hub import hf_hub_download - +from diffusers_holder import DiffusersHolder +from diffusers import DiffusionPipeline # %% First let us spawn a stable diffusion holder. Uncomment your version of choice. -# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt") -fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt") -sdh = StableDiffusionHolder(fp_ckpt) +# dh = DiffusersHolder("stabilityai/stable-diffusion-xl-base-0.9") +pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-0.9" +pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16) +pipe.to('cuda') +dh = DiffusersHolder(pipe) # %% Next let's set up all parameters -depth_strength = 0.65 # Specifies how deep (in terms of diffusion iterations the first branching happens) -t_compute_max_allowed = 15 # Determines the quality of the transition in terms of compute time you grant it -fixed_seeds = [69731932, 504430820] +depth_strength = 0.55 # Specifies how deep (in terms of diffusion iterations the first branching happens) +t_compute_max_allowed = 60 # Determines the quality of the transition in terms of compute time you grant it +fixed_seeds = [6913192, 504443080] +num_inference_steps = 50 -prompt1 = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic" -prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail" +prompt1 = "underwater landscape, fish, und the sea, incredible detail, high resolution" +prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal" fp_movie = 'movie_example1.mp4' duration_transition = 12 # In seconds # Spawn latent blending -lb = LatentBlending(sdh) +lb = LatentBlending(dh) lb.set_prompt1(prompt1) lb.set_prompt2(prompt2) +lb.set_dimensions(1536, 1024) # Run latent blending lb.run_transition( depth_strength=depth_strength, + num_inference_steps=num_inference_steps, t_compute_max_allowed=t_compute_max_allowed, fixed_seeds=fixed_seeds) diff --git a/latent_blending.py b/latent_blending.py index 3899074..3645c0a 100644 --- a/latent_blending.py +++ b/latent_blending.py @@ -26,7 +26,6 @@ from tqdm.auto import tqdm from PIL import Image from movie_util import MovieSaver from typing import List, Optional -from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion import lpips from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save @@ -34,7 +33,7 @@ from utils import interpolate_spherical, interpolate_linear, add_frames_linear_i class LatentBlending(): def __init__( self, - sdh: None, + dh: None, guidance_scale: float = 4, guidance_scale_mid_damper: float = 0.5, mid_compression_scaler: float = 1.2): @@ -59,10 +58,10 @@ class LatentBlending(): and guidance_scale_mid_damper <= 1.0, \ f"guidance_scale_mid_damper neees to be in interval (0,1], you provided 
{guidance_scale_mid_damper}" - self.sdh = sdh - self.device = self.sdh.device - self.width = self.sdh.width - self.height = self.sdh.height + self.dh = dh + self.device = self.dh.device + self.set_dimensions() + self.guidance_scale_mid_damper = guidance_scale_mid_damper self.mid_compression_scaler = mid_compression_scaler self.seed1 = 0 @@ -86,40 +85,49 @@ class LatentBlending(): self.image1_lowres = None self.image2_lowres = None self.negative_prompt = None - self.num_inference_steps = self.sdh.num_inference_steps + self.num_inference_steps = self.dh.num_inference_steps self.noise_level_upscaling = 20 self.list_injection_idx = None self.list_nmb_branches = None # Mixing parameters - self.branch1_crossfeed_power = 0.1 - self.branch1_crossfeed_range = 0.6 - self.branch1_crossfeed_decay = 0.8 + self.branch1_crossfeed_power = 0.05 + self.branch1_crossfeed_range = 0.4 + self.branch1_crossfeed_decay = 0.9 self.parental_crossfeed_power = 0.1 self.parental_crossfeed_range = 0.8 self.parental_crossfeed_power_decay = 0.8 self.set_guidance_scale(guidance_scale) - self.init_mode() + self.mode = 'standard' + # self.init_mode() self.multi_transition_img_first = None self.multi_transition_img_last = None self.dt_per_diff = 0 self.spatial_mask = None self.lpips = lpips.LPIPS(net='alex').cuda(self.device) + + self.set_prompt1("") + self.set_prompt2("") - def init_mode(self): - r""" - Sets the operational mode. Currently supported are standard, inpainting and x4 upscaling. - """ - if isinstance(self.sdh.model, LatentUpscaleDiffusion): - self.mode = 'upscale' - elif isinstance(self.sdh.model, LatentInpaintDiffusion): - self.sdh.image_source = None - self.sdh.mask_image = None - self.mode = 'inpaint' - else: - self.mode = 'standard' + # def init_mode(self): + # r""" + # Sets the operational mode. Currently supported are standard, inpainting and x4 upscaling. + # """ + # if isinstance(self.dh.model, LatentUpscaleDiffusion): + # self.mode = 'upscale' + # elif isinstance(self.dh.model, LatentInpaintDiffusion): + # self.dh.image_source = None + # self.dh.mask_image = None + # self.mode = 'inpaint' + # else: + # self.mode = 'standard' + + def set_dimensions(self, width=None, height=None): + self.dh.set_dimensions(width, height) + + def set_guidance_scale(self, guidance_scale): r""" @@ -127,13 +135,13 @@ class LatentBlending(): """ self.guidance_scale_base = guidance_scale self.guidance_scale = guidance_scale - self.sdh.guidance_scale = guidance_scale + self.dh.guidance_scale = guidance_scale def set_negative_prompt(self, negative_prompt): r"""Set the negative prompt. 
Currenty only one negative prompt is supported """ self.negative_prompt = negative_prompt - self.sdh.set_negative_prompt(negative_prompt) + self.dh.set_negative_prompt(negative_prompt) def set_guidance_mid_dampening(self, fract_mixing): r""" @@ -144,7 +152,7 @@ class LatentBlending(): max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1 guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor self.guidance_scale = guidance_scale_effective - self.sdh.guidance_scale = guidance_scale_effective + self.dh.guidance_scale = guidance_scale_effective def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay): r""" @@ -265,7 +273,7 @@ class LatentBlending(): # Ensure correct num_inference_steps in holder self.num_inference_steps = num_inference_steps - self.sdh.num_inference_steps = num_inference_steps + self.dh.set_num_inference_steps(num_inference_steps) # Compute / Recycle first image if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps: @@ -282,7 +290,7 @@ class LatentBlending(): # Reset the tree, injecting the edge latents1/2 we just generated/recycled self.tree_latents = [list_latents1, list_latents2] self.tree_fracts = [0.0, 1.0] - self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))] + self.tree_final_imgs = [self.dh.latent2image((self.tree_latents[0][-1])), self.dh.latent2image((self.tree_latents[-1][-1]))] self.tree_idx_injection = [0, 0] # Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP... @@ -325,7 +333,7 @@ class LatentBlending(): self.dt_per_diff = (t1 - t0) / self.num_inference_steps self.tree_latents[0] = list_latents1 if return_image: - return self.sdh.latent2image(list_latents1[-1]) + return self.dh.latent2image(list_latents1[-1]) else: return list_latents1 @@ -357,7 +365,7 @@ class LatentBlending(): self.tree_latents[-1] = list_latents2 if return_image: - return self.sdh.latent2image(list_latents2[-1]) + return self.dh.latent2image(list_latents2[-1]) else: return list_latents2 @@ -511,55 +519,17 @@ class LatentBlending(): """ b_parent1, b_parent2 = self.get_closest_idx(fract_mixing) self.tree_latents.insert(b_parent1 + 1, list_latents) - self.tree_final_imgs.insert(b_parent1 + 1, self.sdh.latent2image(list_latents[-1])) + self.tree_final_imgs.insert(b_parent1 + 1, self.dh.latent2image(list_latents[-1])) self.tree_fracts.insert(b_parent1 + 1, fract_mixing) self.tree_idx_injection.insert(b_parent1 + 1, idx_injection) - def get_spatial_mask_template(self): - r""" - Experimental helper function to get a spatial mask template. - """ - shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f] - C, H, W = shape_latents - return np.ones((H, W)) - - def set_spatial_mask(self, img_mask): - r""" - Experimental helper function to set a spatial mask. - The mask forces latents to be overwritten. - Args: - img_mask: - mask image [0,1]. 
You can get a template using get_spatial_mask_template - """ - shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f] - C, H, W = shape_latents - img_mask = np.asarray(img_mask) - assert len(img_mask.shape) == 2, "Currently, only 2D images are supported as mask" - img_mask = np.clip(img_mask, 0, 1) - assert img_mask.shape[0] == H, f"Your mask needs to be of dimension {H} x {W}" - assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}" - spatial_mask = torch.from_numpy(img_mask).to(device=self.device) - spatial_mask = torch.unsqueeze(spatial_mask, 0) - spatial_mask = spatial_mask.repeat((C, 1, 1)) - spatial_mask = torch.unsqueeze(spatial_mask, 0) - self.spatial_mask = spatial_mask - def get_noise(self, seed): r""" Helper function to get noise given seed. Args: seed: int """ - generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed)) - if self.mode == 'standard': - shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f] - C, H, W = shape_latents - elif self.mode == 'upscale': - w = self.image1_lowres.size[0] - h = self.image1_lowres.size[1] - shape_latents = [self.sdh.model.channels, h, w] - C, H, W = shape_latents - return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device) + return self.dh.get_noise(seed, self.mode) @torch.no_grad() def run_diffusion( @@ -590,32 +560,41 @@ class LatentBlending(): """ # Ensure correct num_inference_steps in Holder - self.sdh.num_inference_steps = self.num_inference_steps + self.dh.set_num_inference_steps(self.num_inference_steps) assert type(list_conditionings) is list, "list_conditionings need to be a list" - if self.mode == 'standard': + if self.dh.use_sd_xl: text_embeddings = list_conditionings[0] - return self.sdh.run_diffusion_standard( + return self.dh.run_diffusion_sd_xl( text_embeddings=text_embeddings, latents_start=latents_start, idx_start=idx_start, list_latents_mixing=list_latents_mixing, mixing_coeffs=mixing_coeffs, - spatial_mask=self.spatial_mask, return_image=return_image) - elif self.mode == 'upscale': - cond = list_conditionings[0] - uc_full = list_conditionings[1] - return self.sdh.run_diffusion_upscaling( - cond, - uc_full, + else: + text_embeddings = list_conditionings[0] + return self.dh.run_diffusion_standard( + text_embeddings=text_embeddings, latents_start=latents_start, idx_start=idx_start, list_latents_mixing=list_latents_mixing, mixing_coeffs=mixing_coeffs, return_image=return_image) + # elif self.mode == 'upscale': + # cond = list_conditionings[0] + # uc_full = list_conditionings[1] + # return self.dh.run_diffusion_upscaling( + # cond, + # uc_full, + # latents_start=latents_start, + # idx_start=idx_start, + # list_latents_mixing=list_latents_mixing, + # mixing_coeffs=mixing_coeffs, + # return_image=return_image) + def run_upscaling( self, dp_img: str, @@ -670,8 +649,8 @@ class LatentBlending(): imgs_lowres.append(Image.open(fp_img_lowres)) # set up upscaling - text_embeddingA = self.sdh.get_text_embedding(prompt1) - text_embeddingB = self.sdh.get_text_embedding(prompt2) + text_embeddingA = self.dh.get_text_embedding(prompt1) + text_embeddingB = self.dh.get_text_embedding(prompt2) list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1) for i in range(nmb_max_branches_lowres - 1): print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}") @@ -701,23 +680,35 @@ class LatentBlending(): @torch.no_grad() def get_mixed_conditioning(self, fract_mixing): - if self.mode == 'standard': - 
text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) + if self.dh.use_sd_xl: + text_embeddings_mix = [] + for i in range(len(self.text_embedding1)): + text_embeddings_mix.append(interpolate_linear(self.text_embedding1[i], self.text_embedding2[i], fract_mixing)) list_conditionings = [text_embeddings_mix] - elif self.mode == 'inpaint': - text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) - list_conditionings = [text_embeddings_mix] - elif self.mode == 'upscale': - text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) - cond, uc_full = self.sdh.get_cond_upscaling(self.image1_lowres, text_embeddings_mix, self.noise_level_upscaling) - condB, uc_fullB = self.sdh.get_cond_upscaling(self.image2_lowres, text_embeddings_mix, self.noise_level_upscaling) - cond['c_concat'][0] = interpolate_spherical(cond['c_concat'][0], condB['c_concat'][0], fract_mixing) - uc_full['c_concat'][0] = interpolate_spherical(uc_full['c_concat'][0], uc_fullB['c_concat'][0], fract_mixing) - list_conditionings = [cond, uc_full] else: - raise ValueError(f"mix_conditioning: unknown mode {self.mode}") + text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) + list_conditionings = [text_embeddings_mix] return list_conditionings + # @torch.no_grad() + # def get_mixed_conditioning(self, fract_mixing): + # if self.mode == 'standard': + # text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) + # list_conditionings = [text_embeddings_mix] + # elif self.mode == 'inpaint': + # text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) + # list_conditionings = [text_embeddings_mix] + # elif self.mode == 'upscale': + # text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) + # cond, uc_full = self.dh.get_cond_upscaling(self.image1_lowres, text_embeddings_mix, self.noise_level_upscaling) + # condB, uc_fullB = self.dh.get_cond_upscaling(self.image2_lowres, text_embeddings_mix, self.noise_level_upscaling) + # cond['c_concat'][0] = interpolate_spherical(cond['c_concat'][0], condB['c_concat'][0], fract_mixing) + # uc_full['c_concat'][0] = interpolate_spherical(uc_full['c_concat'][0], uc_fullB['c_concat'][0], fract_mixing) + # list_conditionings = [cond, uc_full] + # else: + # raise ValueError(f"mix_conditioning: unknown mode {self.mode}") + # return list_conditionings + @torch.no_grad() def get_text_embeddings( self, @@ -729,7 +720,7 @@ class LatentBlending(): prompt: str ABC trending on artstation painted by Old Greg. """ - return self.sdh.get_text_embedding(prompt) + return self.dh.get_text_embedding(prompt) def write_imgs_transition(self, dp_img): r""" @@ -766,7 +757,7 @@ class LatentBlending(): # Save as MP4 if os.path.isfile(fp_movie): os.remove(fp_movie) - ms = MovieSaver(fp_movie, fps=fps, shape_hw=[self.sdh.height, self.sdh.width]) + ms = MovieSaver(fp_movie, fps=fps, shape_hw=[self.dh.height_img, self.dh.width_img]) for img in tqdm(imgs_transition_ext): ms.write_frame(img) ms.finalize() @@ -811,7 +802,7 @@ class LatentBlending(): Set a the seed for a fresh start. 
""" self.seed = seed - self.sdh.seed = seed + self.dh.seed = seed def set_width(self, width): r""" @@ -819,7 +810,7 @@ class LatentBlending(): """ assert np.mod(width, 64) == 0, "set_width: value needs to be divisible by 64" self.width = width - self.sdh.width = width + self.dh.width = width def set_height(self, height): r""" @@ -827,7 +818,7 @@ class LatentBlending(): """ assert np.mod(height, 64) == 0, "set_height: value needs to be divisible by 64" self.height = height - self.sdh.height = height + self.dh.height = height def swap_forward(self): r""" diff --git a/stable_diffusion_holder.py b/stable_diffusion_holder.py index 5fd7b20..03e426f 100644 --- a/stable_diffusion_holder.py +++ b/stable_diffusion_holder.py @@ -155,6 +155,24 @@ class StableDiffusionHolder: else: self.height = 512 self.width = 512 + + def get_noise(self, seed, mode='standard'): + r""" + Helper function to get noise given seed. + Args: + seed: int + """ + + generator = torch.Generator(device=self.device).manual_seed(int(seed)) + if mode == 'standard': + shape_latents = [self.C, self.height // self.f, self.width // self.f] + C, H, W = shape_latents + elif mode == 'upscale': + w = self.image1_lowres.size[0] + h = self.image1_lowres.size[1] + shape_latents = [self.model.channels, h, w] + C, H, W = shape_latents + return torch.randn((1, C, H, W), generator=generator, device=self.device) def set_negative_prompt(self, negative_prompt): r"""Set the negative prompt. Currenty only one negative prompt is supported