From ddd6fdee212968137dcfae52b23a2527c68dc50b Mon Sep 17 00:00:00 2001
From: Johannes Stelzer <jsdmail@gmail.com>
Date: Thu, 16 Nov 2023 15:37:02 +0100
Subject: [PATCH] cleanup

---
 diffusers_holder.py    | 77 ++++++++++--------------------------------
 example1_standard.py   | 57 +++++++++++++++++++++++++++++++
 example2_multitrans.py | 24 ++++++-------
 latent_blending.py     | 57 ++++++-------------------------
 4 files changed, 98 insertions(+), 117 deletions(-)

diff --git a/diffusers_holder.py b/diffusers_holder.py
index e6dee73..59a2824 100644
--- a/diffusers_holder.py
+++ b/diffusers_holder.py
@@ -13,20 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import torch
-torch.backends.cudnn.benchmark = False
-torch.set_grad_enabled(False)
 import numpy as np
 import warnings
-warnings.filterwarnings('ignore')
-import warnings
-import torch
-from PIL import Image
-import torch
+
 from typing import Optional
-from torch import autocast
-from contextlib import nullcontext
 from utils import interpolate_spherical
 from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel
 from diffusers.models.attention_processor import (
@@ -35,6 +26,9 @@ from diffusers.models.attention_processor import (
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
+warnings.filterwarnings('ignore')
+torch.backends.cudnn.benchmark = False
+torch.set_grad_enabled(False)
 
 
 class DiffusersHolder():
@@ -71,13 +65,11 @@ class DiffusersHolder():
 
     def set_dimensions(self, size_output):
         s = self.pipe.vae_scale_factor
-
         if size_output is None:
             width = self.pipe.unet.config.sample_size
             height = self.pipe.unet.config.sample_size
         else:
             width, height = size_output
-        
         self.width_img = int(round(width / s) * s)
         self.width_latent = int(self.width_img / s)
         self.height_img = int(round(height / s) * s)
@@ -95,7 +87,6 @@ class DiffusersHolder():
         if len(self.negative_prompt) > 1:
             self.negative_prompt = [self.negative_prompt[0]]
 
-
     def get_text_embedding(self, prompt, do_classifier_free_guidance=True):
         if self.use_sd_xl:
             pr_encoder = self.pipe.encode_prompt
@@ -114,7 +105,7 @@ class DiffusersHolder():
         )
         return prompt_embeds
 
-    def get_noise(self, seed=420, mode=None):
+    def get_noise(self, seed=420):
         H = self.height_latent
         W = self.width_latent
         C = self.pipe.unet.config.in_channels
@@ -164,7 +155,6 @@ class DiffusersHolder():
             return np.asarray(image)
         else:
             return image
-            
 
     def prepare_mixing(self, mixing_coeffs, list_latents_mixing):
         if type(mixing_coeffs) == float:
@@ -265,10 +255,10 @@ class DiffusersHolder():
             list_latents_mixing=None,
             mixing_coeffs=0.0,
             return_image: Optional[bool] = False):
-        
+
         # 0. Default height and width to unet
-        original_size = (self.width_img, self.height_img)  # FIXME
-        crops_coords_top_left = (0, 0) # FIXME
+        original_size = (self.width_img, self.height_img)
+        crops_coords_top_left = (0, 0)
         target_size = original_size
         batch_size = 1
         eta = 0.0
@@ -276,10 +266,10 @@ class DiffusersHolder():
         cross_attention_kwargs = None
         generator = torch.Generator(device=self.device)  # dummy generator
         do_classifier_free_guidance = self.guidance_scale > 1.0
-        
+
         # 1. Check inputs. Raise error if not correct & 2. Define call parameters
         list_mixing_coeffs = self.prepare_mixing(mixing_coeffs, list_latents_mixing)
-        
+
         # 3. Encode input prompt (already encoded outside bc of mixing, just split here)
         prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = text_embeddings
 
@@ -294,28 +284,13 @@ class DiffusersHolder():
         # 6. Prepare extra step kwargs. usedummy generator
         extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)  # dummy
 
