This commit is contained in:
Johannes Stelzer 2023-11-16 15:37:02 +01:00
parent 2d4570a228
commit ddd6fdee21
4 changed files with 98 additions and 117 deletions

View File

@@ -13,20 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import torch
-torch.backends.cudnn.benchmark = False
-torch.set_grad_enabled(False)
import numpy as np
import warnings
-warnings.filterwarnings('ignore')
-import warnings
-import torch
-from PIL import Image
-import torch
from typing import Optional
-from torch import autocast
-from contextlib import nullcontext
from utils import interpolate_spherical
from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.models.attention_processor import (
@@ -35,6 +26,9 @@ from diffusers.models.attention_processor import (
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
+warnings.filterwarnings('ignore')
+torch.backends.cudnn.benchmark = False
+torch.set_grad_enabled(False)
class DiffusersHolder():
@@ -71,13 +65,11 @@ class DiffusersHolder():
def set_dimensions(self, size_output):
s = self.pipe.vae_scale_factor
if size_output is None:
width = self.pipe.unet.config.sample_size
height = self.pipe.unet.config.sample_size
else:
width, height = size_output
self.width_img = int(round(width / s) * s)
self.width_latent = int(self.width_img / s)
self.height_img = int(round(height / s) * s)
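set_dimensions snaps the requested output size to the nearest multiple of the VAE scale factor and derives the latent size from it. A minimal sketch of that rounding, assuming the usual vae_scale_factor of 8 (the real value comes from the loaded pipeline):

def round_to_vae_scale(size_output, s=8):
    # Mirrors the arithmetic above: image dims become multiples of s,
    # latent dims are the image dims divided by s.
    width, height = size_output
    width_img = int(round(width / s) * s)
    height_img = int(round(height / s) * s)
    return (width_img, height_img), (width_img // s, height_img // s)

print(round_to_vae_scale((1280, 768)))  # ((1280, 768), (160, 96))
print(round_to_vae_scale((1023, 700)))  # ((1024, 704), (128, 88))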
@@ -95,7 +87,6 @@ class DiffusersHolder():
if len(self.negative_prompt) > 1:
self.negative_prompt = [self.negative_prompt[0]]
def get_text_embedding(self, prompt, do_classifier_free_guidance=True):
if self.use_sd_xl:
pr_encoder = self.pipe.encode_prompt
@@ -114,7 +105,7 @@ class DiffusersHolder():
)
return prompt_embeds
-def get_noise(self, seed=420, mode=None):
+def get_noise(self, seed=420):
H = self.height_latent
W = self.width_latent
C = self.pipe.unet.config.in_channels
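get_noise now takes only a seed and returns a latent-shaped noise tensor built from it. A minimal sketch of the idea, with assumed dimensions (4 latent channels and 128x128 latents, i.e. a 1024x1024 output); the class reads C, H and W from the pipeline config and from set_dimensions:

import torch

C, H, W = 4, 128, 128  # assumed latent shape, not taken from a real pipeline here
generator = torch.Generator(device="cpu").manual_seed(420)
noise = torch.randn((1, C, H, W), generator=generator)
print(noise.shape)  # torch.Size([1, 4, 128, 128])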
@@ -164,7 +155,6 @@ class DiffusersHolder():
return np.asarray(image)
else:
return image
def prepare_mixing(self, mixing_coeffs, list_latents_mixing):
if type(mixing_coeffs) == float:
@@ -265,10 +255,10 @@ class DiffusersHolder():
list_latents_mixing=None,
mixing_coeffs=0.0,
return_image: Optional[bool] = False):
# 0. Default height and width to unet
-original_size = (self.width_img, self.height_img) # FIXME
-crops_coords_top_left = (0, 0) # FIXME
+original_size = (self.width_img, self.height_img)
+crops_coords_top_left = (0, 0)
target_size = original_size
batch_size = 1
eta = 0.0
@@ -276,10 +266,10 @@ class DiffusersHolder():
cross_attention_kwargs = None
generator = torch.Generator(device=self.device) # dummy generator
do_classifier_free_guidance = self.guidance_scale > 1.0
# 1. Check inputs. Raise error if not correct & 2. Define call parameters
list_mixing_coeffs = self.prepare_mixing(mixing_coeffs, list_latents_mixing)
# 3. Encode input prompt (already encoded outside bc of mixing, just split here)
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = text_embeddings
@@ -294,28 +284,13 @@ class DiffusersHolder():
# 6. Prepare extra step kwargs. Use dummy generator
extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta) # dummy
-# 7. Prepare added time ids & embeddings
-# add_text_embeds = pooled_prompt_embeds
-# add_time_ids = self.pipe._get_add_time_ids(
-# original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
-# )
-# if do_classifier_free_guidance:
-# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-# add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
-# add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
-# prompt_embeds = prompt_embeds.to(self.device)
-# add_text_embeds = add_text_embeds.to(self.device)
-# add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1)
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
if self.pipe.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.pipe.text_encoder_2.config.projection_dim
add_time_ids = self.pipe._get_add_time_ids(
original_size,
crops_coords_top_left,
@@ -323,26 +298,16 @@ class DiffusersHolder():
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
-# if negative_original_size is not None and negative_target_size is not None:
-# negative_add_time_ids = self.pipe._get_add_time_ids(
-# negative_original_size,
-# negative_crops_coords_top_left,
-# negative_target_size,
-# dtype=prompt_embeds.dtype,
-# text_encoder_projection_dim=text_encoder_projection_dim,
-# )
-# else:
negative_add_time_ids = add_time_ids
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(self.device)
add_text_embeds = add_text_embeds.to(self.device)
add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1)
# 8. Denoising loop
for i, t in enumerate(timesteps):
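For SDXL the extra conditioning is small: the original size, the crop offset and the target size are packed into one vector per sample and passed alongside the pooled text embedding. The pipeline builds the real tensor with its _get_add_time_ids helper as shown above; the sketch below only illustrates what ends up inside it, assuming a 1024x1024 output and no cropping:

import torch

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)
# Concatenate the three tuples into a single conditioning vector per sample.
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)])
print(add_time_ids)  # tensor([[1024, 1024, 0, 0, 1024, 1024]])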
@@ -358,7 +323,6 @@ class DiffusersHolder():
latents_mixtarget = list_latents_mixing[i - 1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Always scale latents
@@ -380,14 +344,12 @@ class DiffusersHolder():
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-# FIXME guidance_rescale disabled
# compute the previous noisy sample x_t -> x_t-1
latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# Append latents
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
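The guidance line above is plain classifier-free guidance: the final noise estimate starts from the unconditional prediction and is pushed towards the text-conditioned one. A toy numeric sketch, with dummy tensors and an assumed guidance scale of 7.5:

import torch

noise_pred_uncond = torch.zeros(1, 4, 128, 128)  # stand-in for the unconditional half
noise_pred_text = torch.ones(1, 4, 128, 128)     # stand-in for the text-conditioned half
guidance_scale = 7.5                             # assumed value

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.mean())  # tensor(7.5000)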
@@ -415,7 +377,7 @@ class DiffusersHolder():
batch_size = 1
eta = 0.0
controlnet_conditioning_scale = 1.0
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -527,19 +489,16 @@ class DiffusersHolder():
# Append latents
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
#%%
if __name__ == "__main__":
+from PIL import Image
#%%
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)

View File

@@ -0,0 +1,57 @@
# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import warnings
from latent_blending import LatentBlending
from diffusers_holder import DiffusersHolder
from diffusers import DiffusionPipeline
warnings.filterwarnings('ignore')
torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = False
# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
pipe.to('cuda')
dh = DiffusersHolder(pipe)
# %% Next let's set up all parameters
depth_strength = 0.55 # Specifies how deep (in terms of diffusion iterations) the first branching happens
t_compute_max_allowed = 60 # Maximum compute time (in seconds) granted for the transition; more time gives a smoother result
num_inference_steps = 30
size_output = (1024, 1024)
prompt1 = "underwater landscape, fish, and the sea, incredible detail, high resolution"
prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal"
negative_prompt = "blurry, ugly, pale" # Optional
fp_movie = 'movie_example1.mp4'
duration_transition = 12 # In seconds
# Spawn latent blending
lb = LatentBlending(dh)
lb.set_prompt1(prompt1)
lb.set_prompt2(prompt2)
lb.set_dimensions(size_output)
lb.set_negative_prompt(negative_prompt)
# Run latent blending
lb.run_transition(
depth_strength=depth_strength,
num_inference_steps=num_inference_steps,
t_compute_max_allowed=t_compute_max_allowed)
# Save movie
lb.write_movie_transition(fp_movie, duration_transition)
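The example leaves the endpoint seeds random. If reproducible endpoints are wanted, run_transition also accepts fixed seeds, as the multi-prompt example further down does; a variant of the call above (the seed values are arbitrary):

lb.run_transition(
    depth_strength=depth_strength,
    num_inference_steps=num_inference_steps,
    t_compute_max_allowed=t_compute_max_allowed,
    fixed_seeds=[420, 421])  # arbitrary example seeds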

View File

@@ -14,16 +14,14 @@
# limitations under the License.
import torch
-torch.backends.cudnn.benchmark = False
-torch.set_grad_enabled(False)
-import warnings
-warnings.filterwarnings('ignore')
import warnings
from latent_blending import LatentBlending
from diffusers_holder import DiffusersHolder
from diffusers import DiffusionPipeline
from movie_util import concatenate_movies
-from huggingface_hub import hf_hub_download
+torch.set_grad_enabled(False)
+torch.backends.cudnn.benchmark = False
+warnings.filterwarnings('ignore')
# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -35,21 +33,23 @@ dh = DiffusersHolder(pipe)
fps = 30
duration_single_trans = 20
depth_strength = 0.25 # Specifies how deep (in terms of diffusion iterations) the first branching happens
+size_output = (1280, 768)
+num_inference_steps = 30
# Specify a list of prompts below
list_prompts = []
-list_prompts.append("A panoramic photo of a sentient mirror maze amidst a neon-lit forest, where bioluminescent mushrooms glow eerily, reflecting off the mirrors, and cybernetic crows, with silver wings and ruby eyes, perch ominously, David Lynch, Gaspar Noé, Photograph.")
-list_prompts.append("An unsettling tableau of spectral butterflies with clockwork wings, swirling around an antique typewriter perched precariously atop a floating, gnarled tree trunk, a stormy twilight sky, David Lynch's dreamscape, meticulously crafted.")
-# list_prompts.append("A haunting tableau of an antique dollhouse swallowed by a giant venus flytrap under the neon glow of an alien moon, its uncanny light reflecting from shattered porcelain faces and marbles, in a quiet, abandoned amusement park.")
+list_prompts.append("A beautiful astronomic photo of a nebula, with intricate microscopic structures, mitochondria")
+list_prompts.append("Microscope fluorescence photo, cell filaments, intricate galaxy, astronomic nebula")
+list_prompts.append("telescope photo starry sky, nebula, cell core, dna, stunning")
# You can optionally specify the seeds
-list_seeds = [95437579, 33259350, 956051013, 408831845, 250009012, 675588737]
+list_seeds = [95437579, 33259350, 956051013]
t_compute_max_allowed = 20 # per segment
fp_movie = 'movie_example2.mp4'
lb = LatentBlending(dh)
-lb.dh.set_dimensions(1024, 704)
-lb.dh.set_num_inference_steps(40)
+lb.set_dimensions(size_output)
+lb.dh.set_num_inference_steps(num_inference_steps)
list_movie_parts = []
@@ -68,7 +68,7 @@ for i in range(len(list_prompts) - 1):
fixed_seeds = list_seeds[i:i + 2]
# Run latent blending
lb.run_transition(
-recycle_img1 = recycle_img1,
+recycle_img1=recycle_img1,
depth_strength=depth_strength,
t_compute_max_allowed=t_compute_max_allowed,
fixed_seeds=fixed_seeds)
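The hunk ends before the per-segment movie is written and collected, but concatenate_movies is imported above for exactly that purpose. A sketch of how the loop presumably finishes; fp_movie_part is an assumed per-segment filename, and concatenate_movies is assumed to take the output path followed by the list of parts:

for i in range(len(list_prompts) - 1):
    ...  # run_transition call as above
    # Render this segment and remember its file.
    fp_movie_part = f"tmp_part_{i}.mp4"
    lb.write_movie_transition(fp_movie_part, duration_single_trans)
    list_movie_parts.append(fp_movie_part)

# After the loop: stitch all segments into the final movie.
concatenate_movies(fp_movie, list_movie_parts)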

View File

@@ -15,19 +15,18 @@
import os
import torch
-torch.backends.cudnn.benchmark = False
-torch.set_grad_enabled(False)
import numpy as np
import warnings
-warnings.filterwarnings('ignore')
import time
-import warnings
from tqdm.auto import tqdm
from PIL import Image
from movie_util import MovieSaver
from typing import List, Optional
import lpips
from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save
+warnings.filterwarnings('ignore')
+torch.backends.cudnn.benchmark = False
+torch.set_grad_enabled(False)
class LatentBlending():
@@ -70,7 +69,6 @@ class LatentBlending():
# Initialize vars
self.prompt1 = ""
self.prompt2 = ""
-self.negative_prompt = ""
self.tree_latents = [None, None]
self.tree_fracts = None
@@ -91,17 +89,15 @@ class LatentBlending():
self.list_nmb_branches = None
# Mixing parameters
-self.branch1_crossfeed_power = 0.05
-self.branch1_crossfeed_range = 0.4
-self.branch1_crossfeed_decay = 0.9
-self.parental_crossfeed_power = 0.1
-self.parental_crossfeed_range = 0.8
-self.parental_crossfeed_power_decay = 0.8
+self.branch1_crossfeed_power = 0.3
+self.branch1_crossfeed_range = 0.3
+self.branch1_crossfeed_decay = 0.99
+self.parental_crossfeed_power = 0.3
+self.parental_crossfeed_range = 0.6
+self.parental_crossfeed_power_decay = 0.9
self.set_guidance_scale(guidance_scale)
-self.mode = 'standard'
-# self.init_mode()
self.multi_transition_img_first = None
self.multi_transition_img_last = None
self.dt_per_diff = 0
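The new crossfeed defaults can still be overridden on a LatentBlending instance after construction (dh as in the example scripts); a small sketch with arbitrary illustrative values:

lb = LatentBlending(dh)
# Override the defaults set in __init__ (values chosen only for illustration).
lb.branch1_crossfeed_power = 0.1
lb.branch1_crossfeed_range = 0.5
lb.parental_crossfeed_power = 0.2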
@@ -441,7 +437,7 @@ class LatentBlending():
list_compute_steps = self.num_inference_steps - list_idx_injection
list_compute_steps *= list_nmb_stems
t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
-t_compute += 2*self.num_inference_steps*self.dt_per_diff # outer branches
+t_compute += 2 * self.num_inference_steps * self.dt_per_diff # outer branches
increase_done = False
for s_idx in range(len(list_nmb_stems) - 1):
if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
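The budget heuristic above estimates total compute time from the denoising steps each stem still has to run, plus a small per-stem overhead and the two outer branches. A worked sketch with assumed numbers (0.1 s per denoising step, 30 inference steps, four injection levels):

import numpy as np

num_inference_steps = 30
dt_per_diff = 0.1  # assumed seconds per denoising step; the class measures this at runtime
list_idx_injection = np.array([10, 15, 20, 25])  # assumed injection depths
list_nmb_stems = np.array([1, 2, 4, 8])          # assumed stems per level

list_compute_steps = (num_inference_steps - list_idx_injection) * list_nmb_stems
t_compute = np.sum(list_compute_steps) * dt_per_diff + 0.15 * np.sum(list_nmb_stems)
t_compute += 2 * num_inference_steps * dt_per_diff  # the two outer branches
print(t_compute)  # 13.0 + 2.25 + 6.0 = 21.25 seconds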
@@ -522,7 +518,7 @@ class LatentBlending():
Args:
seed: int
"""
-return self.dh.get_noise(seed, self.mode)
+return self.dh.get_noise(seed)
@torch.no_grad()
def run_diffusion(
@@ -576,18 +572,6 @@ class LatentBlending():
mixing_coeffs=mixing_coeffs,
return_image=return_image)
-# elif self.mode == 'upscale':
-# cond = list_conditionings[0]
-# uc_full = list_conditionings[1]
-# return self.dh.run_diffusion_upscaling(
-# cond,
-# uc_full,
-# latents_start=latents_start,
-# idx_start=idx_start,
-# list_latents_mixing=list_latents_mixing,
-# mixing_coeffs=mixing_coeffs,
-# return_image=return_image)
def run_upscaling(
self,
dp_img: str,
@@ -683,25 +667,6 @@ class LatentBlending():
list_conditionings = [text_embeddings_mix]
return list_conditionings
-# @torch.no_grad()
-# def get_mixed_conditioning(self, fract_mixing):
-# if self.mode == 'standard':
-# text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
-# list_conditionings = [text_embeddings_mix]
-# elif self.mode == 'inpaint':
-# text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
-# list_conditionings = [text_embeddings_mix]
-# elif self.mode == 'upscale':
-# text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
-# cond, uc_full = self.dh.get_cond_upscaling(self.image1_lowres, text_embeddings_mix, self.noise_level_upscaling)
-# condB, uc_fullB = self.dh.get_cond_upscaling(self.image2_lowres, text_embeddings_mix, self.noise_level_upscaling)
-# cond['c_concat'][0] = interpolate_spherical(cond['c_concat'][0], condB['c_concat'][0], fract_mixing)
-# uc_full['c_concat'][0] = interpolate_spherical(uc_full['c_concat'][0], uc_fullB['c_concat'][0], fract_mixing)
-# list_conditionings = [cond, uc_full]
-# else:
-# raise ValueError(f"mix_conditioning: unknown mode {self.mode}")
-# return list_conditionings
@torch.no_grad()
def get_text_embeddings(
self,