updated latent blending for stable diffusion holder

This commit is contained in:
lugo 2022-11-25 15:34:41 +01:00
parent 3113bfaed4
commit 9fe1559b1d
3 changed files with 114 additions and 421 deletions

View File

@ -29,44 +29,33 @@ import torch
from movie_util import MovieSaver
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
#%% First let us spawn a diffusers pipe using DDIMScheduler
#%% First let us spawn a stable diffusion holder
device = "cuda:0"
model_path = "../stable_diffusion_models/stable-diffusion-v1-5"
scheduler = DDIMScheduler(beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False)
pipe = StableDiffusionPipeline.from_pretrained(
model_path,
revision="fp16",
torch_dtype=torch.float16,
scheduler=scheduler,
use_auth_token=True
)
pipe = pipe.to(device)
num_inference_steps = 20 # Number of diffusion iterations
fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
sdh = StableDiffusionHolder(fp_ckpt, fp_config, device, num_inference_steps=num_inference_steps)
#%% Next let's set up all parameters
# FIXME below fix numbers
# We want 20 diffusion steps in total, begin with 2 branches, have 3 branches at step 12 (=0.6*20)
# 10 branches at step 16 (=0.8*20) and 24 branches at step 18 (=0.9*20)
# Furthermore we want seed 993621550 for keyframeA and seed 54878562 for keyframeB ()
num_inference_steps = 20 # Number of diffusion iterations
list_nmb_branches = [2, 3, 10, 24] # Branching structure: how many branches
list_injection_strength = [0.0, 0.6, 0.8, 0.9] # Branching structure: how deep is the blending
width = 512
width = 768
height = 512
height = 768
guidance_scale = 5
fixed_seeds = [993621550, 280335986]
lb = LatentBlending(pipe, device, height, width, num_inference_steps, guidance_scale)
lb = LatentBlending(sdh, num_inference_steps, guidance_scale)
prompt1 = "photo of a beautiful forest covered in white flowers, ambient light, very detailed, magic"
prompt2 = "photo of an eerie statue surrounded by ferns and vines, analog photograph kodak portra, mystical ambience, incredible detail"
prompt2 = "photo of a golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail"
lb.set_prompt1(prompt1)
lb.set_prompt2(prompt2)
@ -78,7 +67,7 @@ fps = 60
imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps)
# movie saving
fp_movie = f"/home/lugo/tmp/latentblending/bobo_incoming.mp4"
fp_movie = "/home/lugo/tmp/latentblending/bobo_incoming.mp4"
if os.path.isfile(fp_movie):
os.remove(fp_movie)
ms = MovieSaver(fp_movie, fps=fps)
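The branching comment above ties the injection strengths to absolute diffusion steps (0.6*20=12, 0.8*20=16, 0.9*20=18). A minimal sketch of that arithmetic, assuming the same values as in this example (the exact rounding used internally may differ):

num_inference_steps = 20
list_nmb_branches = [2, 3, 10, 24]
list_injection_strength = [0.0, 0.6, 0.8, 0.9]
# Translate relative injection strengths into absolute diffusion step indices
list_injection_idx = [int(round(s * num_inference_steps)) for s in list_injection_strength]
print(list_injection_idx)  # [0, 12, 16, 18]
# Read together with list_nmb_branches: 2 branches from step 0, 3 branches from step 12,
# 10 branches from step 16 and 24 branches from step 18.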

View File

@ -27,22 +27,19 @@ from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
import matplotlib.pyplot as plt
import torch
from movie_man import MovieSaver
from movie_util import MovieSaver
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
#%% First let us spawn a diffusers pipe using DDIMScheduler
#%% First let us spawn a stable diffusion holder
device = "cuda:0"
model_path = "../stable_diffusion_models/stable-diffusion-inpainting"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_path,
revision="fp16",
torch_dtype=torch.float16,
safety_checker=None
)
pipe = pipe.to(device)
num_inference_steps = 20 # Number of diffusion iterations
fp_ckpt = "../stable_diffusion_models/ckpt/512-inpainting-ema.ckpt"
fp_config = '../stablediffusion/configs//stable-diffusion/v2-inpainting-inference.yaml'
sdh = StableDiffusionHolder(fp_ckpt, fp_config, device, num_inference_steps=num_inference_steps)
#%% Let's make a source image and mask.
@ -52,7 +49,7 @@ num_inference_steps = 30
guidance_scale = 5
fixed_seeds = [629575320, 670154945]
lb = LatentBlending(pipe, device, height, width, num_inference_steps, guidance_scale)
lb = LatentBlending(sdh, num_inference_steps, guidance_scale)
prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, organic, intricate, sci-fi movie, mesmerizing, scary"
lb.set_prompt1(prompt1)
lb.init_inpainting(init_empty=True)
@ -77,7 +74,6 @@ height = 512
guidance_scale = 5
fixed_seeds = [993621550, 280335986]
lb = LatentBlending(pipe, device, height, width, num_inference_steps, guidance_scale)
prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, organic, intricate, sci-fi movie, mesmerizing, scary"
prompt2 = "aerial photo of a futuristic alien temple in a coastal area, waves clashing"
lb.set_prompt1(prompt1)
@ -92,12 +88,11 @@ fps = 60
imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps)
# movie saving
fp_movie = f"/home/lugo/tmp/latentblending/bobo_incoming.mp4"
fp_movie = "/home/lugo/tmp/latentblending/bobo_incoming.mp4"
if os.path.isfile(fp_movie):
os.remove(fp_movie)
ms = MovieSaver(fp_movie, fps=fps, profile='save')
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[lb.height, lb.width])
for img in tqdm(imgs_transition_ext):
ms.write_frame(img)
ms.finalize()
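init_empty=True above skips the mask entirely; for actual inpainting, init_inpainting takes a source image plus a mask in which 0 keeps a pixel untouched and 255 lets diffusion repaint it (see the docstring in latent_blending.py). A rough sketch of that call, with a made-up flat source image and a hypothetical square mask region:

import numpy as np
from PIL import Image

# Illustrative inputs only: a plain gray 512x512 source and a mask covering its center
image_source = Image.fromarray(127 * np.ones((512, 512, 3), dtype=np.uint8))
mask_image = np.zeros((512, 512), dtype=np.uint8)  # 0 = keep untouched
mask_image[128:384, 128:384] = 255                  # 255 = subject to diffusion
lb.init_inpainting(image_source=image_source, mask_image=mask_image)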

View File

@ -38,15 +38,14 @@ from typing import Callable, List, Optional, Union
import inspect
from threading import Thread
torch.set_grad_enabled(False)
from omegaconf import OmegaConf
from torch import autocast
from contextlib import nullcontext
#%%
class LatentBlending():
def __init__(
self,
pipe: Union[StableDiffusionInpaintPipeline, StableDiffusionPipeline],
sdh: None,
device: str,
height: int = 512,
width: int = 512,
num_inference_steps: int = 30,
guidance_scale: float = 7.5,
seed: int = 420,
@ -54,8 +53,7 @@ class LatentBlending():
r""" r"""
Initializes the latent blending class. Initializes the latent blending class.
Args: Args:
device: str FIXME XXX
Compute device, e.g. cuda:0
height: int height: int
Height of the desired output image. The model was trained on 512. Height of the desired output image. The model was trained on 512.
width: int width: int
@ -72,19 +70,15 @@ class LatentBlending():
Random seed.
"""
self.pipe = pipe
self.device = device
self.guidance_scale = guidance_scale
self.num_inference_steps = num_inference_steps
self.width = width
self.height = height
self.seed = seed
# Inits
self.check_asserts()
self.init_mode()
self.sdh = sdh
self.num_inference_steps = num_inference_steps
self.sdh.num_inference_steps = num_inference_steps
self.device = self.sdh.device
self.guidance_scale = guidance_scale
self.width = self.sdh.width
self.height = self.sdh.height
self.seed = seed
# Initialize vars
self.prompt1 = ""
self.prompt2 = ""
@ -99,60 +93,22 @@ class LatentBlending():
self.stop_diffusion = False
self.negative_prompt = None
self.init_mode()
def check_asserts(self):
r"""
Runs Minimal set of sanity checks.
"""
assert self.pipe.scheduler._class_name == 'DDIMScheduler', 'Currently only the DDIMScheduler is supported.'
def init_mode(self):
def init_mode(self, mode='standard'):
r"""
Automatically sets the mode of this class, depending on the supplied pipeline.
FIXME XXX
"""
if self.pipe._class_name == 'StableDiffusionInpaintPipeline':
if mode == 'inpaint':
self.mask_empty = Image.fromarray(255*np.ones([self.width, self.height], dtype=np.uint8))
self.image_empty = Image.fromarray(np.zeros([self.width, self.height, 3], dtype=np.uint8))
self.image_source = None
self.mask_image = None
self.sdh.image_source = None
self.sdh.mask_image = None
self.mode = 'inpaint'
else:
self.mode = 'standard'
def init_inpainting(
self,
image_source: Union[Image.Image, np.ndarray] = None,
mask_image: Union[Image.Image, np.ndarray] = None,
init_empty: Optional[bool] = False,
):
r"""
Initializes inpainting with a source and mask image.
Args:
image_source: Union[Image.Image, np.ndarray]
Source image onto which the mask will be applied.
mask_image: Union[Image.Image, np.ndarray]
Mask image, value = 0 will stay untouched, value = 255 subject to diffusion
init_empty: Optional[bool]:
Initialize inpainting with an empty image and mask, effectively disabling inpainting.
"""
assert self.mode == 'inpaint', 'Initialize class with an inpainting pipeline!'
if not init_empty:
assert image_source is not None, "init_inpainting: you need to provide image_source"
assert mask_image is not None, "init_inpainting: you need to provide mask_image"
if type(image_source) == np.ndarray:
image_source = Image.fromarray(image_source)
self.image_source = image_source
if type(mask_image) == np.ndarray:
mask_image = Image.fromarray(mask_image)
self.mask_image = mask_image
else:
self.mask_image = self.mask_empty
self.image_source = self.image_empty
def set_prompt1(self, prompt: str):
r"""
Sets the first prompt (for the first keyframe) including text embeddings.
@ -238,6 +194,9 @@ class LatentBlending():
# Process interruption variable
self.stop_diffusion = False
# Ensure correct num_inference_steps in holder
self.sdh.num_inference_steps = self.num_inference_steps
# Recycling? There are requirements
if recycle_img1 or recycle_img2:
if self.list_nmb_branches_prev == []:
@ -291,11 +250,11 @@ class LatentBlending():
self.tree_status[t_block][idx_branch] = 'untouched'
if recycle_img1:
self.tree_status[t_block][0] = 'computed'
self.tree_final_imgs[0] = self.latent2image(self.tree_latents[-1][0][-1])
self.tree_final_imgs[0] = self.sdh.latent2image(self.tree_latents[-1][0][-1])
self.tree_final_imgs_timing[0] = 0
if recycle_img2:
self.tree_status[t_block][-1] = 'computed'
self.tree_final_imgs[-1] = self.latent2image(self.tree_latents[-1][-1][-1])
self.tree_final_imgs[-1] = self.sdh.latent2image(self.tree_latents[-1][-1][-1])
self.tree_final_imgs_timing[-1] = 0
# setup compute order: goal: try to get last branch computed asap.
@ -365,7 +324,7 @@ class LatentBlending():
# Convert latents to image directly for the last t_block
if t_block == nmb_blocks_time-1:
self.tree_final_imgs[idx_branch] = self.latent2image(list_latents[-1])
self.tree_final_imgs[idx_branch] = self.sdh.latent2image(list_latents[-1])
self.tree_final_imgs_timing[idx_branch] = time.time() - time_start
return self.tree_final_imgs
@ -406,6 +365,8 @@ class LatentBlending():
The duration of your movie will be duration_single_trans * len(list_prompts)
"""
# Ensure correct
if list_seeds is None:
list_seeds = list(np.random.randint(0, 10e10, len(list_prompts)))
@ -424,7 +385,7 @@ class LatentBlending():
recycle_img1 = True
local_seeds = [list_seeds[i], list_seeds[i+1]]
list_imgs = lb.run_transition(list_nmb_branches, list_injection_strength=list_injection_strength, list_injection_idx=list_injection_idx, recycle_img1=recycle_img1, fixed_seeds=local_seeds)
list_imgs = self.run_transition(list_nmb_branches, list_injection_strength=list_injection_strength, list_injection_idx=list_injection_idx, recycle_img1=recycle_img1, fixed_seeds=local_seeds)
list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_trans)
# Save movie frame
@ -462,257 +423,37 @@ class LatentBlending():
Optionally return image directly
"""
# Ensure correct num_inference_steps in Holder
self.sdh.num_inference_steps = self.num_inference_steps
if self.mode == 'standard':
return self.run_diffusion_standard(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
return self.sdh.run_diffusion_standard(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
elif self.mode == 'inpaint':
assert self.image_source is not None, "image_source is None. Please run init_inpainting first."
assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
assert self.mask_image is not None, "image_source is None. Please run init_inpainting first."
assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
return self.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
def init_inpainting(
self,
image_source: Union[Image.Image, np.ndarray] = None,
mask_image: Union[Image.Image, np.ndarray] = None,
init_empty: Optional[bool] = False,
):
r"""
Initializes inpainting with a source and mask image.
Args:
image_source: Union[Image.Image, np.ndarray]
Source image onto which the mask will be applied.
mask_image: Union[Image.Image, np.ndarray]
Mask image, value = 0 will stay untouched, value = 255 subject to diffusion
init_empty: Optional[bool]:
Initialize inpainting with an empty image and mask, effectively disabling inpainting,
useful for generating a first image for transitions using diffusion.
"""
self.init_mode('inpaint')
self.sdh.init_inpainting(image_source, mask_image, init_empty)
@torch.no_grad()
def run_diffusion_standard(
self,
text_embeddings: torch.FloatTensor,
latents_for_injection: torch.FloatTensor = None,
idx_start: int = -1,
idx_stop: int = -1,
return_image: Optional[bool] = False
):
r"""
Runs regular diffusion. Returns a list of latents that were computed.
Adaptations allow to supply
a) starting index for diffusion
b) stopping index for diffusion
c) latent representations that are injected at the starting index
Furthermore the intermittent latents are collected and returned.
Adapted from diffusers (https://github.com/huggingface/diffusers)
Args:
text_embeddings: torch.FloatTensor
Text embeddings used for diffusion
latents_for_injection: torch.FloatTensor
Latents that are used for injection
idx_start: int
Index of the diffusion process start and where the latents_for_injection are injected
idx_stop: int
Index of the diffusion process end.
return_image: Optional[bool]
Optionally return image directly
"""
if latents_for_injection is None:
do_inject_latents = False
else:
do_inject_latents = True
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
batch_size = 1
height = self.height
width = self.width
num_inference_steps = self.num_inference_steps
num_images_per_prompt = 1
do_classifier_free_guidance = True
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# set timesteps
self.pipe.scheduler.set_timesteps(num_inference_steps)
# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps_tensor = self.pipe.scheduler.timesteps.to(self.pipe.device)
if not do_inject_latents:
# get the initial random noise unless the user supplied it
latents_shape = (batch_size * num_images_per_prompt, self.pipe.unet.in_channels, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
latents = torch.randn(latents_shape, generator=generator, device=self.pipe.device, dtype=latents_dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.pipe.scheduler.init_noise_sigma
extra_step_kwargs = {}
# collect latents
list_latents_out = []
for i, t in enumerate(timesteps_tensor):
if do_inject_latents:
# Inject latent at right place
if i < idx_start:
continue
elif i == idx_start:
latents = latents_for_injection.clone()
if i == idx_stop:
return list_latents_out
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
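These idx_stop / idx_start / latents_for_injection hooks are what the branching structure builds on: a parent branch is diffused only up to an injection step, the latents of two parents are mixed, and the mix is injected into a child branch that continues from that step. A simplified, hypothetical sketch of the pattern, assuming the dispatching wrapper above is exposed as run_diffusion (the real scheduling lives in run_transition):

# Hypothetical illustration only; variable names and the 0.5 mixing fraction are made up.
idx_injection = 12                      # e.g. 0.6 * 20 diffusion steps
list_latents_a = lb.run_diffusion(text_embeddings1, idx_stop=idx_injection)  # parent A, stopped early
list_latents_b = lb.run_diffusion(text_embeddings2, idx_stop=idx_injection)  # parent B, stopped early
latents_mix = interpolate_spherical(list_latents_a[-1], list_latents_b[-1], 0.5)
img_child = lb.run_diffusion(text_embeddings1, latents_for_injection=latents_mix, idx_start=idx_injection, return_image=True)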
@torch.no_grad()
def run_diffusion_inpaint(
self,
text_embeddings: torch.FloatTensor,
latents_for_injection: torch.FloatTensor = None,
idx_start: int = -1,
idx_stop: int = -1,
return_image: Optional[bool] = False
):
r"""
Runs inpaint-based diffusion. Returns a list of latents that were computed.
Adaptations allow to supply
a) starting index for diffusion
b) stopping index for diffusion
c) latent representations that are injected at the starting index
Furthermore the intermittent latents are collected and returned.
Adapted from diffusers (https://github.com/huggingface/diffusers)
Args:
text_embeddings: torch.FloatTensor
Text embeddings used for diffusion
latents_for_injection: torch.FloatTensor
Latents that are used for injection
idx_start: int
Index of the diffusion process start and where the latents_for_injection are injected
idx_stop: int
Index of the diffusion process end.
return_image: Optional[bool]
Optionally return image directly
"""
if latents_for_injection is None:
do_inject_latents = False
else:
do_inject_latents = True
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
batch_size = 1
height = self.height
width = self.width
num_inference_steps = self.num_inference_steps
num_images_per_prompt = 1
do_classifier_free_guidance = True
# prepare mask and masked_image
mask, masked_image = self.prepare_mask_and_masked_image(self.image_source, self.mask_image)
mask = mask.to(device=self.pipe.device, dtype=text_embeddings.dtype)
masked_image = masked_image.to(device=self.pipe.device, dtype=text_embeddings.dtype)
# resize the mask to latents shape as we concatenate the mask to the latents
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
# encode the mask image into latents space so we can concatenate it to the latents
masked_image_latents = self.pipe.vae.encode(masked_image).latent_dist.sample(generator=generator)
masked_image_latents = 0.18215 * masked_image_latents
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
mask = mask.repeat(num_images_per_prompt, 1, 1, 1)
masked_image_latents = masked_image_latents.repeat(num_images_per_prompt, 1, 1, 1)
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
)
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
num_channels_latents = self.pipe.vae.config.latent_channels
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
latents_dtype = text_embeddings.dtype
latents = torch.randn(latents_shape, generator=generator, device=self.pipe.device, dtype=latents_dtype)
latents = latents.to(self.pipe.device)
# set timesteps
self.pipe.scheduler.set_timesteps(num_inference_steps)
timesteps_tensor = self.pipe.scheduler.timesteps.to(self.pipe.device)
latents = latents * self.pipe.scheduler.init_noise_sigma
extra_step_kwargs = {}
# collect latents
list_latents_out = []
for i, t in enumerate(timesteps_tensor):
if do_inject_latents:
# Inject latent at right place
if i < idx_start:
continue
elif i == idx_start:
latents = latents_for_injection.clone()
if i == idx_stop:
return list_latents_out
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
@torch.no_grad()
def latent2image(
self,
latents: torch.FloatTensor
):
r"""
Returns an image provided a latent representation from diffusion.
Args:
latents: torch.FloatTensor
Result of the diffusion process.
"""
latents = 1 / 0.18215 * latents
image = self.pipe.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = (image[0,:,:,:] * 255).astype(np.uint8)
return image
@torch.no_grad()
def get_text_embeddings(
@ -721,82 +462,14 @@ class LatentBlending():
):
r"""
Computes the text embeddings provided a string with a prompt.
Adapted from diffusers (https://github.com/huggingface/diffusers)
Adapted from stable diffusion repo
Args:
prompt: str
ABC trending on artstation painted by Old Greg.
"""
if self.negative_prompt is None:
uncond_tokens = [""]
else:
if isinstance(self.negative_prompt, str):
uncond_tokens = [self.negative_prompt]
return self.sdh.get_text_embedding(prompt)
batch_size = 1
num_images_per_prompt = 1
do_classifier_free_guidance = True
# get prompt text embeddings
text_inputs = self.pipe.tokenizer(
prompt,
padding="max_length",
max_length=self.pipe.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
# if text_input_ids.shape[-1] > self.pipe.tokenizer.model_max_length:
# removed_text = self.pipe.tokenizer.batch_decode(text_input_ids[:, self.pipe.tokenizer.model_max_length :])
# text_input_ids = text_input_ids[:, : self.pipe.tokenizer.model_max_length]
text_embeddings = self.pipe.text_encoder(text_input_ids.to(self.pipe.device))[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
max_length = text_input_ids.shape[-1]
uncond_input = self.pipe.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.pipe.device))[0]
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
def prepare_mask_and_masked_image(self, image, mask):
r"""
Mask and image preparation for inpainting.
Adapted from diffusers (https://github.com/huggingface/diffusers)
Args:
image:
Source image
mask:
Mask image
"""
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image
def randomize_seed(self):
r"""
@ -858,7 +531,7 @@ def interpolate_spherical(p0, p1, fract_mixing: float):
r""" r"""
Helper function to correctly mix two random variables using spherical interpolation. Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp See https://en.wikipedia.org/wiki/Slerp
The function will always cast up to float64 for sake of extra precision. The function will always cast up to float64 for sake of extra 4.
Args: Args:
p0: p0:
First tensor for interpolation First tensor for interpolation
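interpolate_spherical is the slerp referenced above; a minimal standalone sketch of the same math (not the repo's exact implementation, which additionally upcasts to float64):

import torch

def slerp_sketch(p0: torch.Tensor, p1: torch.Tensor, fract_mixing: float, eps: float = 1e-7) -> torch.Tensor:
    # Angle between the two tensors, computed on normalized copies
    p0_unit = p0 / (p0.norm() + eps)
    p1_unit = p1 / (p1.norm() + eps)
    dot = torch.clamp((p0_unit * p1_unit).sum(), -1 + eps, 1 - eps)
    omega = torch.arccos(dot)
    # Slerp coefficients applied to the raw (unnormalized) tensors
    c0 = torch.sin((1.0 - fract_mixing) * omega) / torch.sin(omega)
    c1 = torch.sin(fract_mixing * omega) / torch.sin(omega)
    return c0 * p0 + c1 * p1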
@ -992,7 +665,7 @@ def add_frames_linear_interp(
nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
list_imgs_interp = []
for i in tqdm(range(len(list_imgs_float)-1), desc="STAGE linear interp"):
for i in range(len(list_imgs_float)-1):#, desc="STAGE linear interp"):
img0 = list_imgs_float[i]
img1 = list_imgs_float[i+1]
list_imgs_interp.append(img0.astype(np.uint8))
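add_frames_linear_interp fills the gap between consecutive transition images with evenly spaced blends before the movie is written. A small sketch of the per-pair blending step, with the number of inserted frames chosen arbitrarily for illustration:

import numpy as np

def blend_pair_sketch(img0: np.ndarray, img1: np.ndarray, nmb_frames_to_insert: int) -> list:
    # img0 plus linearly blended in-between frames; img1 itself is appended by the next pair
    frames = [img0.astype(np.uint8)]
    img0f = img0.astype(np.float32)
    img1f = img1.astype(np.float32)
    for j in range(nmb_frames_to_insert):
        fract = (j + 1) / (nmb_frames_to_insert + 1)  # mixing weight strictly between 0 and 1
        frames.append(((1.0 - fract) * img0f + fract * img1f).astype(np.uint8))
    return frames

frames = blend_pair_sketch(np.zeros((64, 64, 3)), 255 * np.ones((64, 64, 3)), nmb_frames_to_insert=3)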
@ -1051,8 +724,7 @@ def get_branching(
""" """
#%%
if quality == 'lowest':
num_inference_steps = 12
nmb_branches_final = 5
@ -1095,12 +767,49 @@ def get_branching(
print(f"list_injection_idx: {list_injection_idx_clean}") print(f"list_injection_idx: {list_injection_idx_clean}")
print(f"list_nmb_branches: {list_nmb_branches_clean}") print(f"list_nmb_branches: {list_nmb_branches_clean}")
return num_inference_steps, list_injection_idx_clean, list_nmb_branches_clean # return num_inference_steps, list_injection_idx_clean, list_nmb_branches_clean
#%% le main #%% le main
if __name__ == "__main__": if __name__ == "__main__":
sys.path.append('../stablediffusion/ldm')
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
num_inference_steps = 20 # Number of diffusion iterations
sdh = StableDiffusionHolder(num_inference_steps)
# fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
# fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
fp_ckpt= "../stable_diffusion_models/ckpt/512-base-ema.ckpt"
fp_config = '../stablediffusion/configs//stable-diffusion/v2-inference.yaml'
sdh.init_model(fp_ckpt, fp_config)
#%%
list_nmb_branches = [2, 3, 10, 24] # Branching structure: how many branches
list_injection_strength = [0.0, 0.6, 0.8, 0.9] # Branching structure: how deep is the blending
width = 512
height = 512
guidance_scale = 5
fixed_seeds = [993621550, 280335986]
device = "cuda:0"
lb = LatentBlending(sdh, device, height, width, num_inference_steps, guidance_scale)
prompt1 = "photo of a forest covered in white flowers, ambient light, very detailed, magic"
prompt2 = "photo of an eerie statue surrounded by ferns and vines, analog photograph kodak portra, mystical ambience, incredible detail"
lb.set_prompt1(prompt1)
lb.set_prompt2(prompt2)
lx = lb.run_transition(list_nmb_branches, list_injection_strength)
#%%
xxx
device = "cuda:0" device = "cuda:0"
model_path = "../stable_diffusion_models/stable-diffusion-v1-5" model_path = "../stable_diffusion_models/stable-diffusion-v1-5"
@ -1110,7 +819,7 @@ if __name__ == "__main__":
clip_sample=False,
set_alpha_to_one=False)
pipe = StableDiffusionPipeline.from_pretrained(
pipe = StableDiffusionPipeline.from_Union[StableDiffusionInpaintPipeline, StableDiffusionPipeline],pretrained(
model_path,
revision="fp16",
torch_dtype=torch.float16,