updated latent blending for stable diffusion holder
This commit is contained in:
parent
3113bfaed4
commit
9fe1559b1d
|
@ -29,44 +29,33 @@ import torch
|
||||||
from movie_util import MovieSaver
|
from movie_util import MovieSaver
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
from latent_blending import LatentBlending, add_frames_linear_interp
|
from latent_blending import LatentBlending, add_frames_linear_interp
|
||||||
|
from stable_diffusion_holder import StableDiffusionHolder
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
#%% First let us spawn a diffusers pipe using DDIMScheduler
|
#%% First let us spawn a stable diffusion holder
|
||||||
device = "cuda:0"
|
device = "cuda:0"
|
||||||
model_path = "../stable_diffusion_models/stable-diffusion-v1-5"
|
num_inference_steps = 20 # Number of diffusion iterations
|
||||||
|
fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
|
||||||
|
fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
|
||||||
|
|
||||||
scheduler = DDIMScheduler(beta_start=0.00085,
|
sdh = StableDiffusionHolder(fp_ckpt, fp_config, device, num_inference_steps=num_inference_steps)
|
||||||
beta_end=0.012,
|
|
||||||
beta_schedule="scaled_linear",
|
|
||||||
clip_sample=False,
|
|
||||||
set_alpha_to_one=False)
|
|
||||||
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
|
||||||
model_path,
|
|
||||||
revision="fp16",
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
scheduler=scheduler,
|
|
||||||
use_auth_token=True
|
|
||||||
)
|
|
||||||
pipe = pipe.to(device)
|
|
||||||
|
|
||||||
#%% Next let's set up all parameters
|
#%% Next let's set up all parameters
|
||||||
# FIXME below fix numbers
|
# FIXME below fix numbers
|
||||||
# We want 20 diffusion steps in total, begin with 2 branches, have 3 branches at step 12 (=0.6*20)
|
# We want 20 diffusion steps in total, begin with 2 branches, have 3 branches at step 12 (=0.6*20)
|
||||||
# 10 branches at step 16 (=0.8*20) and 24 branches at step 18 (=0.9*20)
|
# 10 branches at step 16 (=0.8*20) and 24 branches at step 18 (=0.9*20)
|
||||||
# Furthermore we want seed 993621550 for keyframeA and seed 54878562 for keyframeB ()
|
# Furthermore we want seed 993621550 for keyframeA and seed 54878562 for keyframeB ()
|
||||||
|
|
||||||
num_inference_steps = 20 # Number of diffusion iterations
|
|
||||||
list_nmb_branches = [2, 3, 10, 24] # Branching structure: how many branches
|
list_nmb_branches = [2, 3, 10, 24] # Branching structure: how many branches
|
||||||
list_injection_strength = [0.0, 0.6, 0.8, 0.9] # Branching structure: how deep is the blending
|
list_injection_strength = [0.0, 0.6, 0.8, 0.9] # Branching structure: how deep is the blending
|
||||||
width = 512
|
width = 768
|
||||||
height = 512
|
height = 768
|
||||||
guidance_scale = 5
|
guidance_scale = 5
|
||||||
fixed_seeds = [993621550, 280335986]
|
fixed_seeds = [993621550, 280335986]
|
||||||
|
|
||||||
lb = LatentBlending(pipe, device, height, width, num_inference_steps, guidance_scale)
|
lb = LatentBlending(sdh, num_inference_steps, guidance_scale)
|
||||||
prompt1 = "photo of a beautiful forest covered in white flowers, ambient light, very detailed, magic"
|
prompt1 = "photo of a beautiful forest covered in white flowers, ambient light, very detailed, magic"
|
||||||
prompt2 = "photo of an eerie statue surrounded by ferns and vines, analog photograph kodak portra, mystical ambience, incredible detail"
|
prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph,, mystical ambience, incredible detail"
|
||||||
lb.set_prompt1(prompt1)
|
lb.set_prompt1(prompt1)
|
||||||
lb.set_prompt2(prompt2)
|
lb.set_prompt2(prompt2)
|
||||||
|
|
||||||
|
@ -78,7 +67,7 @@ fps = 60
|
||||||
imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps)
|
imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps)
|
||||||
|
|
||||||
# movie saving
|
# movie saving
|
||||||
fp_movie = f"/home/lugo/tmp/latentblending/bobo_incoming.mp4"
|
fp_movie = "/home/lugo/tmp/latentblending/bobo_incoming.mp4"
|
||||||
if os.path.isfile(fp_movie):
|
if os.path.isfile(fp_movie):
|
||||||
os.remove(fp_movie)
|
os.remove(fp_movie)
|
||||||
ms = MovieSaver(fp_movie, fps=fps)
|
ms = MovieSaver(fp_movie, fps=fps)
|
||||||
|
|
|
@ -27,22 +27,19 @@ from diffusers import StableDiffusionInpaintPipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
from movie_man import MovieSaver
|
from movie_util import MovieSaver
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
from latent_blending import LatentBlending, add_frames_linear_interp
|
from latent_blending import LatentBlending, add_frames_linear_interp
|
||||||
|
from stable_diffusion_holder import StableDiffusionHolder
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
#%% First let us spawn a diffusers pipe using DDIMScheduler
|
#%% First let us spawn a stable diffusion holder
|
||||||
device = "cuda:0"
|
device = "cuda:0"
|
||||||
model_path = "../stable_diffusion_models/stable-diffusion-inpainting"
|
num_inference_steps = 20 # Number of diffusion iterations
|
||||||
|
fp_ckpt= "../stable_diffusion_models/ckpt/512-inpainting-ema.ckpt"
|
||||||
|
fp_config = '../stablediffusion/configs//stable-diffusion/v2-inpainting-inference.yaml'
|
||||||
|
|
||||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
sdh = StableDiffusionHolder(fp_ckpt, fp_config, device, num_inference_steps=num_inference_steps)
|
||||||
model_path,
|
|
||||||
revision="fp16",
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
safety_checker=None
|
|
||||||
)
|
|
||||||
pipe = pipe.to(device)
|
|
||||||
|
|
||||||
|
|
||||||
#%% Let's make a source image and mask.
|
#%% Let's make a source image and mask.
|
||||||
|
@ -52,7 +49,7 @@ num_inference_steps = 30
|
||||||
guidance_scale = 5
|
guidance_scale = 5
|
||||||
fixed_seeds = [629575320, 670154945]
|
fixed_seeds = [629575320, 670154945]
|
||||||
|
|
||||||
lb = LatentBlending(pipe, device, height, width, num_inference_steps, guidance_scale)
|
lb = LatentBlending(sdh, num_inference_steps, guidance_scale)
|
||||||
prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, organic, intricate, sci-fi movie, mesmerizing, scary"
|
prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, organic, intricate, sci-fi movie, mesmerizing, scary"
|
||||||
lb.set_prompt1(prompt1)
|
lb.set_prompt1(prompt1)
|
||||||
lb.init_inpainting(init_empty=True)
|
lb.init_inpainting(init_empty=True)
|
||||||
|
@ -77,7 +74,6 @@ height = 512
|
||||||
guidance_scale = 5
|
guidance_scale = 5
|
||||||
fixed_seeds = [993621550, 280335986]
|
fixed_seeds = [993621550, 280335986]
|
||||||
|
|
||||||
lb = LatentBlending(pipe, device, height, width, num_inference_steps, guidance_scale)
|
|
||||||
prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, organic, intricate, sci-fi movie, mesmerizing, scary"
|
prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, organic, intricate, sci-fi movie, mesmerizing, scary"
|
||||||
prompt2 = "aerial photo of a futuristic alien temple in a coastal area, waves clashing"
|
prompt2 = "aerial photo of a futuristic alien temple in a coastal area, waves clashing"
|
||||||
lb.set_prompt1(prompt1)
|
lb.set_prompt1(prompt1)
|
||||||
|
@ -92,12 +88,11 @@ fps = 60
|
||||||
imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps)
|
imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps)
|
||||||
|
|
||||||
# movie saving
|
# movie saving
|
||||||
fp_movie = f"/home/lugo/tmp/latentblending/bobo_incoming.mp4"
|
fp_movie = "/home/lugo/tmp/latentblending/bobo_incoming.mp4"
|
||||||
if os.path.isfile(fp_movie):
|
if os.path.isfile(fp_movie):
|
||||||
os.remove(fp_movie)
|
os.remove(fp_movie)
|
||||||
ms = MovieSaver(fp_movie, fps=fps, profile='save')
|
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[lb.height, lb.width])
|
||||||
for img in tqdm(imgs_transition_ext):
|
for img in tqdm(imgs_transition_ext):
|
||||||
ms.write_frame(img)
|
ms.write_frame(img)
|
||||||
ms.finalize()
|
ms.finalize()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -38,15 +38,14 @@ from typing import Callable, List, Optional, Union
|
||||||
import inspect
|
import inspect
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from torch import autocast
|
||||||
|
from contextlib import nullcontext
|
||||||
#%%
|
#%%
|
||||||
class LatentBlending():
|
class LatentBlending():
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pipe: Union[StableDiffusionInpaintPipeline, StableDiffusionPipeline],
|
sdh: None,
|
||||||
device: str,
|
|
||||||
height: int = 512,
|
|
||||||
width: int = 512,
|
|
||||||
num_inference_steps: int = 30,
|
num_inference_steps: int = 30,
|
||||||
guidance_scale: float = 7.5,
|
guidance_scale: float = 7.5,
|
||||||
seed: int = 420,
|
seed: int = 420,
|
||||||
|
@ -54,8 +53,7 @@ class LatentBlending():
|
||||||
r"""
|
r"""
|
||||||
Initializes the latent blending class.
|
Initializes the latent blending class.
|
||||||
Args:
|
Args:
|
||||||
device: str
|
FIXME XXX
|
||||||
Compute device, e.g. cuda:0
|
|
||||||
height: int
|
height: int
|
||||||
Height of the desired output image. The model was trained on 512.
|
Height of the desired output image. The model was trained on 512.
|
||||||
width: int
|
width: int
|
||||||
|
@ -72,19 +70,15 @@ class LatentBlending():
|
||||||
Random seed.
|
Random seed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
self.sdh = sdh
|
||||||
self.pipe = pipe
|
|
||||||
self.device = device
|
|
||||||
self.guidance_scale = guidance_scale
|
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
self.width = width
|
self.sdh.num_inference_steps = num_inference_steps
|
||||||
self.height = height
|
self.device = self.sdh.device
|
||||||
|
self.guidance_scale = guidance_scale
|
||||||
|
self.width = self.sdh.width
|
||||||
|
self.height = self.sdh.height
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
|
||||||
# Inits
|
|
||||||
self.check_asserts()
|
|
||||||
self.init_mode()
|
|
||||||
|
|
||||||
# Initialize vars
|
# Initialize vars
|
||||||
self.prompt1 = ""
|
self.prompt1 = ""
|
||||||
self.prompt2 = ""
|
self.prompt2 = ""
|
||||||
|
@ -99,60 +93,22 @@ class LatentBlending():
|
||||||
self.stop_diffusion = False
|
self.stop_diffusion = False
|
||||||
self.negative_prompt = None
|
self.negative_prompt = None
|
||||||
|
|
||||||
|
self.init_mode()
|
||||||
def check_asserts(self):
|
|
||||||
r"""
|
|
||||||
Runs Minimal set of sanity checks.
|
|
||||||
"""
|
|
||||||
assert self.pipe.scheduler._class_name == 'DDIMScheduler', 'Currently only the DDIMScheduler is supported.'
|
|
||||||
|
|
||||||
|
|
||||||
def init_mode(self):
|
def init_mode(self, mode='standard'):
|
||||||
r"""
|
r"""
|
||||||
Automatically sets the mode of this class, depending on the supplied pipeline.
|
Automatically sets the mode of this class, depending on the supplied pipeline.
|
||||||
|
FIXME XXX
|
||||||
"""
|
"""
|
||||||
if self.pipe._class_name == 'StableDiffusionInpaintPipeline':
|
if mode == 'inpaint':
|
||||||
self.mask_empty = Image.fromarray(255*np.ones([self.width, self.height], dtype=np.uint8))
|
self.sdh.image_source = None
|
||||||
self.image_empty = Image.fromarray(np.zeros([self.width, self.height, 3], dtype=np.uint8))
|
self.sdh.mask_image = None
|
||||||
self.image_source = None
|
|
||||||
self.mask_image = None
|
|
||||||
self.mode = 'inpaint'
|
self.mode = 'inpaint'
|
||||||
else:
|
else:
|
||||||
self.mode = 'standard'
|
self.mode = 'standard'
|
||||||
|
|
||||||
|
|
||||||
def init_inpainting(
|
|
||||||
self,
|
|
||||||
image_source: Union[Image.Image, np.ndarray] = None,
|
|
||||||
mask_image: Union[Image.Image, np.ndarray] = None,
|
|
||||||
init_empty: Optional[bool] = False,
|
|
||||||
):
|
|
||||||
r"""
|
|
||||||
Initializes inpainting with a source and mask image.
|
|
||||||
Args:
|
|
||||||
image_source: Union[Image.Image, np.ndarray]
|
|
||||||
Source image onto which the mask will be applied.
|
|
||||||
mask_image: Union[Image.Image, np.ndarray]
|
|
||||||
Mask image, value = 0 will stay untouched, value = 255 subject to diffusion
|
|
||||||
init_empty: Optional[bool]:
|
|
||||||
Initialize inpainting with an empty image and mask, effectively disabling inpainting.
|
|
||||||
"""
|
|
||||||
assert self.mode == 'inpaint', 'Initialize class with an inpainting pipeline!'
|
|
||||||
if not init_empty:
|
|
||||||
assert image_source is not None, "init_inpainting: you need to provide image_source"
|
|
||||||
assert mask_image is not None, "init_inpainting: you need to provide mask_image"
|
|
||||||
if type(image_source) == np.ndarray:
|
|
||||||
image_source = Image.fromarray(image_source)
|
|
||||||
self.image_source = image_source
|
|
||||||
|
|
||||||
if type(mask_image) == np.ndarray:
|
|
||||||
mask_image = Image.fromarray(mask_image)
|
|
||||||
self.mask_image = mask_image
|
|
||||||
else:
|
|
||||||
self.mask_image = self.mask_empty
|
|
||||||
self.image_source = self.image_empty
|
|
||||||
|
|
||||||
|
|
||||||
def set_prompt1(self, prompt: str):
|
def set_prompt1(self, prompt: str):
|
||||||
r"""
|
r"""
|
||||||
Sets the first prompt (for the first keyframe) including text embeddings.
|
Sets the first prompt (for the first keyframe) including text embeddings.
|
||||||
|
@ -238,6 +194,9 @@ class LatentBlending():
|
||||||
# Process interruption variable
|
# Process interruption variable
|
||||||
self.stop_diffusion = False
|
self.stop_diffusion = False
|
||||||
|
|
||||||
|
# Ensure correct num_inference_steps in holder
|
||||||
|
self.sdh.num_inference_steps = self.num_inference_steps
|
||||||
|
|
||||||
# Recycling? There are requirements
|
# Recycling? There are requirements
|
||||||
if recycle_img1 or recycle_img2:
|
if recycle_img1 or recycle_img2:
|
||||||
if self.list_nmb_branches_prev == []:
|
if self.list_nmb_branches_prev == []:
|
||||||
|
@ -291,11 +250,11 @@ class LatentBlending():
|
||||||
self.tree_status[t_block][idx_branch] = 'untouched'
|
self.tree_status[t_block][idx_branch] = 'untouched'
|
||||||
if recycle_img1:
|
if recycle_img1:
|
||||||
self.tree_status[t_block][0] = 'computed'
|
self.tree_status[t_block][0] = 'computed'
|
||||||
self.tree_final_imgs[0] = self.latent2image(self.tree_latents[-1][0][-1])
|
self.tree_final_imgs[0] = self.sdh.latent2image(self.tree_latents[-1][0][-1])
|
||||||
self.tree_final_imgs_timing[0] = 0
|
self.tree_final_imgs_timing[0] = 0
|
||||||
if recycle_img2:
|
if recycle_img2:
|
||||||
self.tree_status[t_block][-1] = 'computed'
|
self.tree_status[t_block][-1] = 'computed'
|
||||||
self.tree_final_imgs[-1] = self.latent2image(self.tree_latents[-1][-1][-1])
|
self.tree_final_imgs[-1] = self.sdh.latent2image(self.tree_latents[-1][-1][-1])
|
||||||
self.tree_final_imgs_timing[-1] = 0
|
self.tree_final_imgs_timing[-1] = 0
|
||||||
|
|
||||||
# setup compute order: goal: try to get last branch computed asap.
|
# setup compute order: goal: try to get last branch computed asap.
|
||||||
|
@ -365,7 +324,7 @@ class LatentBlending():
|
||||||
|
|
||||||
# Convert latents to image directly for the last t_block
|
# Convert latents to image directly for the last t_block
|
||||||
if t_block == nmb_blocks_time-1:
|
if t_block == nmb_blocks_time-1:
|
||||||
self.tree_final_imgs[idx_branch] = self.latent2image(list_latents[-1])
|
self.tree_final_imgs[idx_branch] = self.sdh.latent2image(list_latents[-1])
|
||||||
self.tree_final_imgs_timing[idx_branch] = time.time() - time_start
|
self.tree_final_imgs_timing[idx_branch] = time.time() - time_start
|
||||||
|
|
||||||
return self.tree_final_imgs
|
return self.tree_final_imgs
|
||||||
|
@ -406,6 +365,8 @@ class LatentBlending():
|
||||||
The duration of your movie will be duration_single_trans * len(list_prompts)
|
The duration of your movie will be duration_single_trans * len(list_prompts)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Ensure correct
|
||||||
if list_seeds is None:
|
if list_seeds is None:
|
||||||
list_seeds = list(np.random.randint(0, 10e10, len(list_prompts)))
|
list_seeds = list(np.random.randint(0, 10e10, len(list_prompts)))
|
||||||
|
|
||||||
|
@ -424,7 +385,7 @@ class LatentBlending():
|
||||||
recycle_img1 = True
|
recycle_img1 = True
|
||||||
|
|
||||||
local_seeds = [list_seeds[i], list_seeds[i+1]]
|
local_seeds = [list_seeds[i], list_seeds[i+1]]
|
||||||
list_imgs = lb.run_transition(list_nmb_branches, list_injection_strength=list_injection_strength, list_injection_idx=list_injection_idx, recycle_img1=recycle_img1, fixed_seeds=local_seeds)
|
list_imgs = self.run_transition(list_nmb_branches, list_injection_strength=list_injection_strength, list_injection_idx=list_injection_idx, recycle_img1=recycle_img1, fixed_seeds=local_seeds)
|
||||||
list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_trans)
|
list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_trans)
|
||||||
|
|
||||||
# Save movie frame
|
# Save movie frame
|
||||||
|
@ -462,257 +423,37 @@ class LatentBlending():
|
||||||
Optionally return image directly
|
Optionally return image directly
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Ensure correct num_inference_steps in Holder
|
||||||
|
self.sdh.num_inference_steps = self.num_inference_steps
|
||||||
|
|
||||||
if self.mode == 'standard':
|
if self.mode == 'standard':
|
||||||
return self.run_diffusion_standard(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
return self.sdh.run_diffusion_standard(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
||||||
|
|
||||||
elif self.mode == 'inpaint':
|
elif self.mode == 'inpaint':
|
||||||
assert self.image_source is not None, "image_source is None. Please run init_inpainting first."
|
assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
|
||||||
assert self.mask_image is not None, "image_source is None. Please run init_inpainting first."
|
assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
|
||||||
return self.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
||||||
|
|
||||||
|
def init_inpainting(
|
||||||
@torch.no_grad()
|
|
||||||
def run_diffusion_standard(
|
|
||||||
self,
|
self,
|
||||||
text_embeddings: torch.FloatTensor,
|
image_source: Union[Image.Image, np.ndarray] = None,
|
||||||
latents_for_injection: torch.FloatTensor = None,
|
mask_image: Union[Image.Image, np.ndarray] = None,
|
||||||
idx_start: int = -1,
|
init_empty: Optional[bool] = False,
|
||||||
idx_stop: int = -1,
|
|
||||||
return_image: Optional[bool] = False
|
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Runs regular diffusion. Returns a list of latents that were computed.
|
Initializes inpainting with a source and mask image.
|
||||||
Adaptations allow to supply
|
|
||||||
a) starting index for diffusion
|
|
||||||
b) stopping index for diffusion
|
|
||||||
c) latent representations that are injected at the starting index
|
|
||||||
Furthermore the intermittent latents are collected and returned.
|
|
||||||
Adapted from diffusers (https://github.com/huggingface/diffusers)
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text_embeddings: torch.FloatTensor
|
image_source: Union[Image.Image, np.ndarray]
|
||||||
Text embeddings used for diffusion
|
Source image onto which the mask will be applied.
|
||||||
latents_for_injection: torch.FloatTensor
|
mask_image: Union[Image.Image, np.ndarray]
|
||||||
Latents that are used for injection
|
Mask image, value = 0 will stay untouched, value = 255 subject to diffusion
|
||||||
idx_start: int
|
init_empty: Optional[bool]:
|
||||||
Index of the diffusion process start and where the latents_for_injection are injected
|
Initialize inpainting with an empty image and mask, effectively disabling inpainting,
|
||||||
idx_stop: int
|
useful for generating a first image for transitions using diffusion.
|
||||||
Index of the diffusion process end.
|
|
||||||
return_image: Optional[bool]
|
|
||||||
Optionally return image directly
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if latents_for_injection is None:
|
self.init_mode('inpaint')
|
||||||
do_inject_latents = False
|
self.sdh.init_inpainting(image_source, mask_image, init_empty)
|
||||||
else:
|
|
||||||
do_inject_latents = True
|
|
||||||
|
|
||||||
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
|
|
||||||
batch_size = 1
|
|
||||||
height = self.height
|
|
||||||
width = self.width
|
|
||||||
num_inference_steps = self.num_inference_steps
|
|
||||||
num_images_per_prompt = 1
|
|
||||||
do_classifier_free_guidance = True
|
|
||||||
|
|
||||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
|
||||||
bs_embed, seq_len, _ = text_embeddings.shape
|
|
||||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
|
||||||
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
|
||||||
|
|
||||||
# set timesteps
|
|
||||||
self.pipe.scheduler.set_timesteps(num_inference_steps)
|
|
||||||
|
|
||||||
# Some schedulers like PNDM have timesteps as arrays
|
|
||||||
# It's more optimized to move all timesteps to correct device beforehand
|
|
||||||
timesteps_tensor = self.pipe.scheduler.timesteps.to(self.pipe.device)
|
|
||||||
|
|
||||||
if not do_inject_latents:
|
|
||||||
# get the initial random noise unless the user supplied it
|
|
||||||
latents_shape = (batch_size * num_images_per_prompt, self.pipe.unet.in_channels, height // 8, width // 8)
|
|
||||||
latents_dtype = text_embeddings.dtype
|
|
||||||
latents = torch.randn(latents_shape, generator=generator, device=self.pipe.device, dtype=latents_dtype)
|
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
|
||||||
latents = latents * self.pipe.scheduler.init_noise_sigma
|
|
||||||
extra_step_kwargs = {}
|
|
||||||
|
|
||||||
# collect latents
|
|
||||||
list_latents_out = []
|
|
||||||
for i, t in enumerate(timesteps_tensor):
|
|
||||||
|
|
||||||
|
|
||||||
if do_inject_latents:
|
|
||||||
# Inject latent at right place
|
|
||||||
if i < idx_start:
|
|
||||||
continue
|
|
||||||
elif i == idx_start:
|
|
||||||
latents = latents_for_injection.clone()
|
|
||||||
|
|
||||||
if i == idx_stop:
|
|
||||||
return list_latents_out
|
|
||||||
|
|
||||||
# expand the latents if we are doing classifier free guidance
|
|
||||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
|
||||||
latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
|
|
||||||
|
|
||||||
# predict the noise residual
|
|
||||||
noise_pred = self.pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
|
||||||
|
|
||||||
# perform guidance
|
|
||||||
if do_classifier_free_guidance:
|
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
|
||||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
|
||||||
latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
|
||||||
|
|
||||||
list_latents_out.append(latents.clone())
|
|
||||||
|
|
||||||
if return_image:
|
|
||||||
return self.latent2image(latents)
|
|
||||||
else:
|
|
||||||
return list_latents_out
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def run_diffusion_inpaint(
|
|
||||||
self,
|
|
||||||
text_embeddings: torch.FloatTensor,
|
|
||||||
latents_for_injection: torch.FloatTensor = None,
|
|
||||||
idx_start: int = -1,
|
|
||||||
idx_stop: int = -1,
|
|
||||||
return_image: Optional[bool] = False
|
|
||||||
):
|
|
||||||
r"""
|
|
||||||
Runs inpaint-based diffusion. Returns a list of latents that were computed.
|
|
||||||
Adaptations allow to supply
|
|
||||||
a) starting index for diffusion
|
|
||||||
b) stopping index for diffusion
|
|
||||||
c) latent representations that are injected at the starting index
|
|
||||||
Furthermore the intermittent latents are collected and returned.
|
|
||||||
|
|
||||||
Adapted from diffusers (https://github.com/huggingface/diffusers)
|
|
||||||
Args:
|
|
||||||
text_embeddings: torch.FloatTensor
|
|
||||||
Text embeddings used for diffusion
|
|
||||||
latents_for_injection: torch.FloatTensor
|
|
||||||
Latents that are used for injection
|
|
||||||
idx_start: int
|
|
||||||
Index of the diffusion process start and where the latents_for_injection are injected
|
|
||||||
idx_stop: int
|
|
||||||
Index of the diffusion process end.
|
|
||||||
return_image: Optional[bool]
|
|
||||||
Optionally return image directly
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
if latents_for_injection is None:
|
|
||||||
do_inject_latents = False
|
|
||||||
else:
|
|
||||||
do_inject_latents = True
|
|
||||||
|
|
||||||
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
|
|
||||||
batch_size = 1
|
|
||||||
height = self.height
|
|
||||||
width = self.width
|
|
||||||
num_inference_steps = self.num_inference_steps
|
|
||||||
num_images_per_prompt = 1
|
|
||||||
do_classifier_free_guidance = True
|
|
||||||
|
|
||||||
# prepare mask and masked_image
|
|
||||||
mask, masked_image = self.prepare_mask_and_masked_image(self.image_source, self.mask_image)
|
|
||||||
mask = mask.to(device=self.pipe.device, dtype=text_embeddings.dtype)
|
|
||||||
masked_image = masked_image.to(device=self.pipe.device, dtype=text_embeddings.dtype)
|
|
||||||
|
|
||||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
|
||||||
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
|
|
||||||
|
|
||||||
# encode the mask image into latents space so we can concatenate it to the latents
|
|
||||||
masked_image_latents = self.pipe.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
|
||||||
masked_image_latents = 0.18215 * masked_image_latents
|
|
||||||
|
|
||||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
|
||||||
mask = mask.repeat(num_images_per_prompt, 1, 1, 1)
|
|
||||||
masked_image_latents = masked_image_latents.repeat(num_images_per_prompt, 1, 1, 1)
|
|
||||||
|
|
||||||
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
|
||||||
masked_image_latents = (
|
|
||||||
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
|
|
||||||
)
|
|
||||||
|
|
||||||
num_channels_mask = mask.shape[1]
|
|
||||||
num_channels_masked_image = masked_image_latents.shape[1]
|
|
||||||
|
|
||||||
num_channels_latents = self.pipe.vae.config.latent_channels
|
|
||||||
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
|
|
||||||
latents_dtype = text_embeddings.dtype
|
|
||||||
latents = torch.randn(latents_shape, generator=generator, device=self.pipe.device, dtype=latents_dtype)
|
|
||||||
latents = latents.to(self.pipe.device)
|
|
||||||
# set timesteps
|
|
||||||
self.pipe.scheduler.set_timesteps(num_inference_steps)
|
|
||||||
timesteps_tensor = self.pipe.scheduler.timesteps.to(self.pipe.device)
|
|
||||||
latents = latents * self.pipe.scheduler.init_noise_sigma
|
|
||||||
extra_step_kwargs = {}
|
|
||||||
# collect latents
|
|
||||||
list_latents_out = []
|
|
||||||
|
|
||||||
for i, t in enumerate(timesteps_tensor):
|
|
||||||
if do_inject_latents:
|
|
||||||
# Inject latent at right place
|
|
||||||
if i < idx_start:
|
|
||||||
continue
|
|
||||||
elif i == idx_start:
|
|
||||||
latents = latents_for_injection.clone()
|
|
||||||
|
|
||||||
if i == idx_stop:
|
|
||||||
return list_latents_out
|
|
||||||
|
|
||||||
# expand the latents if we are doing classifier free guidance
|
|
||||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
|
||||||
# concat latents, mask, masked_image_latents in the channel dimension
|
|
||||||
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
|
||||||
|
|
||||||
latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
|
|
||||||
|
|
||||||
# predict the noise residual
|
|
||||||
noise_pred = self.pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
|
||||||
|
|
||||||
# perform guidance
|
|
||||||
if do_classifier_free_guidance:
|
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
|
||||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
|
||||||
latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
|
||||||
|
|
||||||
list_latents_out.append(latents.clone())
|
|
||||||
|
|
||||||
if return_image:
|
|
||||||
return self.latent2image(latents)
|
|
||||||
else:
|
|
||||||
return list_latents_out
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def latent2image(
|
|
||||||
self,
|
|
||||||
latents: torch.FloatTensor
|
|
||||||
):
|
|
||||||
r"""
|
|
||||||
Returns an image provided a latent representation from diffusion.
|
|
||||||
Args:
|
|
||||||
latents: torch.FloatTensor
|
|
||||||
Result of the diffusion process.
|
|
||||||
"""
|
|
||||||
|
|
||||||
latents = 1 / 0.18215 * latents
|
|
||||||
image = self.pipe.vae.decode(latents).sample
|
|
||||||
image = (image / 2 + 0.5).clamp(0, 1)
|
|
||||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
|
||||||
image = (image[0,:,:,:] * 255).astype(np.uint8)
|
|
||||||
|
|
||||||
return image
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_text_embeddings(
|
def get_text_embeddings(
|
||||||
|
@ -721,82 +462,14 @@ class LatentBlending():
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Computes the text embeddings provided a string with a prompts.
|
Computes the text embeddings provided a string with a prompts.
|
||||||
Adapted from diffusers (https://github.com/huggingface/diffusers)
|
Adapted from stable diffusion repo
|
||||||
Args:
|
Args:
|
||||||
prompt: str
|
prompt: str
|
||||||
ABC trending on artstation painted by Old Greg.
|
ABC trending on artstation painted by Old Greg.
|
||||||
"""
|
"""
|
||||||
if self.negative_prompt is None:
|
|
||||||
uncond_tokens = [""]
|
|
||||||
else:
|
|
||||||
if isinstance(self.negative_prompt, str):
|
|
||||||
uncond_tokens = [self.negative_prompt]
|
|
||||||
|
|
||||||
batch_size = 1
|
return self.sdh.get_text_embedding(prompt)
|
||||||
num_images_per_prompt = 1
|
|
||||||
do_classifier_free_guidance = True
|
|
||||||
# get prompt text embeddings
|
|
||||||
text_inputs = self.pipe.tokenizer(
|
|
||||||
prompt,
|
|
||||||
padding="max_length",
|
|
||||||
max_length=self.pipe.tokenizer.model_max_length,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
text_input_ids = text_inputs.input_ids
|
|
||||||
|
|
||||||
# if text_input_ids.shape[-1] > self.pipe.tokenizer.model_max_length:
|
|
||||||
# removed_text = self.pipe.tokenizer.batch_decode(text_input_ids[:, self.pipe.tokenizer.model_max_length :])
|
|
||||||
# text_input_ids = text_input_ids[:, : self.pipe.tokenizer.model_max_length]
|
|
||||||
text_embeddings = self.pipe.text_encoder(text_input_ids.to(self.pipe.device))[0]
|
|
||||||
|
|
||||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
|
||||||
bs_embed, seq_len, _ = text_embeddings.shape
|
|
||||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
|
||||||
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
|
||||||
|
|
||||||
# get unconditional embeddings for classifier free guidance
|
|
||||||
if do_classifier_free_guidance:
|
|
||||||
max_length = text_input_ids.shape[-1]
|
|
||||||
uncond_input = self.pipe.tokenizer(
|
|
||||||
uncond_tokens,
|
|
||||||
padding="max_length",
|
|
||||||
max_length=max_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
uncond_embeddings = self.pipe.text_encoder(uncond_input.input_ids.to(self.pipe.device))[0]
|
|
||||||
|
|
||||||
seq_len = uncond_embeddings.shape[1]
|
|
||||||
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
|
||||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
|
||||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
|
||||||
return text_embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_mask_and_masked_image(self, image, mask):
    r"""
    Builds the binary mask tensor and the masked source image for inpainting.
    Adapted from diffusers (https://github.com/huggingface/diffusers)
    Args:
        image:
            Source image; must offer .convert("RGB")
            (presumably a PIL image -- confirm against callers).
        mask:
            Mask image; must offer .convert("L").
    Returns:
        (mask, masked_image): the mask as a float tensor of zeros and ones
        with two leading singleton axes, and the source image rescaled by
        x / 127.5 - 1.0 with every pixel whose mask value is 1 set to zero.
    """
    # Source image: colour axis moved in front of the two spatial axes,
    # plus a leading batch axis, then rescaled.
    rgb = np.array(image.convert("RGB"))
    batched_rgb = rgb.transpose(2, 0, 1)[None]
    image_tensor = torch.from_numpy(batched_rgb).to(dtype=torch.float32) / 127.5 - 1.0

    # Mask: grayscale values scaled by 1/255, then thresholded in place at 0.5.
    gray = np.array(mask.convert("L"))
    mask_arr = gray.astype(np.float32) / 255.0
    mask_arr = mask_arr[np.newaxis, np.newaxis]
    mask_arr[mask_arr < 0.5] = 0
    mask_arr[mask_arr >= 0.5] = 1
    mask_tensor = torch.from_numpy(mask_arr)

    # Keep only the pixels where the mask is zero.
    keep = mask_tensor < 0.5
    masked_image = image_tensor * keep

    return mask_tensor, masked_image
|
|
||||||
|
|
||||||
def randomize_seed(self):
|
def randomize_seed(self):
|
||||||
r"""
|
r"""
|
||||||
|
@ -858,7 +531,7 @@ def interpolate_spherical(p0, p1, fract_mixing: float):
|
||||||
r"""
|
r"""
|
||||||
Helper function to correctly mix two random variables using spherical interpolation.
|
Helper function to correctly mix two random variables using spherical interpolation.
|
||||||
See https://en.wikipedia.org/wiki/Slerp
|
See https://en.wikipedia.org/wiki/Slerp
|
||||||
The function will always cast up to float64 for sake of extra precision.
|
The function will always cast up to float64 for sake of extra precision.
|
||||||
Args:
|
Args:
|
||||||
p0:
|
p0:
|
||||||
First tensor for interpolation
|
First tensor for interpolation
|
||||||
|
@ -992,7 +665,7 @@ def add_frames_linear_interp(
|
||||||
|
|
||||||
nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
|
nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
|
||||||
list_imgs_interp = []
|
list_imgs_interp = []
|
||||||
for i in tqdm(range(len(list_imgs_float)-1), desc="STAGE linear interp"):
|
for i in range(len(list_imgs_float)-1):#, desc="STAGE linear interp"):
|
||||||
img0 = list_imgs_float[i]
|
img0 = list_imgs_float[i]
|
||||||
img1 = list_imgs_float[i+1]
|
img1 = list_imgs_float[i+1]
|
||||||
list_imgs_interp.append(img0.astype(np.uint8))
|
list_imgs_interp.append(img0.astype(np.uint8))
|
||||||
|
@ -1051,8 +724,7 @@ def get_branching(
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
#%%
|
||||||
|
|
||||||
if quality == 'lowest':
|
if quality == 'lowest':
|
||||||
num_inference_steps = 12
|
num_inference_steps = 12
|
||||||
nmb_branches_final = 5
|
nmb_branches_final = 5
|
||||||
|
@ -1095,12 +767,49 @@ def get_branching(
|
||||||
print(f"list_injection_idx: {list_injection_idx_clean}")
|
print(f"list_injection_idx: {list_injection_idx_clean}")
|
||||||
print(f"list_nmb_branches: {list_nmb_branches_clean}")
|
print(f"list_nmb_branches: {list_nmb_branches_clean}")
|
||||||
|
|
||||||
return num_inference_steps, list_injection_idx_clean, list_nmb_branches_clean
|
# return num_inference_steps, list_injection_idx_clean, list_nmb_branches_clean
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#%% le main
|
#%% le main
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
sys.path.append('../stablediffusion/ldm')
|
||||||
|
from ldm.util import instantiate_from_config
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
|
||||||
|
|
||||||
|
num_inference_steps = 20 # Number of diffusion iterations
|
||||||
|
sdh = StableDiffusionHolder(num_inference_steps)
|
||||||
|
# fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
|
||||||
|
# fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
|
||||||
|
|
||||||
|
fp_ckpt= "../stable_diffusion_models/ckpt/512-base-ema.ckpt"
|
||||||
|
fp_config = '../stablediffusion/configs//stable-diffusion/v2-inference.yaml'
|
||||||
|
|
||||||
|
sdh.init_model(fp_ckpt, fp_config)
|
||||||
|
|
||||||
|
#%%
|
||||||
|
list_nmb_branches = [2, 3, 10, 24] # Branching structure: how many branches
|
||||||
|
list_injection_strength = [0.0, 0.6, 0.8, 0.9] # Branching structure: how deep is the blending
|
||||||
|
width = 512
|
||||||
|
height = 512
|
||||||
|
guidance_scale = 5
|
||||||
|
fixed_seeds = [993621550, 280335986]
|
||||||
|
device = "cuda:0"
|
||||||
|
lb = LatentBlending(sdh, device, height, width, num_inference_steps, guidance_scale)
|
||||||
|
prompt1 = "photo of a forest covered in white flowers, ambient light, very detailed, magic"
|
||||||
|
prompt2 = "photo of an eerie statue surrounded by ferns and vines, analog photograph kodak portra, mystical ambience, incredible detail"
|
||||||
|
lb.set_prompt1(prompt1)
|
||||||
|
lb.set_prompt2(prompt2)
|
||||||
|
|
||||||
|
lx = lb.run_transition(list_nmb_branches, list_injection_strength)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#%%
|
||||||
|
xxx
|
||||||
device = "cuda:0"
|
device = "cuda:0"
|
||||||
model_path = "../stable_diffusion_models/stable-diffusion-v1-5"
|
model_path = "../stable_diffusion_models/stable-diffusion-v1-5"
|
||||||
|
|
||||||
|
@ -1110,7 +819,7 @@ if __name__ == "__main__":
|
||||||
clip_sample=False,
|
clip_sample=False,
|
||||||
set_alpha_to_one=False)
|
set_alpha_to_one=False)
|
||||||
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_Union[StableDiffusionInpaintPipeline, StableDiffusionPipeline],pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
revision="fp16",
|
revision="fp16",
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
|
|
Loading…
Reference in New Issue