-        # 7. Prepare added time ids & embeddings
-        # add_text_embeds = pooled_prompt_embeds
-        # add_time_ids = self.pipe._get_add_time_ids(
-        #     original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
-        # )
-
-        # if do_classifier_free_guidance:
-        #     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-        #     add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
-        #     add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
-
-        # prompt_embeds = prompt_embeds.to(self.device)
-        # add_text_embeds = add_text_embeds.to(self.device)
-        # add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1)
-        
         # 7. Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
         if self.pipe.text_encoder_2 is None:
             text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
         else:
             text_encoder_projection_dim = self.pipe.text_encoder_2.config.projection_dim
-        
+
         add_time_ids = self.pipe._get_add_time_ids(
             original_size,
             crops_coords_top_left,
@@ -323,26 +298,16 @@ class DiffusersHolder():
             dtype=prompt_embeds.dtype,
             text_encoder_projection_dim=text_encoder_projection_dim,
         )
-        # if negative_original_size is not None and negative_target_size is not None:
-        #     negative_add_time_ids = self.pipe._get_add_time_ids(
-        #         negative_original_size,
-        #         negative_crops_coords_top_left,
-        #         negative_target_size,
-        #         dtype=prompt_embeds.dtype,
-        #         text_encoder_projection_dim=text_encoder_projection_dim,
-        #     )
-        # else:
+
         negative_add_time_ids = add_time_ids
-        
+
         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
         add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
         add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
-        
+
         prompt_embeds = prompt_embeds.to(self.device)
         add_text_embeds = add_text_embeds.to(self.device)
         add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1)
-        
-        
 
         # 8. Denoising loop
         for i, t in enumerate(timesteps):
@@ -358,7 +323,6 @@ class DiffusersHolder():
                 latents_mixtarget = list_latents_mixing[i - 1].clone()
                 latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
 
-
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             # Always scale latents
@@ -380,14 +344,12 @@ class DiffusersHolder():
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-            # FIXME guidance_rescale disabled
-
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
             # Append latents
             list_latents_out.append(latents.clone())
-        
+
         if return_image:
             return self.latent2image(latents)
         else:
@@ -415,7 +377,7 @@ class DiffusersHolder():
         batch_size = 1
         eta = 0.0
         controlnet_conditioning_scale = 1.0
-        
+       
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -527,19 +489,16 @@ class DiffusersHolder():
 
             # Append latents
             list_latents_out.append(latents.clone())
-        
+
         if return_image:
             return self.latent2image(latents)
         else:
             return list_latents_out
-    
 
 
 #%%
-
 if __name__ == "__main__":
-    
-    
+    from PIL import Image
     #%% 
     pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
     pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
diff --git a/example1_standard.py b/example1_standard.py
index e69de29..125ce61 100644
--- a/example1_standard.py
+++ b/example1_standard.py
@@ -0,0 +1,57 @@
+# Copyright 2022 Lunar Ring. All rights reserved.
+# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import warnings
+from latent_blending import LatentBlending
+from diffusers_holder import DiffusersHolder
+from diffusers import DiffusionPipeline
+warnings.filterwarnings('ignore')
+torch.set_grad_enabled(False)
+torch.backends.cudnn.benchmark = False
+
+# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
+pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
+pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
+pipe.to('cuda')
+dh = DiffusersHolder(pipe)
+# %% Next let's set up all parameters
+depth_strength = 0.55  # Specifies how deep (in terms of diffusion iterations the first branching happens)
+t_compute_max_allowed = 60  # Determines the quality of the transition in terms of compute time you grant it
+num_inference_steps = 30
+size_output = (1024, 1024)
+
+prompt1 = "underwater landscape, fish, und the sea, incredible detail, high resolution"
+prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal"
+negative_prompt = "blurry, ugly, pale"  # Optional
+
+fp_movie = 'movie_example1.mp4'
+duration_transition = 12  # In seconds
+
+# Spawn latent blending
+lb = LatentBlending(dh)
+lb.set_prompt1(prompt1)
+lb.set_prompt2(prompt2)
+lb.set_dimensions(size_output)
+lb.set_negative_prompt(negative_prompt)
+
+# Run latent blending
+lb.run_transition(
+    depth_strength=depth_strength,
+    num_inference_steps=num_inference_steps,
+    t_compute_max_allowed=t_compute_max_allowed)
+
+# Save movie
+lb.write_movie_transition(fp_movie, duration_transition)
diff --git a/example2_multitrans.py b/example2_multitrans.py
index ce985cc..7409137 100644
--- a/example2_multitrans.py
+++ b/example2_multitrans.py
@@ -14,16 +14,14 @@
 # limitations under the License.
 
 import torch
