This commit is contained in:
Johannes Stelzer 2023-02-22 10:15:03 +01:00
parent 3ed876e0ee
commit 297bb9abe6
9 changed files with 906 additions and 1322 deletions


@ -13,39 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
import torch
from movie_util import MovieSaver
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
from huggingface_hub import hf_hub_download
#%% First let us spawn a stable diffusion holder
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
sdh = StableDiffusionHolder(fp_ckpt)
#%% Next let's set up all parameters
depth_strength = 0.65 # Specifies how deep (in terms of diffusion iterations) the first branching happens
t_compute_max_allowed = 15 # Determines the quality of the transition in terms of compute time you grant it
# %% Next let's set up all parameters
depth_strength = 0.65 # Specifies how deep (in terms of diffusion iterations) the first branching happens
t_compute_max_allowed = 15 # Determines the quality of the transition in terms of compute time you grant it
fixed_seeds = [69731932, 504430820]
prompt1 = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic"
prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail"
fp_movie = 'movie_example1.mp4'
duration_transition = 12 # In seconds
duration_transition = 12 # In seconds
# Spawn latent blending
lb = LatentBlending(sdh)
@ -54,10 +46,9 @@ lb.set_prompt2(prompt2)
# Run latent blending
lb.run_transition(
depth_strength = depth_strength,
t_compute_max_allowed = t_compute_max_allowed,
fixed_seeds = fixed_seeds
)
depth_strength=depth_strength,
t_compute_max_allowed=t_compute_max_allowed,
fixed_seeds=fixed_seeds)
# Save movie
lb.write_movie_transition(fp_movie, duration_transition)
lb.write_movie_transition(fp_movie, duration_transition)
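For reference, write_movie_transition wraps the manual movie-writing path that the gradio app later in this commit still uses. The sketch below is only an illustrative approximation of that path (not the method's actual implementation); it reuses the variables defined in the example above, assumes run_transition returns a list of uint8 frames as it does in the gradio app, and picks fps=30 as an assumption.

from utils import add_frames_linear_interp  # as imported by the gradio app in this commit
from movie_util import MovieSaver

fps = 30  # assumption; the method may use a different default
imgs_transition = lb.run_transition(
    depth_strength=depth_strength,
    t_compute_max_allowed=t_compute_max_allowed,
    fixed_seeds=fixed_seeds)
# Stretch the computed frames to the target duration with cheap linear interpolation,
# then encode them with MovieSaver.
imgs_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps)
ms = MovieSaver(fp_movie, fps=fps)
for img in imgs_ext:
    ms.write_frame(img)
ms.finalize()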


@ -13,33 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
import torch
from movie_util import MovieSaver, concatenate_movies
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
from movie_util import concatenate_movies
from huggingface_hub import hf_hub_download
#%% First let us spawn a stable diffusion holder
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
# fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
sdh = StableDiffusionHolder(fp_ckpt)
#%% Let's setup the multi transition
# %% Let's setup the multi transition
fps = 30
duration_single_trans = 6
depth_strength = 0.55 #Specifies how deep (in terms of diffusion iterations) the first branching happens
depth_strength = 0.55 # Specifies how deep (in terms of diffusion iterations) the first branching happens
# Specify a list of prompts below
list_prompts = []
@ -52,36 +45,33 @@ list_prompts.append("statue of an ancient cybernetic messenger annoucing good ne
# You can optionally specify the seeds
list_seeds = [954375479, 332539350, 956051013, 408831845, 250009012, 675588737]
t_compute_max_allowed = 12 # per segment
t_compute_max_allowed = 12 # per segment
fp_movie = 'movie_example2.mp4'
lb = LatentBlending(sdh)
list_movie_parts = [] #
for i in range(len(list_prompts)-1):
list_movie_parts = []
for i in range(len(list_prompts) - 1):
# For a multi transition we can save some computation time and recycle the latents
if i==0:
if i == 0:
lb.set_prompt1(list_prompts[i])
lb.set_prompt2(list_prompts[i+1])
lb.set_prompt2(list_prompts[i + 1])
recycle_img1 = False
else:
lb.swap_forward()
lb.set_prompt2(list_prompts[i+1])
recycle_img1 = True
lb.set_prompt2(list_prompts[i + 1])
recycle_img1 = True
fp_movie_part = f"tmp_part_{str(i).zfill(3)}.mp4"
fixed_seeds = list_seeds[i:i+2]
fixed_seeds = list_seeds[i:i + 2]
# Run latent blending
lb.run_transition(
depth_strength = depth_strength,
t_compute_max_allowed = t_compute_max_allowed,
fixed_seeds = fixed_seeds
)
depth_strength=depth_strength,
t_compute_max_allowed=t_compute_max_allowed,
fixed_seeds=fixed_seeds)
# Save movie
lb.write_movie_transition(fp_movie_part, duration_single_trans)
list_movie_parts.append(fp_movie_part)
# Finally, concatenate the result
concatenate_movies(fp_movie, list_movie_parts)
concatenate_movies(fp_movie, list_movie_parts)


@ -13,25 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
import torch
from movie_util import MovieSaver
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
from huggingface_hub import hf_hub_download
#%% Define vars for low-resolution pass
# %% Define vars for low-resolution pass
prompt1 = "photo of mount vesuvius erupting a terrifying pyroclastic ash cloud"
prompt2 = "photo of a inside a building full of ash, fire, death, destruction, explosions"
fixed_seeds = [5054613, 1168652]
@ -41,21 +33,18 @@ height = 384
num_inference_steps_lores = 40
nmb_max_branches_lores = 10
depth_strength_lores = 0.5
fp_ckpt_lores = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
fp_ckpt_lores = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
#%% Define vars for high-resolution pass
fp_ckpt_hires = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
# %% Define vars for high-resolution pass
fp_ckpt_hires = hf_hub_download(repo_id="stabilityai/stable-diffusion-x4-upscaler", filename="x4-upscaler-ema.ckpt")
depth_strength_hires = 0.65
num_inference_steps_hires = 100
nmb_branches_final_hires = 6
dp_imgs = "tmp_transition" # folder for results and intermediate steps
dp_imgs = "tmp_transition" # Folder for results and intermediate steps
#%% Run low-res pass
# %% Run low-res pass
sdh = StableDiffusionHolder(fp_ckpt_lores)
#%%
lb = LatentBlending(sdh)
lb.set_prompt1(prompt1)
lb.set_prompt2(prompt2)
@ -64,14 +53,13 @@ lb.set_height(height)
# Run latent blending
lb.run_transition(
depth_strength = depth_strength_lores,
nmb_max_branches = nmb_max_branches_lores,
fixed_seeds = fixed_seeds
)
depth_strength=depth_strength_lores,
nmb_max_branches=nmb_max_branches_lores,
fixed_seeds=fixed_seeds)
lb.write_imgs_transition(dp_imgs)
#%% Run high-res pass
# %% Run high-res pass
sdh = StableDiffusionHolder(fp_ckpt_hires)
lb = LatentBlending(sdh)
lb = LatentBlending(sdh)
lb.run_upscaling(dp_imgs, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)


@ -13,25 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import os
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
import torch
from movie_util import MovieSaver, concatenate_movies
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
from movie_util import concatenate_movies
from huggingface_hub import hf_hub_download
#%% Define vars for low-resolution pass
# %% Define vars for low-resolution pass
list_prompts = []
list_prompts.append("surrealistic statue made of glitter and dirt, standing in a lake, atmospheric light, strange glow")
list_prompts.append("statue of a mix between a tree and human, made of marble, incredibly detailed")
@ -50,61 +44,59 @@ num_inference_steps_lores = 40
nmb_max_branches_lores = 10
depth_strength_lores = 0.5
fp_ckpt_lores = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
fp_ckpt_lores = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
#%% Define vars for high-resolution pass
fp_ckpt_hires = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
# %% Define vars for high-resolution pass
fp_ckpt_hires = hf_hub_download(repo_id="stabilityai/stable-diffusion-x4-upscaler", filename="x4-upscaler-ema.ckpt")
depth_strength_hires = 0.65
num_inference_steps_hires = 100
nmb_branches_final_hires = 6
#%% Run low-res pass
# %% Run low-res pass
sdh = StableDiffusionHolder(fp_ckpt_lores)
t_compute_max_allowed = 12 # per segment
t_compute_max_allowed = 12 # Per segment
lb = LatentBlending(sdh)
list_movie_dirs = [] #
for i in range(len(list_prompts)-1):
list_movie_dirs = []
for i in range(len(list_prompts) - 1):
# For a multi transition we can save some computation time and recycle the latents
if i==0:
if i == 0:
lb.set_prompt1(list_prompts[i])
lb.set_prompt2(list_prompts[i+1])
lb.set_prompt2(list_prompts[i + 1])
recycle_img1 = False
else:
lb.swap_forward()
lb.set_prompt2(list_prompts[i+1])
recycle_img1 = True
lb.set_prompt2(list_prompts[i + 1])
recycle_img1 = True
dp_movie_part = f"tmp_part_{str(i).zfill(3)}"
fp_movie_part = os.path.join(dp_movie_part, "movie_lowres.mp4")
os.makedirs(dp_movie_part, exist_ok=True)
fixed_seeds = list_seeds[i:i+2]
fixed_seeds = list_seeds[i:i + 2]
# Run latent blending
lb.run_transition(
depth_strength = depth_strength_lores,
nmb_max_branches = nmb_max_branches_lores,
fixed_seeds = fixed_seeds
)
depth_strength=depth_strength_lores,
nmb_max_branches=nmb_max_branches_lores,
fixed_seeds=fixed_seeds)
# Save movie and images (needed for upscaling!)
lb.write_movie_transition(fp_movie_part, duration_single_trans)
lb.write_imgs_transition(dp_movie_part)
list_movie_dirs.append(dp_movie_part)
#%% Run high-res pass on each segment
# %% Run high-res pass on each segment
sdh = StableDiffusionHolder(fp_ckpt_hires)
lb = LatentBlending(sdh)
lb = LatentBlending(sdh)
for dp_part in list_movie_dirs:
lb.run_upscaling(dp_part, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)
#%% concatenate into one long movie
# %% concatenate into one long movie
list_fp_movies = []
for dp_part in list_movie_dirs:
fp_movie = os.path.join(dp_part, "movie_highres.mp4")
assert os.path.isfile(fp_movie)
list_fp_movies.append(fp_movie)
fp_final = "example4.mp4"
concatenate_movies(fp_final, list_fp_movies)
concatenate_movies(fp_final, list_fp_movies)


@ -13,83 +13,90 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import os
import torch
torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
import torch
from movie_util import MovieSaver, concatenate_movies
from typing import Callable, List, Optional, Union
from latent_blending import get_time, yml_save, LatentBlending, add_frames_linear_interp, compare_dicts
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
import gradio as gr
import copy
from dotenv import find_dotenv, load_dotenv
import shutil
import random
import time
from utils import get_time, add_frames_linear_interp
from huggingface_hub import hf_hub_download
#%%
class BlendingFrontend():
def __init__(self, sdh=None):
def __init__(
self,
sdh,
share=False):
r"""
Gradio Helper Class to collect UI data and start latent blending.
Args:
sdh:
StableDiffusionHolder
share: bool
Set true to get a shareable gradio link (e.g. for running a remote server)
"""
self.share = share
# UI Defaults
self.num_inference_steps = 30
if sdh is None:
self.use_debug = True
self.height = 768
self.width = 768
else:
self.use_debug = False
self.lb = LatentBlending(sdh)
self.lb.sdh.num_inference_steps = self.num_inference_steps
self.height = self.lb.sdh.height
self.width = self.lb.sdh.width
self.init_save_dir()
self.save_empty_image()
self.share = False
self.transition_can_be_computed = False
self.depth_strength = 0.25
self.seed1 = 420
self.seed2 = 420
self.guidance_scale = 4.0
self.guidance_scale_mid_damper = 0.5
self.mid_compression_scaler = 1.2
self.prompt1 = ""
self.prompt2 = ""
self.negative_prompt = ""
self.state_current = {}
self.fps = 30
self.duration_video = 8
self.t_compute_max_allowed = 10
self.lb = LatentBlending(sdh)
self.lb.sdh.num_inference_steps = self.num_inference_steps
self.init_parameters_from_lb()
self.init_save_dir()
# Vars
self.list_fp_imgs_current = []
self.recycle_img1 = False
self.recycle_img2 = False
self.list_all_segments = []
self.dp_session = ""
self.user_id = None
def init_parameters_from_lb(self):
r"""
Automatically initialize parameters from the LatentBlending instance
"""
self.height = self.lb.sdh.height
self.width = self.lb.sdh.width
self.guidance_scale = self.lb.guidance_scale
self.guidance_scale_mid_damper = self.lb.guidance_scale_mid_damper
self.mid_compression_scaler = self.lb.mid_compression_scaler
self.branch1_crossfeed_power = self.lb.branch1_crossfeed_power
self.branch1_crossfeed_range = self.lb.branch1_crossfeed_range
self.branch1_crossfeed_decay = self.lb.branch1_crossfeed_decay
self.parental_crossfeed_power = self.lb.parental_crossfeed_power
self.parental_crossfeed_range = self.lb.parental_crossfeed_range
self.parental_crossfeed_power_decay = self.lb.parental_crossfeed_power_decay
self.fps = 30
self.duration_video = 10
self.t_compute_max_allowed = 10
self.list_fp_imgs_current = []
self.current_timestamp = None
self.recycle_img1 = False
self.recycle_img2 = False
self.multi_idx_current = -1
self.list_imgs_shown_last = 5*[self.fp_img_empty]
self.list_all_segments = []
self.dp_session = ""
self.user_id = None
self.block_transition = False
def init_save_dir(self):
load_dotenv(find_dotenv(), verbose=False)
r"""
Initializes the directory where results are saved.
You can specify this directory in a ".env" file in your latentblending root, setting
DIR_OUT='/path/to/saving'
"""
load_dotenv(find_dotenv(), verbose=False)
self.dp_out = os.getenv("DIR_OUT")
if self.dp_out is None:
self.dp_out = ""
@ -97,151 +104,151 @@ class BlendingFrontend():
os.makedirs(self.dp_imgs, exist_ok=True)
self.dp_movies = os.path.join(self.dp_out, "movies")
os.makedirs(self.dp_movies, exist_ok=True)
# make dummy image
self.save_empty_image()
def save_empty_image(self):
r"""
Saves an empty/black dummy image.
"""
self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
def randomize_seed1(self):
# Dont randomize seed if we are in a multi concat mode. we don't want to change this one otherwise the movie breaks
r"""
Randomizes the first seed
"""
seed = np.random.randint(0, 10000000)
self.seed1 = int(seed)
print(f"randomize_seed1: new seed = {self.seed1}")
return seed
def randomize_seed2(self):
r"""
Randomizes the second seed
"""
seed = np.random.randint(0, 10000000)
self.seed2 = int(seed)
print(f"randomize_seed2: new seed = {self.seed2}")
return seed
def setup_lb(self, list_ui_elem):
def setup_lb(self, list_ui_vals):
r"""
Sets all parameters from the UI. Since gradio does not support passing dictionaries,
we instead pass keys (list_ui_keys, global) and values (list_ui_vals); a minimal sketch of this plumbing follows the app code below.
"""
# Collect latent blending variables
self.state_current = self.get_state_dict()
self.lb.set_width(list_ui_elem[list_ui_keys.index('width')])
self.lb.set_height(list_ui_elem[list_ui_keys.index('height')])
self.lb.set_prompt1(list_ui_elem[list_ui_keys.index('prompt1')])
self.lb.set_prompt2(list_ui_elem[list_ui_keys.index('prompt2')])
self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
self.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
self.duration_video = list_ui_elem[list_ui_keys.index('duration_video')]
self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')] #seed
self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
self.lb.branch1_crossfeed_power = list_ui_elem[list_ui_keys.index('branch1_crossfeed_power')]
self.lb.branch1_crossfeed_range = list_ui_elem[list_ui_keys.index('branch1_crossfeed_range')]
self.lb.branch1_crossfeed_decay = list_ui_elem[list_ui_keys.index('branch1_crossfeed_decay')]
self.lb.parental_crossfeed_power = list_ui_elem[list_ui_keys.index('parental_crossfeed_power')]
self.lb.parental_crossfeed_range = list_ui_elem[list_ui_keys.index('parental_crossfeed_range')]
self.lb.parental_crossfeed_power_decay = list_ui_elem[list_ui_keys.index('parental_crossfeed_power_decay')]
self.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
self.depth_strength = list_ui_elem[list_ui_keys.index('depth_strength')]
if len(list_ui_elem[list_ui_keys.index('user_id')]) > 1:
self.user_id = list_ui_elem[list_ui_keys.index('user_id')]
self.lb.set_width(list_ui_vals[list_ui_keys.index('width')])
self.lb.set_height(list_ui_vals[list_ui_keys.index('height')])
self.lb.set_prompt1(list_ui_vals[list_ui_keys.index('prompt1')])
self.lb.set_prompt2(list_ui_vals[list_ui_keys.index('prompt2')])
self.lb.set_negative_prompt(list_ui_vals[list_ui_keys.index('negative_prompt')])
self.lb.guidance_scale = list_ui_vals[list_ui_keys.index('guidance_scale')]
self.lb.guidance_scale_mid_damper = list_ui_vals[list_ui_keys.index('guidance_scale_mid_damper')]
self.t_compute_max_allowed = list_ui_vals[list_ui_keys.index('duration_compute')]
self.lb.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
self.lb.sdh.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
self.duration_video = list_ui_vals[list_ui_keys.index('duration_video')]
self.lb.seed1 = list_ui_vals[list_ui_keys.index('seed1')]
self.lb.seed2 = list_ui_vals[list_ui_keys.index('seed2')]
self.lb.branch1_crossfeed_power = list_ui_vals[list_ui_keys.index('branch1_crossfeed_power')]
self.lb.branch1_crossfeed_range = list_ui_vals[list_ui_keys.index('branch1_crossfeed_range')]
self.lb.branch1_crossfeed_decay = list_ui_vals[list_ui_keys.index('branch1_crossfeed_decay')]
self.lb.parental_crossfeed_power = list_ui_vals[list_ui_keys.index('parental_crossfeed_power')]
self.lb.parental_crossfeed_range = list_ui_vals[list_ui_keys.index('parental_crossfeed_range')]
self.lb.parental_crossfeed_power_decay = list_ui_vals[list_ui_keys.index('parental_crossfeed_power_decay')]
self.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
self.depth_strength = list_ui_vals[list_ui_keys.index('depth_strength')]
if len(list_ui_vals[list_ui_keys.index('user_id')]) > 1:
self.user_id = list_ui_vals[list_ui_keys.index('user_id')]
else:
# generate new user id
self.user_id = ''.join((random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ') for i in range(8)))
print(f"made new user_id: {self.user_id}")
print(f"made new user_id: {self.user_id} at {get_time('second')}")
def save_latents(self, fp_latents, list_latents):
r"""
Saves a latent trajectory on disk, in npy format.
"""
list_latents_cpu = [l.cpu().numpy() for l in list_latents]
np.save(fp_latents, list_latents_cpu)
def load_latents(self, fp_latents):
r"""
Loads a latent trajectory from disk, converts to torch tensor.
"""
list_latents_cpu = np.load(fp_latents)
list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu]
return list_latents
def compute_img1(self, *args):
list_ui_elem = args
self.setup_lb(list_ui_elem)
r"""
Computes the first transition image and returns it for display.
Sets all other transition images and the last image to empty (as they are obsolete with this operation)
"""
list_ui_vals = args
self.setup_lb(list_ui_vals)
fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}")
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
img1.save(fp_img1+".jpg")
self.save_latents(fp_img1+".npy", self.lb.tree_latents[0])
img1.save(fp_img1 + ".jpg")
self.save_latents(fp_img1 + ".npy", self.lb.tree_latents[0])
self.recycle_img1 = True
self.recycle_img2 = False
# fixme save seeds. change filenames?
return [fp_img1+".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
return [fp_img1 + ".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
def compute_img2(self, *args):
if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")): # don't do anything
r"""
Computes the last transition image and returns it for display.
Sets all other transition images to empty (as they are obsolete with this operation)
"""
if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")): # don't do anything
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
list_ui_elem = args
self.setup_lb(list_ui_elem)
list_ui_vals = args
self.setup_lb(list_ui_vals)
self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}")
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
img2.save(fp_img2+'.jpg')
self.save_latents(fp_img2+".npy", self.lb.tree_latents[-1])
img2.save(fp_img2 + '.jpg')
self.save_latents(fp_img2 + ".npy", self.lb.tree_latents[-1])
self.recycle_img2 = True
self.transition_can_be_computed = True
# fixme save seeds. change filenames?
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2+".jpg", self.user_id]
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2 + ".jpg", self.user_id]
def compute_transition(self, *args):
if not self.transition_can_be_computed:
list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
return list_return
list_ui_elem = args
self.setup_lb(list_ui_elem)
r"""
Computes transition images and movie.
"""
list_ui_vals = args
self.setup_lb(list_ui_vals)
print("STARTING TRANSITION...")
fixed_seeds = [self.seed1, self.seed2]
# Run Latent Blending
# Check if another user is blocking this... otherwise everything will become mixed.
# t_now = time.time()
# if self.block_transition:
# while True:
# time.sleep(1)
# if not self.block_transition:
# break
# if time.time() - t_now > 1000:
# return
self.block_transition = True
# Inject loaded latents (other user interference)
self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
imgs_transition = self.lb.run_transition(
recycle_img1=self.recycle_img1,
recycle_img2=self.recycle_img2,
num_inference_steps=self.num_inference_steps,
depth_strength=self.depth_strength,
recycle_img1=self.recycle_img1,
recycle_img2=self.recycle_img2,
num_inference_steps=self.num_inference_steps,
depth_strength=self.depth_strength,
t_compute_max_allowed=self.t_compute_max_allowed,
fixed_seeds=fixed_seeds
)
print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
fixed_seeds=fixed_seeds)
print(f"Latent Blending pass finished ({get_time('second')}). Resulted in {len(imgs_transition)} images")
# Subselect three preview images
idx_img_prev = np.round(np.linspace(0, len(imgs_transition)-1, 5)[1:-1]).astype(np.int32)
idx_img_prev = np.round(np.linspace(0, len(imgs_transition) - 1, 5)[1:-1]).astype(np.int32)
list_imgs_preview = []
for j in idx_img_prev:
list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
# Save the preview imgs as jpgs on disk so we are not sending uncompressed data around
self.current_timestamp = get_time('second')
current_timestamp = get_time('second')
self.list_fp_imgs_current = []
for i in range(len(list_imgs_preview)):
fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg")
fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{current_timestamp}.jpg")
list_imgs_preview[i].save(fp_img)
self.list_fp_imgs_current.append(fp_img)
self.block_transition = False
# Insert cheap frames for the movie
imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
@ -254,44 +261,43 @@ class BlendingFrontend():
ms.write_frame(img)
ms.finalize()
print("DONE SAVING MOVIE! SENDING BACK...")
# Assemble output, updating the preview images and the movie
list_return = self.list_fp_imgs_current + [self.fp_movie]
return list_return
def stack_forward(self, prompt2, seed2):
r"""
Allows generating multi-segment movies. Sets last image -> first image with all
relevant parameters.
"""
# Save preview images, prompts and seeds into dictionary for stacking
if len(self.list_all_segments) == 0:
timestamp_session = get_time('second')
self.dp_session = os.path.join(self.dp_out, f"session_{timestamp_session}")
os.makedirs(self.dp_session)
self.transition_can_be_computed = False
idx_segment = len(self.list_all_segments)
idx_segment = len(self.list_all_segments)
dp_segment = os.path.join(self.dp_session, f"segment_{str(idx_segment).zfill(3)}")
self.list_all_segments.append(dp_segment)
self.lb.write_imgs_transition(dp_segment)
fp_movie_last = self.get_fp_video_last()
fp_movie_next = self.get_fp_video_next()
shutil.copyfile(fp_movie_last, fp_movie_next)
self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
self.lb.swap_forward()
shutil.copyfile(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"), os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
fp_multi = self.multi_concat()
list_out = [fp_multi]
list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")])
list_out.extend([self.fp_img_empty]*4)
list_out.extend([self.fp_img_empty] * 4)
list_out.append(gr.update(interactive=False, value=prompt2))
list_out.append(gr.update(interactive=False, value=seed2))
list_out.append("")
@ -299,25 +305,31 @@ class BlendingFrontend():
print(f"stack_forward: fp_multi {fp_multi}")
return list_out
def multi_concat(self):
r"""
Concatenates all stacked segments into one long movie.
"""
list_fp_movies = self.get_fp_video_all()
# Concatenate movies and save
fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4")
concatenate_movies(fp_final, list_fp_movies)
return fp_final
def get_fp_video_all(self):
r"""
Collects all stacked movie segments.
"""
list_all = os.listdir(self.dp_movies)
str_beg = f"movie_{self.user_id}_"
list_user = [l for l in list_all if str_beg in l]
list_user.sort()
list_user = [os.path.join(self.dp_movies, l) for l in list_user]
return list_user
def get_fp_video_next(self):
r"""
Gets the filepath of the next movie segment.
"""
list_videos = self.get_fp_video_all()
if len(list_videos) == 0:
idx_next = 0
@ -325,93 +337,81 @@ class BlendingFrontend():
idx_next = len(list_videos)
fp_video_next = os.path.join(self.dp_movies, f"movie_{self.user_id}_{str(idx_next).zfill(3)}.mp4")
return fp_video_next
def get_fp_video_last(self):
r"""
Gets the filepath of the video segment that was saved last.
"""
fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4")
return fp_video_last
def get_state_dict(self):
state_dict = {}
grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
'num_inference_steps', 'depth_strength', 'guidance_scale',
'guidance_scale_mid_damper', 'mid_compression_scaler']
for v in grab_vars:
state_dict[v] = getattr(self, v)
return state_dict
if __name__ == "__main__":
fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
# self = BlendingFrontend(None)
if __name__ == "__main__":
# fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
# self = BlendingFrontend(None)
with gr.Blocks() as demo:
with gr.Row():
prompt1 = gr.Textbox(label="prompt 1")
prompt2 = gr.Textbox(label="prompt 2")
with gr.Row():
duration_compute = gr.Slider(5, 200, bf.t_compute_max_allowed, step=1, label='compute budget', interactive=True)
duration_video = gr.Slider(1, 100, bf.duration_video, step=0.1, label='video duration', interactive=True)
duration_compute = gr.Slider(5, 200, bf.t_compute_max_allowed, step=1, label='compute budget', interactive=True)
duration_video = gr.Slider(1, 100, bf.duration_video, step=0.1, label='video duration', interactive=True)
height = gr.Slider(256, 2048, bf.height, step=128, label='height', interactive=True)
width = gr.Slider(256, 2048, bf.width, step=128, label='width', interactive=True)
width = gr.Slider(256, 2048, bf.width, step=128, label='width', interactive=True)
with gr.Accordion("Advanced Settings (click to expand)", open=False):
with gr.Accordion("Diffusion settings", open=True):
with gr.Row():
num_inference_steps = gr.Slider(5, 100, bf.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
guidance_scale = gr.Slider(1, 25, bf.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
negative_prompt = gr.Textbox(label="negative prompt")
guidance_scale = gr.Slider(1, 25, bf.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
negative_prompt = gr.Textbox(label="negative prompt")
with gr.Accordion("Seed control: adjust seeds for first and last images", open=True):
with gr.Row():
b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
seed1 = gr.Number(bf.seed1, label="seed 1", interactive=True)
seed2 = gr.Number(bf.seed2, label="seed 2", interactive=True)
b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
with gr.Accordion("Last image crossfeeding.", open=True):
with gr.Row():
branch1_crossfeed_power = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_power, step=0.01, label='branch1 crossfeed power', interactive=True)
branch1_crossfeed_range = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_range, step=0.01, label='branch1 crossfeed range', interactive=True)
branch1_crossfeed_decay = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_decay, step=0.01, label='branch1 crossfeed decay', interactive=True)
branch1_crossfeed_power = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_power, step=0.01, label='branch1 crossfeed power', interactive=True)
branch1_crossfeed_range = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_range, step=0.01, label='branch1 crossfeed range', interactive=True)
branch1_crossfeed_decay = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_decay, step=0.01, label='branch1 crossfeed decay', interactive=True)
with gr.Accordion("Transition settings", open=True):
with gr.Row():
parental_crossfeed_power = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power, step=0.01, label='parental crossfeed power', interactive=True)
parental_crossfeed_range = gr.Slider(0.0, 1.0, bf.parental_crossfeed_range, step=0.01, label='parental crossfeed range', interactive=True)
parental_crossfeed_power_decay = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power_decay, step=0.01, label='parental crossfeed decay', interactive=True)
parental_crossfeed_power = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power, step=0.01, label='parental crossfeed power', interactive=True)
parental_crossfeed_range = gr.Slider(0.0, 1.0, bf.parental_crossfeed_range, step=0.01, label='parental crossfeed range', interactive=True)
parental_crossfeed_power_decay = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power_decay, step=0.01, label='parental crossfeed decay', interactive=True)
with gr.Row():
depth_strength = gr.Slider(0.01, 0.99, bf.depth_strength, step=0.01, label='depth_strength', interactive=True)
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, bf.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
depth_strength = gr.Slider(0.01, 0.99, bf.depth_strength, step=0.01, label='depth_strength', interactive=True)
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, bf.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
with gr.Row():
b_compute1 = gr.Button('compute first image', variant='primary')
b_compute_transition = gr.Button('compute transition', variant='primary')
b_compute2 = gr.Button('compute last image', variant='primary')
with gr.Row():
img1 = gr.Image(label="1/5")
img2 = gr.Image(label="2/5", show_progress=False)
img3 = gr.Image(label="3/5", show_progress=False)
img4 = gr.Image(label="4/5", show_progress=False)
img5 = gr.Image(label="5/5")
with gr.Row():
vid_single = gr.Video(label="single trans")
vid_multi = gr.Video(label="multi trans")
vid_single = gr.Video(label="current single trans")
vid_multi = gr.Video(label="concatented multi trans")
with gr.Row():
# b_restart = gr.Button("RESTART EVERYTHING")
b_stackforward = gr.Button('append last movie segment (left) to multi movie (right)', variant='primary')
with gr.Row():
gr.Markdown(
"""
@ -420,75 +420,73 @@ if __name__ == "__main__":
- compute budget: set your waiting time for the transition. high values = better quality
- video duration: seconds per segment
- height/width: in pixels
## Diffusion settings
- num_inference_steps: number of diffusion steps
- guidance_scale: latent blending seems to prefer lower values here
- negative prompt: enter negative prompt here, applied for all images
## Last image crossfeeding
- branch1_crossfeed_power: Controls the level of cross-feeding between the first and last image branch. For preserving structures.
- branch1_crossfeed_range: Sets the duration of active crossfeed during development. High values enforce strong structural similarity.
- branch1_crossfeed_decay: Sets decay for branch1_crossfeed_power. Lower values make the decay stronger across the range.
## Transition settings
- parental_crossfeed_power: Similar to branch1_crossfeed_power, however applied to the images within the transition.
- parental_crossfeed_range: Similar to branch1_crossfeed_range, however applied to the images within the transition.
- parental_crossfeed_power_decay: Similar to branch1_crossfeed_decay, however applied to the images within the transition.
- depth_strength: Determines when the blending process will begin in terms of diffusion steps. Low values are more inventive but can cause motion.
- guidance_scale_mid_damper: Decreases the guidance scale in the middle of a transition.
"""
)
""")
with gr.Row():
user_id = gr.Textbox(label="user id", interactive=False)
# Collect all UI elements in a list to easily pass as inputs in gradio
dict_ui_elem = {}
dict_ui_elem["prompt1"] = prompt1
dict_ui_elem["negative_prompt"] = negative_prompt
dict_ui_elem["prompt2"] = prompt2
dict_ui_elem["duration_compute"] = duration_compute
dict_ui_elem["duration_video"] = duration_video
dict_ui_elem["height"] = height
dict_ui_elem["width"] = width
dict_ui_elem["depth_strength"] = depth_strength
dict_ui_elem["branch1_crossfeed_power"] = branch1_crossfeed_power
dict_ui_elem["branch1_crossfeed_range"] = branch1_crossfeed_range
dict_ui_elem["branch1_crossfeed_decay"] = branch1_crossfeed_decay
dict_ui_elem["num_inference_steps"] = num_inference_steps
dict_ui_elem["guidance_scale"] = guidance_scale
dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
dict_ui_elem["seed1"] = seed1
dict_ui_elem["seed2"] = seed2
dict_ui_elem["parental_crossfeed_range"] = parental_crossfeed_range
dict_ui_elem["parental_crossfeed_power"] = parental_crossfeed_power
dict_ui_elem["parental_crossfeed_power_decay"] = parental_crossfeed_power_decay
dict_ui_elem["user_id"] = user_id
# Convert to list, as gradio doesn't seem to accept dicts
list_ui_elem = []
list_ui_vals = []
list_ui_keys = []
for k in dict_ui_elem.keys():
list_ui_elem.append(dict_ui_elem[k])
list_ui_vals.append(dict_ui_elem[k])
list_ui_keys.append(k)
bf.list_ui_keys = list_ui_keys
b_newseed1.click(bf.randomize_seed1, outputs=seed1)
b_newseed2.click(bf.randomize_seed2, outputs=seed2)
b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5, user_id])
b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5, user_id])
b_compute_transition.click(bf.compute_transition,
inputs=list_ui_elem,
outputs=[img2, img3, img4, vid_single])
b_stackforward.click(bf.stack_forward,
inputs=[prompt2, seed2],
outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
b_compute1.click(bf.compute_img1, inputs=list_ui_vals, outputs=[img1, img2, img3, img4, img5, user_id])
b_compute2.click(bf.compute_img2, inputs=list_ui_vals, outputs=[img2, img3, img4, img5, user_id])
b_compute_transition.click(bf.compute_transition,
inputs=list_ui_vals,
outputs=[img2, img3, img4, vid_single])
b_stackforward.click(bf.stack_forward,
inputs=[prompt2, seed2],
outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
demo.launch(share=bf.share, inbrowser=True, inline=False)
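The setup_lb docstring above notes that gradio cannot pass dictionaries into callbacks, so the app keeps a global list of keys and hands the component values over positionally. A minimal standalone sketch of that plumbing, with placeholder components rather than the app's full set:

import gradio as gr

with gr.Blocks() as sketch:
    prompt1 = gr.Textbox(label="prompt 1")
    seed1 = gr.Number(420, label="seed 1")
    dict_ui_elem = {"prompt1": prompt1, "seed1": seed1}
    list_ui_keys = list(dict_ui_elem.keys())    # kept around for index lookups
    list_ui_vals = list(dict_ui_elem.values())  # passed to gradio as inputs

    def on_click(*args):
        # Values arrive positionally; recover each one by its key index.
        prompt1_val = args[list_ui_keys.index("prompt1")]
        seed1_val = args[list_ui_keys.index("seed1")]
        return f"{prompt1_val} / {seed1_val}"

    out = gr.Textbox(label="result")
    gr.Button("run").click(on_click, inputs=list_ui_vals, outputs=out)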

File diff suppressed because it is too large.


@ -1,5 +1,6 @@
# Copyright 2022 Lunar Ring. All rights reserved.
#
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -17,26 +18,24 @@ import os
import numpy as np
from tqdm import tqdm
import cv2
from typing import Callable, List, Optional, Union
import ffmpeg # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg
from typing import List
import ffmpeg # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg
#%%
class MovieSaver():
def __init__(
self,
fp_out: str,
fps: int = 24,
self,
fp_out: str,
fps: int = 24,
shape_hw: List[int] = None,
crf: int = 24,
codec: str = 'libx264',
preset: str ='fast',
pix_fmt: str = 'yuv420p',
silent_ffmpeg: bool = True
):
preset: str = 'fast',
pix_fmt: str = 'yuv420p',
silent_ffmpeg: bool = True):
r"""
Initializes movie saver class - a human friendly ffmpeg wrapper.
After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x).
After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x).
Don't forget to finalize the movie file with moviesaver.finalize().
Args:
fp_out: str
@ -47,22 +46,22 @@ class MovieSaver():
Output shape, optional argument. Can be initialized automatically when first frame is written.
crf: int
ffmpeg doc: the range of the CRF scale is 0–51, where 0 is lossless
(for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
Consider 17 or 18 to be visually lossless or nearly so;
it should look the same or nearly the same as the input but it isn't technically lossless.
The range is exponential, so increasing the CRF value +6 results in
roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
(for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
Consider 17 or 18 to be visually lossless or nearly so;
it should look the same or nearly the same as the input but it isn't technically lossless.
The range is exponential, so increasing the CRF value +6 results in
roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
codec: str
Video codec passed to ffmpeg, e.g. 'libx264'.
preset: str
Choose between ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow.
ffmpeg doc: A preset is a collection of options that will provide a certain encoding speed
to compression ratio. A slower preset will provide better compression
(compression is quality per filesize).
This means that, for example, if you target a certain file size or constant bit rate,
ffmpeg doc: A preset is a collection of options that will provide a certain encoding speed
to compression ratio. A slower preset will provide better compression
(compression is quality per filesize).
This means that, for example, if you target a certain file size or constant bit rate,
you will achieve better quality with a slower preset. Similarly, for constant quality encoding,
you will simply save bitrate by choosing a slower preset.
you will simply save bitrate by choosing a slower preset.
pix_fmt: str
Pixel format. Run 'ffmpeg -pix_fmts' in your shell to see all options.
silent_ffmpeg: bool
@ -70,7 +69,7 @@ class MovieSaver():
"""
if len(os.path.split(fp_out)[0]) > 0:
assert os.path.isdir(os.path.split(fp_out)[0]), "Directory does not exist!"
self.fp_out = fp_out
self.fps = fps
self.crf = crf
@ -78,10 +77,10 @@ class MovieSaver():
self.codec = codec
self.preset = preset
self.silent_ffmpeg = silent_ffmpeg
if os.path.isfile(fp_out):
os.remove(fp_out)
self.init_done = False
self.nmb_frames = 0
if shape_hw is None:
@ -91,11 +90,9 @@ class MovieSaver():
shape_hw.append(3)
self.shape_hw = shape_hw
self.initialize()
print(f"MovieSaver initialized. fps={fps} crf={crf} pix_fmt={pix_fmt} codec={codec} preset={preset}")
def initialize(self):
args = (
ffmpeg
@ -111,8 +108,7 @@ class MovieSaver():
self.init_done = True
self.shape_hw = tuple(self.shape_hw)
print(f"Initialization done. Movie shape: {self.shape_hw}")
def write_frame(self, out_frame: np.ndarray):
r"""
Function to dump a numpy array as a frame of a movie.
@ -123,18 +119,17 @@ class MovieSaver():
Dim 1: x
Dim 2: RGB
"""
assert out_frame.dtype == np.uint8, "Convert to np.uint8 before"
assert len(out_frame.shape) == 3, "out_frame needs to be three dimensional, Y X C"
assert out_frame.shape[2] == 3, f"need three color channels, but you provided {out_frame.shape[2]}."
if not self.init_done:
self.shape_hw = out_frame.shape
self.initialize()
assert self.shape_hw == out_frame.shape, f"You cannot change the image size after init. Initialized with {self.shape_hw}, out_frame {out_frame.shape}"
# write frame
# write frame
self.ffmpg_process.stdin.write(
out_frame
.astype(np.uint8)
@ -142,8 +137,7 @@ class MovieSaver():
)
self.nmb_frames += 1
def finalize(self):
r"""
Call this function to finalize the movie. If you forget to call it, your movie will be garbage.
@ -157,7 +151,6 @@ class MovieSaver():
print(f"Movie saved, {duration}s playtime, watch here: \n{self.fp_out}")
def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
r"""
Concatenate multiple movie segments into one long movie, using ffmpeg.
@ -167,13 +160,13 @@ def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
fp_final : str
Full path of the final movie file. Should end with .mp4
list_fp_movies : list[str]
List of full paths of movie segments.
List of full paths of movie segments.
"""
assert fp_final[-4] == ".", "fp_final seems to miss file extension: {fp_final}"
for fp in list_fp_movies:
assert os.path.isfile(fp), f"Input movie does not exist: {fp}"
assert os.path.getsize(fp) > 100, f"Input movie seems empty: {fp}"
if os.path.isfile(fp_final):
os.remove(fp_final)
@ -181,32 +174,32 @@ def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
list_concat = []
for fp_part in list_fp_movies:
list_concat.append(f"""file '{fp_part}'""")
# save this list
fp_list = "tmp_move.txt"
with open(fp_list, "w") as fa:
for item in list_concat:
fa.write("%s\n" % item)
cmd = f'ffmpeg -f concat -safe 0 -i {fp_list} -c copy {fp_final}'
dp_movie = os.path.split(fp_final)[0]
subprocess.call(cmd, shell=True)
os.remove(fp_list)
if os.path.isfile(fp_final):
print(f"concatenate_movies: success! Watch here: {fp_final}")
class MovieReader():
r"""
Class to read in a movie.
"""
def __init__(self, fp_movie):
self.video_player_object = cv2.VideoCapture(fp_movie)
self.nmb_frames = int(self.video_player_object.get(cv2.CAP_PROP_FRAME_COUNT))
self.fps_movie = int(self.video_player_object.get(cv2.CAP_PROP_FPS))
self.shape = [100,100,3]
self.shape = [100, 100, 3]
self.shape_is_set = False
def get_next_frame(self):
success, image = self.video_player_object.read()
if success:
@ -217,19 +210,18 @@ class MovieReader():
else:
return np.zeros(self.shape)
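A brief usage sketch for MovieReader (illustrative only, not part of this commit; the path mirrors the demo below):

# get_next_frame returns the next frame as a numpy array and falls back to
# np.zeros(self.shape) once no frame can be read.
reader = MovieReader("/tmp/my_random_movie_0.mp4")
frame = reader.get_next_frame()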
#%%
if __name__ == "__main__":
fps=2
if __name__ == "__main__":
fps = 2
list_fp_movies = []
for k in range(4):
fp_movie = f"/tmp/my_random_movie_{k}.mp4"
list_fp_movies.append(fp_movie)
ms = MovieSaver(fp_movie, fps=fps)
for fn in tqdm(range(30)):
img = (np.random.rand(512, 1024, 3)*255).astype(np.uint8)
img = (np.random.rand(512, 1024, 3) * 255).astype(np.uint8)
ms.write_frame(img)
ms.finalize()
fp_final = "/tmp/my_concatenated_movie.mp4"
concatenate_movies(fp_final, list_fp_movies)
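As a side note on the crf argument documented above: the docstring's rule of thumb (every +6 in CRF roughly halves the bitrate relative to the default of 23) can be written as a tiny helper. This is only an approximation for orientation, not something MovieSaver computes:

def approx_relative_bitrate(crf: int, crf_ref: int = 23) -> float:
    # ffmpeg rule of thumb: +6 CRF ~ half the bitrate, -6 CRF ~ double it.
    return 2.0 ** ((crf_ref - crf) / 6.0)

print(approx_relative_bitrate(17))  # ~2.0x the default bitrate / file size
print(approx_relative_bitrate(29))  # ~0.5x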


@ -13,36 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
dp_git = "/home/lugo/git/"
sys.path.append(os.path.join(dp_git,'garden4'))
sys.path.append('util')
import os
import torch
torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import time
import subprocess
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
import torch
from movie_util import MovieSaver
import datetime
from typing import Callable, List, Optional, Union
import inspect
from threading import Thread
torch.set_grad_enabled(False)
from typing import Optional
from omegaconf import OmegaConf
from torch import autocast
from contextlib import nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from einops import repeat, rearrange
#%%
from utils import interpolate_spherical
def pad_image(input_image):
@ -53,41 +42,11 @@ def pad_image(input_image):
return im_padded
def make_batch_inpaint(
image,
mask,
txt,
device,
num_samples=1):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
batch = {
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
"txt": num_samples * [txt],
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
}
return batch
def make_batch_superres(
image,
txt,
device,
num_samples=1,
):
num_samples=1):
image = np.array(image.convert("RGB"))
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
batch = {
@ -107,14 +66,14 @@ def make_noise_augmentation(model, batch, noise_level=None):
class StableDiffusionHolder:
def __init__(self,
fp_ckpt: str = None,
def __init__(self,
fp_ckpt: str = None,
fp_config: str = None,
num_inference_steps: int = 30,
num_inference_steps: int = 30,
height: Optional[int] = None,
width: Optional[int] = None,
device: str = None,
precision: str='autocast',
precision: str = 'autocast',
):
r"""
Initializes the stable diffusion holder, which contains the models and sampler.
@ -122,26 +81,26 @@ class StableDiffusionHolder:
fp_ckpt: File pointer to the .ckpt model file
fp_config: File pointer to the .yaml config file
num_inference_steps: Number of diffusion iterations. Will be overwritten by latent blending.
height: Height of the resulting image.
width: Width of the resulting image.
height: Height of the resulting image.
width: Width of the resulting image.
device: Device to run the model on.
precision: Precision to run the model on.
"""
self.seed = 42
self.guidance_scale = 5.0
if device is None:
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
else:
self.device = device
self.precision = precision
self.init_model(fp_ckpt, fp_config)
self.f = 8 #downsampling factor, most often 8 or 16",
self.f = 8 # downsampling factor, most often 8 or 16
self.C = 4
self.ddim_eta = 0
self.num_inference_steps = num_inference_steps
if height is None and width is None:
self.init_auto_res()
else:
@ -149,53 +108,44 @@ class StableDiffusionHolder:
assert width is not None, "specify both width and height"
self.height = height
self.width = width
# Inpainting inits
self.mask_empty = Image.fromarray(255*np.ones([self.width, self.height], dtype=np.uint8))
self.image_empty = Image.fromarray(np.zeros([self.width, self.height, 3], dtype=np.uint8))
self.negative_prompt = [""]
def init_model(self, fp_ckpt, fp_config):
r"""Loads the models and sampler.
"""
assert os.path.isfile(fp_ckpt), f"Your model checkpoint file does not exist: {fp_ckpt}"
self.fp_ckpt = fp_ckpt
# Auto init the config?
if fp_config is None:
fn_ckpt = os.path.basename(fp_ckpt)
if 'depth' in fn_ckpt:
fp_config = 'configs/v2-midas-inference.yaml'
elif 'inpain' in fn_ckpt:
fp_config = 'configs/v2-inpainting-inference.yaml'
elif 'upscaler' in fn_ckpt:
fp_config = 'configs/x4-upscaling.yaml'
fp_config = 'configs/x4-upscaling.yaml'
elif '512' in fn_ckpt:
fp_config = 'configs/v2-inference.yaml'
elif '768'in fn_ckpt:
fp_config = 'configs/v2-inference-v.yaml'
fp_config = 'configs/v2-inference.yaml'
elif '768' in fn_ckpt:
fp_config = 'configs/v2-inference-v.yaml'
elif 'v1-5' in fn_ckpt:
fp_config = 'configs/v1-inference.yaml'
fp_config = 'configs/v1-inference.yaml'
else:
raise ValueError("auto detect of config failed. please specify fp_config manually!")
assert os.path.isfile(fp_config), "Auto-init of the config file failed. Please specify manually."
assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"
config = OmegaConf.load(fp_config)
self.model = instantiate_from_config(config.model)
self.model.load_state_dict(torch.load(fp_ckpt)["state_dict"], strict=False)
self.model = self.model.to(self.device)
self.sampler = DDIMSampler(self.model)
def init_auto_res(self):
r"""Automatically set the resolution to the one used in training.
"""
@ -205,7 +155,7 @@ class StableDiffusionHolder:
else:
self.height = 512
self.width = 512
def set_negative_prompt(self, negative_prompt):
r"""Set the negative prompt. Currenty only one negative prompt is supported
"""
@ -214,51 +164,46 @@ class StableDiffusionHolder:
self.negative_prompt = [negative_prompt]
else:
self.negative_prompt = negative_prompt
if len(self.negative_prompt) > 1:
self.negative_prompt = [self.negative_prompt[0]]
def get_text_embedding(self, prompt):
c = self.model.get_learned_conditioning(prompt)
return c
@torch.no_grad()
def get_cond_upscaling(self, image, text_embedding, noise_level):
r"""
Initializes the conditioning for the x4 upscaling model.
"""
image = pad_image(image) # resize to integer multiple of 32
w, h = image.size
noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()
batch = make_batch_superres(image, txt="placeholder", device=self.device, num_samples=1)
x_augment, noise_level = make_noise_augmentation(self.model, batch, noise_level)
cond = {"c_concat": [x_augment], "c_crossattn": [text_embedding], "c_adm": noise_level}
# uncond cond
uc_cross = self.model.get_unconditional_conditioning(1, "")
uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}
return cond, uc_full
@torch.no_grad()
def run_diffusion_standard(
self,
text_embeddings: torch.FloatTensor,
self,
text_embeddings: torch.FloatTensor,
latents_start: torch.FloatTensor,
idx_start: int = 0,
list_latents_mixing = None,
mixing_coeffs = 0.0,
spatial_mask = None,
return_image: Optional[bool] = False,
):
idx_start: int = 0,
list_latents_mixing=None,
mixing_coeffs=0.0,
spatial_mask=None,
return_image: Optional[bool] = False):
r"""
Diffusion standard version.
Diffusion standard version.
Args:
text_embeddings: torch.FloatTensor
text_embeddings: torch.FloatTensor
Text embeddings used for diffusion
latents_for_injection: torch.FloatTensor or list
Latents that are used for injection
@ -270,41 +215,32 @@ class StableDiffusionHolder:
experimental feature for enforcing pixels from list_latents_mixing
return_image: Optional[bool]
Optionally return image directly
"""
# Asserts
if type(mixing_coeffs) == float:
list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
elif type(mixing_coeffs) == list:
assert len(mixing_coeffs) == self.num_inference_steps
list_mixing_coeffs = mixing_coeffs
else:
raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
if np.sum(list_mixing_coeffs) > 0:
assert len(list_latents_mixing) == self.num_inference_steps
precision_scope = autocast if self.precision == "autocast" else nullcontext
with precision_scope("cuda"):
with self.model.ema_scope():
if self.guidance_scale != 1.0:
uc = self.model.get_learned_conditioning(self.negative_prompt)
else:
uc = None
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
latents = latents_start.clone()
timesteps = self.sampler.ddim_timesteps
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
# collect latents
# Collect latents
list_latents_out = []
for i, step in enumerate(time_range):
# Set the right starting latents
@ -313,83 +249,71 @@ class StableDiffusionHolder:
continue
elif i == idx_start:
latents = latents_start.clone()
# Mix the latents.
if i > 0 and list_mixing_coeffs[i]>0:
latents_mixtarget = list_latents_mixing[i-1].clone()
# Mix latents
if i > 0 and list_mixing_coeffs[i] > 0:
latents_mixtarget = list_latents_mixing[i - 1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
if spatial_mask is not None and list_latents_mixing is not None:
latents = interpolate_spherical(latents, list_latents_mixing[i-1], 1-spatial_mask)
# latents[:,:,-15:,:] = latents_mixtarget[:,:,-15:,:]
latents = interpolate_spherical(latents, list_latents_mixing[i - 1], 1 - spatial_mask)
index = total_steps - i - 1
ts = torch.full((1,), step, device=self.device, dtype=torch.long)
outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
quantize_denoised=False, temperature=1.0,
noise_dropout=0.0, score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=self.guidance_scale,
unconditional_conditioning=uc,
dynamic_threshold=None)
quantize_denoised=False, temperature=1.0,
noise_dropout=0.0, score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=self.guidance_scale,
unconditional_conditioning=uc,
dynamic_threshold=None)
latents, pred_x0 = outs
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
@torch.no_grad()
def run_diffusion_upscaling(
self,
cond,
uc_full,
latents_start: torch.FloatTensor,
idx_start: int = -1,
list_latents_mixing: list = None,
mixing_coeffs: float = 0.0,
return_image: Optional[bool] = False):
r"""
Diffusion upscaling version.
"""
# Asserts
if type(mixing_coeffs) == float:
list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
elif type(mixing_coeffs) == list:
assert len(mixing_coeffs) == self.num_inference_steps
list_mixing_coeffs = mixing_coeffs
else:
raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
if np.sum(list_mixing_coeffs) > 0:
assert len(list_latents_mixing) == self.num_inference_steps
precision_scope = autocast if self.precision == "autocast" else nullcontext
h = uc_full['c_concat'][0].shape[2]
w = uc_full['c_concat'][0].shape[3]
with precision_scope("cuda"):
with self.model.ema_scope():
shape_latents = [self.model.channels, h, w]
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
C, H, W = shape_latents
size = (1, C, H, W)
b = size[0]
latents = latents_start.clone()
timesteps = self.sampler.ddim_timesteps
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
# collect latents
list_latents_out = []
for i, step in enumerate(time_range):
if i < idx_start:
continue
elif i == idx_start:
latents = latents_start.clone()
# Mix the latents.
if i > 0 and list_mixing_coeffs[i] > 0:
latents_mixtarget = list_latents_mixing[i - 1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
# print(f"diffusion iter {i}")
index = total_steps - i - 1
ts = torch.full((b,), step, device=self.device, dtype=torch.long)
outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
quantize_denoised=False, temperature=1.0,
noise_dropout=0.0, score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=self.guidance_scale,
unconditional_conditioning=uc_full,
dynamic_threshold=None)
latents, pred_x0 = outs
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
@torch.no_grad()
def run_diffusion_inpaint(
self,
text_embeddings: torch.FloatTensor,
latents_for_injection: torch.FloatTensor = None,
idx_start: int = -1,
idx_stop: int = -1,
return_image: Optional[bool] = False
):
r"""
Runs inpaint-based diffusion. Returns a list of latents that were computed.
Adaptations allow supplying
a) starting index for diffusion
b) stopping index for diffusion
c) latent representations that are injected at the starting index
Furthermore, the intermediate latents are collected and returned.
Adapted from diffusers (https://github.com/huggingface/diffusers)
Args:
text_embeddings: torch.FloatTensor
Text embeddings used for diffusion
latents_for_injection: torch.FloatTensor
Latents that are used for injection
idx_start: int
Index of the diffusion process start and where the latents_for_injection are injected
idx_stop: int
Index of the diffusion process end.
return_image: Optional[bool]
Optionally return image directly
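Example (illustrative sketch, not part of the original docstring; assumes an inpainting
checkpoint is loaded and that `sdh.image_source`, `sdh.mask_image` and `sdh.seed` have
been set beforehand):
text_embeddings = sdh.get_text_embedding("a red velvet armchair")
list_latents = sdh.run_diffusion_inpaint(text_embeddings)
img = sdh.latent2image(list_latents[-1])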
"""
if latents_for_injection is None:
do_inject_latents = False
else:
do_inject_latents = True
precision_scope = autocast if self.precision == "autocast" else nullcontext
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
with precision_scope("cuda"):
with self.model.ema_scope():
batch = make_batch_inpaint(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
c = text_embeddings
c_cat = list()
for ck in self.model.concat_keys:
cc = batch[ck].float()
if ck != self.model.masked_image_key:
bchw = [1, 4, self.height // 8, self.width // 8]
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
# cond
cond = {"c_concat": [c_cat], "c_crossattn": [c]}
# uncond cond
uc_cross = self.model.get_unconditional_conditioning(1, "")
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
shape_latents = [self.model.channels, self.height // 8, self.width // 8]
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
# sampling
C, H, W = shape_latents
size = (1, C, H, W)
device = self.model.betas.device
b = size[0]
latents = torch.randn(size, generator=generator, device=device)
timesteps = self.sampler.ddim_timesteps
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
# collect latents
list_latents_out = []
for i, step in enumerate(time_range):
if do_inject_latents:
# Inject latent at right place
if i < idx_start:
continue
elif i == idx_start:
latents = latents_for_injection.clone()
if i == idx_stop:
return list_latents_out
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
quantize_denoised=False, temperature=1.0,
noise_dropout=0.0, score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=self.guidance_scale,
unconditional_conditioning=uc_full,
dynamic_threshold=None)
latents, pred_x0 = outs
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
@torch.no_grad()
def latent2image(
self,
latents: torch.FloatTensor):
r"""
Returns an image provided a latent representation from diffusion.
Args:
latents: torch.FloatTensor
Result of the diffusion process.
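Example (illustrative sketch, not part of the original docstring; `latents` is the last
element returned by one of the run_diffusion_* methods):
img = sdh.latent2image(latents)
Image.fromarray(img).save('frame.png')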
"""
x_sample = self.model.decode_first_stage(latents)
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255 * x_sample[0, :, :].permute([1, 2, 0]).cpu().numpy()
image = x_sample.astype(np.uint8)
return image
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function always casts up to float64 for the sake of extra precision.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return p0
1 will return p1
0.x will return a mix between both preserving angular velocity.
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'
p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1+epsilon, 1-epsilon)
theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0*s0 + p1*s1
if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()
return interp
if __name__ == "__main__":
num_inference_steps = 20  # Number of diffusion iterations
# fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
# fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
# fp_ckpt= "../stable_diffusion_models/ckpt/512-inpainting-ema.ckpt"
# fp_config = '../stablediffusion/configs//stable-diffusion/v2-inpainting-inference.yaml'
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
# fp_config = 'configs/v2-inference-v.yaml'
self = StableDiffusionHolder(fp_ckpt, num_inference_steps=num_inference_steps)
#%%
self.width = 1536
self.height = 768
prompt = "360 degree equirectangular, a huge rocky hill full of pianos and keyboards, musical instruments, cinematic, masterpiece 8 k, artstation"
self.set_negative_prompt("out of frame, faces, rendering, blurry")
te = self.get_text_embedding(prompt)
img = self.run_diffusion_standard(te, return_image=True)
Image.fromarray(img).show()

utils.py (new file, 260 lines)
# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import time
import datetime
from typing import List, Union
torch.set_grad_enabled(False)
import yaml
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function always casts up to float64 for the sake of extra precision.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return p0
1 will return p1
0.x will return a mix between both preserving angular velocity.
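Example (illustrative sketch, not part of the original docstring; shapes are arbitrary):
p0 = torch.randn(1, 4, 96, 96)
p1 = torch.randn(1, 4, 96, 96)
p_mid = interpolate_spherical(p0, p1, 0.5)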
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'
p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1 + epsilon, 1 - epsilon)
theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0 * s0 + p1 * s1
if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()
return interp
def interpolate_linear(p0, p1, fract_mixing):
r"""
Helper function to mix two variables using standard linear interpolation.
Args:
p0:
First tensor / np.ndarray for interpolation
p1:
Second tensor / np.ndarray for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return p0
1 will return p1
0.x will return a linear mix between both.
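Example (illustrative sketch, not part of the original docstring):
img0 = np.zeros((64, 64, 3), dtype=np.uint8)
img1 = np.full((64, 64, 3), 255, dtype=np.uint8)
img_mix = interpolate_linear(img0, img1, 0.25)  # uint8 image, 25% of the way from img0 to img1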
"""
reconvert_uint8 = False
if type(p0) is np.ndarray and p0.dtype == 'uint8':
reconvert_uint8 = True
p0 = p0.astype(np.float64)
if type(p1) is np.ndarray and p1.dtype == 'uint8':
reconvert_uint8 = True
p1 = p1.astype(np.float64)
interp = (1 - fract_mixing) * p0 + fract_mixing * p1
if reconvert_uint8:
interp = np.clip(interp, 0, 255).astype(np.uint8)
return interp
def add_frames_linear_interp(
list_imgs: List[np.ndarray],
fps_target: Union[float, int] = None,
duration_target: Union[float, int] = None,
nmb_frames_target: int = None):
r"""
Helper function to cheaply increase the number of frames given a list of images,
by virtue of standard linear interpolation.
The number of inserted frames will be automatically adjusted so that the total number
of frames can be fixed precisely, using a random shuffling technique.
The function allows 1:1 comparisons between transitions as videos.
Args:
list_imgs: List[np.ndarray]
List of images, between each image new frames will be inserted via linear interpolation.
fps_target:
OptionA: specify here the desired frames per second.
duration_target:
OptionA: specify here the desired duration of the transition in seconds.
nmb_frames_target:
OptionB: directly fix the total number of frames of the output.
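Example (illustrative sketch with random placeholder frames, not part of the original docstring):
keyframes = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in range(5)]
frames = add_frames_linear_interp(keyframes, fps_target=30, duration_target=2)  # 60 frames total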
"""
# Sanity
if nmb_frames_target is not None and fps_target is not None:
raise ValueError("You cannot specify both fps_target and nmb_frames_target")
if fps_target is None:
assert nmb_frames_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
if nmb_frames_target is None:
assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
nmb_frames_target = fps_target * duration_target
# Get number of frames that are missing
nmb_frames_diff = len(list_imgs) - 1
nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
if nmb_frames_missing < 1:
return list_imgs
list_imgs_float = [img.astype(np.float32) for img in list_imgs]
# Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
constfact = np.floor(mean_nmb_frames_insert)
remainder_x = 1 - (mean_nmb_frames_insert - constfact)
nmb_iter = 0
while True:
nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
nmb_frames_to_insert[nmb_frames_to_insert <= remainder_x] = 0
nmb_frames_to_insert[nmb_frames_to_insert > remainder_x] = 1
nmb_frames_to_insert += constfact
if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
break
nmb_iter += 1
if nmb_iter > 100000:
print("add_frames_linear_interp: issue with inserting the right number of frames")
break
nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
list_imgs_interp = []
for i in range(len(list_imgs_float) - 1):
img0 = list_imgs_float[i]
img1 = list_imgs_float[i + 1]
list_imgs_interp.append(img0.astype(np.uint8))
list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i] + 2)[1:-1]
for fract_linblend in list_fracts_linblend:
img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
list_imgs_interp.append(img_blend.astype(np.uint8))
if i == len(list_imgs_float) - 2:
list_imgs_interp.append(img1.astype(np.uint8))
return list_imgs_interp
def get_spacing(nmb_points: int, scaling: float):
"""
Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
Args:
nmb_points: int
Number of points between [0, 1]
scaling: float
Higher values will return higher sampling density around 0.5
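Example (illustrative sketch, not part of the original docstring):
fracts = get_spacing(7, 3.0)  # 7 values spanning [0, 1], sampled more densely around 0.5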
"""
if scaling < 1.7:
return np.linspace(0, 1, nmb_points)
nmb_points_per_side = nmb_points // 2 + 1
if np.mod(nmb_points, 2) != 0:  # Odd number of points
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
right_side = 1 - left_side[::-1][1:]
else:
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
right_side = 1 - left_side[::-1]
all_fracts = np.hstack([left_side, right_side])
return all_fracts
def get_time(resolution=None):
"""
Helper function returning a nicely formatted time string, e.g. 221117_1620
"""
if resolution is None:
resolution = "second"
if resolution == "day":
t = time.strftime('%y%m%d', time.localtime())
elif resolution == "minute":
t = time.strftime('%y%m%d_%H%M', time.localtime())
elif resolution == "second":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
elif resolution == "millisecond":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
t += "_"
t += "{:03d}".format(int(int(datetime.datetime.utcnow().strftime('%f')) / 1000))
else:
raise ValueError("bad resolution provided: %s" % resolution)
return t
def compare_dicts(a, b):
"""
Compares two dictionaries a and b and returns a dictionary c containing all
keys that are shared between a and b but whose values differ.
The differing values of a and b are stacked together in the output.
Example:
a = {}; a['bobo'] = 4
b = {}; b['bobo'] = 5
c = compare_dicts(a, b)
c == {"bobo": [4, 5]}
"""
c = {}
for key in a.keys():
if key in b.keys():
val_a = a[key]
val_b = b[key]
if val_a != val_b:
c[key] = [val_a, val_b]
return c
def yml_load(fp_yml, print_fields=False):
"""
Helper function for loading yaml files
"""
with open(fp_yml) as f:
data = yaml.load(f, Loader=yaml.loader.SafeLoader)
dict_data = dict(data)
print("load: loaded {}".format(fp_yml))
return dict_data
def yml_save(fp_yml, dict_stuff):
"""
Helper function for saving yaml files
"""
with open(fp_yml, 'w') as f:
yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
print("yml_save: saved {}".format(fp_yml))