Johannes Stelzer 2023-02-22 10:15:03 +01:00
parent 3ed876e0ee
commit 297bb9abe6
9 changed files with 906 additions and 1322 deletions


@@ -13,30 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os, sys
import torch
torch.backends.cudnn.benchmark = False
-import numpy as np
+torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
-import torch
-from tqdm.auto import tqdm
-from PIL import Image
-# import matplotlib.pyplot as plt
-import torch
-from movie_util import MovieSaver
-from typing import Callable, List, Optional, Union
-from latent_blending import LatentBlending, add_frames_linear_interp
+from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
-torch.set_grad_enabled(False)
+from huggingface_hub import hf_hub_download

-#%% First let us spawn a stable diffusion holder
-fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
+# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
+# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
+fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
sdh = StableDiffusionHolder(fp_ckpt)

-#%% Next let's set up all parameters
+# %% Next let's set up all parameters
depth_strength = 0.65  # Specifies how deep (in terms of diffusion iterations the first branching happens)
t_compute_max_allowed = 15  # Determines the quality of the transition in terms of compute time you grant it
fixed_seeds = [69731932, 504430820]

@@ -54,10 +46,9 @@ lb.set_prompt2(prompt2)
# Run latent blending
lb.run_transition(
-    depth_strength = depth_strength,
-    t_compute_max_allowed = t_compute_max_allowed,
-    fixed_seeds = fixed_seeds
-    )
+    depth_strength=depth_strength,
+    t_compute_max_allowed=t_compute_max_allowed,
+    fixed_seeds=fixed_seeds)

# Save movie
lb.write_movie_transition(fp_movie, duration_transition)
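For orientation, the updated example boils down to the following minimal flow (a sketch; the prompt strings, output path and duration below are placeholders, not values from this commit):

import torch
torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
from huggingface_hub import hf_hub_download
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder

# Checkpoint is now fetched from the Hugging Face Hub instead of a local path
fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
lb = LatentBlending(StableDiffusionHolder(fp_ckpt))
lb.set_prompt1("photo of a forest covered in white flowers")  # placeholder prompt
lb.set_prompt2("photo of an eerie statue surrounded by fog")  # placeholder prompt
lb.run_transition(depth_strength=0.65, t_compute_max_allowed=15, fixed_seeds=[69731932, 504430820])
lb.write_movie_transition("movie_example1.mp4", 12)  # placeholder output path and duration in seconds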


@@ -13,33 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os, sys
import torch
torch.backends.cudnn.benchmark = False
-import numpy as np
+torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
-import torch
-from tqdm.auto import tqdm
-from PIL import Image
-import torch
-from movie_util import MovieSaver, concatenate_movies
-from typing import Callable, List, Optional, Union
-from latent_blending import LatentBlending, add_frames_linear_interp
+from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
-torch.set_grad_enabled(False)
+from movie_util import concatenate_movies
+from huggingface_hub import hf_hub_download

-#%% First let us spawn a stable diffusion holder
-fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
-# fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
+# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
+# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
+fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
sdh = StableDiffusionHolder(fp_ckpt)

-#%% Let's setup the multi transition
+# %% Let's setup the multi transition
fps = 30
duration_single_trans = 6
-depth_strength = 0.55 #Specifies how deep (in terms of diffusion iterations the first branching happens)
+depth_strength = 0.55  # Specifies how deep (in terms of diffusion iterations the first branching happens)

# Specify a list of prompts below
list_prompts = []

@@ -56,28 +49,25 @@ t_compute_max_allowed = 12 # per segment
fp_movie = 'movie_example2.mp4'
lb = LatentBlending(sdh)

list_movie_parts = []
-for i in range(len(list_prompts)-1):
+for i in range(len(list_prompts) - 1):
    # For a multi transition we can save some computation time and recycle the latents
-    if i==0:
+    if i == 0:
        lb.set_prompt1(list_prompts[i])
-        lb.set_prompt2(list_prompts[i+1])
+        lb.set_prompt2(list_prompts[i + 1])
        recycle_img1 = False
    else:
        lb.swap_forward()
-        lb.set_prompt2(list_prompts[i+1])
+        lb.set_prompt2(list_prompts[i + 1])
        recycle_img1 = True

    fp_movie_part = f"tmp_part_{str(i).zfill(3)}.mp4"
-    fixed_seeds = list_seeds[i:i+2]
+    fixed_seeds = list_seeds[i:i + 2]

    # Run latent blending
    lb.run_transition(
-        depth_strength = depth_strength,
-        t_compute_max_allowed = t_compute_max_allowed,
-        fixed_seeds = fixed_seeds
-        )
+        depth_strength=depth_strength,
+        t_compute_max_allowed=t_compute_max_allowed,
+        fixed_seeds=fixed_seeds)

    # Save movie
    lb.write_movie_transition(fp_movie_part, duration_single_trans)
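The hunk ends before the per-segment movies are joined; presumably the script finishes with the concatenate_movies helper imported above. A hedged sketch of that final step, assuming the temporary part files written in the loop:

# After the loop has written tmp_part_000.mp4, tmp_part_001.mp4, ...
list_movie_parts = [f"tmp_part_{str(i).zfill(3)}.mp4" for i in range(len(list_prompts) - 1)]
concatenate_movies(fp_movie, list_movie_parts)  # joins the segments into movie_example2.mp4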


@@ -13,25 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os, sys
import torch
torch.backends.cudnn.benchmark = False
-import numpy as np
+torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
-import torch
-from tqdm.auto import tqdm
-from PIL import Image
-# import matplotlib.pyplot as plt
-import torch
-from movie_util import MovieSaver
-from typing import Callable, List, Optional, Union
-from latent_blending import LatentBlending, add_frames_linear_interp
+from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
-torch.set_grad_enabled(False)
+from huggingface_hub import hf_hub_download

-#%% Define vars for low-resoltion pass
+# %% Define vars for low-resoltion pass
prompt1 = "photo of mount vesuvius erupting a terrifying pyroclastic ash cloud"
prompt2 = "photo of a inside a building full of ash, fire, death, destruction, explosions"
fixed_seeds = [5054613, 1168652]

@@ -41,21 +33,18 @@ height = 384
num_inference_steps_lores = 40
nmb_max_branches_lores = 10
depth_strength_lores = 0.5
-fp_ckpt_lores = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
+fp_ckpt_lores = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")

-#%% Define vars for high-resoltion pass
-fp_ckpt_hires = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
+# %% Define vars for high-resoltion pass
+fp_ckpt_hires = hf_hub_download(repo_id="stabilityai/stable-diffusion-x4-upscaler", filename="x4-upscaler-ema.ckpt")
depth_strength_hires = 0.65
num_inference_steps_hires = 100
nmb_branches_final_hires = 6
-dp_imgs = "tmp_transition" # folder for results and intermediate steps
+dp_imgs = "tmp_transition"  # Folder for results and intermediate steps

-#%% Run low-res pass
+# %% Run low-res pass
sdh = StableDiffusionHolder(fp_ckpt_lores)
-#%%
lb = LatentBlending(sdh)
lb.set_prompt1(prompt1)
lb.set_prompt2(prompt2)

@@ -64,14 +53,13 @@ lb.set_height(height)
# Run latent blending
lb.run_transition(
-    depth_strength = depth_strength_lores,
-    nmb_max_branches = nmb_max_branches_lores,
-    fixed_seeds = fixed_seeds
-    )
+    depth_strength=depth_strength_lores,
+    nmb_max_branches=nmb_max_branches_lores,
+    fixed_seeds=fixed_seeds)
lb.write_imgs_transition(dp_imgs)

-#%% Run high-res pass
+# %% Run high-res pass
sdh = StableDiffusionHolder(fp_ckpt_hires)
lb = LatentBlending(sdh)
lb.run_upscaling(dp_imgs, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)
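Read together, this example is a two-stage pipeline: the 512px base model renders the transition and persists it with write_imgs_transition, then the x4 upscaler re-runs the stored keyframes at high resolution. A hedged sketch wrapping that flow into one helper, reusing the variables defined above (not code from the commit):

def two_pass_transition(prompt1, prompt2, fixed_seeds, dp_imgs="tmp_transition"):
    # Pass 1: low-res transition, written to disk for the upscaler
    lb_lo = LatentBlending(StableDiffusionHolder(fp_ckpt_lores))
    lb_lo.set_prompt1(prompt1)
    lb_lo.set_prompt2(prompt2)
    lb_lo.run_transition(depth_strength=depth_strength_lores,
                         nmb_max_branches=nmb_max_branches_lores,
                         fixed_seeds=fixed_seeds)
    lb_lo.write_imgs_transition(dp_imgs)
    # Pass 2: x4 upscaling of the stored transition images
    lb_hi = LatentBlending(StableDiffusionHolder(fp_ckpt_hires))
    lb_hi.run_upscaling(dp_imgs, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)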


@@ -13,25 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os, sys
+import os
import torch
torch.backends.cudnn.benchmark = False
-import numpy as np
+torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
-import torch
-from tqdm.auto import tqdm
-from PIL import Image
-# import matplotlib.pyplot as plt
-import torch
-from movie_util import MovieSaver, concatenate_movies
-from typing import Callable, List, Optional, Union
-from latent_blending import LatentBlending, add_frames_linear_interp
+from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
-torch.set_grad_enabled(False)
+from movie_util import concatenate_movies
+from huggingface_hub import hf_hub_download

-#%% Define vars for low-resoltion pass
+# %% Define vars for low-resoltion pass
list_prompts = []
list_prompts.append("surrealistic statue made of glitter and dirt, standing in a lake, atmospheric light, strange glow")
list_prompts.append("statue of a mix between a tree and human, made of marble, incredibly detailed")

@@ -50,56 +44,54 @@ num_inference_steps_lores = 40
nmb_max_branches_lores = 10
depth_strength_lores = 0.5
-fp_ckpt_lores = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
+fp_ckpt_lores = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")

-#%% Define vars for high-resoltion pass
-fp_ckpt_hires = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
+# %% Define vars for high-resoltion pass
+fp_ckpt_hires = hf_hub_download(repo_id="stabilityai/stable-diffusion-x4-upscaler", filename="x4-upscaler-ema.ckpt")
depth_strength_hires = 0.65
num_inference_steps_hires = 100
nmb_branches_final_hires = 6

-#%% Run low-res pass
+# %% Run low-res pass
sdh = StableDiffusionHolder(fp_ckpt_lores)
-t_compute_max_allowed = 12 # per segment
+t_compute_max_allowed = 12  # Per segment
lb = LatentBlending(sdh)

list_movie_dirs = []
-for i in range(len(list_prompts)-1):
+for i in range(len(list_prompts) - 1):
    # For a multi transition we can save some computation time and recycle the latents
-    if i==0:
+    if i == 0:
        lb.set_prompt1(list_prompts[i])
-        lb.set_prompt2(list_prompts[i+1])
+        lb.set_prompt2(list_prompts[i + 1])
        recycle_img1 = False
    else:
        lb.swap_forward()
-        lb.set_prompt2(list_prompts[i+1])
+        lb.set_prompt2(list_prompts[i + 1])
        recycle_img1 = True

    dp_movie_part = f"tmp_part_{str(i).zfill(3)}"
    fp_movie_part = os.path.join(dp_movie_part, "movie_lowres.mp4")
    os.makedirs(dp_movie_part, exist_ok=True)
-    fixed_seeds = list_seeds[i:i+2]
+    fixed_seeds = list_seeds[i:i + 2]

    # Run latent blending
    lb.run_transition(
-        depth_strength = depth_strength_lores,
-        nmb_max_branches = nmb_max_branches_lores,
-        fixed_seeds = fixed_seeds
-        )
+        depth_strength=depth_strength_lores,
+        nmb_max_branches=nmb_max_branches_lores,
+        fixed_seeds=fixed_seeds)

    # Save movie and images (needed for upscaling!)
    lb.write_movie_transition(fp_movie_part, duration_single_trans)
    lb.write_imgs_transition(dp_movie_part)
    list_movie_dirs.append(dp_movie_part)

-#%% Run high-res pass on each segment
+# %% Run high-res pass on each segment
sdh = StableDiffusionHolder(fp_ckpt_hires)
lb = LatentBlending(sdh)
for dp_part in list_movie_dirs:
    lb.run_upscaling(dp_part, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)

-#%% concatenate into one long movie
+# %% concatenate into one long movie
list_fp_movies = []
for dp_part in list_movie_dirs:
    fp_movie = os.path.join(dp_part, "movie_highres.mp4")
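The hunk stops while the high-res segment movies are being collected; a plausible completion, again using the concatenate_movies helper imported at the top of the file (the output filename is a placeholder):

list_fp_movies = []
for dp_part in list_movie_dirs:
    list_fp_movies.append(os.path.join(dp_part, "movie_highres.mp4"))
concatenate_movies("movie_example4.mp4", list_fp_movies)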


@@ -13,82 +13,89 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os, sys
+import os
import torch
torch.backends.cudnn.benchmark = False
+torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import warnings
-import torch
from tqdm.auto import tqdm
from PIL import Image
-import torch
from movie_util import MovieSaver, concatenate_movies
-from typing import Callable, List, Optional, Union
-from latent_blending import get_time, yml_save, LatentBlending, add_frames_linear_interp, compare_dicts
+from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
-torch.set_grad_enabled(False)
import gradio as gr
-import copy
from dotenv import find_dotenv, load_dotenv
import shutil
import random
-import time
-
-#%%
+from utils import get_time, add_frames_linear_interp
+from huggingface_hub import hf_hub_download


class BlendingFrontend():
-    def __init__(self, sdh=None):
-        self.num_inference_steps = 30
-        if sdh is None:
-            self.use_debug = True
-            self.height = 768
-            self.width = 768
-        else:
-            self.use_debug = False
-            self.lb = LatentBlending(sdh)
-            self.lb.sdh.num_inference_steps = self.num_inference_steps
-            self.height = self.lb.sdh.height
-            self.width = self.lb.sdh.width
-        self.init_save_dir()
-        self.save_empty_image()
-        self.share = False
-        self.transition_can_be_computed = False
+    def __init__(
+            self,
+            sdh,
+            share=False):
+        r"""
+        Gradio Helper Class to collect UI data and start latent blending.
+        Args:
+            sdh:
+                StableDiffusionHolder
+            share: bool
+                Set true to get a shareable gradio link (e.g. for running a remote server)
+        """
+        self.share = share
+
+        # UI Defaults
+        self.num_inference_steps = 30
        self.depth_strength = 0.25
        self.seed1 = 420
        self.seed2 = 420
-        self.guidance_scale = 4.0
-        self.guidance_scale_mid_damper = 0.5
-        self.mid_compression_scaler = 1.2
        self.prompt1 = ""
        self.prompt2 = ""
        self.negative_prompt = ""
-        self.state_current = {}
+        self.fps = 30
+        self.duration_video = 8
+        self.t_compute_max_allowed = 10
+
+        self.lb = LatentBlending(sdh)
+        self.lb.sdh.num_inference_steps = self.num_inference_steps
+        self.init_parameters_from_lb()
+        self.init_save_dir()
+
+        # Vars
+        self.list_fp_imgs_current = []
+        self.recycle_img1 = False
+        self.recycle_img2 = False
+        self.list_all_segments = []
+        self.dp_session = ""
+        self.user_id = None
+
+    def init_parameters_from_lb(self):
+        r"""
+        Automatically init parameters from latentblending instance
+        """
+        self.height = self.lb.sdh.height
+        self.width = self.lb.sdh.width
+        self.guidance_scale = self.lb.guidance_scale
+        self.guidance_scale_mid_damper = self.lb.guidance_scale_mid_damper
+        self.mid_compression_scaler = self.lb.mid_compression_scaler
        self.branch1_crossfeed_power = self.lb.branch1_crossfeed_power
        self.branch1_crossfeed_range = self.lb.branch1_crossfeed_range
        self.branch1_crossfeed_decay = self.lb.branch1_crossfeed_decay
        self.parental_crossfeed_power = self.lb.parental_crossfeed_power
        self.parental_crossfeed_range = self.lb.parental_crossfeed_range
        self.parental_crossfeed_power_decay = self.lb.parental_crossfeed_power_decay
-        self.fps = 30
-        self.duration_video = 10
-        self.t_compute_max_allowed = 10
-        self.list_fp_imgs_current = []
-        self.current_timestamp = None
-        self.recycle_img1 = False
-        self.recycle_img2 = False
-        self.multi_idx_current = -1
-        self.list_imgs_shown_last = 5*[self.fp_img_empty]
-        self.list_all_segments = []
-        self.dp_session = ""
-        self.user_id = None
-        self.block_transition = False

    def init_save_dir(self):
+        r"""
+        Initializes the directory where stuff is being saved.
+        You can specify this directory in a ".env" file in your latentblending root, setting
+        DIR_OUT='/path/to/saving'
+        """
        load_dotenv(find_dotenv(), verbose=False)
        self.dp_out = os.getenv("DIR_OUT")
        if self.dp_out is None:

@@ -97,124 +104,125 @@ class BlendingFrontend():
        os.makedirs(self.dp_imgs, exist_ok=True)
        self.dp_movies = os.path.join(self.dp_out, "movies")
        os.makedirs(self.dp_movies, exist_ok=True)
+        self.save_empty_image()

-    # make dummy image
    def save_empty_image(self):
+        r"""
+        Saves an empty/black dummy image.
+        """
        self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
        Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)

    def randomize_seed1(self):
-        # Dont randomize seed if we are in a multi concat mode. we don't want to change this one otherwise the movie breaks
+        r"""
+        Randomizes the first seed
+        """
        seed = np.random.randint(0, 10000000)
        self.seed1 = int(seed)
        print(f"randomize_seed1: new seed = {self.seed1}")
        return seed

    def randomize_seed2(self):
+        r"""
+        Randomizes the second seed
+        """
        seed = np.random.randint(0, 10000000)
        self.seed2 = int(seed)
        print(f"randomize_seed2: new seed = {self.seed2}")
        return seed

-    def setup_lb(self, list_ui_elem):
+    def setup_lb(self, list_ui_vals):
+        r"""
+        Sets all parameters from the UI. Since gradio does not support to pass dictionaries,
+        we have to instead pass keys (list_ui_keys, global) and values (list_ui_vals)
+        """
        # Collect latent blending variables
-        self.state_current = self.get_state_dict()
-        self.lb.set_width(list_ui_elem[list_ui_keys.index('width')])
-        self.lb.set_height(list_ui_elem[list_ui_keys.index('height')])
-        self.lb.set_prompt1(list_ui_elem[list_ui_keys.index('prompt1')])
-        self.lb.set_prompt2(list_ui_elem[list_ui_keys.index('prompt2')])
-        self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
-        self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
-        self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
-        self.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
-        self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
-        self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
-        self.duration_video = list_ui_elem[list_ui_keys.index('duration_video')]
-        self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')] #seed
-        self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
-        self.lb.branch1_crossfeed_power = list_ui_elem[list_ui_keys.index('branch1_crossfeed_power')]
-        self.lb.branch1_crossfeed_range = list_ui_elem[list_ui_keys.index('branch1_crossfeed_range')]
-        self.lb.branch1_crossfeed_decay = list_ui_elem[list_ui_keys.index('branch1_crossfeed_decay')]
-        self.lb.parental_crossfeed_power = list_ui_elem[list_ui_keys.index('parental_crossfeed_power')]
-        self.lb.parental_crossfeed_range = list_ui_elem[list_ui_keys.index('parental_crossfeed_range')]
-        self.lb.parental_crossfeed_power_decay = list_ui_elem[list_ui_keys.index('parental_crossfeed_power_decay')]
-        self.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
-        self.depth_strength = list_ui_elem[list_ui_keys.index('depth_strength')]
-        if len(list_ui_elem[list_ui_keys.index('user_id')]) > 1:
-            self.user_id = list_ui_elem[list_ui_keys.index('user_id')]
+        self.lb.set_width(list_ui_vals[list_ui_keys.index('width')])
+        self.lb.set_height(list_ui_vals[list_ui_keys.index('height')])
+        self.lb.set_prompt1(list_ui_vals[list_ui_keys.index('prompt1')])
+        self.lb.set_prompt2(list_ui_vals[list_ui_keys.index('prompt2')])
+        self.lb.set_negative_prompt(list_ui_vals[list_ui_keys.index('negative_prompt')])
+        self.lb.guidance_scale = list_ui_vals[list_ui_keys.index('guidance_scale')]
+        self.lb.guidance_scale_mid_damper = list_ui_vals[list_ui_keys.index('guidance_scale_mid_damper')]
+        self.t_compute_max_allowed = list_ui_vals[list_ui_keys.index('duration_compute')]
+        self.lb.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
+        self.lb.sdh.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
+        self.duration_video = list_ui_vals[list_ui_keys.index('duration_video')]
+        self.lb.seed1 = list_ui_vals[list_ui_keys.index('seed1')]
+        self.lb.seed2 = list_ui_vals[list_ui_keys.index('seed2')]
+        self.lb.branch1_crossfeed_power = list_ui_vals[list_ui_keys.index('branch1_crossfeed_power')]
+        self.lb.branch1_crossfeed_range = list_ui_vals[list_ui_keys.index('branch1_crossfeed_range')]
+        self.lb.branch1_crossfeed_decay = list_ui_vals[list_ui_keys.index('branch1_crossfeed_decay')]
+        self.lb.parental_crossfeed_power = list_ui_vals[list_ui_keys.index('parental_crossfeed_power')]
+        self.lb.parental_crossfeed_range = list_ui_vals[list_ui_keys.index('parental_crossfeed_range')]
+        self.lb.parental_crossfeed_power_decay = list_ui_vals[list_ui_keys.index('parental_crossfeed_power_decay')]
+        self.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
+        self.depth_strength = list_ui_vals[list_ui_keys.index('depth_strength')]
+        if len(list_ui_vals[list_ui_keys.index('user_id')]) > 1:
+            self.user_id = list_ui_vals[list_ui_keys.index('user_id')]
        else:
            # generate new user id
            self.user_id = ''.join((random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ') for i in range(8)))
-            print(f"made new user_id: {self.user_id}")
+            print(f"made new user_id: {self.user_id} at {get_time('second')}")

    def save_latents(self, fp_latents, list_latents):
+        r"""
+        Saves a latent trajectory on disk, in npy format.
+        """
        list_latents_cpu = [l.cpu().numpy() for l in list_latents]
        np.save(fp_latents, list_latents_cpu)

    def load_latents(self, fp_latents):
+        r"""
+        Loads a latent trajectory from disk, converts to torch tensor.
+        """
        list_latents_cpu = np.load(fp_latents)
        list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu]
        return list_latents

    def compute_img1(self, *args):
-        list_ui_elem = args
-        self.setup_lb(list_ui_elem)
+        r"""
+        Computes the first transition image and returns it for display.
+        Sets all other transition images and last image to empty (as they are obsolete with this operation)
+        """
+        list_ui_vals = args
+        self.setup_lb(list_ui_vals)
        fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}")
        img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
-        img1.save(fp_img1+".jpg")
-        self.save_latents(fp_img1+".npy", self.lb.tree_latents[0])
+        img1.save(fp_img1 + ".jpg")
+        self.save_latents(fp_img1 + ".npy", self.lb.tree_latents[0])
        self.recycle_img1 = True
        self.recycle_img2 = False
-        # fixme save seeds. change filenames?
-        return [fp_img1+".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
+        return [fp_img1 + ".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]

    def compute_img2(self, *args):
+        r"""
+        Computes the last transition image and returns it for display.
+        Sets all other transition images to empty (as they are obsolete with this operation)
+        """
        if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")):  # don't do anything
            return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
-        list_ui_elem = args
-        self.setup_lb(list_ui_elem)
+        list_ui_vals = args
+        self.setup_lb(list_ui_vals)
        self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
        fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}")
        img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
-        img2.save(fp_img2+'.jpg')
-        self.save_latents(fp_img2+".npy", self.lb.tree_latents[-1])
+        img2.save(fp_img2 + '.jpg')
+        self.save_latents(fp_img2 + ".npy", self.lb.tree_latents[-1])
        self.recycle_img2 = True
-        self.transition_can_be_computed = True
        # fixme save seeds. change filenames?
-        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2+".jpg", self.user_id]
+        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2 + ".jpg", self.user_id]

    def compute_transition(self, *args):
-        if not self.transition_can_be_computed:
-            list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
-            return list_return
-        list_ui_elem = args
-        self.setup_lb(list_ui_elem)
+        r"""
+        Computes transition images and movie.
+        """
+        list_ui_vals = args
+        self.setup_lb(list_ui_vals)
        print("STARTING TRANSITION...")
        fixed_seeds = [self.seed1, self.seed2]
-        # Run Latent Blending
-        # Check if another user is blocking this... otherwise everything will become mixed.
-        # t_now = time.time()
-        # if self.block_transition:
-        #     while True:
-        #         time.sleep(1)
-        #         if not self.block_transition:
-        #             break
-        #         if time.time() - t_now > 1000:
-        #             return
-        self.block_transition = True
        # Inject loaded latents (other user interference)
        self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
        self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))

@@ -224,24 +232,23 @@ class BlendingFrontend():
            num_inference_steps=self.num_inference_steps,
            depth_strength=self.depth_strength,
            t_compute_max_allowed=self.t_compute_max_allowed,
-            fixed_seeds=fixed_seeds
-            )
-        print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
+            fixed_seeds=fixed_seeds)
+        print(f"Latent Blending pass finished ({get_time('second')}). Resulted in {len(imgs_transition)} images")

        # Subselect three preview images
-        idx_img_prev = np.round(np.linspace(0, len(imgs_transition)-1, 5)[1:-1]).astype(np.int32)
+        idx_img_prev = np.round(np.linspace(0, len(imgs_transition) - 1, 5)[1:-1]).astype(np.int32)
        list_imgs_preview = []
        for j in idx_img_prev:
            list_imgs_preview.append(Image.fromarray(imgs_transition[j]))

        # Save the preview imgs as jpgs on disk so we are not sending umcompressed data around
-        self.current_timestamp = get_time('second')
+        current_timestamp = get_time('second')
        self.list_fp_imgs_current = []
        for i in range(len(list_imgs_preview)):
-            fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg")
+            fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{current_timestamp}.jpg")
            list_imgs_preview[i].save(fp_img)
            self.list_fp_imgs_current.append(fp_img)
-        self.block_transition = False

        # Insert cheap frames for the movie
        imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)

@@ -259,16 +266,17 @@ class BlendingFrontend():
        list_return = self.list_fp_imgs_current + [self.fp_movie]
        return list_return

    def stack_forward(self, prompt2, seed2):
+        r"""
+        Allows to generate multi-segment movies. Sets last image -> first image with all
+        relevant parameters.
+        """
        # Save preview images, prompts and seeds into dictionary for stacking
        if len(self.list_all_segments) == 0:
            timestamp_session = get_time('second')
            self.dp_session = os.path.join(self.dp_out, f"session_{timestamp_session}")
            os.makedirs(self.dp_session)

-        self.transition_can_be_computed = False
        idx_segment = len(self.list_all_segments)
        dp_segment = os.path.join(self.dp_session, f"segment_{str(idx_segment).zfill(3)}")

@@ -285,13 +293,11 @@ class BlendingFrontend():
        self.lb.swap_forward()
        shutil.copyfile(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"), os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))

        fp_multi = self.multi_concat()
        list_out = [fp_multi]
        list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")])
-        list_out.extend([self.fp_img_empty]*4)
+        list_out.extend([self.fp_img_empty] * 4)
        list_out.append(gr.update(interactive=False, value=prompt2))
        list_out.append(gr.update(interactive=False, value=seed2))
        list_out.append("")

@@ -299,16 +305,20 @@ class BlendingFrontend():
        print(f"stack_forward: fp_multi {fp_multi}")
        return list_out

    def multi_concat(self):
+        r"""
+        Concatentates all stacked segments into one long movie.
+        """
        list_fp_movies = self.get_fp_video_all()
        # Concatenate movies and save
        fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4")
        concatenate_movies(fp_final, list_fp_movies)
        return fp_final

    def get_fp_video_all(self):
+        r"""
+        Collects all stacked movie segments.
+        """
        list_all = os.listdir(self.dp_movies)
        str_beg = f"movie_{self.user_id}_"
        list_user = [l for l in list_all if str_beg in l]

@@ -316,8 +326,10 @@ class BlendingFrontend():
        list_user = [os.path.join(self.dp_movies, l) for l in list_user]
        return list_user

    def get_fp_video_next(self):
+        r"""
+        Gets the filepath of the next movie segment.
+        """
        list_videos = self.get_fp_video_all()
        if len(list_videos) == 0:
            idx_next = 0

@@ -327,26 +339,16 @@ class BlendingFrontend():
        return fp_video_next

    def get_fp_video_last(self):
+        r"""
+        Gets the current video that was saved.
+        """
        fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4")
        return fp_video_last

-    def get_state_dict(self):
-        state_dict = {}
-        grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
-                     'num_inference_steps', 'depth_strength', 'guidance_scale',
-                     'guidance_scale_mid_damper', 'mid_compression_scaler']
-        for v in grab_vars:
-            state_dict[v] = getattr(self, v)
-        return state_dict


if __name__ == "__main__":
-    # fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
-    fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
+    fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
+    # fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
    bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
    # self = BlendingFrontend(None)

@@ -391,7 +393,6 @@ if __name__ == "__main__":
            depth_strength = gr.Slider(0.01, 0.99, bf.depth_strength, step=0.01, label='depth_strength', interactive=True)
            guidance_scale_mid_damper = gr.Slider(0.01, 2.0, bf.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)

        with gr.Row():
            b_compute1 = gr.Button('compute first image', variant='primary')
            b_compute_transition = gr.Button('compute transition', variant='primary')

@@ -405,11 +406,10 @@ if __name__ == "__main__":
            img5 = gr.Image(label="5/5")

        with gr.Row():
-            vid_single = gr.Video(label="single trans")
-            vid_multi = gr.Video(label="multi trans")
+            vid_single = gr.Video(label="current single trans")
+            vid_multi = gr.Video(label="concatented multi trans")

        with gr.Row():
-            # b_restart = gr.Button("RESTART EVERYTHING")
            b_stackforward = gr.Button('append last movie segment (left) to multi movie (right)', variant='primary')

        with gr.Row():

@@ -437,8 +437,7 @@ if __name__ == "__main__":
                - parental_crossfeed_power_decay: Similar to branch1_crossfeed_decay, however applied for the images withinin the transition.
                - depth_strength: Determines when the blending process will begin in terms of diffusion steps. Low values more inventive but can cause motion.
                - guidance_scale_mid_damper: Decreases the guidance scale in the middle of a transition.
-                """
-                )
+                """)

        with gr.Row():
            user_id = gr.Textbox(label="user id", interactive=False)

@@ -471,24 +470,23 @@ if __name__ == "__main__":
    dict_ui_elem["user_id"] = user_id

    # Convert to list, as gradio doesn't seem to accept dicts
-    list_ui_elem = []
+    list_ui_vals = []
    list_ui_keys = []
    for k in dict_ui_elem.keys():
-        list_ui_elem.append(dict_ui_elem[k])
+        list_ui_vals.append(dict_ui_elem[k])
        list_ui_keys.append(k)
    bf.list_ui_keys = list_ui_keys

    b_newseed1.click(bf.randomize_seed1, outputs=seed1)
    b_newseed2.click(bf.randomize_seed2, outputs=seed2)
-    b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5, user_id])
-    b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5, user_id])
+    b_compute1.click(bf.compute_img1, inputs=list_ui_vals, outputs=[img1, img2, img3, img4, img5, user_id])
+    b_compute2.click(bf.compute_img2, inputs=list_ui_vals, outputs=[img2, img3, img4, img5, user_id])
    b_compute_transition.click(bf.compute_transition,
-                               inputs=list_ui_elem,
+                               inputs=list_ui_vals,
                               outputs=[img2, img3, img4, vid_single])
    b_stackforward.click(bf.stack_forward,
                         inputs=[prompt2, seed2],
                         outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
    demo.launch(share=bf.share, inbrowser=True, inline=False)
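Because gradio callbacks are wired with flat lists rather than dictionaries, the UI collects its components into dict_ui_elem, flattens it into parallel list_ui_keys / list_ui_vals, and setup_lb looks values up by key. A small standalone illustration of that pattern (generic stand-in values, not taken from the app):

# Hedged sketch of the keys/values workaround used above.
dict_ui_elem = {"prompt1": "a photo of a cat", "seed1": 420}  # stand-ins for gradio components
list_ui_keys = list(dict_ui_elem.keys())
list_ui_vals = [dict_ui_elem[k] for k in list_ui_keys]

def lookup(name):
    # Mirrors list_ui_vals[list_ui_keys.index(name)] in setup_lb
    return list_ui_vals[list_ui_keys.index(name)]

assert lookup("seed1") == 420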


@@ -13,41 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os, sys
+import os
import torch
torch.backends.cudnn.benchmark = False
+torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import time
-import subprocess
import warnings
from tqdm.auto import tqdm
from PIL import Image
-# import matplotlib.pyplot as plt
from movie_util import MovieSaver
-import datetime
-from typing import Callable, List, Optional, Union
-import inspect
-from threading import Thread
-torch.set_grad_enabled(False)
-from contextlib import nullcontext
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.util import instantiate_from_config
+from typing import List, Optional
from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
-from stable_diffusion_holder import StableDiffusionHolder
-import yaml
import lpips
-#%%
+from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save


class LatentBlending():
    def __init__(
            self,
            sdh: None,
            guidance_scale: float = 4,
            guidance_scale_mid_damper: float = 0.5,
-            mid_compression_scaler: float = 1.2,
-            ):
+            mid_compression_scaler: float = 1.2):
        r"""
        Initializes the latent blending class.
        Args:

@@ -64,9 +54,10 @@ class LatentBlending():
            Increases the sampling density in the middle (where most changes happen). Higher value
            imply more values in the middle. However the inflection point can occur outside the middle,
            thus high values can give rough transitions. Values around 2 should be fine.
        """
-        assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
+        assert guidance_scale_mid_damper > 0 \
+            and guidance_scale_mid_damper <= 1.0, \
+            f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
        self.sdh = sdh
        self.device = self.sdh.device

@@ -115,10 +106,8 @@ class LatentBlending():
        self.multi_transition_img_last = None
        self.dt_per_diff = 0
        self.spatial_mask = None
        self.lpips = lpips.LPIPS(net='alex').cuda(self.device)

    def init_mode(self):
        r"""
        Sets the operational mode. Currently supported are standard, inpainting and x4 upscaling.

@@ -151,13 +140,12 @@
        Tunes the guidance scale down as a linear function of fract_mixing,
        towards 0.5 the minimum will be reached.
        """
-        mid_factor = 1 - np.abs(fract_mixing - 0.5)/ 0.5
-        max_guidance_reduction = self.guidance_scale_base * (1-self.guidance_scale_mid_damper) - 1
-        guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction*mid_factor
+        mid_factor = 1 - np.abs(fract_mixing - 0.5) / 0.5
+        max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1
+        guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor
        self.guidance_scale = guidance_scale_effective
        self.sdh.guidance_scale = guidance_scale_effective
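The damping above is a linear ramp: at the keyframes (fract_mixing 0 or 1) the effective guidance equals guidance_scale_base, and at the midpoint it drops to guidance_scale_base * guidance_scale_mid_damper + 1. A standalone numeric check of the same three lines (example values, not from the commit):

import numpy as np

def effective_guidance(fract_mixing, guidance_scale_base=4.0, guidance_scale_mid_damper=0.5):
    mid_factor = 1 - np.abs(fract_mixing - 0.5) / 0.5
    max_guidance_reduction = guidance_scale_base * (1 - guidance_scale_mid_damper) - 1
    return guidance_scale_base - max_guidance_reduction * mid_factor

print(effective_guidance(0.0))  # 4.0 at the keyframes
print(effective_guidance(0.5))  # 3.0 in the middle of the transition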
    def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
        r"""
        Sets the crossfeed parameters for the first branch to the last branch.

@@ -173,7 +161,6 @@
        self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
        self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)

    def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
        r"""
        Sets the crossfeed parameters for all transition images (within the first and last branch).

@@ -189,7 +176,6 @@
        self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
        self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1)

    def set_prompt1(self, prompt: str):
        r"""
        Sets the first prompt (for the first keyframe) including text embeddings.

@@ -201,7 +187,6 @@
        self.prompt1 = prompt
        self.text_embedding1 = self.get_text_embeddings(self.prompt1)

    def set_prompt2(self, prompt: str):
        r"""
        Sets the second prompt (for the second keyframe) including text embeddings.

@@ -237,8 +222,7 @@
            depth_strength: Optional[float] = 0.3,
            t_compute_max_allowed: Optional[float] = None,
            nmb_max_branches: Optional[int] = None,
-            fixed_seeds: Optional[List[int]] = None,
-            ):
+            fixed_seeds: Optional[List[int]] = None):
        r"""
        Function for computing transitions.
        Returns a list of transition images using spherical latent blending.

@@ -263,7 +247,6 @@
            fixed_seeds: Optional[List[int)]:
                You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
                Otherwise random seeds will be taken.
        """
        # Sanity checks first

@@ -275,7 +258,7 @@
        if fixed_seeds == 'randomize':
            fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
        else:
-            assert len(fixed_seeds)==2, "Supply a list with len = 2"
+            assert len(fixed_seeds) == 2, "Supply a list with len = 2"
        self.seed1 = fixed_seeds[0]
        self.seed2 = fixed_seeds[1]

@@ -323,7 +306,6 @@
        return self.tree_final_imgs

    def compute_latents1(self, return_image=False):
        r"""
        Runs a diffusion trajectory for the first image

@@ -337,11 +319,10 @@
        latents_start = self.get_noise(self.seed1)
        list_latents1 = self.run_diffusion(
            list_conditionings,
-            latents_start = latents_start,
-            idx_start = 0
-            )
+            latents_start=latents_start,
+            idx_start=0)
        t1 = time.time()
-        self.dt_per_diff = (t1-t0) / self.num_inference_steps
+        self.dt_per_diff = (t1 - t0) / self.num_inference_steps
        self.tree_latents[0] = list_latents1
        if return_image:
            return self.sdh.latent2image(list_latents1[-1])

@@ -361,17 +342,16 @@
        # Influence from branch1
        if self.branch1_crossfeed_power > 0.0:
            # Set up the mixing_coeffs
-            idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_crossfeed_range))
-            mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power*self.branch1_crossfeed_decay, idx_mixing_stop))
-            mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
+            idx_mixing_stop = int(round(self.num_inference_steps * self.branch1_crossfeed_range))
+            mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power * self.branch1_crossfeed_decay, idx_mixing_stop))
+            mixing_coeffs.extend((self.num_inference_steps - idx_mixing_stop) * [0])
            list_latents_mixing = self.tree_latents[0]
            list_latents2 = self.run_diffusion(
                list_conditionings,
-                latents_start = latents_start,
-                idx_start = 0,
-                list_latents_mixing = list_latents_mixing,
-                mixing_coeffs = mixing_coeffs
-                )
+                latents_start=latents_start,
+                idx_start=0,
+                list_latents_mixing=list_latents_mixing,
+                mixing_coeffs=mixing_coeffs)
        else:
            list_latents2 = self.run_diffusion(list_conditionings, latents_start)
        self.tree_latents[-1] = list_latents2

@@ -381,7 +361,6 @@
        else:
            return list_latents2

    def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
        r"""
        Runs a diffusion trajectory, using the latents from the respective parents

@@ -409,22 +388,19 @@
            latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
            list_latents_parental_mix.append(latents_parental)

-        idx_mixing_stop = int(round(self.num_inference_steps*self.parental_crossfeed_range))
-        mixing_coeffs = idx_injection*[self.parental_crossfeed_power]
+        idx_mixing_stop = int(round(self.num_inference_steps * self.parental_crossfeed_range))
+        mixing_coeffs = idx_injection * [self.parental_crossfeed_power]
        nmb_mixing = idx_mixing_stop - idx_injection
        if nmb_mixing > 0:
-            mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power*self.parental_crossfeed_power_decay, nmb_mixing)))
-        mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
-        latents_start = list_latents_parental_mix[idx_injection-1]
+            mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_power_decay, nmb_mixing)))
+        mixing_coeffs.extend((self.num_inference_steps - len(mixing_coeffs)) * [0])
+        latents_start = list_latents_parental_mix[idx_injection - 1]
        list_latents = self.run_diffusion(
            list_conditionings,
-            latents_start = latents_start,
-            idx_start = idx_injection,
-            list_latents_mixing = list_latents_parental_mix,
-            mixing_coeffs = mixing_coeffs
-            )
+            latents_start=latents_start,
+            idx_start=idx_injection,
+            list_latents_mixing=list_latents_parental_mix,
+            mixing_coeffs=mixing_coeffs)
        return list_latents
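For intuition, the mixing schedule built above holds the full parental_crossfeed_power for the first idx_injection steps, decays linearly until idx_mixing_stop, and is zero afterwards. A standalone sketch with example numbers (values are illustrative, not from the commit):

import numpy as np

num_inference_steps = 30
idx_injection = 10
parental_crossfeed_power = 0.1
parental_crossfeed_range = 0.8
parental_crossfeed_power_decay = 0.8

idx_mixing_stop = int(round(num_inference_steps * parental_crossfeed_range))  # 24
mixing_coeffs = idx_injection * [parental_crossfeed_power]
nmb_mixing = idx_mixing_stop - idx_injection
if nmb_mixing > 0:
    mixing_coeffs.extend(np.linspace(parental_crossfeed_power, parental_crossfeed_power * parental_crossfeed_power_decay, nmb_mixing))
mixing_coeffs.extend((num_inference_steps - len(mixing_coeffs)) * [0])
print(len(mixing_coeffs))  # 30: one coefficient per diffusion step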
    def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):

@@ -445,8 +421,8 @@
            results. Use this if you want to have controllable results independent
            of your computer.
        """
-        idx_injection_base = int(round(self.num_inference_steps*depth_strength))
-        list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3)
+        idx_injection_base = int(round(self.num_inference_steps * depth_strength))
+        list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps - 1, 3)
        list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
        t_compute = 0

@@ -456,20 +432,18 @@
        elif t_compute_max_allowed is None:
            assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
            stop_criterion = "nmb_max_branches"
-            nmb_max_branches -= 2 # discounting the outer frames
+            nmb_max_branches -= 2  # Discounting the outer frames
        else:
            raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
        stop_criterion_reached = False
        is_first_iteration = True

        while not stop_criterion_reached:
            list_compute_steps = self.num_inference_steps - list_idx_injection
            list_compute_steps *= list_nmb_stems
-            t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15*np.sum(list_nmb_stems)
+            t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
            increase_done = False
-            for s_idx in range(len(list_nmb_stems)-1):
-                if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
+            for s_idx in range(len(list_nmb_stems) - 1):
+                if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
                    list_nmb_stems[s_idx] += 1
                    increase_done = True
                    break
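The loop above keeps adding stems at the deepest injection levels until the estimated compute time (or the branch budget) is spent. A standalone sketch of just the cost estimate, with example numbers (dt_per_diff would be whatever compute_latents1 measured):

import numpy as np

num_inference_steps = 30
depth_strength = 0.5
dt_per_diff = 0.1  # illustrative seconds per diffusion step

idx_injection_base = int(round(num_inference_steps * depth_strength))           # 15
list_idx_injection = np.arange(idx_injection_base, num_inference_steps - 1, 3)  # [15 18 21 24 27]
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
list_compute_steps = (num_inference_steps - list_idx_injection) * list_nmb_stems
t_compute = np.sum(list_compute_steps) * dt_per_diff + 0.15 * np.sum(list_nmb_stems)
print(t_compute)  # 45 remaining steps * 0.1 s + 0.75 s overhead = 5.25 s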
@@ -501,10 +475,10 @@
        """
        # get_lpips_similarity
        similarities = []
-        for i in range(len(self.tree_final_imgs)-1):
-            similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i+1]))
+        for i in range(len(self.tree_final_imgs) - 1):
+            similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1]))
        b_closest1 = np.argmax(similarities)
-        b_closest2 = b_closest1+1
+        b_closest2 = b_closest1 + 1
        fract_closest1 = self.tree_fracts[b_closest1]
        fract_closest2 = self.tree_fracts[b_closest2]

@@ -515,23 +489,15 @@
                break
            else:
                b_parent1 -= 1
        b_parent2 = b_closest2
        while True:
            if self.tree_idx_injection[b_parent2] < idx_injection:
                break
            else:
                b_parent2 += 1
-        # print(f"\n\nb_closest: {b_closest1} {b_closest2} fract_closest1 {fract_closest1} fract_closest2 {fract_closest2}")
-        # print(f"b_parent: {b_parent1} {b_parent2}")
-        # print(f"similarities {similarities}")
-        # print(f"idx_injection {idx_injection} tree_idx_injection {self.tree_idx_injection}")
-        fract_mixing = (fract_closest1 + fract_closest2) /2
+        fract_mixing = (fract_closest1 + fract_closest2) / 2
        return fract_mixing, b_parent1, b_parent2

    def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
        r"""
        Inserts all necessary parameters into the trajectory tree.

@@ -543,12 +509,11 @@
            list_latents: list
                list of the latents to be inserted
        """
-        b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
-        self.tree_latents.insert(b_parent1+1, list_latents)
-        self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
-        self.tree_fracts.insert(b_parent1+1, fract_mixing)
-        self.tree_idx_injection.insert(b_parent1+1, idx_injection)
+        b_parent1, b_parent2 = self.get_closest_idx(fract_mixing)
+        self.tree_latents.insert(b_parent1 + 1, list_latents)
+        self.tree_final_imgs.insert(b_parent1 + 1, self.sdh.latent2image(list_latents[-1]))
+        self.tree_fracts.insert(b_parent1 + 1, fract_mixing)
+        self.tree_idx_injection.insert(b_parent1 + 1, idx_injection)

    def get_spatial_mask_template(self):
        r"""

@@ -565,9 +530,7 @@
        Args:
            img_mask:
                mask image [0,1]. You can get a template using get_spatial_mask_template
        """
        shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
        C, H, W = shape_latents
        img_mask = np.asarray(img_mask)

@@ -577,18 +540,15 @@
        assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
        spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
        spatial_mask = torch.unsqueeze(spatial_mask, 0)
spatial_mask = spatial_mask.repeat((C,1,1)) spatial_mask = spatial_mask.repeat((C, 1, 1))
spatial_mask = torch.unsqueeze(spatial_mask, 0) spatial_mask = torch.unsqueeze(spatial_mask, 0)
self.spatial_mask = spatial_mask self.spatial_mask = spatial_mask
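A hedged sketch of how the spatial-mask setter above might be used. The setter's name (assumed here to be set_spatial_mask) and the exact template shape are assumptions, since the method signature is collapsed in this hunk; per the 1 - spatial_mask interpolation in run_diffusion_standard, regions at 0 are pulled towards the mixing latents while regions at 1 follow the current branch.

import numpy as np
from stable_diffusion_holder import StableDiffusionHolder
from latent_blending import LatentBlending

sdh = StableDiffusionHolder("../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt")  # assumed path
lb = LatentBlending(sdh)

mask = lb.get_spatial_mask_template()        # template in [0, 1]; latent-resolution shape assumed
mask = np.asarray(mask, dtype=np.float32)    # defensive copy to keep the sketch self-contained
mask[:, : mask.shape[1] // 2] = 0.0          # let only the left half be overwritten by the mixing latents
lb.set_spatial_mask(mask)                    # setter name assumed; its def line is not shown in this diff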
def get_noise(self, seed): def get_noise(self, seed):
r""" r"""
Helper function to get noise given seed. Helper function to get noise given seed.
Args: Args:
seed: int seed: int
""" """
generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed)) generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
if self.mode == 'standard': if self.mode == 'standard':
@ -599,21 +559,17 @@ class LatentBlending():
h = self.image1_lowres.size[1] h = self.image1_lowres.size[1]
shape_latents = [self.sdh.model.channels, h, w] shape_latents = [self.sdh.model.channels, h, w]
C, H, W = shape_latents C, H, W = shape_latents
return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device) return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
@torch.no_grad() @torch.no_grad()
def run_diffusion( def run_diffusion(
self, self,
list_conditionings, list_conditionings,
latents_start: torch.FloatTensor = None, latents_start: torch.FloatTensor = None,
idx_start: int = 0, idx_start: int = 0,
list_latents_mixing = None, list_latents_mixing=None,
mixing_coeffs = 0.0, mixing_coeffs=0.0,
return_image: Optional[bool] = False return_image: Optional[bool] = False):
):
r""" r"""
Wrapper function for diffusion runners. Wrapper function for diffusion runners.
Depending on the mode, the correct one will be executed. Depending on the mode, the correct one will be executed.
@ -640,14 +596,13 @@ class LatentBlending():
if self.mode == 'standard': if self.mode == 'standard':
text_embeddings = list_conditionings[0] text_embeddings = list_conditionings[0]
return self.sdh.run_diffusion_standard( return self.sdh.run_diffusion_standard(
text_embeddings = text_embeddings, text_embeddings=text_embeddings,
latents_start = latents_start, latents_start=latents_start,
idx_start = idx_start, idx_start=idx_start,
list_latents_mixing = list_latents_mixing, list_latents_mixing=list_latents_mixing,
mixing_coeffs = mixing_coeffs, mixing_coeffs=mixing_coeffs,
spatial_mask = self.spatial_mask, spatial_mask=self.spatial_mask,
return_image = return_image, return_image=return_image)
)
elif self.mode == 'upscale': elif self.mode == 'upscale':
cond = list_conditionings[0] cond = list_conditionings[0]
@ -657,11 +612,10 @@ class LatentBlending():
uc_full, uc_full,
latents_start=latents_start, latents_start=latents_start,
idx_start=idx_start, idx_start=idx_start,
list_latents_mixing = list_latents_mixing, list_latents_mixing=list_latents_mixing,
mixing_coeffs = mixing_coeffs, mixing_coeffs=mixing_coeffs,
return_image=return_image) return_image=return_image)
def run_upscaling( def run_upscaling(
self, self,
dp_img: str, dp_img: str,
@ -669,9 +623,9 @@ class LatentBlending():
num_inference_steps: int = 100, num_inference_steps: int = 100,
nmb_max_branches_highres: int = 5, nmb_max_branches_highres: int = 5,
nmb_max_branches_lowres: int = 6, nmb_max_branches_lowres: int = 6,
duration_single_segment = 3, duration_single_segment=3,
fixed_seeds: Optional[List[int]] = None, fps=24,
): fixed_seeds: Optional[List[int]] = None):
r""" r"""
Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition. Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition.
@ -692,13 +646,14 @@ class LatentBlending():
Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much. Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
duration_single_segment: float duration_single_segment: float
The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total. The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total.
fps: float
frames per second of movie
fixed_seeds: Optional[List[int]]: fixed_seeds: Optional[List[int]]:
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2). You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
Otherwise random seeds will be taken. Otherwise random seeds will be taken.
""" """
fp_yml = os.path.join(dp_img, "lowres.yaml") fp_yml = os.path.join(dp_img, "lowres.yaml")
fp_movie = os.path.join(dp_img, "movie_highres.mp4") fp_movie = os.path.join(dp_img, "movie_highres.mp4")
fps = 24
ms = MovieSaver(fp_movie, fps=fps) ms = MovieSaver(fp_movie, fps=fps)
assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?" assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
dict_stuff = yml_load(fp_yml) dict_stuff = yml_load(fp_yml)
@ -707,53 +662,43 @@ class LatentBlending():
nmb_images_lowres = dict_stuff['nmb_images'] nmb_images_lowres = dict_stuff['nmb_images']
prompt1 = dict_stuff['prompt1'] prompt1 = dict_stuff['prompt1']
prompt2 = dict_stuff['prompt2'] prompt2 = dict_stuff['prompt2']
idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres-1, nmb_max_branches_lowres)).astype(np.int32) idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres - 1, nmb_max_branches_lowres)).astype(np.int32)
imgs_lowres = [] imgs_lowres = []
for i in idx_img_lowres: for i in idx_img_lowres:
fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg") fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?" assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
imgs_lowres.append(Image.open(fp_img_lowres)) imgs_lowres.append(Image.open(fp_img_lowres))
# set up upscaling # set up upscaling
text_embeddingA = self.sdh.get_text_embedding(prompt1) text_embeddingA = self.sdh.get_text_embedding(prompt1)
text_embeddingB = self.sdh.get_text_embedding(prompt2) text_embeddingB = self.sdh.get_text_embedding(prompt2)
list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1)
list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres-1) for i in range(nmb_max_branches_lowres - 1):
for i in range(nmb_max_branches_lowres-1):
print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}") print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i]) self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1-list_fract_mixing[i]) self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1 - list_fract_mixing[i])
if i == 0:
if i==0:
recycle_img1 = False recycle_img1 = False
else: else:
self.swap_forward() self.swap_forward()
recycle_img1 = True recycle_img1 = True
self.set_image1(imgs_lowres[i]) self.set_image1(imgs_lowres[i])
self.set_image2(imgs_lowres[i+1]) self.set_image2(imgs_lowres[i + 1])
list_imgs = self.run_transition( list_imgs = self.run_transition(
recycle_img1 = recycle_img1, recycle_img1=recycle_img1,
recycle_img2 = False, recycle_img2=False,
num_inference_steps = num_inference_steps, num_inference_steps=num_inference_steps,
depth_strength = depth_strength, depth_strength=depth_strength,
nmb_max_branches = nmb_max_branches_highres, nmb_max_branches=nmb_max_branches_highres)
)
list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment) list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment)
# Save movie frame # Save movie frame
for img in list_imgs_interp: for img in list_imgs_interp:
ms.write_frame(img) ms.write_frame(img)
ms.finalize() ms.finalize()
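A hedged end-to-end sketch of the two-stage workflow this method describes: run a low-res transition, persist it with write_imgs_transition, then upscale with the x4 model. The checkpoint paths and the output folder are assumptions (the prompts and seeds are taken from the example code elsewhere in this commit), and run_upscaling may accept additional parameters hidden by the collapsed hunk above.

from stable_diffusion_holder import StableDiffusionHolder
from latent_blending import LatentBlending

# Stage 1: low-res transition, persisted to disk.
sdh = StableDiffusionHolder("../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt")  # assumed path
lb = LatentBlending(sdh)
lb.set_prompt1("photo of a desert and a sky")
lb.set_prompt2("photo of a tree with a lake")
lb.run_transition(depth_strength=0.65, fixed_seeds=[697164, 430214])
dp_img = "/tmp/my_transition"        # assumed output folder
lb.write_imgs_transition(dp_img)     # writes lowres_img_XXXX.jpg and lowres.yaml

# Stage 2: reload with the x4 upscaler checkpoint and upscale segment by segment.
sdh_hires = StableDiffusionHolder("../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt")  # assumed path
lb_hires = LatentBlending(sdh_hires)
lb_hires.run_upscaling(dp_img, nmb_max_branches_lowres=6, nmb_max_branches_highres=5, fps=24)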
@torch.no_grad() @torch.no_grad()
def get_mixed_conditioning(self, fract_mixing): def get_mixed_conditioning(self, fract_mixing):
if self.mode == 'standard': if self.mode == 'standard':
@ -776,8 +721,7 @@ class LatentBlending():
@torch.no_grad() @torch.no_grad()
def get_text_embeddings( def get_text_embeddings(
self, self,
prompt: str prompt: str):
):
r""" r"""
Computes the text embeddings provided a string with a prompts. Computes the text embeddings provided a string with a prompts.
Adapted from stable diffusion repo Adapted from stable diffusion repo
@ -785,10 +729,8 @@ class LatentBlending():
prompt: str prompt: str
ABC trending on artstation painted by Old Greg. ABC trending on artstation painted by Old Greg.
""" """
return self.sdh.get_text_embedding(prompt) return self.sdh.get_text_embedding(prompt)
def write_imgs_transition(self, dp_img): def write_imgs_transition(self, dp_img):
r""" r"""
Writes the transition images into the folder dp_img. Writes the transition images into the folder dp_img.
@ -802,7 +744,6 @@ class LatentBlending():
for i, img in enumerate(imgs_transition): for i, img in enumerate(imgs_transition):
img_leaf = Image.fromarray(img) img_leaf = Image.fromarray(img)
img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")) img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
fp_yml = os.path.join(dp_img, "lowres.yaml") fp_yml = os.path.join(dp_img, "lowres.yaml")
self.save_statedict(fp_yml) self.save_statedict(fp_yml)
@ -817,7 +758,6 @@ class LatentBlending():
duration of the movie in seconds duration of the movie in seconds
fps: int fps: int
fps of the movie fps of the movie
""" """
# Let's get more cheap frames via linear interpolation (duration_transition*fps frames) # Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
@ -831,8 +771,6 @@ class LatentBlending():
ms.write_frame(img) ms.write_frame(img)
ms.finalize() ms.finalize()
def save_statedict(self, fp_yml): def save_statedict(self, fp_yml):
# Dump everything relevant into yaml # Dump everything relevant into yaml
imgs_transition = self.tree_final_imgs imgs_transition = self.tree_final_imgs
@ -857,9 +795,8 @@ class LatentBlending():
else: else:
try: try:
state_dict[v] = getattr(self, v) state_dict[v] = getattr(self, v)
except Exception as e: except Exception:
pass pass
return state_dict return state_dict
def randomize_seed(self): def randomize_seed(self):
@ -892,7 +829,6 @@ class LatentBlending():
self.height = height self.height = height
self.sdh.height = height self.sdh.height = height
def swap_forward(self): def swap_forward(self):
r""" r"""
Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions
@ -900,15 +836,12 @@ class LatentBlending():
""" """
# Move over all latents # Move over all latents
self.tree_latents[0] = self.tree_latents[-1] self.tree_latents[0] = self.tree_latents[-1]
# Move over prompts and text embeddings # Move over prompts and text embeddings
self.prompt1 = self.prompt2 self.prompt1 = self.prompt2
self.text_embedding1 = self.text_embedding2 self.text_embedding1 = self.text_embedding2
# Final cleanup for extra sanity # Final cleanup for extra sanity
self.tree_final_imgs = [] self.tree_final_imgs = []
def get_lpips_similarity(self, imgA, imgB): def get_lpips_similarity(self, imgA, imgB):
r""" r"""
Computes the image similarity between two images imgA and imgB. Computes the image similarity between two images imgA and imgB.
@ -916,36 +849,32 @@ class LatentBlending():
High values indicate low similarity. High values indicate low similarity.
""" """
tensorA = torch.from_numpy(imgA).float().cuda(self.device) tensorA = torch.from_numpy(imgA).float().cuda(self.device)
tensorA = 2*tensorA/255.0 - 1 tensorA = 2 * tensorA / 255.0 - 1
tensorA = tensorA.permute([2,0,1]).unsqueeze(0) tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0)
tensorB = torch.from_numpy(imgB).float().cuda(self.device) tensorB = torch.from_numpy(imgB).float().cuda(self.device)
tensorB = 2*tensorB/255.0 - 1 tensorB = 2 * tensorB / 255.0 - 1
tensorB = tensorB.permute([2,0,1]).unsqueeze(0) tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0)
lploss = self.lpips(tensorA, tensorB) lploss = self.lpips(tensorA, tensorB)
lploss = float(lploss[0][0][0][0]) lploss = float(lploss[0][0][0][0])
return lploss return lploss
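get_lpips_similarity assumes self.lpips is a perceptual-distance network. A standalone sketch that mirrors the same normalization, assuming the lpips package with an AlexNet backbone (the backbone actually used by the class is not shown in this diff):

import numpy as np
import torch
import lpips  # pip install lpips; assumed backend

lpips_net = lpips.LPIPS(net='alex')  # assumption: AlexNet backbone

def lpips_distance(imgA: np.ndarray, imgB: np.ndarray) -> float:
    # Same normalization as above: uint8 HWC -> [-1, 1] NCHW
    tA = 2 * torch.from_numpy(imgA).float() / 255.0 - 1
    tA = tA.permute([2, 0, 1]).unsqueeze(0)
    tB = 2 * torch.from_numpy(imgB).float() / 255.0 - 1
    tB = tB.permute([2, 0, 1]).unsqueeze(0)
    return float(lpips_net(tA, tB))

imgA = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)
imgB = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)
print(lpips_distance(imgA, imgB))  # higher value = less similar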
# Auxiliary functions
# Auxiliary functions def get_closest_idx(
def get_closest_idx( self,
fract_mixing: float, fract_mixing: float):
list_fract_mixing_prev: List[float],
):
r""" r"""
Helper function to retrieve the parents for any given mixing. Helper function to retrieve the parents for any given mixing.
Example: fract_mixing = 0.4 and list_fract_mixing_prev = [0, 0.3, 0.6, 1.0] Example: fract_mixing = 0.4 and self.tree_fracts = [0, 0.3, 0.6, 1.0]
Will return the indices of the two closest values from list_fract_mixing_prev, i.e. [1, 2] Will return the indices of the two closest values here, i.e. [1, 2]
""" """
pdist = fract_mixing - np.asarray(list_fract_mixing_prev) pdist = fract_mixing - np.asarray(self.tree_fracts)
pdist_pos = pdist.copy() pdist_pos = pdist.copy()
pdist_pos[pdist_pos<0] = np.inf pdist_pos[pdist_pos < 0] = np.inf
b_parent1 = np.argmin(pdist_pos) b_parent1 = np.argmin(pdist_pos)
pdist_neg = -pdist.copy() pdist_neg = -pdist.copy()
pdist_neg[pdist_neg<=0] = np.inf pdist_neg[pdist_neg <= 0] = np.inf
b_parent2= np.argmin(pdist_neg) b_parent2 = np.argmin(pdist_neg)
if b_parent1 > b_parent2: if b_parent1 > b_parent2:
tmp = b_parent2 tmp = b_parent2
@ -953,291 +882,3 @@ def get_closest_idx(
b_parent1 = tmp b_parent1 = tmp
return b_parent1, b_parent2 return b_parent1, b_parent2
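A standalone trace of the docstring example, reproducing the argmin logic above by hand:

import numpy as np

fract_mixing = 0.4
tree_fracts = [0.0, 0.3, 0.6, 1.0]

pdist = fract_mixing - np.asarray(tree_fracts)   # [ 0.4,  0.1, -0.2, -0.6]
pdist_pos = pdist.copy()
pdist_pos[pdist_pos < 0] = np.inf                # [ 0.4,  0.1,  inf,  inf]
b_parent1 = np.argmin(pdist_pos)                 # 1  (closest fraction below 0.4)
pdist_neg = -pdist.copy()
pdist_neg[pdist_neg <= 0] = np.inf               # [ inf,  inf,  0.2,  0.6]
b_parent2 = np.argmin(pdist_neg)                 # 2  (closest fraction above 0.4)
print(b_parent1, b_parent2)                      # -> 1 2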
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function will always cast up to float64 for sake of extra 4.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return in p0
1 will return in p1
0.x will return a mix between both preserving angular velocity.
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'
p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1+epsilon, 1-epsilon)
theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0*s0 + p1*s1
if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()
return interp
def interpolate_linear(p0, p1, fract_mixing):
r"""
Helper function to mix two variables using standard linear interpolation.
Args:
p0:
First tensor / np.ndarray for interpolation
p1:
Second tensor / np.ndarray for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return in p0
1 will return in p1
0.x will return a linear mix between both.
"""
reconvert_uint8 = False
if type(p0) is np.ndarray and p0.dtype == 'uint8':
reconvert_uint8 = True
p0 = p0.astype(np.float64)
if type(p1) is np.ndarray and p1.dtype == 'uint8':
reconvert_uint8 = True
p1 = p1.astype(np.float64)
interp = (1-fract_mixing) * p0 + fract_mixing * p1
if reconvert_uint8:
interp = np.clip(interp, 0, 255).astype(np.uint8)
return interp
def add_frames_linear_interp(
list_imgs: List[np.ndarray],
fps_target: Union[float, int] = None,
duration_target: Union[float, int] = None,
nmb_frames_target: int=None,
):
r"""
Helper function to cheaply increase the number of frames given a list of images,
by virtue of standard linear interpolation.
The number of inserted frames will be automatically adjusted so that the total of number
of frames can be fixed precisely, using a random shuffling technique.
The function allows 1:1 comparisons between transitions as videos.
Args:
list_imgs: List[np.ndarray)
List of images, between each image new frames will be inserted via linear interpolation.
fps_target:
OptionA: specify here the desired frames per second.
duration_target:
OptionA: specify here the desired duration of the transition in seconds.
nmb_frames_target:
OptionB: directly fix the total number of frames of the output.
"""
# Sanity
if nmb_frames_target is not None and fps_target is not None:
raise ValueError("You cannot specify both fps_target and nmb_frames_target")
if fps_target is None:
assert nmb_frames_target is not None, "Either specify nmb_frames_target or nmb_frames_target"
if nmb_frames_target is None:
assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
nmb_frames_target = fps_target*duration_target
# Get number of frames that are missing
nmb_frames_diff = len(list_imgs)-1
nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
if nmb_frames_missing < 1:
return list_imgs
list_imgs_float = [img.astype(np.float32) for img in list_imgs]
# Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
mean_nmb_frames_insert = nmb_frames_missing/nmb_frames_diff
constfact = np.floor(mean_nmb_frames_insert)
remainder_x = 1-(mean_nmb_frames_insert - constfact)
nmb_iter = 0
while True:
nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
nmb_frames_to_insert[nmb_frames_to_insert<=remainder_x] = 0
nmb_frames_to_insert[nmb_frames_to_insert>remainder_x] = 1
nmb_frames_to_insert += constfact
if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
break
nmb_iter += 1
if nmb_iter > 100000:
print("add_frames_linear_interp: issue with inserting the right number of frames")
break
nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
list_imgs_interp = []
for i in range(len(list_imgs_float)-1):#, desc="STAGE linear interp"):
img0 = list_imgs_float[i]
img1 = list_imgs_float[i+1]
list_imgs_interp.append(img0.astype(np.uint8))
list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i]+2)[1:-1]
for fract_linblend in list_fracts_linblend:
img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
list_imgs_interp.append(img_blend.astype(np.uint8))
if i==len(list_imgs_float)-2:
list_imgs_interp.append(img1.astype(np.uint8))
return list_imgs_interp
def get_spacing(nmb_points: int, scaling: float):
"""
Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
Args:
nmb_points: int
Number of points between [0, 1]
scaling: float
Higher values will return higher sampling density around 0.5
"""
if scaling < 1.7:
return np.linspace(0, 1, nmb_points)
nmb_points_per_side = nmb_points//2 + 1
if np.mod(nmb_points, 2) != 0: # uneven case
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
right_side = 1-left_side[::-1][1:]
else:
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
right_side = 1-left_side[::-1]
all_fracts = np.hstack([left_side, right_side])
return all_fracts
def get_time(resolution=None):
"""
Helper function returning an nicely formatted time string, e.g. 221117_1620
"""
if resolution==None:
resolution="second"
if resolution == "day":
t = time.strftime('%y%m%d', time.localtime())
elif resolution == "minute":
t = time.strftime('%y%m%d_%H%M', time.localtime())
elif resolution == "second":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
elif resolution == "millisecond":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
t += "_"
t += str("{:03d}".format(int(int(datetime.utcnow().strftime('%f'))/1000)))
else:
raise ValueError("bad resolution provided: %s" %resolution)
return t
def compare_dicts(a, b):
"""
Compares two dictionaries a and b and returns a dictionary c, with all
keys,values that have shared keys in a and b but same values in a and b.
The values of a and b are stacked together in the output.
Example:
a = {}; a['bobo'] = 4
b = {}; b['bobo'] = 5
c = dict_compare(a,b)
c = {"bobo",[4,5]}
"""
c = {}
for key in a.keys():
if key in b.keys():
val_a = a[key]
val_b = b[key]
if val_a != val_b:
c[key] = [val_a, val_b]
return c
def yml_load(fp_yml, print_fields=False):
"""
Helper function for loading yaml files
"""
with open(fp_yml) as f:
data = yaml.load(f, Loader=yaml.loader.SafeLoader)
dict_data = dict(data)
print("load: loaded {}".format(fp_yml))
return dict_data
def yml_save(fp_yml, dict_stuff):
"""
Helper function for saving yaml files
"""
with open(fp_yml, 'w') as f:
data = yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
print("yml_save: saved {}".format(fp_yml))
#%% le main
if __name__ == "__main__":
# xxxx
#%% First let us spawn a stable diffusion holder
device = "cuda"
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
sdh = StableDiffusionHolder(fp_ckpt)
xxx
#%% Next let's set up all parameters
depth_strength = 0.3 # Specifies how deep (in terms of diffusion iterations the first branching happens)
fixed_seeds = [697164, 430214]
prompt1 = "photo of a desert and a sky"
prompt2 = "photo of a tree with a lake"
duration_transition = 12 # In seconds
fps = 30
# Spawn latent blending
self = LatentBlending(sdh)
self.set_prompt1(prompt1)
self.set_prompt2(prompt2)
# Run latent blending
self.branch1_crossfeed_power = 0.3
self.branch1_crossfeed_range = 0.4
# self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
self.seed1=21312
img1 =self.compute_latents1(True)
#%
self.seed2=1234121
self.branch1_crossfeed_power = 0.7
self.branch1_crossfeed_range = 0.3
self.branch1_crossfeed_decay = 0.3
img2 =self.compute_latents2(True)
# Image.fromarray(np.concatenate((img1, img2), axis=1))
#%%
t0 = time.time()
self.t_compute_max_allowed = 30
self.parental_crossfeed_range = 1.0
self.parental_crossfeed_power = 0.0
self.parental_crossfeed_power_decay = 1.0
imgs_transition = self.run_transition(recycle_img1=True, recycle_img2=True)
t1 = time.time()
print(f"took: {t1-t0}s")

movie_util.py
@ -1,5 +1,6 @@
# Copyright 2022 Lunar Ring. All rights reserved. # Copyright 2022 Lunar Ring. All rights reserved.
# # Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
@ -17,10 +18,9 @@ import os
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
import cv2 import cv2
from typing import Callable, List, Optional, Union from typing import List
import ffmpeg # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg import ffmpeg # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg
#%%
class MovieSaver(): class MovieSaver():
def __init__( def __init__(
@ -30,10 +30,9 @@ class MovieSaver():
shape_hw: List[int] = None, shape_hw: List[int] = None,
crf: int = 24, crf: int = 24,
codec: str = 'libx264', codec: str = 'libx264',
preset: str ='fast', preset: str = 'fast',
pix_fmt: str = 'yuv420p', pix_fmt: str = 'yuv420p',
silent_ffmpeg: bool = True silent_ffmpeg: bool = True):
):
r""" r"""
Initializes movie saver class - a human friendly ffmpeg wrapper. Initializes movie saver class - a human friendly ffmpeg wrapper.
After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x). After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x).
@ -92,10 +91,8 @@ class MovieSaver():
self.shape_hw = shape_hw self.shape_hw = shape_hw
self.initialize() self.initialize()
print(f"MovieSaver initialized. fps={fps} crf={crf} pix_fmt={pix_fmt} codec={codec} preset={preset}") print(f"MovieSaver initialized. fps={fps} crf={crf} pix_fmt={pix_fmt} codec={codec} preset={preset}")
def initialize(self): def initialize(self):
args = ( args = (
ffmpeg ffmpeg
@ -112,7 +109,6 @@ class MovieSaver():
self.shape_hw = tuple(self.shape_hw) self.shape_hw = tuple(self.shape_hw)
print(f"Initialization done. Movie shape: {self.shape_hw}") print(f"Initialization done. Movie shape: {self.shape_hw}")
def write_frame(self, out_frame: np.ndarray): def write_frame(self, out_frame: np.ndarray):
r""" r"""
Function to dump a numpy array as frame of a movie. Function to dump a numpy array as frame of a movie.
@ -123,7 +119,6 @@ class MovieSaver():
Dim 1: x Dim 1: x
Dim 2: RGB Dim 2: RGB
""" """
assert out_frame.dtype == np.uint8, "Convert to np.uint8 before" assert out_frame.dtype == np.uint8, "Convert to np.uint8 before"
assert len(out_frame.shape) == 3, "out_frame needs to be three dimensional, Y X C" assert len(out_frame.shape) == 3, "out_frame needs to be three dimensional, Y X C"
assert out_frame.shape[2] == 3, f"need three color channels, but you provided {out_frame.shape[2]}." assert out_frame.shape[2] == 3, f"need three color channels, but you provided {out_frame.shape[2]}."
@ -143,7 +138,6 @@ class MovieSaver():
self.nmb_frames += 1 self.nmb_frames += 1
def finalize(self): def finalize(self):
r""" r"""
Call this function to finalize the movie. If you forget to call it your movie will be garbage. Call this function to finalize the movie. If you forget to call it your movie will be garbage.
@ -157,7 +151,6 @@ class MovieSaver():
print(f"Movie saved, {duration}s playtime, watch here: \n{self.fp_out}") print(f"Movie saved, {duration}s playtime, watch here: \n{self.fp_out}")
def concatenate_movies(fp_final: str, list_fp_movies: List[str]): def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
r""" r"""
Concatenate multiple movie segments into one long movie, using ffmpeg. Concatenate multiple movie segments into one long movie, using ffmpeg.
@ -189,7 +182,6 @@ def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
fa.write("%s\n" % item) fa.write("%s\n" % item)
cmd = f'ffmpeg -f concat -safe 0 -i {fp_list} -c copy {fp_final}' cmd = f'ffmpeg -f concat -safe 0 -i {fp_list} -c copy {fp_final}'
dp_movie = os.path.split(fp_final)[0]
subprocess.call(cmd, shell=True) subprocess.call(cmd, shell=True)
os.remove(fp_list) os.remove(fp_list)
if os.path.isfile(fp_final): if os.path.isfile(fp_final):
@ -200,11 +192,12 @@ class MovieReader():
r""" r"""
Class to read in a movie. Class to read in a movie.
""" """
def __init__(self, fp_movie): def __init__(self, fp_movie):
self.video_player_object = cv2.VideoCapture(fp_movie) self.video_player_object = cv2.VideoCapture(fp_movie)
self.nmb_frames = int(self.video_player_object.get(cv2.CAP_PROP_FRAME_COUNT)) self.nmb_frames = int(self.video_player_object.get(cv2.CAP_PROP_FRAME_COUNT))
self.fps_movie = int(self.video_player_object.get(cv2.CAP_PROP_FPS)) self.fps_movie = int(self.video_player_object.get(cv2.CAP_PROP_FPS))
self.shape = [100,100,3] self.shape = [100, 100, 3]
self.shape_is_set = False self.shape_is_set = False
def get_next_frame(self): def get_next_frame(self):
@ -217,19 +210,18 @@ class MovieReader():
else: else:
return np.zeros(self.shape) return np.zeros(self.shape)
#%%
if __name__ == "__main__": if __name__ == "__main__":
fps=2 fps = 2
list_fp_movies = [] list_fp_movies = []
for k in range(4): for k in range(4):
fp_movie = f"/tmp/my_random_movie_{k}.mp4" fp_movie = f"/tmp/my_random_movie_{k}.mp4"
list_fp_movies.append(fp_movie) list_fp_movies.append(fp_movie)
ms = MovieSaver(fp_movie, fps=fps) ms = MovieSaver(fp_movie, fps=fps)
for fn in tqdm(range(30)): for fn in tqdm(range(30)):
img = (np.random.rand(512, 1024, 3)*255).astype(np.uint8) img = (np.random.rand(512, 1024, 3) * 255).astype(np.uint8)
ms.write_frame(img) ms.write_frame(img)
ms.finalize() ms.finalize()
fp_final = "/tmp/my_concatenated_movie.mp4" fp_final = "/tmp/my_concatenated_movie.mp4"
concatenate_movies(fp_final, list_fp_movies) concatenate_movies(fp_final, list_fp_movies)

stable_diffusion_holder.py
@ -13,36 +13,25 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os, sys import os
dp_git = "/home/lugo/git/"
sys.path.append(os.path.join(dp_git,'garden4'))
sys.path.append('util')
import torch import torch
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
import numpy as np import numpy as np
import warnings import warnings
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
import time
import subprocess
import warnings import warnings
import torch import torch
from tqdm.auto import tqdm
from PIL import Image from PIL import Image
# import matplotlib.pyplot as plt
import torch import torch
from movie_util import MovieSaver from typing import Optional
import datetime
from typing import Callable, List, Optional, Union
import inspect
from threading import Thread
torch.set_grad_enabled(False)
from omegaconf import OmegaConf from omegaconf import OmegaConf
from torch import autocast from torch import autocast
from contextlib import nullcontext from contextlib import nullcontext
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from einops import repeat, rearrange from einops import repeat, rearrange
#%% from utils import interpolate_spherical
def pad_image(input_image): def pad_image(input_image):
@ -53,41 +42,11 @@ def pad_image(input_image):
return im_padded return im_padded
def make_batch_inpaint(
image,
mask,
txt,
device,
num_samples=1):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
batch = {
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
"txt": num_samples * [txt],
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
}
return batch
def make_batch_superres( def make_batch_superres(
image, image,
txt, txt,
device, device,
num_samples=1, num_samples=1):
):
image = np.array(image.convert("RGB")) image = np.array(image.convert("RGB"))
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
batch = { batch = {
@ -114,7 +73,7 @@ class StableDiffusionHolder:
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
device: str = None, device: str = None,
precision: str='autocast', precision: str = 'autocast',
): ):
r""" r"""
Initializes the stable diffusion holder, which contains the models and sampler. Initializes the stable diffusion holder, which contains the models and sampler.
@ -137,7 +96,7 @@ class StableDiffusionHolder:
self.precision = precision self.precision = precision
self.init_model(fp_ckpt, fp_config) self.init_model(fp_ckpt, fp_config)
self.f = 8 #downsampling factor, most often 8 or 16", self.f = 8 # downsampling factor, most often 8 or 16"
self.C = 4 self.C = 4
self.ddim_eta = 0 self.ddim_eta = 0
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
@ -150,13 +109,8 @@ class StableDiffusionHolder:
self.height = height self.height = height
self.width = width self.width = width
# Inpainting inits
self.mask_empty = Image.fromarray(255*np.ones([self.width, self.height], dtype=np.uint8))
self.image_empty = Image.fromarray(np.zeros([self.width, self.height, 3], dtype=np.uint8))
self.negative_prompt = [""] self.negative_prompt = [""]
def init_model(self, fp_ckpt, fp_config): def init_model(self, fp_ckpt, fp_config):
r"""Loads the models and sampler. r"""Loads the models and sampler.
""" """
@ -169,13 +123,11 @@ class StableDiffusionHolder:
fn_ckpt = os.path.basename(fp_ckpt) fn_ckpt = os.path.basename(fp_ckpt)
if 'depth' in fn_ckpt: if 'depth' in fn_ckpt:
fp_config = 'configs/v2-midas-inference.yaml' fp_config = 'configs/v2-midas-inference.yaml'
elif 'inpain' in fn_ckpt:
fp_config = 'configs/v2-inpainting-inference.yaml'
elif 'upscaler' in fn_ckpt: elif 'upscaler' in fn_ckpt:
fp_config = 'configs/x4-upscaling.yaml' fp_config = 'configs/x4-upscaling.yaml'
elif '512' in fn_ckpt: elif '512' in fn_ckpt:
fp_config = 'configs/v2-inference.yaml' fp_config = 'configs/v2-inference.yaml'
elif '768'in fn_ckpt: elif '768' in fn_ckpt:
fp_config = 'configs/v2-inference-v.yaml' fp_config = 'configs/v2-inference-v.yaml'
elif 'v1-5' in fn_ckpt: elif 'v1-5' in fn_ckpt:
fp_config = 'configs/v1-inference.yaml' fp_config = 'configs/v1-inference.yaml'
@ -186,7 +138,6 @@ class StableDiffusionHolder:
assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}" assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"
config = OmegaConf.load(fp_config) config = OmegaConf.load(fp_config)
self.model = instantiate_from_config(config.model) self.model = instantiate_from_config(config.model)
@ -195,7 +146,6 @@ class StableDiffusionHolder:
self.model = self.model.to(self.device) self.model = self.model.to(self.device)
self.sampler = DDIMSampler(self.model) self.sampler = DDIMSampler(self.model)
def init_auto_res(self): def init_auto_res(self):
r"""Automatically set the resolution to the one used in training. r"""Automatically set the resolution to the one used in training.
""" """
@ -218,7 +168,6 @@ class StableDiffusionHolder:
if len(self.negative_prompt) > 1: if len(self.negative_prompt) > 1:
self.negative_prompt = [self.negative_prompt[0]] self.negative_prompt = [self.negative_prompt[0]]
def get_text_embedding(self, prompt): def get_text_embedding(self, prompt):
c = self.model.get_learned_conditioning(prompt) c = self.model.get_learned_conditioning(prompt)
return c return c
@ -228,7 +177,6 @@ class StableDiffusionHolder:
r""" r"""
Initializes the conditioning for the x4 upscaling model. Initializes the conditioning for the x4 upscaling model.
""" """
image = pad_image(image) # resize to integer multiple of 32 image = pad_image(image) # resize to integer multiple of 32
w, h = image.size w, h = image.size
noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long() noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()
@ -240,7 +188,6 @@ class StableDiffusionHolder:
# uncond cond # uncond cond
uc_cross = self.model.get_unconditional_conditioning(1, "") uc_cross = self.model.get_unconditional_conditioning(1, "")
uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level} uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}
return cond, uc_full return cond, uc_full
@torch.no_grad() @torch.no_grad()
@ -249,14 +196,12 @@ class StableDiffusionHolder:
text_embeddings: torch.FloatTensor, text_embeddings: torch.FloatTensor,
latents_start: torch.FloatTensor, latents_start: torch.FloatTensor,
idx_start: int = 0, idx_start: int = 0,
list_latents_mixing = None, list_latents_mixing=None,
mixing_coeffs = 0.0, mixing_coeffs=0.0,
spatial_mask = None, spatial_mask=None,
return_image: Optional[bool] = False, return_image: Optional[bool] = False):
):
r""" r"""
Diffusion standard version. Diffusion standard version.
Args: Args:
text_embeddings: torch.FloatTensor text_embeddings: torch.FloatTensor
Text embeddings used for diffusion Text embeddings used for diffusion
@ -270,12 +215,10 @@ class StableDiffusionHolder:
experimental feature for enforcing pixels from list_latents_mixing experimental feature for enforcing pixels from list_latents_mixing
return_image: Optional[bool] return_image: Optional[bool]
Optionally return image directly Optionally return image directly
""" """
# Asserts # Asserts
if type(mixing_coeffs) == float: if type(mixing_coeffs) == float:
list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs] list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
elif type(mixing_coeffs) == list: elif type(mixing_coeffs) == list:
assert len(mixing_coeffs) == self.num_inference_steps assert len(mixing_coeffs) == self.num_inference_steps
list_mixing_coeffs = mixing_coeffs list_mixing_coeffs = mixing_coeffs
@ -285,26 +228,19 @@ class StableDiffusionHolder:
if np.sum(list_mixing_coeffs) > 0: if np.sum(list_mixing_coeffs) > 0:
assert len(list_latents_mixing) == self.num_inference_steps assert len(list_latents_mixing) == self.num_inference_steps
precision_scope = autocast if self.precision == "autocast" else nullcontext precision_scope = autocast if self.precision == "autocast" else nullcontext
with precision_scope("cuda"): with precision_scope("cuda"):
with self.model.ema_scope(): with self.model.ema_scope():
if self.guidance_scale != 1.0: if self.guidance_scale != 1.0:
uc = self.model.get_learned_conditioning(self.negative_prompt) uc = self.model.get_learned_conditioning(self.negative_prompt)
else: else:
uc = None uc = None
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
latents = latents_start.clone() latents = latents_start.clone()
timesteps = self.sampler.ddim_timesteps timesteps = self.sampler.ddim_timesteps
time_range = np.flip(timesteps) time_range = np.flip(timesteps)
total_steps = timesteps.shape[0] total_steps = timesteps.shape[0]
# Collect latents
# collect latents
list_latents_out = [] list_latents_out = []
for i, step in enumerate(time_range): for i, step in enumerate(time_range):
# Set the right starting latents # Set the right starting latents
@ -313,15 +249,13 @@ class StableDiffusionHolder:
continue continue
elif i == idx_start: elif i == idx_start:
latents = latents_start.clone() latents = latents_start.clone()
# Mix latents
# Mix the latents. if i > 0 and list_mixing_coeffs[i] > 0:
if i > 0 and list_mixing_coeffs[i]>0: latents_mixtarget = list_latents_mixing[i - 1].clone()
latents_mixtarget = list_latents_mixing[i-1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i]) latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
if spatial_mask is not None and list_latents_mixing is not None: if spatial_mask is not None and list_latents_mixing is not None:
latents = interpolate_spherical(latents, list_latents_mixing[i-1], 1-spatial_mask) latents = interpolate_spherical(latents, list_latents_mixing[i - 1], 1 - spatial_mask)
# latents[:,:,-15:,:] = latents_mixtarget[:,:,-15:,:]
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full((1,), step, device=self.device, dtype=torch.long) ts = torch.full((1,), step, device=self.device, dtype=torch.long)
@ -334,13 +268,11 @@ class StableDiffusionHolder:
dynamic_threshold=None) dynamic_threshold=None)
latents, pred_x0 = outs latents, pred_x0 = outs
list_latents_out.append(latents.clone()) list_latents_out.append(latents.clone())
if return_image: if return_image:
return self.latent2image(latents) return self.latent2image(latents)
else: else:
return list_latents_out return list_latents_out
@torch.no_grad() @torch.no_grad()
def run_diffusion_upscaling( def run_diffusion_upscaling(
self, self,
@ -348,17 +280,16 @@ class StableDiffusionHolder:
uc_full, uc_full,
latents_start: torch.FloatTensor, latents_start: torch.FloatTensor,
idx_start: int = -1, idx_start: int = -1,
list_latents_mixing = None, list_latents_mixing: list = None,
mixing_coeffs = 0.0, mixing_coeffs: float = 0.0,
return_image: Optional[bool] = False return_image: Optional[bool] = False):
):
r""" r"""
Diffusion upscaling version. Diffusion upscaling version.
""" """
# Asserts # Asserts
if type(mixing_coeffs) == float: if type(mixing_coeffs) == float:
list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs] list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
elif type(mixing_coeffs) == list: elif type(mixing_coeffs) == list:
assert len(mixing_coeffs) == self.num_inference_steps assert len(mixing_coeffs) == self.num_inference_steps
list_mixing_coeffs = mixing_coeffs list_mixing_coeffs = mixing_coeffs
@ -369,27 +300,20 @@ class StableDiffusionHolder:
assert len(list_latents_mixing) == self.num_inference_steps assert len(list_latents_mixing) == self.num_inference_steps
precision_scope = autocast if self.precision == "autocast" else nullcontext precision_scope = autocast if self.precision == "autocast" else nullcontext
h = uc_full['c_concat'][0].shape[2] h = uc_full['c_concat'][0].shape[2]
w = uc_full['c_concat'][0].shape[3] w = uc_full['c_concat'][0].shape[3]
with precision_scope("cuda"): with precision_scope("cuda"):
with self.model.ema_scope(): with self.model.ema_scope():
shape_latents = [self.model.channels, h, w] shape_latents = [self.model.channels, h, w]
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
C, H, W = shape_latents C, H, W = shape_latents
size = (1, C, H, W) size = (1, C, H, W)
b = size[0] b = size[0]
latents = latents_start.clone() latents = latents_start.clone()
timesteps = self.sampler.ddim_timesteps timesteps = self.sampler.ddim_timesteps
time_range = np.flip(timesteps) time_range = np.flip(timesteps)
total_steps = timesteps.shape[0] total_steps = timesteps.shape[0]
# collect latents # collect latents
list_latents_out = [] list_latents_out = []
for i, step in enumerate(time_range): for i, step in enumerate(time_range):
@ -399,12 +323,10 @@ class StableDiffusionHolder:
continue continue
elif i == idx_start: elif i == idx_start:
latents = latents_start.clone() latents = latents_start.clone()
# Mix the latents. # Mix the latents.
if i > 0 and list_mixing_coeffs[i]>0: if i > 0 and list_mixing_coeffs[i] > 0:
latents_mixtarget = list_latents_mixing[i-1].clone() latents_mixtarget = list_latents_mixing[i - 1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i]) latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
# print(f"diffusion iter {i}") # print(f"diffusion iter {i}")
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full((b,), step, device=self.device, dtype=torch.long) ts = torch.full((b,), step, device=self.device, dtype=torch.long)
@ -423,121 +345,10 @@ class StableDiffusionHolder:
else: else:
return list_latents_out return list_latents_out
@torch.no_grad()
def run_diffusion_inpaint(
self,
text_embeddings: torch.FloatTensor,
latents_for_injection: torch.FloatTensor = None,
idx_start: int = -1,
idx_stop: int = -1,
return_image: Optional[bool] = False
):
r"""
Runs inpaint-based diffusion. Returns a list of latents that were computed.
Adaptations allow to supply
a) starting index for diffusion
b) stopping index for diffusion
c) latent representations that are injected at the starting index
Furthermore the intermittent latents are collected and returned.
Adapted from diffusers (https://github.com/huggingface/diffusers)
Args:
text_embeddings: torch.FloatTensor
Text embeddings used for diffusion
latents_for_injection: torch.FloatTensor
Latents that are used for injection
idx_start: int
Index of the diffusion process start and where the latents_for_injection are injected
idx_stop: int
Index of the diffusion process end.
return_image: Optional[bool]
Optionally return image directly
"""
if latents_for_injection is None:
do_inject_latents = False
else:
do_inject_latents = True
precision_scope = autocast if self.precision == "autocast" else nullcontext
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
with precision_scope("cuda"):
with self.model.ema_scope():
batch = make_batch_inpaint(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
c = text_embeddings
c_cat = list()
for ck in self.model.concat_keys:
cc = batch[ck].float()
if ck != self.model.masked_image_key:
bchw = [1, 4, self.height // 8, self.width // 8]
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
# cond
cond = {"c_concat": [c_cat], "c_crossattn": [c]}
# uncond cond
uc_cross = self.model.get_unconditional_conditioning(1, "")
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
shape_latents = [self.model.channels, self.height // 8, self.width // 8]
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
# sampling
C, H, W = shape_latents
size = (1, C, H, W)
device = self.model.betas.device
b = size[0]
latents = torch.randn(size, generator=generator, device=device)
timesteps = self.sampler.ddim_timesteps
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
# collect latents
list_latents_out = []
for i, step in enumerate(time_range):
if do_inject_latents:
# Inject latent at right place
if i < idx_start:
continue
elif i == idx_start:
latents = latents_for_injection.clone()
if i == idx_stop:
return list_latents_out
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
quantize_denoised=False, temperature=1.0,
noise_dropout=0.0, score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=self.guidance_scale,
unconditional_conditioning=uc_full,
dynamic_threshold=None)
latents, pred_x0 = outs
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
@torch.no_grad() @torch.no_grad()
def latent2image( def latent2image(
self, self,
latents: torch.FloatTensor latents: torch.FloatTensor):
):
r""" r"""
Returns an image provided a latent representation from diffusion. Returns an image provided a latent representation from diffusion.
Args: Args:
@ -546,85 +357,6 @@ class StableDiffusionHolder:
""" """
x_sample = self.model.decode_first_stage(latents) x_sample = self.model.decode_first_stage(latents)
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255 * x_sample[0,:,:].permute([1,2,0]).cpu().numpy() x_sample = 255 * x_sample[0, :, :].permute([1, 2, 0]).cpu().numpy()
image = x_sample.astype(np.uint8) image = x_sample.astype(np.uint8)
return image return image
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function will always cast up to float64 for sake of extra 4.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return in p0
1 will return in p1
0.x will return a mix between both preserving angular velocity.
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'
p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1+epsilon, 1-epsilon)
theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0*s0 + p1*s1
if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()
return interp
if __name__ == "__main__":
num_inference_steps = 20 # Number of diffusion interations
# fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
# fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
# fp_ckpt= "../stable_diffusion_models/ckpt/512-inpainting-ema.ckpt"
# fp_config = '../stablediffusion/configs//stable-diffusion/v2-inpainting-inference.yaml'
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
# fp_config = 'configs/v2-inference-v.yaml'
self = StableDiffusionHolder(fp_ckpt, num_inference_steps=num_inference_steps)
xxx
#%%
self.width = 1536
self.height = 768
prompt = "360 degree equirectangular, a huge rocky hill full of pianos and keyboards, musical instruments, cinematic, masterpiece 8 k, artstation"
self.set_negative_prompt("out of frame, faces, rendering, blurry")
te = self.get_text_embedding(prompt)
img = self.run_diffusion_standard(te, return_image=True)
Image.fromarray(img).show()

utils.py (new file, 260 lines)
@ -0,0 +1,260 @@
# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import time
import warnings
import datetime
from typing import List, Union
torch.set_grad_enabled(False)
import yaml
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function will always cast up to float64 for the sake of extra precision.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return p0
1 will return p1
0.x will return a mix between both preserving angular velocity.
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'
p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1 + epsilon, 1 - epsilon)
theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0 * s0 + p1 * s1
if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()
return interp
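A short usage sketch (the tensor shape is an arbitrary assumption). Unlike a plain 0.5 * p0 + 0.5 * p1, the slerp keeps the interpolant at a comparable norm, which is what makes it suitable for mixing Gaussian diffusion latents:

import torch
from utils import interpolate_spherical

p0 = torch.randn(1, 4, 64, 64)   # latent-sized tensors (shape assumed)
p1 = torch.randn(1, 4, 64, 64)
p_mid = interpolate_spherical(p0, p1, 0.5)
print(p_mid.shape, p_mid.dtype)  # torch.Size([1, 4, 64, 64]) torch.float32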
def interpolate_linear(p0, p1, fract_mixing):
r"""
Helper function to mix two variables using standard linear interpolation.
Args:
p0:
First tensor / np.ndarray for interpolation
p1:
Second tensor / np.ndarray for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return p0
1 will return p1
0.x will return a linear mix between both.
"""
reconvert_uint8 = False
if type(p0) is np.ndarray and p0.dtype == 'uint8':
reconvert_uint8 = True
p0 = p0.astype(np.float64)
if type(p1) is np.ndarray and p1.dtype == 'uint8':
reconvert_uint8 = True
p1 = p1.astype(np.float64)
interp = (1 - fract_mixing) * p0 + fract_mixing * p1
if reconvert_uint8:
interp = np.clip(interp, 0, 255).astype(np.uint8)
return interp
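A short usage sketch showing the uint8 round-trip (image sizes and values are arbitrary):

import numpy as np
from utils import interpolate_linear

img0 = np.zeros((64, 64, 3), dtype=np.uint8)
img1 = np.full((64, 64, 3), 255, dtype=np.uint8)
img_mid = interpolate_linear(img0, img1, 0.25)
print(img_mid.dtype, img_mid[0, 0])   # uint8 is preserved, pixel value [63 63 63]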
def add_frames_linear_interp(
list_imgs: List[np.ndarray],
fps_target: Union[float, int] = None,
duration_target: Union[float, int] = None,
nmb_frames_target: int = None):
r"""
Helper function to cheaply increase the number of frames given a list of images,
by virtue of standard linear interpolation.
The number of inserted frames will be automatically adjusted so that the total number
of frames can be fixed precisely, using a random shuffling technique.
The function allows 1:1 comparisons between transitions as videos.
Args:
list_imgs: List[np.ndarray]
List of images, between each image new frames will be inserted via linear interpolation.
fps_target:
OptionA: specify here the desired frames per second.
duration_target:
OptionA: specify here the desired duration of the transition in seconds.
nmb_frames_target:
OptionB: directly fix the total number of frames of the output.
"""
# Sanity
if nmb_frames_target is not None and fps_target is not None:
raise ValueError("You cannot specify both fps_target and nmb_frames_target")
if fps_target is None:
assert nmb_frames_target is not None, "Either specify fps_target (with duration_target) or nmb_frames_target"
if nmb_frames_target is None:
assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
nmb_frames_target = fps_target * duration_target
# Get number of frames that are missing
nmb_frames_diff = len(list_imgs) - 1
nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
if nmb_frames_missing < 1:
return list_imgs
list_imgs_float = [img.astype(np.float32) for img in list_imgs]
# Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
constfact = np.floor(mean_nmb_frames_insert)
remainder_x = 1 - (mean_nmb_frames_insert - constfact)
nmb_iter = 0
while True:
nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
nmb_frames_to_insert[nmb_frames_to_insert <= remainder_x] = 0
nmb_frames_to_insert[nmb_frames_to_insert > remainder_x] = 1
nmb_frames_to_insert += constfact
if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
break
nmb_iter += 1
if nmb_iter > 100000:
print("add_frames_linear_interp: issue with inserting the right number of frames")
break
nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
list_imgs_interp = []
for i in range(len(list_imgs_float) - 1):
img0 = list_imgs_float[i]
img1 = list_imgs_float[i + 1]
list_imgs_interp.append(img0.astype(np.uint8))
list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i] + 2)[1:-1]
for fract_linblend in list_fracts_linblend:
img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
list_imgs_interp.append(img_blend.astype(np.uint8))
if i == len(list_imgs_float) - 2:
list_imgs_interp.append(img1.astype(np.uint8))
return list_imgs_interp
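# Usage sketch (illustrative only; assumes list_imgs holds the keyframes of a transition):
#   frames_30fps = add_frames_linear_interp(list_imgs, fps_target=30, duration_target=12)
#   frames_fixed = add_frames_linear_interp(list_imgs, nmb_frames_target=360)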
def get_spacing(nmb_points: int, scaling: float):
"""
Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
Args:
nmb_points: int
Number of points between [0, 1]
scaling: float
Higher values will return higher sampling density around 0.5
"""
if scaling < 1.7:
return np.linspace(0, 1, nmb_points)
nmb_points_per_side = nmb_points // 2 + 1
    if np.mod(nmb_points, 2) != 0:  # Odd number of points
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
right_side = 1 - left_side[::-1][1:]
else:
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
right_side = 1 - left_side[::-1]
all_fracts = np.hstack([left_side, right_side])
return all_fracts
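# Usage sketch (illustrative only):
#   fracts_uniform = get_spacing(11, 1.0)  # scaling < 1.7 falls back to np.linspace(0, 1, 11)
#   fracts_dense = get_spacing(11, 3.0)    # 11 points in [0, 1], sampled more densely around 0.5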
def get_time(resolution=None):
"""
    Helper function returning a nicely formatted time string, e.g. 221117_1620
"""
if resolution is None:
resolution = "second"
if resolution == "day":
t = time.strftime('%y%m%d', time.localtime())
elif resolution == "minute":
t = time.strftime('%y%m%d_%H%M', time.localtime())
elif resolution == "second":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
elif resolution == "millisecond":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
t += "_"
        t += "{:03d}".format(int(int(datetime.datetime.utcnow().strftime('%f')) / 1000))
else:
raise ValueError("bad resolution provided: %s" % resolution)
return t
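# Usage sketch (illustrative only; the returned string depends on the current local time):
#   get_time("minute")  # e.g. '230222_1015'
#   get_time()          # defaults to second resolution, e.g. '230222_101503'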
def compare_dicts(a, b):
"""
    Compares two dictionaries a and b and returns a dictionary c containing
    all keys that are shared by a and b but whose values differ.
    The differing values of a and b are stacked together in the output.
    Example:
        a = {'bobo': 4}
        b = {'bobo': 5}
        c = compare_dicts(a, b)
        # c == {'bobo': [4, 5]}
"""
c = {}
for key in a.keys():
if key in b.keys():
val_a = a[key]
val_b = b[key]
if val_a != val_b:
c[key] = [val_a, val_b]
return c
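# Usage sketch (illustrative only; the keys below are hypothetical):
#   d1 = {'num_inference_steps': 30, 'guidance_scale': 5.0}
#   d2 = {'num_inference_steps': 30, 'guidance_scale': 7.5}
#   compare_dicts(d1, d2)  # -> {'guidance_scale': [5.0, 7.5]}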
def yml_load(fp_yml, print_fields=False):
"""
Helper function for loading yaml files
"""
with open(fp_yml) as f:
data = yaml.load(f, Loader=yaml.loader.SafeLoader)
dict_data = dict(data)
print("load: loaded {}".format(fp_yml))
return dict_data
def yml_save(fp_yml, dict_stuff):
"""
Helper function for saving yaml files
"""
with open(fp_yml, 'w') as f:
yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
print("yml_save: saved {}".format(fp_yml))