-torch.backends.cudnn.benchmark = False
-torch.set_grad_enabled(False)
-import warnings
-warnings.filterwarnings('ignore')
 import warnings
 from latent_blending import LatentBlending
 from diffusers_holder import DiffusersHolder
 from diffusers import DiffusionPipeline
 from movie_util import concatenate_movies
-from huggingface_hub import hf_hub_download
+torch.set_grad_enabled(False)
+torch.backends.cudnn.benchmark = False
+warnings.filterwarnings('ignore')
 
 # %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
 pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -35,21 +33,23 @@ dh = DiffusersHolder(pipe)
 fps = 30
 duration_single_trans = 20
 depth_strength = 0.25  # Specifies how deep (in terms of diffusion iterations the first branching happens)
+size_output = (1280, 768)
+num_inference_steps = 30
 
 # Specify a list of prompts below
 list_prompts = []
-list_prompts.append("A panoramic photo of a sentient mirror maze amidst a neon-lit forest, where bioluminescent mushrooms glow eerily, reflecting off the mirrors, and cybernetic crows, with silver wings and ruby eyes, perch ominously, David Lynch, Gaspar NoƩ, Photograph.")
-list_prompts.append("An unsettling tableau of spectral butterflies with clockwork wings, swirling around an antique typewriter perched precariously atop a floating, gnarled tree trunk, a stormy twilight sky, David Lynch's dreamscape, meticulously crafted.")
-# list_prompts.append("A haunting tableau of an antique dollhouse swallowed by a giant venus flytrap under the neon glow of an alien moon, its uncanny light reflecting from shattered porcelain faces and marbles, in a quiet, abandoned amusement park.")
+list_prompts.append("A beautiful astronomic photo of a nebula, with intricate microscopic structures, mitochondria")
+list_prompts.append("Microscope fluorescence photo, cell filaments, intricate galaxy, astronomic nebula")
+list_prompts.append("telescope photo starry sky, nebula, cell core, dna, stunning")
 
 
 # You can optionally specify the seeds
-list_seeds = [95437579, 33259350, 956051013, 408831845, 250009012, 675588737]
+list_seeds = [95437579, 33259350, 956051013]
 t_compute_max_allowed = 20  # per segment
 fp_movie = 'movie_example2.mp4'
 lb = LatentBlending(dh)
-lb.dh.set_dimensions(1024, 704)
-lb.dh.set_num_inference_steps(40)
+lb.set_dimensions(size_output)
+lb.dh.set_num_inference_steps(num_inference_steps)
 
 
 list_movie_parts = []
@@ -68,7 +68,7 @@ for i in range(len(list_prompts) - 1):
     fixed_seeds = list_seeds[i:i + 2]
     # Run latent blending
     lb.run_transition(
-        recycle_img1 = recycle_img1,
+        recycle_img1=recycle_img1,
         depth_strength=depth_strength,
         t_compute_max_allowed=t_compute_max_allowed,
         fixed_seeds=fixed_seeds)
diff --git a/latent_blending.py b/latent_blending.py
index 57b16a4..21846e1 100644
--- a/latent_blending.py
+++ b/latent_blending.py
@@ -15,19 +15,18 @@
 
 import os
 import torch
-torch.backends.cudnn.benchmark = False
-torch.set_grad_enabled(False)
 import numpy as np
 import warnings
-warnings.filterwarnings('ignore')
 import time
-import warnings
 from tqdm.auto import tqdm
 from PIL import Image
 from movie_util import MovieSaver
 from typing import List, Optional
 import lpips
 from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save
