diff --git a/cherry_picknick.py b/cherry_picknick.py index f59b049..0d6813a 100644 --- a/cherry_picknick.py +++ b/cherry_picknick.py @@ -21,83 +21,58 @@ warnings.filterwarnings('ignore') import warnings import torch from tqdm.auto import tqdm -from diffusers import StableDiffusionPipeline -from diffusers.schedulers import DDIMScheduler from PIL import Image import matplotlib.pyplot as plt import torch from movie_util import MovieSaver from typing import Callable, List, Optional, Union from latent_blending import LatentBlending, add_frames_linear_interp +from stable_diffusion_holder import StableDiffusionHolder torch.set_grad_enabled(False) -#%% First let us spawn a diffusers pipe using DDIMScheduler -device = "cuda:0" -model_path = "../stable_diffusion_models/stable-diffusion-v1-5" -scheduler = DDIMScheduler(beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False) - -pipe = StableDiffusionPipeline.from_pretrained( - model_path, - revision="fp16", - torch_dtype=torch.float16, - scheduler=scheduler, - use_auth_token=True -) -pipe = pipe.to(device) +#%% First let us spawn a stable diffusion holder +device = "cuda:0" +num_inference_steps = 20 # Number of diffusion interations +fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt" +fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml' + +sdh = StableDiffusionHolder(fp_ckpt, fp_config, device, num_inference_steps=num_inference_steps) #%% Next let's set up all parameters num_inference_steps = 30 # Number of diffusion interations list_nmb_branches = [2, 3, 10, 24]#, 50] # Branching structure: how many branches list_injection_strength = [0.0, 0.6, 0.8, 0.9]#, 0.95] # Branching structure: how deep is the blending -width = 512 -height = 512 guidance_scale = 5 fps = 30 duration_target = 10 width = 512 height = 512 -lb = LatentBlending(pipe, device, height, width, num_inference_steps, guidance_scale) - +lb = LatentBlending(sdh, num_inference_steps, guidance_scale) list_prompts = [] -list_prompts.append("surrealistic statue made of glitter and dirt, standing in a lake, atmospheric light, strange glow") -list_prompts.append("weird statue of a frog monkey, many colors, standing next to the ruins of an ancient city") -list_prompts.append("statue of a mix between a tree and human, made of marble, incredibly detailed") -list_prompts.append("statue made of hot metal, bizzarre, dark clouds in the sky") -list_prompts.append("statue of a spider that looked like a human") -list_prompts.append("statue of a bird that looked like a scorpion") -list_prompts.append("statue of an ancient cybernetic messenger annoucing good news, golden, futuristic") +list_prompts.append("photo of a beautiful forest covered in white flowers, ambient light, very detailed, magic") +list_prompts.append("photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail") -k = 6 -prompt = list_prompts[k] -for i in range(4): - lb.set_prompt1(prompt) - - seed = np.random.randint(999999999) - lb.set_seed(seed) - plt.imshow(lb.run_diffusion(lb.text_embedding1, return_image=True)) - plt.title(f"{i} seed {seed}") - plt.show() - print(f"prompt {k} seed {seed} trial {i}") +for k, prompt in enumerate(list_prompts): + # k = 6 + # prompt = list_prompts[k] + for i in range(10): + lb.set_prompt1(prompt) + + seed = np.random.randint(999999999) + lb.set_seed(seed) + plt.imshow(lb.run_diffusion(lb.text_embedding1, return_image=True)) + plt.title(f"prompt {k}, seed {i} {seed}") + plt.show() + print(f"prompt {k} seed {seed} trial {i}") + #%% """ - -prompt 3 seed 28652396 trial 2 -prompt 4 seed 783279867 trial 3 -prompt 5 seed 831049796 trial 3 - -prompt 6 seed 798876383 trial 2 -prompt 6 seed 750494819 trial 2 -prompt 6 seed 416472011 trial 1 - +69731932, 504430820 """ \ No newline at end of file diff --git a/example1_standard.py b/example1_standard.py index 0652659..eb9bd64 100644 --- a/example1_standard.py +++ b/example1_standard.py @@ -32,32 +32,25 @@ torch.set_grad_enabled(False) #%% First let us spawn a stable diffusion holder device = "cuda:0" -num_inference_steps = 20 # Number of diffusion interations fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt" fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml' -sdh = StableDiffusionHolder(fp_ckpt, fp_config, device, num_inference_steps=num_inference_steps) +sdh = StableDiffusionHolder(fp_ckpt, fp_config, device) #%% Next let's set up all parameters -# FIXME below fix numbers -# We want 20 diffusion steps in total, begin with 2 branches, have 3 branches at step 12 (=0.6*20) -# 10 branches at step 16 (=0.8*20) and 24 branches at step 18 (=0.9*20) -# Furthermore we want seed 993621550 for keyframeA and seed 54878562 for keyframeB () -list_nmb_branches = [2, 3, 10, 24] # Branching structure: how many branches -list_injection_strength = [0.0, 0.6, 0.8, 0.9] # Branching structure: how deep is the blending -width = 768 -height = 768 guidance_scale = 5 -fixed_seeds = [993621550, 280335986] +quality = 'high' +fixed_seeds = [69731932, 504430820] -lb = LatentBlending(sdh, num_inference_steps, guidance_scale) +lb = LatentBlending(sdh, guidance_scale) prompt1 = "photo of a beautiful forest covered in white flowers, ambient light, very detailed, magic" -prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph,, mystical ambience, incredible detail" +prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail" lb.set_prompt1(prompt1) lb.set_prompt2(prompt2) +lb.autosetup_branching(quality=quality) -imgs_transition = lb.run_transition(list_nmb_branches, list_injection_strength, fixed_seeds=fixed_seeds) +imgs_transition = lb.run_transition(fixed_seeds=fixed_seeds) # let's get more cheap frames via linear interpolation duration_transition = 12 @@ -65,10 +58,10 @@ fps = 60 imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps) # movie saving -fp_movie = "movie_example1.mp4" +fp_movie = f"movie_example1_{quality}.mp4" if os.path.isfile(fp_movie): os.remove(fp_movie) -ms = MovieSaver(fp_movie, fps=fps) +ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width]) for img in tqdm(imgs_transition_ext): ms.write_frame(img) ms.finalize() diff --git a/latent_blending.py b/latent_blending.py index 6661331..275c624 100644 --- a/latent_blending.py +++ b/latent_blending.py @@ -47,9 +47,7 @@ class LatentBlending(): def __init__( self, sdh: None, - num_inference_steps: int = 30, guidance_scale: float = 7.5, - seed: int = 420, ): r""" Initializes the latent blending class. @@ -59,8 +57,6 @@ class LatentBlending(): Height of the desired output image. The model was trained on 512. width: int Width of the desired output image. The model was trained on 512. - num_inference_steps: int - Number of diffusion steps. Larger values will take more compute time. guidance_scale: float Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -72,13 +68,11 @@ class LatentBlending(): """ self.sdh = sdh - self.num_inference_steps = num_inference_steps - self.sdh.num_inference_steps = num_inference_steps self.device = self.sdh.device self.guidance_scale = guidance_scale self.width = self.sdh.width self.height = self.sdh.height - self.seed = seed + self.seed = 420 #use self.set_seed or fixed_seeds argument in run_transition # Initialize vars self.prompt1 = "" @@ -93,6 +87,9 @@ class LatentBlending(): self.text_embedding2 = None self.stop_diffusion = False self.negative_prompt = None + self.num_inference_steps = -1 + self.list_injection_idx = None + self.list_nmb_branches = None self.init_mode() @@ -133,19 +130,92 @@ class LatentBlending(): self.prompt2 = prompt self.text_embedding2 = self.get_text_embeddings(self.prompt2) - - def run_transition( + def autosetup_branching( self, - list_nmb_branches: List[int], - list_injection_strength: List[float] = None, - list_injection_idx: List[int] = None, - recycle_img1: Optional[bool] = False, - recycle_img2: Optional[bool] = False, - fixed_seeds: Optional[List[int]] = None, + quality: str = 'medium', + deepth_strength: float = 0.65, + nmb_frames: int = 360, + nmb_mindist: int = 3, ): r""" - Returns a list of transition images using spherical latent blending. + Helper function to set up the branching structure automatically. + Args: + quality: str + Determines how many diffusion steps are being made + how many branches in total. + Tradeoff between quality and speed of computation. + Choose: lowest, low, medium, high, ultra + deepth_strength: float = 0.65, + Determines how deep the first injection will happen. + Deeper injections will cause (unwanted) formation of new structures, + more shallow values will go into alpha-blendy land. + nmb_frames: int = 360, + total number of frames + nmb_mindist: int = 3 + minimum distance in terms of diffusion iteratinos between subsequent injections + """ + + if quality == 'lowest': + num_inference_steps = 12 + nmb_branches_final = 5 + elif quality == 'low': + num_inference_steps = 15 + nmb_branches_final = nmb_frames//16 + elif quality == 'medium': + num_inference_steps = 30 + nmb_branches_final = nmb_frames//8 + elif quality == 'high': + num_inference_steps = 60 + nmb_branches_final = nmb_frames//4 + elif quality == 'ultra': + num_inference_steps = 100 + nmb_branches_final = nmb_frames//2 + else: + raise ValueError("quality = '{quality}' not supported") + + idx_injection_first = int(np.round(num_inference_steps*deepth_strength)) + idx_injection_last = num_inference_steps - 3 + nmb_injections = int(np.floor(num_inference_steps/5)) - 1 + + list_injection_idx = [0] + list_injection_idx.extend(np.linspace(idx_injection_first, idx_injection_last, nmb_injections).astype(int)) + list_nmb_branches = np.round(np.logspace(np.log10(2), np.log10(nmb_branches_final), nmb_injections+1)).astype(int) + + # Cleanup. There should be at least 3 diffusion steps between each injection + list_injection_idx_clean = [list_injection_idx[0]] + list_nmb_branches_clean = [list_nmb_branches[0]] + idx_last_check = 0 + for i in range(len(list_injection_idx)-1): + if list_injection_idx[i+1] - list_injection_idx_clean[idx_last_check] >= nmb_mindist: + list_injection_idx_clean.append(list_injection_idx[i+1]) + list_nmb_branches_clean.append(list_nmb_branches[i+1]) + idx_last_check +=1 + list_injection_idx_clean = [int(l) for l in list_injection_idx_clean] + list_nmb_branches_clean = [int(l) for l in list_nmb_branches_clean] + + list_injection_idx = list_injection_idx_clean + list_nmb_branches = list_nmb_branches_clean + + print(f"num_inference_steps: {num_inference_steps}") + print(f"list_injection_idx: {list_injection_idx}") + print(f"list_nmb_branches: {list_nmb_branches}") + + self.num_inference_steps = num_inference_steps + self.list_injection_idx = list_injection_idx + self.list_nmb_branches = list_nmb_branches + + + def setup_branching(self, + num_inference_steps: int =30, + list_nmb_branches: List[int] = None, + list_injection_strength: List[float] = None, + list_injection_idx: List[int] = None, + guidance_downscale: float = 1.0, + ): + r""" + Sets the branching structure for making transitions. + num_inference_steps: int + Number of diffusion steps. Larger values will take more compute time. list_nmb_branches: List[int]: list of the number of branches for each injection. list_injection_strength: List[float]: @@ -154,6 +224,51 @@ class LatentBlending(): list_injection_idx: List[int]: list of injection strengths within interval [0, 1), values need to be increasing. Alternatively you can specify the list_injection_strength. + guidance_downscale: float = 1.0 + reduces the guidance scale towards the middle of the transition + + + """ + # Assert + assert guidance_downscale>0 and guidance_downscale<=1.0, "guidance_downscale neees to be in interval (0,1]" + assert not((list_injection_strength is not None) and (list_injection_idx is not None)), "suppyl either list_injection_strength or list_injection_idx" + + if list_injection_strength is None: + assert list_injection_idx is not None, "Supply either list_injection_idx or list_injection_strength" + assert isinstance(list_injection_idx[0], int) or isinstance(list_injection_idx[0], np.int) , "Need to supply integers for list_injection_idx" + + if list_injection_idx is None: + assert list_injection_strength is not None, "Supply either list_injection_idx or list_injection_strength" + # Create the injection indexes + list_injection_idx = [int(round(x*num_inference_steps)) for x in list_injection_strength] + assert min(np.diff(list_injection_idx)) > 0, 'Injection idx needs to be increasing' + if min(np.diff(list_injection_idx)) < 2: + print("Warning: your injection spacing is very tight. consider increasing the distances") + assert isinstance(list_injection_strength[1], np.floating) or isinstance(list_injection_strength[1], float), "Need to supply floats for list_injection_strength" + # we are checking element 1 in list_injection_strength because "0" is an int... [0, 0.5] + + assert max(list_injection_idx) < num_inference_steps, "Decrease the injection index or strength" + assert len(list_injection_idx) == len(list_nmb_branches), "Need to have same length" + assert max(list_injection_idx) < num_inference_steps,"Injection index cannot happen after last diffusion step! Decrease list_injection_idx or list_injection_strength[-1]" + + + # Set attributes + self.num_inference_steps = num_inference_steps + self.sdh.num_inference_steps = num_inference_steps + self.list_nmb_branches = list_nmb_branches + self.list_injection_idx = list_injection_idx + + + + def run_transition( + self, + recycle_img1: Optional[bool] = False, + recycle_img2: Optional[bool] = False, + fixed_seeds: Optional[List[int]] = None, + ): + r""" + Returns a list of transition images using spherical latent blending. + Args: recycle_img1: Optional[bool]: Don't recompute the latents for the first keyframe (purely prompt1). Saves compute. recycle_img2: Optional[bool]: @@ -166,25 +281,7 @@ class LatentBlending(): # Sanity checks first assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before' assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before' - assert not((list_injection_strength is not None) and (list_injection_idx is not None)), "suppyl either list_injection_strength or list_injection_idx" - - if list_injection_strength is None: - assert list_injection_idx is not None, "Supply either list_injection_idx or list_injection_strength" - assert isinstance(list_injection_idx[0], int) or isinstance(list_injection_idx[0], np.int) , "Need to supply integers for list_injection_idx" - - if list_injection_idx is None: - assert list_injection_strength is not None, "Supply either list_injection_idx or list_injection_strength" - # Create the injection indexes - list_injection_idx = [int(round(x*self.num_inference_steps)) for x in list_injection_strength] - assert min(np.diff(list_injection_idx)) > 0, 'Injection idx needs to be increasing' - if min(np.diff(list_injection_idx)) < 2: - print("Warning: your injection spacing is very tight. consider increasing the distances") - assert isinstance(list_injection_strength[1], np.floating) or isinstance(list_injection_strength[1], float), "Need to supply floats for list_injection_strength" - # we are checking element 1 in list_injection_strength because "0" is an int... [0, 0.5] - - assert max(list_injection_idx) < self.num_inference_steps, "Decrease the injection index or strength" - assert len(list_injection_idx) == len(list_nmb_branches), "Need to have same length" - assert max(list_injection_idx) < self.num_inference_steps,"Injection index cannot happen after last diffusion step! Decrease list_injection_idx or list_injection_strength[-1]" + assert self.list_injection_idx is not None, 'Set the branching structure before, by calling autosetup_branching or setup_branching' if fixed_seeds is not None: if fixed_seeds == 'randomize': @@ -204,21 +301,22 @@ class LatentBlending(): print("Warning. You want to recycle but there is nothing here. Disabling recycling.") recycle_img1 = False recycle_img2 = False - elif self.list_nmb_branches_prev != list_nmb_branches: + elif self.list_nmb_branches_prev != self.list_nmb_branches: print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.") recycle_img1 = False recycle_img2 = False - elif self.list_injection_idx_prev != list_injection_idx: + elif self.list_injection_idx_prev != self.list_injection_idx: print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.") recycle_img1 = False recycle_img2 = False # Make a backup for future reference - self.list_nmb_branches_prev = list_nmb_branches - self.list_injection_idx_prev = list_injection_idx + self.list_nmb_branches_prev = self.list_nmb_branches[:] + self.list_injection_idx_prev = self.list_injection_idx[:] # Auto inits - list_injection_idx_ext = list_injection_idx[:] + list_injection_idx_ext = self.list_injection_idx[:] + list_nmb_branches = self.list_nmb_branches[:] list_injection_idx_ext.append(self.num_inference_steps) # If injection at depth 0 not specified, we will start out with 2 branches @@ -291,7 +389,7 @@ class LatentBlending(): # Diffusion computations start here time_start = time.time() - for t_block, idx_branch in tqdm(list_compute, desc="computing transition"): + for t_block, idx_branch in tqdm(list_compute, desc="computing transition", smoothing=-1): if self.stop_diffusion: print("run_transition: process interrupted") return self.tree_final_imgs @@ -484,6 +582,7 @@ class LatentBlending(): Set a the seed for a fresh start. """ self.seed = seed + self.sdh.seed = seed def swap_forward(self): @@ -703,76 +802,6 @@ def get_time(resolution=None): raise ValueError("bad resolution provided: %s" %resolution) return t -def get_branching( - quality: str = 'medium', - deepth_strength: float = 0.65, - nmb_frames: int = 360, - nmb_mindist: int = 3, - ): - r""" - Helper function to set up the branching structure automatically. - - Args: - quality: str - Determines how many diffusion steps are being made + how many branches in total. - Choose: fast, medium, high, ultra - deepth_strength: float = 0.65, - Determines how deep the first injection will happen. - Deeper injections will cause (unwanted) formation of new structures, - more shallow values will go into alpha-blendy land. - nmb_frames: int = 360, - total number of frames - nmb_mindist: int = 3 - minimum distance in terms of diffusion iteratinos between subsequent injections - - """ -#%% - if quality == 'lowest': - num_inference_steps = 12 - nmb_branches_final = 5 - elif quality == 'low': - num_inference_steps = 15 - nmb_branches_final = nmb_frames//16 - elif quality == 'medium': - num_inference_steps = 30 - nmb_branches_final = nmb_frames//8 - elif quality == 'high': - num_inference_steps = 60 - nmb_branches_final = nmb_frames//4 - elif quality == 'ultra': - num_inference_steps = 100 - nmb_branches_final = nmb_frames//2 - else: - raise ValueError("quality = '{quality}' not supported") - - idx_injection_first = int(np.round(num_inference_steps*deepth_strength)) - idx_injection_last = num_inference_steps - 3 - nmb_injections = int(np.floor(num_inference_steps/5)) - 1 - - list_injection_idx = [0] - list_injection_idx.extend(np.linspace(idx_injection_first, idx_injection_last, nmb_injections).astype(int)) - list_nmb_branches = np.round(np.logspace(np.log10(2), np.log10(nmb_branches_final), nmb_injections+1)).astype(int) - - # Cleanup. There should be at least 3 diffusion steps between each injection - list_injection_idx_clean = [list_injection_idx[0]] - list_nmb_branches_clean = [list_nmb_branches[0]] - idx_last_check = 0 - for i in range(len(list_injection_idx)-1): - if list_injection_idx[i+1] - list_injection_idx_clean[idx_last_check] >= nmb_mindist: - list_injection_idx_clean.append(list_injection_idx[i+1]) - list_nmb_branches_clean.append(list_nmb_branches[i+1]) - idx_last_check +=1 - list_injection_idx_clean = [int(l) for l in list_injection_idx_clean] - list_nmb_branches_clean = [int(l) for l in list_nmb_branches_clean] - - list_injection_idx = list_injection_idx_clean - list_nmb_branches = list_nmb_branches_clean - - print(f"num_inference_steps: {num_inference_steps}") - print(f"list_injection_idx: {list_injection_idx}") - print(f"list_nmb_branches: {list_nmb_branches}") - - return num_inference_steps, list_injection_idx, list_nmb_branches @@ -786,6 +815,7 @@ if __name__ == "__main__": TODO Coding: RUNNING WITHOUT PROMPT! save value ranges, can it be trashed? + in the middle: have more branches + lower guidance scale TODO Other: github