+warnings.filterwarnings('ignore')
+torch.backends.cudnn.benchmark = False
+torch.set_grad_enabled(False)
 
 
 class LatentBlending():
@@ -70,7 +69,6 @@ class LatentBlending():
         # Initialize vars
         self.prompt1 = ""
         self.prompt2 = ""
-        self.negative_prompt = ""
 
         self.tree_latents = [None, None]
         self.tree_fracts = None
@@ -91,17 +89,15 @@ class LatentBlending():
         self.list_nmb_branches = None
 
         # Mixing parameters
-        self.branch1_crossfeed_power = 0.05
-        self.branch1_crossfeed_range = 0.4
-        self.branch1_crossfeed_decay = 0.9
+        self.branch1_crossfeed_power = 0.3
+        self.branch1_crossfeed_range = 0.3
+        self.branch1_crossfeed_decay = 0.99
 
-        self.parental_crossfeed_power = 0.1
-        self.parental_crossfeed_range = 0.8
-        self.parental_crossfeed_power_decay = 0.8
+        self.parental_crossfeed_power = 0.3
+        self.parental_crossfeed_range = 0.6
+        self.parental_crossfeed_power_decay = 0.9
 
         self.set_guidance_scale(guidance_scale)
-        self.mode = 'standard'
-        # self.init_mode()
         self.multi_transition_img_first = None
         self.multi_transition_img_last = None
         self.dt_per_diff = 0
@@ -441,7 +437,7 @@ class LatentBlending():
             list_compute_steps = self.num_inference_steps - list_idx_injection
             list_compute_steps *= list_nmb_stems
             t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
-            t_compute += 2*self.num_inference_steps*self.dt_per_diff # outer branches
+            t_compute += 2 * self.num_inference_steps * self.dt_per_diff  # outer branches
             increase_done = False
             for s_idx in range(len(list_nmb_stems) - 1):
                 if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
@@ -522,7 +518,7 @@ class LatentBlending():
         Args:
             seed: int
         """
-        return self.dh.get_noise(seed, self.mode)
+        return self.dh.get_noise(seed)
 
     @torch.no_grad()
     def run_diffusion(
@@ -576,18 +572,6 @@ class LatentBlending():
                 mixing_coeffs=mixing_coeffs,
                 return_image=return_image)
 
-        # elif self.mode == 'upscale':
-        #     cond = list_conditionings[0]
-        #     uc_full = list_conditionings[1]
-        #     return self.dh.run_diffusion_upscaling(
-        #         cond,
-        #         uc_full,
-        #         latents_start=latents_start,
-        #         idx_start=idx_start,
-        #         list_latents_mixing=list_latents_mixing,
-        #         mixing_coeffs=mixing_coeffs,
-        #         return_image=return_image)
-
     def run_upscaling(
             self,
             dp_img: str,
@@ -683,25 +667,6 @@ class LatentBlending():
             list_conditionings = [text_embeddings_mix]
         return list_conditionings
 
-    # @torch.no_grad()
-    # def get_mixed_conditioning(self, fract_mixing):
-    #     if self.mode == 'standard':
-    #         text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
-    #         list_conditionings = [text_embeddings_mix]
-    #     elif self.mode == 'inpaint':
-    #         text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
-    #         list_conditionings = [text_embeddings_mix]
-    #     elif self.mode == 'upscale':
-    #         text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
-    #         cond, uc_full = self.dh.get_cond_upscaling(self.image1_lowres, text_embeddings_mix, self.noise_level_upscaling)
-    #         condB, uc_fullB = self.dh.get_cond_upscaling(self.image2_lowres, text_embeddings_mix, self.noise_level_upscaling)
-    #         cond['c_concat'][0] = interpolate_spherical(cond['c_concat'][0], condB['c_concat'][0], fract_mixing)
-    #         uc_full['c_concat'][0] = interpolate_spherical(uc_full['c_concat'][0], uc_fullB['c_concat'][0], fract_mixing)
-    #         list_conditionings = [cond, uc_full]
-    #     else:
-    #         raise ValueError(f"mix_conditioning: unknown mode {self.mode}")
-    #     return list_conditionings
-
     @torch.no_grad()
     def get_text_embeddings(
             self,