cleanup

commit 297bb9abe6
parent 3ed876e0ee

@@ -13,30 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os, sys
 import torch
 torch.backends.cudnn.benchmark = False
 import numpy as np
-torch.set_grad_enabled(False)
 import warnings
 warnings.filterwarnings('ignore')
-import warnings
-import torch
-from tqdm.auto import tqdm
-from PIL import Image
-# import matplotlib.pyplot as plt
-import torch
-from movie_util import MovieSaver
-from typing import Callable, List, Optional, Union
-from latent_blending import LatentBlending, add_frames_linear_interp
+from latent_blending import LatentBlending
 from stable_diffusion_holder import StableDiffusionHolder
+torch.set_grad_enabled(False)
 from huggingface_hub import hf_hub_download
 
-
-#%% First let us spawn a stable diffusion holder
-fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
+# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
+# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
+fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
 sdh = StableDiffusionHolder(fp_ckpt)
 
-#%% Next let's set up all parameters
+# %% Next let's set up all parameters
 depth_strength = 0.65  # Specifies how deep (in terms of diffusion iterations) the first branching happens
 t_compute_max_allowed = 15  # Determines the quality of the transition in terms of compute time you grant it
 fixed_seeds = [69731932, 504430820]
@@ -54,10 +46,9 @@ lb.set_prompt2(prompt2)
 
 # Run latent blending
 lb.run_transition(
-    depth_strength = depth_strength,
-    t_compute_max_allowed = t_compute_max_allowed,
-    fixed_seeds = fixed_seeds
-    )
+    depth_strength=depth_strength,
+    t_compute_max_allowed=t_compute_max_allowed,
+    fixed_seeds=fixed_seeds)
 
 # Save movie
 lb.write_movie_transition(fp_movie, duration_transition)
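For orientation, the cleaned-up single-transition script reduces to the sketch below. The prompts, output filename, and duration are placeholders (the diff elides them); every API call itself appears in the hunks above.

    from huggingface_hub import hf_hub_download
    from latent_blending import LatentBlending
    from stable_diffusion_holder import StableDiffusionHolder

    fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
    lb = LatentBlending(StableDiffusionHolder(fp_ckpt))
    lb.set_prompt1("a photo of a forest at dawn")    # placeholder prompt
    lb.set_prompt2("a photo of a forest in winter")  # placeholder prompt
    lb.run_transition(
        depth_strength=0.65,
        t_compute_max_allowed=15,
        fixed_seeds=[69731932, 504430820])
    lb.write_movie_transition("movie_example1.mp4", 12)  # filename and duration are placeholders
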
@@ -13,33 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os, sys
 import torch
 torch.backends.cudnn.benchmark = False
 import numpy as np
-torch.set_grad_enabled(False)
 import warnings
 warnings.filterwarnings('ignore')
-import warnings
-import torch
-from tqdm.auto import tqdm
-from PIL import Image
-import torch
-from movie_util import MovieSaver, concatenate_movies
-from typing import Callable, List, Optional, Union
-from latent_blending import LatentBlending, add_frames_linear_interp
+from latent_blending import LatentBlending
 from stable_diffusion_holder import StableDiffusionHolder
+torch.set_grad_enabled(False)
+from movie_util import concatenate_movies
 from huggingface_hub import hf_hub_download
 
-#%% First let us spawn a stable diffusion holder
-fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
-# fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
+# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
+# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
+fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
 sdh = StableDiffusionHolder(fp_ckpt)
 
 
-#%% Let's setup the multi transition
+# %% Let's setup the multi transition
 fps = 30
 duration_single_trans = 6
-depth_strength = 0.55 #Specifies how deep (in terms of diffusion iterations) the first branching happens
+depth_strength = 0.55  # Specifies how deep (in terms of diffusion iterations) the first branching happens
 
 # Specify a list of prompts below
 list_prompts = []
@@ -56,28 +49,25 @@ t_compute_max_allowed = 12 # per segment
 fp_movie = 'movie_example2.mp4'
 lb = LatentBlending(sdh)
 
-list_movie_parts = [] #
-for i in range(len(list_prompts)-1):
+list_movie_parts = []
+for i in range(len(list_prompts) - 1):
     # For a multi transition we can save some computation time and recycle the latents
-    if i==0:
+    if i == 0:
         lb.set_prompt1(list_prompts[i])
-        lb.set_prompt2(list_prompts[i+1])
+        lb.set_prompt2(list_prompts[i + 1])
         recycle_img1 = False
     else:
         lb.swap_forward()
-        lb.set_prompt2(list_prompts[i+1])
+        lb.set_prompt2(list_prompts[i + 1])
         recycle_img1 = True
 
     fp_movie_part = f"tmp_part_{str(i).zfill(3)}.mp4"
 
-    fixed_seeds = list_seeds[i:i+2]
-
+    fixed_seeds = list_seeds[i:i + 2]
     # Run latent blending
     lb.run_transition(
-        depth_strength = depth_strength,
-        t_compute_max_allowed = t_compute_max_allowed,
-        fixed_seeds = fixed_seeds
-        )
+        depth_strength=depth_strength,
+        t_compute_max_allowed=t_compute_max_allowed,
+        fixed_seeds=fixed_seeds)
 
     # Save movie
     lb.write_movie_transition(fp_movie_part, duration_single_trans)
 
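The loop above writes one movie file per prompt pair. A hedged sketch of how the parts would then be joined (the two-argument concatenate_movies call appears verbatim in gradio_ui.py further down; list_prompts and list_seeds values here are placeholders, and collecting the parts inside the loop is inferred, not shown in this hunk):

    list_prompts = ["prompt A", "prompt B", "prompt C"]  # placeholders
    list_seeds = [993621550, 280335986, 724839971]       # placeholders
    # ... inside the loop above:
    list_movie_parts.append(fp_movie_part)
    # After the loop, stitch the segments into the final movie:
    concatenate_movies(fp_movie, list_movie_parts)       # fp_movie = 'movie_example2.mp4'
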
@@ -13,25 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os, sys
 import torch
 torch.backends.cudnn.benchmark = False
 import numpy as np
-torch.set_grad_enabled(False)
 import warnings
 warnings.filterwarnings('ignore')
-import warnings
-import torch
-from tqdm.auto import tqdm
 from PIL import Image
-# import matplotlib.pyplot as plt
-import torch
-from movie_util import MovieSaver
-from typing import Callable, List, Optional, Union
-from latent_blending import LatentBlending, add_frames_linear_interp
+from latent_blending import LatentBlending
 from stable_diffusion_holder import StableDiffusionHolder
+torch.set_grad_enabled(False)
 from huggingface_hub import hf_hub_download
 
-#%% Define vars for low-resolution pass
+# %% Define vars for low-resolution pass
 prompt1 = "photo of mount vesuvius erupting a terrifying pyroclastic ash cloud"
 prompt2 = "photo of a inside a building full of ash, fire, death, destruction, explosions"
 fixed_seeds = [5054613, 1168652]
@@ -41,21 +33,18 @@ height = 384
 num_inference_steps_lores = 40
 nmb_max_branches_lores = 10
 depth_strength_lores = 0.5
+fp_ckpt_lores = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
 
-fp_ckpt_lores = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
-
-#%% Define vars for high-resolution pass
-fp_ckpt_hires = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
+# %% Define vars for high-resolution pass
+fp_ckpt_hires = hf_hub_download(repo_id="stabilityai/stable-diffusion-x4-upscaler", filename="x4-upscaler-ema.ckpt")
 depth_strength_hires = 0.65
 num_inference_steps_hires = 100
 nmb_branches_final_hires = 6
-dp_imgs = "tmp_transition" # folder for results and intermediate steps
+dp_imgs = "tmp_transition"  # Folder for results and intermediate steps
 
 
-#%% Run low-res pass
+# %% Run low-res pass
 sdh = StableDiffusionHolder(fp_ckpt_lores)
 
-#%%
 lb = LatentBlending(sdh)
 lb.set_prompt1(prompt1)
 lb.set_prompt2(prompt2)
@@ -64,14 +53,13 @@ lb.set_height(height)
 
 # Run latent blending
 lb.run_transition(
-    depth_strength = depth_strength_lores,
-    nmb_max_branches = nmb_max_branches_lores,
-    fixed_seeds = fixed_seeds
-    )
+    depth_strength=depth_strength_lores,
+    nmb_max_branches=nmb_max_branches_lores,
+    fixed_seeds=fixed_seeds)
 
 lb.write_imgs_transition(dp_imgs)
 
-#%% Run high-res pass
+# %% Run high-res pass
 sdh = StableDiffusionHolder(fp_ckpt_hires)
 lb = LatentBlending(sdh)
 lb.run_upscaling(dp_imgs, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)
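As a quick sanity check on what depth_strength means here: following get_time_based_branching in latent_blending.py further down, the first branching step is int(round(num_inference_steps * depth_strength)). A minimal illustration with the values from this example:

    num_inference_steps_lores = 40
    depth_strength_lores = 0.5
    idx_injection_base = int(round(num_inference_steps_lores * depth_strength_lores))  # = 20
    # The low-res transition branches only re-run diffusion steps 20..39;
    # lower depth_strength branches earlier (more inventive, but less stable).
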
@@ -13,25 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os, sys
+import os
 import torch
 torch.backends.cudnn.benchmark = False
 import numpy as np
-torch.set_grad_enabled(False)
 import warnings
 warnings.filterwarnings('ignore')
-import warnings
-import torch
-from tqdm.auto import tqdm
 from PIL import Image
-# import matplotlib.pyplot as plt
-import torch
-from movie_util import MovieSaver, concatenate_movies
-from typing import Callable, List, Optional, Union
-from latent_blending import LatentBlending, add_frames_linear_interp
+from latent_blending import LatentBlending
 from stable_diffusion_holder import StableDiffusionHolder
+torch.set_grad_enabled(False)
+from movie_util import concatenate_movies
 from huggingface_hub import hf_hub_download
 
-#%% Define vars for low-resolution pass
+# %% Define vars for low-resolution pass
 list_prompts = []
 list_prompts.append("surrealistic statue made of glitter and dirt, standing in a lake, atmospheric light, strange glow")
 list_prompts.append("statue of a mix between a tree and human, made of marble, incredibly detailed")
@@ -50,56 +44,54 @@ num_inference_steps_lores = 40
 nmb_max_branches_lores = 10
 depth_strength_lores = 0.5
 
-fp_ckpt_lores = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
+fp_ckpt_lores = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
 
-#%% Define vars for high-resolution pass
-fp_ckpt_hires = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
+# %% Define vars for high-resolution pass
+fp_ckpt_hires = hf_hub_download(repo_id="stabilityai/stable-diffusion-x4-upscaler", filename="x4-upscaler-ema.ckpt")
 depth_strength_hires = 0.65
 num_inference_steps_hires = 100
 nmb_branches_final_hires = 6
-#%% Run low-res pass
+
+# %% Run low-res pass
 sdh = StableDiffusionHolder(fp_ckpt_lores)
-t_compute_max_allowed = 12 # per segment
+t_compute_max_allowed = 12  # Per segment
 lb = LatentBlending(sdh)
 
-list_movie_dirs = [] #
-for i in range(len(list_prompts)-1):
+list_movie_dirs = []
+for i in range(len(list_prompts) - 1):
     # For a multi transition we can save some computation time and recycle the latents
-    if i==0:
+    if i == 0:
         lb.set_prompt1(list_prompts[i])
-        lb.set_prompt2(list_prompts[i+1])
+        lb.set_prompt2(list_prompts[i + 1])
         recycle_img1 = False
     else:
         lb.swap_forward()
-        lb.set_prompt2(list_prompts[i+1])
+        lb.set_prompt2(list_prompts[i + 1])
         recycle_img1 = True
 
     dp_movie_part = f"tmp_part_{str(i).zfill(3)}"
     fp_movie_part = os.path.join(dp_movie_part, "movie_lowres.mp4")
     os.makedirs(dp_movie_part, exist_ok=True)
-    fixed_seeds = list_seeds[i:i+2]
+    fixed_seeds = list_seeds[i:i + 2]
 
     # Run latent blending
     lb.run_transition(
-        depth_strength = depth_strength_lores,
-        nmb_max_branches = nmb_max_branches_lores,
-        fixed_seeds = fixed_seeds
-        )
+        depth_strength=depth_strength_lores,
+        nmb_max_branches=nmb_max_branches_lores,
+        fixed_seeds=fixed_seeds)
 
     # Save movie and images (needed for upscaling!)
     lb.write_movie_transition(fp_movie_part, duration_single_trans)
     lb.write_imgs_transition(dp_movie_part)
    list_movie_dirs.append(dp_movie_part)
 
 
-#%% Run high-res pass on each segment
+# %% Run high-res pass on each segment
 sdh = StableDiffusionHolder(fp_ckpt_hires)
 lb = LatentBlending(sdh)
 for dp_part in list_movie_dirs:
     lb.run_upscaling(dp_part, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)
 
-#%% concatenate into one long movie
+# %% concatenate into one long movie
 list_fp_movies = []
 for dp_part in list_movie_dirs:
     fp_movie = os.path.join(dp_part, "movie_highres.mp4")
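The hunk is cut off here by the extraction. A plausible completion of this last loop, inferred from the concatenate_movies(fp_final, list_fp_movies) call that appears in gradio_ui.py below (the final filename is a placeholder, not from the diff):

    list_fp_movies.append(fp_movie)
    # After the loop, stitch all high-res segments into a single movie:
    concatenate_movies("movie_example4.mp4", list_fp_movies)
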
306 gradio_ui.py
@@ -13,82 +13,89 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os, sys
+import os
 import torch
 torch.backends.cudnn.benchmark = False
+torch.set_grad_enabled(False)
 import numpy as np
 import warnings
 warnings.filterwarnings('ignore')
-import warnings
-import torch
-from tqdm.auto import tqdm
 from PIL import Image
-import torch
 from movie_util import MovieSaver, concatenate_movies
-from typing import Callable, List, Optional, Union
-from latent_blending import get_time, yml_save, LatentBlending, add_frames_linear_interp, compare_dicts
+from latent_blending import LatentBlending
 from stable_diffusion_holder import StableDiffusionHolder
-torch.set_grad_enabled(False)
 import gradio as gr
 import copy
 from dotenv import find_dotenv, load_dotenv
 import shutil
 import random
 import time
+from utils import get_time, add_frames_linear_interp
+from huggingface_hub import hf_hub_download
 
 
 #%%
 
 class BlendingFrontend():
-    def __init__(self, sdh=None):
-        self.num_inference_steps = 30
-        if sdh is None:
-            self.use_debug = True
-            self.height = 768
-            self.width = 768
-        else:
-            self.use_debug = False
-            self.lb = LatentBlending(sdh)
-            self.lb.sdh.num_inference_steps = self.num_inference_steps
-            self.height = self.lb.sdh.height
-            self.width = self.lb.sdh.width
+    def __init__(
+            self,
+            sdh,
+            share=False):
+        r"""
+        Gradio Helper Class to collect UI data and start latent blending.
+        Args:
+            sdh:
+                StableDiffusionHolder
+            share: bool
+                Set true to get a shareable gradio link (e.g. for running a remote server)
+        """
+        self.share = share
 
-        self.init_save_dir()
-        self.save_empty_image()
-        self.share = False
-        self.transition_can_be_computed = False
+        # UI Defaults
+        self.num_inference_steps = 30
+        self.depth_strength = 0.25
+        self.seed1 = 420
+        self.seed2 = 420
-        self.guidance_scale = 4.0
-        self.guidance_scale_mid_damper = 0.5
-        self.mid_compression_scaler = 1.2
         self.prompt1 = ""
         self.prompt2 = ""
         self.negative_prompt = ""
-        self.state_current = {}
+        self.fps = 30
+        self.duration_video = 8
+        self.t_compute_max_allowed = 10
 
+        self.lb = LatentBlending(sdh)
+        self.lb.sdh.num_inference_steps = self.num_inference_steps
+        self.init_parameters_from_lb()
+        self.init_save_dir()
 
+        # Vars
+        self.list_fp_imgs_current = []
+        self.recycle_img1 = False
+        self.recycle_img2 = False
+        self.list_all_segments = []
+        self.dp_session = ""
+        self.user_id = None
 
+    def init_parameters_from_lb(self):
+        r"""
+        Automatically init parameters from latentblending instance
+        """
+        self.height = self.lb.sdh.height
+        self.width = self.lb.sdh.width
+        self.guidance_scale = self.lb.guidance_scale
+        self.guidance_scale_mid_damper = self.lb.guidance_scale_mid_damper
+        self.mid_compression_scaler = self.lb.mid_compression_scaler
+        self.branch1_crossfeed_power = self.lb.branch1_crossfeed_power
+        self.branch1_crossfeed_range = self.lb.branch1_crossfeed_range
+        self.branch1_crossfeed_decay = self.lb.branch1_crossfeed_decay
+        self.parental_crossfeed_power = self.lb.parental_crossfeed_power
+        self.parental_crossfeed_range = self.lb.parental_crossfeed_range
+        self.parental_crossfeed_power_decay = self.lb.parental_crossfeed_power_decay
-        self.fps = 30
-        self.duration_video = 10
-        self.t_compute_max_allowed = 10
-        self.list_fp_imgs_current = []
-        self.current_timestamp = None
-        self.recycle_img1 = False
-        self.recycle_img2 = False
-        self.multi_idx_current = -1
-        self.list_imgs_shown_last = 5*[self.fp_img_empty]
-        self.list_all_segments = []
-        self.dp_session = ""
-        self.user_id = None
-        self.block_transition = False
 
 
     def init_save_dir(self):
         r"""
         Initializes the directory where stuff is being saved.
         You can specify this directory in a ".env" file in your latentblending root, setting
         DIR_OUT='/path/to/saving'
         """
         load_dotenv(find_dotenv(), verbose=False)
         self.dp_out = os.getenv("DIR_OUT")
         if self.dp_out is None:
@@ -97,124 +104,125 @@ class BlendingFrontend():
         os.makedirs(self.dp_imgs, exist_ok=True)
         self.dp_movies = os.path.join(self.dp_out, "movies")
         os.makedirs(self.dp_movies, exist_ok=True)
+        self.save_empty_image()
 
-
-    # make dummy image
     def save_empty_image(self):
+        r"""
+        Saves an empty/black dummy image.
+        """
         self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
         Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
 
-
     def randomize_seed1(self):
-        # Dont randomize seed if we are in a multi concat mode. we don't want to change this one otherwise the movie breaks
+        r"""
+        Randomizes the first seed
+        """
         seed = np.random.randint(0, 10000000)
         self.seed1 = int(seed)
         print(f"randomize_seed1: new seed = {self.seed1}")
         return seed
 
     def randomize_seed2(self):
+        r"""
+        Randomizes the second seed
+        """
         seed = np.random.randint(0, 10000000)
         self.seed2 = int(seed)
         print(f"randomize_seed2: new seed = {self.seed2}")
         return seed
 
 
-    def setup_lb(self, list_ui_elem):
+    def setup_lb(self, list_ui_vals):
         r"""
         Sets all parameters from the UI. Since gradio does not support passing dictionaries,
         we have to instead pass keys (list_ui_keys, global) and values (list_ui_vals)
         """
         # Collect latent blending variables
         self.state_current = self.get_state_dict()
-        self.lb.set_width(list_ui_elem[list_ui_keys.index('width')])
-        self.lb.set_height(list_ui_elem[list_ui_keys.index('height')])
-        self.lb.set_prompt1(list_ui_elem[list_ui_keys.index('prompt1')])
-        self.lb.set_prompt2(list_ui_elem[list_ui_keys.index('prompt2')])
-        self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
-        self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
-        self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
-        self.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
-        self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
-        self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
-        self.duration_video = list_ui_elem[list_ui_keys.index('duration_video')]
-        self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')] #seed
-        self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
+        self.lb.set_width(list_ui_vals[list_ui_keys.index('width')])
+        self.lb.set_height(list_ui_vals[list_ui_keys.index('height')])
+        self.lb.set_prompt1(list_ui_vals[list_ui_keys.index('prompt1')])
+        self.lb.set_prompt2(list_ui_vals[list_ui_keys.index('prompt2')])
+        self.lb.set_negative_prompt(list_ui_vals[list_ui_keys.index('negative_prompt')])
+        self.lb.guidance_scale = list_ui_vals[list_ui_keys.index('guidance_scale')]
+        self.lb.guidance_scale_mid_damper = list_ui_vals[list_ui_keys.index('guidance_scale_mid_damper')]
+        self.t_compute_max_allowed = list_ui_vals[list_ui_keys.index('duration_compute')]
+        self.lb.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
+        self.lb.sdh.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
+        self.duration_video = list_ui_vals[list_ui_keys.index('duration_video')]
+        self.lb.seed1 = list_ui_vals[list_ui_keys.index('seed1')]
+        self.lb.seed2 = list_ui_vals[list_ui_keys.index('seed2')]
+        self.lb.branch1_crossfeed_power = list_ui_vals[list_ui_keys.index('branch1_crossfeed_power')]
+        self.lb.branch1_crossfeed_range = list_ui_vals[list_ui_keys.index('branch1_crossfeed_range')]
+        self.lb.branch1_crossfeed_decay = list_ui_vals[list_ui_keys.index('branch1_crossfeed_decay')]
+        self.lb.parental_crossfeed_power = list_ui_vals[list_ui_keys.index('parental_crossfeed_power')]
+        self.lb.parental_crossfeed_range = list_ui_vals[list_ui_keys.index('parental_crossfeed_range')]
+        self.lb.parental_crossfeed_power_decay = list_ui_vals[list_ui_keys.index('parental_crossfeed_power_decay')]
+        self.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
+        self.depth_strength = list_ui_vals[list_ui_keys.index('depth_strength')]
 
-        self.lb.branch1_crossfeed_power = list_ui_elem[list_ui_keys.index('branch1_crossfeed_power')]
-        self.lb.branch1_crossfeed_range = list_ui_elem[list_ui_keys.index('branch1_crossfeed_range')]
-        self.lb.branch1_crossfeed_decay = list_ui_elem[list_ui_keys.index('branch1_crossfeed_decay')]
-        self.lb.parental_crossfeed_power = list_ui_elem[list_ui_keys.index('parental_crossfeed_power')]
-        self.lb.parental_crossfeed_range = list_ui_elem[list_ui_keys.index('parental_crossfeed_range')]
-        self.lb.parental_crossfeed_power_decay = list_ui_elem[list_ui_keys.index('parental_crossfeed_power_decay')]
-        self.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
-        self.depth_strength = list_ui_elem[list_ui_keys.index('depth_strength')]
 
-        if len(list_ui_elem[list_ui_keys.index('user_id')]) > 1:
-            self.user_id = list_ui_elem[list_ui_keys.index('user_id')]
+        if len(list_ui_vals[list_ui_keys.index('user_id')]) > 1:
+            self.user_id = list_ui_vals[list_ui_keys.index('user_id')]
         else:
             # generate new user id
             self.user_id = ''.join((random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ') for i in range(8)))
-            print(f"made new user_id: {self.user_id}")
+            print(f"made new user_id: {self.user_id} at {get_time('second')}")
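The keys/values convention in setup_lb exists because gradio passes component values positionally rather than as a dict. A minimal illustration of the pattern (the names and values here are illustrative, not from the diff):

    # At UI build time, keys and values are collected in parallel lists:
    list_ui_keys = ['prompt1', 'seed1']        # global, fixed order
    list_ui_vals = ['a photo of a cat', 420]   # what gradio hands to the callback
    # Inside setup_lb, a value is recovered by its key's position:
    seed1 = list_ui_vals[list_ui_keys.index('seed1')]
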
 
     def save_latents(self, fp_latents, list_latents):
         r"""
         Saves a latent trajectory on disk, in npy format.
         """
         list_latents_cpu = [l.cpu().numpy() for l in list_latents]
         np.save(fp_latents, list_latents_cpu)
 
     def load_latents(self, fp_latents):
         r"""
         Loads a latent trajectory from disk, converts to torch tensor.
         """
         list_latents_cpu = np.load(fp_latents)
         list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu]
         return list_latents
 
     def compute_img1(self, *args):
-        list_ui_elem = args
-        self.setup_lb(list_ui_elem)
+        r"""
+        Computes the first transition image and returns it for display.
+        Sets all other transition images and last image to empty (as they are obsolete with this operation)
+        """
+        list_ui_vals = args
+        self.setup_lb(list_ui_vals)
         fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}")
         img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
-        img1.save(fp_img1+".jpg")
-        self.save_latents(fp_img1+".npy", self.lb.tree_latents[0])
-
+        img1.save(fp_img1 + ".jpg")
+        self.save_latents(fp_img1 + ".npy", self.lb.tree_latents[0])
         self.recycle_img1 = True
         self.recycle_img2 = False
-        # fixme save seeds. change filenames?
-        return [fp_img1+".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
+        return [fp_img1 + ".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
 
     def compute_img2(self, *args):
+        r"""
+        Computes the last transition image and returns it for display.
+        Sets all other transition images to empty (as they are obsolete with this operation)
+        """
         if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")):  # don't do anything
             return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
-        list_ui_elem = args
-        self.setup_lb(list_ui_elem)
+        list_ui_vals = args
+        self.setup_lb(list_ui_vals)
 
         self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
         fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}")
         img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
-        img2.save(fp_img2+'.jpg')
-        self.save_latents(fp_img2+".npy", self.lb.tree_latents[-1])
+        img2.save(fp_img2 + '.jpg')
+        self.save_latents(fp_img2 + ".npy", self.lb.tree_latents[-1])
         self.recycle_img2 = True
         self.transition_can_be_computed = True
-        # fixme save seeds. change filenames?
-        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2+".jpg", self.user_id]
-
+        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2 + ".jpg", self.user_id]
 
     def compute_transition(self, *args):
         if not self.transition_can_be_computed:
             list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
             return list_return
 
-        list_ui_elem = args
-        self.setup_lb(list_ui_elem)
+        r"""
+        Computes transition images and movie.
+        """
+        list_ui_vals = args
+        self.setup_lb(list_ui_vals)
         print("STARTING TRANSITION...")
 
         fixed_seeds = [self.seed1, self.seed2]
 
-        # Run Latent Blending
+        # Check if another user is blocking this... otherwise everything will become mixed.
+        # t_now = time.time()
+        # if self.block_transition:
+        #     while True:
+        #         time.sleep(1)
+        #         if not self.block_transition:
+        #             break
+        #         if time.time() - t_now > 1000:
+        #             return
 
         self.block_transition = True
         # Inject loaded latents (other user interference)
         self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
         self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
@@ -224,24 +232,23 @@ class BlendingFrontend():
             num_inference_steps=self.num_inference_steps,
             depth_strength=self.depth_strength,
             t_compute_max_allowed=self.t_compute_max_allowed,
-            fixed_seeds=fixed_seeds
-            )
-        print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
+            fixed_seeds=fixed_seeds)
+        print(f"Latent Blending pass finished ({get_time('second')}). Resulted in {len(imgs_transition)} images")
 
         # Subselect three preview images
-        idx_img_prev = np.round(np.linspace(0, len(imgs_transition)-1, 5)[1:-1]).astype(np.int32)
+        idx_img_prev = np.round(np.linspace(0, len(imgs_transition) - 1, 5)[1:-1]).astype(np.int32)
 
         list_imgs_preview = []
         for j in idx_img_prev:
             list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
 
         # Save the preview imgs as jpgs on disk so we are not sending uncompressed data around
-        self.current_timestamp = get_time('second')
+        current_timestamp = get_time('second')
         self.list_fp_imgs_current = []
         for i in range(len(list_imgs_preview)):
-            fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg")
+            fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{current_timestamp}.jpg")
             list_imgs_preview[i].save(fp_img)
             self.list_fp_imgs_current.append(fp_img)
         self.block_transition = False
         # Insert cheap frames for the movie
         imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
@@ -259,16 +266,17 @@ class BlendingFrontend():
         list_return = self.list_fp_imgs_current + [self.fp_movie]
         return list_return
 
-
     def stack_forward(self, prompt2, seed2):
+        r"""
+        Allows to generate multi-segment movies. Sets last image -> first image with all
+        relevant parameters.
+        """
         # Save preview images, prompts and seeds into dictionary for stacking
         if len(self.list_all_segments) == 0:
             timestamp_session = get_time('second')
             self.dp_session = os.path.join(self.dp_out, f"session_{timestamp_session}")
             os.makedirs(self.dp_session)
 
         self.transition_can_be_computed = False
 
         idx_segment = len(self.list_all_segments)
         dp_segment = os.path.join(self.dp_session, f"segment_{str(idx_segment).zfill(3)}")

@@ -285,13 +293,11 @@ class BlendingFrontend():
         self.lb.swap_forward()
 
         shutil.copyfile(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"), os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
 
 
         fp_multi = self.multi_concat()
         list_out = [fp_multi]
 
         list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")])
-        list_out.extend([self.fp_img_empty]*4)
+        list_out.extend([self.fp_img_empty] * 4)
         list_out.append(gr.update(interactive=False, value=prompt2))
         list_out.append(gr.update(interactive=False, value=seed2))
         list_out.append("")

@@ -299,16 +305,20 @@ class BlendingFrontend():
         print(f"stack_forward: fp_multi {fp_multi}")
         return list_out
 
-
     def multi_concat(self):
+        r"""
+        Concatenates all stacked segments into one long movie.
+        """
         list_fp_movies = self.get_fp_video_all()
         # Concatenate movies and save
         fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4")
         concatenate_movies(fp_final, list_fp_movies)
         return fp_final
 
-
     def get_fp_video_all(self):
+        r"""
+        Collects all stacked movie segments.
+        """
         list_all = os.listdir(self.dp_movies)
         str_beg = f"movie_{self.user_id}_"
         list_user = [l for l in list_all if str_beg in l]
|
|||
list_user = [os.path.join(self.dp_movies, l) for l in list_user]
|
||||
return list_user
|
||||
|
||||
|
||||
def get_fp_video_next(self):
|
||||
r"""
|
||||
Gets the filepath of the next movie segment.
|
||||
"""
|
||||
list_videos = self.get_fp_video_all()
|
||||
if len(list_videos) == 0:
|
||||
idx_next = 0
|
||||
|
@ -327,26 +339,16 @@ class BlendingFrontend():
|
|||
return fp_video_next
|
||||
|
||||
def get_fp_video_last(self):
|
||||
r"""
|
||||
Gets the current video that was saved.
|
||||
"""
|
||||
fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4")
|
||||
return fp_video_last
|
||||
|
||||
|
||||
def get_state_dict(self):
|
||||
state_dict = {}
|
||||
grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
|
||||
'num_inference_steps', 'depth_strength', 'guidance_scale',
|
||||
'guidance_scale_mid_damper', 'mid_compression_scaler']
|
||||
|
||||
for v in grab_vars:
|
||||
state_dict[v] = getattr(self, v)
|
||||
return state_dict
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
|
||||
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
|
||||
fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
|
||||
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
|
||||
bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
|
||||
# self = BlendingFrontend(None)
|
||||
|
||||
|
@@ -391,7 +393,6 @@ if __name__ == "__main__":
             depth_strength = gr.Slider(0.01, 0.99, bf.depth_strength, step=0.01, label='depth_strength', interactive=True)
             guidance_scale_mid_damper = gr.Slider(0.01, 2.0, bf.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
 
-
         with gr.Row():
             b_compute1 = gr.Button('compute first image', variant='primary')
             b_compute_transition = gr.Button('compute transition', variant='primary')
@@ -405,11 +406,10 @@ if __name__ == "__main__":
             img5 = gr.Image(label="5/5")
 
         with gr.Row():
-            vid_single = gr.Video(label="single trans")
-            vid_multi = gr.Video(label="multi trans")
+            vid_single = gr.Video(label="current single trans")
+            vid_multi = gr.Video(label="concatenated multi trans")
 
         with gr.Row():
-            # b_restart = gr.Button("RESTART EVERYTHING")
             b_stackforward = gr.Button('append last movie segment (left) to multi movie (right)', variant='primary')
 
         with gr.Row():
@@ -437,8 +437,7 @@ if __name__ == "__main__":
             - parental_crossfeed_power_decay: Similar to branch1_crossfeed_decay, however applied for the images within the transition.
             - depth_strength: Determines when the blending process will begin in terms of diffusion steps. Low values are more inventive but can cause motion.
             - guidance_scale_mid_damper: Decreases the guidance scale in the middle of a transition.
-            """
-        )
+            """)
 
         with gr.Row():
             user_id = gr.Textbox(label="user id", interactive=False)
@@ -471,24 +470,23 @@ if __name__ == "__main__":
         dict_ui_elem["user_id"] = user_id
 
         # Convert to list, as gradio doesn't seem to accept dicts
-        list_ui_elem = []
+        list_ui_vals = []
         list_ui_keys = []
         for k in dict_ui_elem.keys():
-            list_ui_elem.append(dict_ui_elem[k])
+            list_ui_vals.append(dict_ui_elem[k])
             list_ui_keys.append(k)
         bf.list_ui_keys = list_ui_keys
 
         b_newseed1.click(bf.randomize_seed1, outputs=seed1)
         b_newseed2.click(bf.randomize_seed2, outputs=seed2)
-        b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5, user_id])
-        b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5, user_id])
+        b_compute1.click(bf.compute_img1, inputs=list_ui_vals, outputs=[img1, img2, img3, img4, img5, user_id])
+        b_compute2.click(bf.compute_img2, inputs=list_ui_vals, outputs=[img2, img3, img4, img5, user_id])
         b_compute_transition.click(bf.compute_transition,
-                                   inputs=list_ui_elem,
+                                   inputs=list_ui_vals,
                                    outputs=[img2, img3, img4, vid_single])
 
         b_stackforward.click(bf.stack_forward,
                              inputs=[prompt2, seed2],
                              outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
 
         demo.launch(share=bf.share, inbrowser=True, inline=False)
@@ -13,41 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os, sys
+import os
 import torch
 torch.backends.cudnn.benchmark = False
+torch.set_grad_enabled(False)
 import numpy as np
 import warnings
 warnings.filterwarnings('ignore')
 import time
 import subprocess
-import warnings
 from tqdm.auto import tqdm
 from PIL import Image
-# import matplotlib.pyplot as plt
 from movie_util import MovieSaver
-import datetime
-from typing import Callable, List, Optional, Union
-import inspect
-from threading import Thread
-torch.set_grad_enabled(False)
-from contextlib import nullcontext
-
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import instantiate_from_config
+from typing import List, Optional
 from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
-from stable_diffusion_holder import StableDiffusionHolder
 import yaml
 import lpips
-#%%
+from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save
 
 
 class LatentBlending():
     def __init__(
             self,
             sdh: None,
             guidance_scale: float = 4,
             guidance_scale_mid_damper: float = 0.5,
-            mid_compression_scaler: float = 1.2,
-            ):
+            mid_compression_scaler: float = 1.2):
         r"""
         Initializes the latent blending class.
         Args:
@@ -64,9 +54,10 @@ class LatentBlending():
             Increases the sampling density in the middle (where most changes happen). Higher values
             imply more values in the middle. However the inflection point can occur outside the middle,
             thus high values can give rough transitions. Values around 2 should be fine.
-
         """
-        assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper needs to be in interval (0,1], you provided {guidance_scale_mid_damper}"
+        assert guidance_scale_mid_damper > 0 \
+            and guidance_scale_mid_damper <= 1.0, \
+            f"guidance_scale_mid_damper needs to be in interval (0,1], you provided {guidance_scale_mid_damper}"
 
         self.sdh = sdh
         self.device = self.sdh.device
@@ -115,10 +106,8 @@ class LatentBlending():
         self.multi_transition_img_last = None
         self.dt_per_diff = 0
         self.spatial_mask = None
-
         self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
 
-
     def init_mode(self):
         r"""
         Sets the operational mode. Currently supported are standard, inpainting and x4 upscaling.
@@ -151,13 +140,12 @@ class LatentBlending():
         Tunes the guidance scale down as a linear function of fract_mixing,
         towards 0.5 the minimum will be reached.
         """
-        mid_factor = 1 - np.abs(fract_mixing - 0.5)/ 0.5
-        max_guidance_reduction = self.guidance_scale_base * (1-self.guidance_scale_mid_damper) - 1
-        guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction*mid_factor
+        mid_factor = 1 - np.abs(fract_mixing - 0.5) / 0.5
+        max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1
+        guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor
         self.guidance_scale = guidance_scale_effective
         self.sdh.guidance_scale = guidance_scale_effective
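A quick numeric check of the damper above (values chosen for illustration): with guidance_scale_base = 4.0 and guidance_scale_mid_damper = 0.5, at the midpoint fract_mixing = 0.5 we get mid_factor = 1 and max_guidance_reduction = 4.0 * 0.5 - 1 = 1.0, so guidance_scale_effective = 4.0 - 1.0 = 3.0; at the endpoints (fract_mixing = 0 or 1) mid_factor = 0 and the full guidance_scale_base = 4.0 is used.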
 
     def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
         r"""
         Sets the crossfeed parameters for the first branch to the last branch.

@@ -173,7 +161,6 @@ class LatentBlending():
         self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
         self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
 
-
     def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
         r"""
         Sets the crossfeed parameters for all transition images (within the first and last branch).
@@ -189,7 +176,6 @@ class LatentBlending():
         self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
         self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1)
 
-
     def set_prompt1(self, prompt: str):
         r"""
         Sets the first prompt (for the first keyframe) including text embeddings.
@@ -201,7 +187,6 @@ class LatentBlending():
         self.prompt1 = prompt
         self.text_embedding1 = self.get_text_embeddings(self.prompt1)
 
-
     def set_prompt2(self, prompt: str):
         r"""
         Sets the second prompt (for the second keyframe) including text embeddings.
@@ -237,8 +222,7 @@ class LatentBlending():
             depth_strength: Optional[float] = 0.3,
             t_compute_max_allowed: Optional[float] = None,
             nmb_max_branches: Optional[int] = None,
-            fixed_seeds: Optional[List[int]] = None,
-            ):
+            fixed_seeds: Optional[List[int]] = None):
         r"""
         Function for computing transitions.
         Returns a list of transition images using spherical latent blending.

@@ -263,7 +247,6 @@ class LatentBlending():
             fixed_seeds: Optional[List[int]]:
                 You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
                 Otherwise random seeds will be taken.
-
         """
 
         # Sanity checks first
@@ -275,7 +258,7 @@ class LatentBlending():
         if fixed_seeds == 'randomize':
             fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
         else:
-            assert len(fixed_seeds)==2, "Supply a list with len = 2"
+            assert len(fixed_seeds) == 2, "Supply a list with len = 2"
 
         self.seed1 = fixed_seeds[0]
         self.seed2 = fixed_seeds[1]

@@ -323,7 +306,6 @@ class LatentBlending():
 
         return self.tree_final_imgs
 
-
     def compute_latents1(self, return_image=False):
         r"""
         Runs a diffusion trajectory for the first image

@@ -337,11 +319,10 @@ class LatentBlending():
         latents_start = self.get_noise(self.seed1)
         list_latents1 = self.run_diffusion(
             list_conditionings,
-            latents_start = latents_start,
-            idx_start = 0
-            )
+            latents_start=latents_start,
+            idx_start=0)
         t1 = time.time()
-        self.dt_per_diff = (t1-t0) / self.num_inference_steps
+        self.dt_per_diff = (t1 - t0) / self.num_inference_steps
         self.tree_latents[0] = list_latents1
         if return_image:
             return self.sdh.latent2image(list_latents1[-1])
@@ -361,17 +342,16 @@ class LatentBlending():
         # Influence from branch1
         if self.branch1_crossfeed_power > 0.0:
             # Set up the mixing_coeffs
-            idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_crossfeed_range))
-            mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power*self.branch1_crossfeed_decay, idx_mixing_stop))
-            mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
+            idx_mixing_stop = int(round(self.num_inference_steps * self.branch1_crossfeed_range))
+            mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power * self.branch1_crossfeed_decay, idx_mixing_stop))
+            mixing_coeffs.extend((self.num_inference_steps - idx_mixing_stop) * [0])
             list_latents_mixing = self.tree_latents[0]
             list_latents2 = self.run_diffusion(
                 list_conditionings,
-                latents_start = latents_start,
-                idx_start = 0,
-                list_latents_mixing = list_latents_mixing,
-                mixing_coeffs = mixing_coeffs
-                )
+                latents_start=latents_start,
+                idx_start=0,
+                list_latents_mixing=list_latents_mixing,
+                mixing_coeffs=mixing_coeffs)
         else:
             list_latents2 = self.run_diffusion(list_conditionings, latents_start)
         self.tree_latents[-1] = list_latents2
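To make the crossfeed schedule above concrete, a small worked example (illustrative numbers, not from the diff): with num_inference_steps = 30, branch1_crossfeed_range = 0.6, branch1_crossfeed_power = 0.5 and branch1_crossfeed_decay = 0.5, idx_mixing_stop = int(round(30 * 0.6)) = 18, so mixing_coeffs ramps linearly from 0.5 down to 0.25 over the first 18 steps and is padded with 12 zeros for the remaining steps, i.e. branch-1 latents stop being mixed in after step 18.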
@@ -381,7 +361,6 @@ class LatentBlending():
         else:
             return list_latents2
 
-
     def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
         r"""
         Runs a diffusion trajectory, using the latents from the respective parents
@@ -409,22 +388,19 @@ class LatentBlending():
             latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
             list_latents_parental_mix.append(latents_parental)
 
-        idx_mixing_stop = int(round(self.num_inference_steps*self.parental_crossfeed_range))
-        mixing_coeffs = idx_injection*[self.parental_crossfeed_power]
+        idx_mixing_stop = int(round(self.num_inference_steps * self.parental_crossfeed_range))
+        mixing_coeffs = idx_injection * [self.parental_crossfeed_power]
         nmb_mixing = idx_mixing_stop - idx_injection
         if nmb_mixing > 0:
-            mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power*self.parental_crossfeed_power_decay, nmb_mixing)))
-        mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
-
-        latents_start = list_latents_parental_mix[idx_injection-1]
+            mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_power_decay, nmb_mixing)))
+        mixing_coeffs.extend((self.num_inference_steps - len(mixing_coeffs)) * [0])
+        latents_start = list_latents_parental_mix[idx_injection - 1]
         list_latents = self.run_diffusion(
             list_conditionings,
-            latents_start = latents_start,
-            idx_start = idx_injection,
-            list_latents_mixing = list_latents_parental_mix,
-            mixing_coeffs = mixing_coeffs
-            )
-
+            latents_start=latents_start,
+            idx_start=idx_injection,
+            list_latents_mixing=list_latents_parental_mix,
+            mixing_coeffs=mixing_coeffs)
         return list_latents
 
     def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
@@ -445,8 +421,8 @@ class LatentBlending():
             results. Use this if you want to have controllable results independent
             of your computer.
         """
-        idx_injection_base = int(round(self.num_inference_steps*depth_strength))
-        list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3)
+        idx_injection_base = int(round(self.num_inference_steps * depth_strength))
+        list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps - 1, 3)
         list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
         t_compute = 0
@@ -456,20 +432,18 @@ class LatentBlending():
         elif t_compute_max_allowed is None:
             assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
             stop_criterion = "nmb_max_branches"
-            nmb_max_branches -= 2 # discounting the outer frames
+            nmb_max_branches -= 2  # Discounting the outer frames
         else:
             raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
 
         stop_criterion_reached = False
         is_first_iteration = True
 
         while not stop_criterion_reached:
             list_compute_steps = self.num_inference_steps - list_idx_injection
             list_compute_steps *= list_nmb_stems
-            t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15*np.sum(list_nmb_stems)
+            t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
             increase_done = False
-            for s_idx in range(len(list_nmb_stems)-1):
-                if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
+            for s_idx in range(len(list_nmb_stems) - 1):
+                if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
                     list_nmb_stems[s_idx] += 1
                     increase_done = True
                     break
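The t_compute estimate above is easy to sanity-check by hand. Suppose num_inference_steps = 30, depth_strength = 0.5 and dt_per_diff = 0.1 s (illustrative values): list_idx_injection = [15, 18, 21, 24, 27], so with one stem each list_compute_steps = [15, 12, 9, 6, 3], giving t_compute = 45 * 0.1 + 0.15 * 5 = 5.25 s. The loop then keeps adding stems until this estimate exceeds t_compute_max_allowed.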
@@ -501,10 +475,10 @@ class LatentBlending():
         """
         # get_lpips_similarity
         similarities = []
-        for i in range(len(self.tree_final_imgs)-1):
-            similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i+1]))
+        for i in range(len(self.tree_final_imgs) - 1):
+            similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1]))
         b_closest1 = np.argmax(similarities)
-        b_closest2 = b_closest1+1
+        b_closest2 = b_closest1 + 1
         fract_closest1 = self.tree_fracts[b_closest1]
         fract_closest2 = self.tree_fracts[b_closest2]
 
@@ -515,23 +489,15 @@ class LatentBlending():
                 break
             else:
                 b_parent1 -= 1
 
         b_parent2 = b_closest2
         while True:
             if self.tree_idx_injection[b_parent2] < idx_injection:
                 break
             else:
                 b_parent2 += 1
-
-        # print(f"\n\nb_closest: {b_closest1} {b_closest2} fract_closest1 {fract_closest1} fract_closest2 {fract_closest2}")
-        # print(f"b_parent: {b_parent1} {b_parent2}")
-        # print(f"similarities {similarities}")
-        # print(f"idx_injection {idx_injection} tree_idx_injection {self.tree_idx_injection}")
-
-        fract_mixing = (fract_closest1 + fract_closest2) /2
+        fract_mixing = (fract_closest1 + fract_closest2) / 2
         return fract_mixing, b_parent1, b_parent2
 
-
     def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
         r"""
         Inserts all necessary parameters into the trajectory tree.

@@ -543,12 +509,11 @@ class LatentBlending():
             list_latents: list
                 list of the latents to be inserted
         """
-        b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
-        self.tree_latents.insert(b_parent1+1, list_latents)
-        self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
-        self.tree_fracts.insert(b_parent1+1, fract_mixing)
-        self.tree_idx_injection.insert(b_parent1+1, idx_injection)
-
+        b_parent1, b_parent2 = self.get_closest_idx(fract_mixing)
+        self.tree_latents.insert(b_parent1 + 1, list_latents)
+        self.tree_final_imgs.insert(b_parent1 + 1, self.sdh.latent2image(list_latents[-1]))
+        self.tree_fracts.insert(b_parent1 + 1, fract_mixing)
+        self.tree_idx_injection.insert(b_parent1 + 1, idx_injection)
 
     def get_spatial_mask_template(self):
         r"""
@@ -565,9 +530,7 @@ class LatentBlending():
         Args:
             img_mask:
                 mask image [0,1]. You can get a template using get_spatial_mask_template
-
         """
-
         shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
         C, H, W = shape_latents
         img_mask = np.asarray(img_mask)

@@ -577,18 +540,15 @@ class LatentBlending():
         assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
         spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
         spatial_mask = torch.unsqueeze(spatial_mask, 0)
-        spatial_mask = spatial_mask.repeat((C,1,1))
+        spatial_mask = spatial_mask.repeat((C, 1, 1))
         spatial_mask = torch.unsqueeze(spatial_mask, 0)
 
         self.spatial_mask = spatial_mask
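A hedged usage sketch for the spatial mask (the setter's name set_spatial_mask is an assumption since its def line is elided here; the dimension requirement follows the assert above):

    # The mask must match the latent grid: (height // f) x (width // f), values in [0, 1].
    img_mask = lb.get_spatial_mask_template()    # lb is a LatentBlending instance
    img_mask[:, :img_mask.shape[1] // 2] = 0.0   # e.g. blend only the right half
    lb.set_spatial_mask(img_mask)                # assumed setter; applied on the next transition
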
 
     def get_noise(self, seed):
         r"""
         Helper function to get noise given seed.
         Args:
             seed: int
-
         """
         generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
         if self.mode == 'standard':
@@ -599,21 +559,17 @@ class LatentBlending():
             h = self.image1_lowres.size[1]
             shape_latents = [self.sdh.model.channels, h, w]
             C, H, W = shape_latents
-
         return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
 
     @torch.no_grad()
     def run_diffusion(
             self,
             list_conditionings,
             latents_start: torch.FloatTensor = None,
             idx_start: int = 0,
-            list_latents_mixing = None,
-            mixing_coeffs = 0.0,
-            return_image: Optional[bool] = False
-            ):
-
+            list_latents_mixing=None,
+            mixing_coeffs=0.0,
+            return_image: Optional[bool] = False):
         r"""
         Wrapper function for diffusion runners.
         Depending on the mode, the correct one will be executed.
@@ -640,14 +596,13 @@ class LatentBlending():
         if self.mode == 'standard':
             text_embeddings = list_conditionings[0]
             return self.sdh.run_diffusion_standard(
-                text_embeddings = text_embeddings,
-                latents_start = latents_start,
-                idx_start = idx_start,
-                list_latents_mixing = list_latents_mixing,
-                mixing_coeffs = mixing_coeffs,
-                spatial_mask = self.spatial_mask,
-                return_image = return_image,
-                )
+                text_embeddings=text_embeddings,
+                latents_start=latents_start,
+                idx_start=idx_start,
+                list_latents_mixing=list_latents_mixing,
+                mixing_coeffs=mixing_coeffs,
+                spatial_mask=self.spatial_mask,
+                return_image=return_image)
 
         elif self.mode == 'upscale':
             cond = list_conditionings[0]
@@ -657,11 +612,10 @@ class LatentBlending():
                 uc_full,
                 latents_start=latents_start,
                 idx_start=idx_start,
-                list_latents_mixing = list_latents_mixing,
-                mixing_coeffs = mixing_coeffs,
+                list_latents_mixing=list_latents_mixing,
+                mixing_coeffs=mixing_coeffs,
                 return_image=return_image)
 
-
     def run_upscaling(
             self,
             dp_img: str,
@@ -669,9 +623,9 @@ class LatentBlending():
             num_inference_steps: int = 100,
             nmb_max_branches_highres: int = 5,
             nmb_max_branches_lowres: int = 6,
-            duration_single_segment = 3,
-            fixed_seeds: Optional[List[int]] = None,
-            ):
+            duration_single_segment=3,
+            fps=24,
+            fixed_seeds: Optional[List[int]] = None):
         r"""
         Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition.
 
@@ -692,13 +646,14 @@ class LatentBlending():
             Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
             duration_single_segment: float
                 The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total.
+            fps: float
+                frames per second of movie
             fixed_seeds: Optional[List[int]]:
                 You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
                 Otherwise random seeds will be taken.
         """
         fp_yml = os.path.join(dp_img, "lowres.yaml")
         fp_movie = os.path.join(dp_img, "movie_highres.mp4")
-        fps = 24
         ms = MovieSaver(fp_movie, fps=fps)
         assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
         dict_stuff = yml_load(fp_yml)
@ -707,53 +662,43 @@ class LatentBlending():

nmb_images_lowres = dict_stuff['nmb_images']
prompt1 = dict_stuff['prompt1']
prompt2 = dict_stuff['prompt2']
idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres-1, nmb_max_branches_lowres)).astype(np.int32)
idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres - 1, nmb_max_branches_lowres)).astype(np.int32)
imgs_lowres = []
for i in idx_img_lowres:
fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. Did you forget run_upscaling_step1?"
imgs_lowres.append(Image.open(fp_img_lowres))

# Set up upscaling
text_embeddingA = self.sdh.get_text_embedding(prompt1)
text_embeddingB = self.sdh.get_text_embedding(prompt2)

list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres-1)

for i in range(nmb_max_branches_lowres-1):
list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1)
for i in range(nmb_max_branches_lowres - 1):
print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")

self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1-list_fract_mixing[i])

if i==0:
self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1 - list_fract_mixing[i])
if i == 0:
recycle_img1 = False
else:
self.swap_forward()
recycle_img1 = True

self.set_image1(imgs_lowres[i])
self.set_image2(imgs_lowres[i+1])
self.set_image2(imgs_lowres[i + 1])

list_imgs = self.run_transition(
recycle_img1 = recycle_img1,
recycle_img2 = False,
num_inference_steps = num_inference_steps,
depth_strength = depth_strength,
nmb_max_branches = nmb_max_branches_highres,
)

recycle_img1=recycle_img1,
recycle_img2=False,
num_inference_steps=num_inference_steps,
depth_strength=depth_strength,
nmb_max_branches=nmb_max_branches_highres)
list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment)

# Save movie frames
for img in list_imgs_interp:
ms.write_frame(img)

ms.finalize()
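For orientation, here is a minimal usage sketch of the intended two-step workflow around run_upscaling, assuming a low-res transition was computed and saved first; the checkpoint variables (fp_ckpt_lowres, fp_ckpt_upscaler) and the output directory name are illustrative assumptions, not part of this commit:

# Hypothetical sketch: low-res transition first, then x4 upscaling.
lb = LatentBlending(StableDiffusionHolder(fp_ckpt_lowres))      # fp_ckpt_lowres is an assumption
lb.set_prompt1("photo of a desert")
lb.set_prompt2("photo of a tree with a lake")
lb.run_transition(fixed_seeds=[69731932, 504430820])
lb.write_imgs_transition("tmp_transition")                      # writes lowres_img_*.jpg and lowres.yaml

lb_up = LatentBlending(StableDiffusionHolder(fp_ckpt_upscaler)) # x4 upscaler checkpoint, an assumption
lb_up.run_upscaling("tmp_transition", num_inference_steps=100)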
@torch.no_grad()
def get_mixed_conditioning(self, fract_mixing):
if self.mode == 'standard':

@ -776,8 +721,7 @@ class LatentBlending():

@torch.no_grad()
def get_text_embeddings(
self,
prompt: str
):
prompt: str):
r"""
Computes the text embeddings for a given prompt string.
Adapted from the stable diffusion repo.

@ -785,10 +729,8 @@ class LatentBlending():

prompt: str
ABC trending on artstation painted by Old Greg.
"""
return self.sdh.get_text_embedding(prompt)

def write_imgs_transition(self, dp_img):
r"""
Writes the transition images into the folder dp_img.

@ -802,7 +744,6 @@ class LatentBlending():

for i, img in enumerate(imgs_transition):
img_leaf = Image.fromarray(img)
img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))

fp_yml = os.path.join(dp_img, "lowres.yaml")
self.save_statedict(fp_yml)
@ -817,7 +758,6 @@ class LatentBlending():

duration of the movie in seconds
fps: int
fps of the movie
"""

# Let's get more cheap frames via linear interpolation (duration_transition*fps frames)

@ -831,8 +771,6 @@ class LatentBlending():

ms.write_frame(img)
ms.finalize()
def save_statedict(self, fp_yml):
# Dump everything relevant into yaml
imgs_transition = self.tree_final_imgs

@ -857,9 +795,8 @@ class LatentBlending():

else:
try:
state_dict[v] = getattr(self, v)
except Exception as e:
except Exception:
pass

return state_dict

def randomize_seed(self):

@ -892,7 +829,6 @@ class LatentBlending():

self.height = height
self.sdh.height = height

def swap_forward(self):
r"""
Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions

@ -900,15 +836,12 @@ class LatentBlending():

"""
# Move over all latents
self.tree_latents[0] = self.tree_latents[-1]

# Move over prompts and text embeddings
self.prompt1 = self.prompt2
self.text_embedding1 = self.text_embedding2

# Final cleanup for extra sanity
self.tree_final_imgs = []
def get_lpips_similarity(self, imgA, imgB):
r"""
Computes the image similarity between two images imgA and imgB.

@ -916,36 +849,32 @@ class LatentBlending():

High values indicate low similarity.
"""
tensorA = torch.from_numpy(imgA).float().cuda(self.device)
tensorA = 2*tensorA/255.0 - 1
tensorA = tensorA.permute([2,0,1]).unsqueeze(0)
tensorA = 2 * tensorA / 255.0 - 1
tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0)
tensorB = torch.from_numpy(imgB).float().cuda(self.device)
tensorB = 2*tensorB/255.0 - 1
tensorB = tensorB.permute([2,0,1]).unsqueeze(0)
tensorB = 2 * tensorB / 255.0 - 1
tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0)
lploss = self.lpips(tensorA, tensorB)
lploss = float(lploss[0][0][0][0])

return lploss
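As a standalone reference for the distance computed above, a minimal sketch using the lpips package; the net choice and image sizes are illustrative, and self.lpips is assumed to be such an lpips.LPIPS model:

import lpips
import numpy as np
import torch

loss_fn = lpips.LPIPS(net='alex').cuda()

def to_lpips_tensor(img):
    # Same preprocessing as above: scale uint8 RGB to [-1, 1], make NCHW.
    t = torch.from_numpy(img).float().cuda()
    t = 2 * t / 255.0 - 1
    return t.permute([2, 0, 1]).unsqueeze(0)

imgA = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)
imgB = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)
print(float(loss_fn(to_lpips_tensor(imgA), to_lpips_tensor(imgB))))  # near 0 means very similar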
# Auxiliary functions
def get_closest_idx(
fract_mixing: float,
list_fract_mixing_prev: List[float],
):
# Auxiliary functions
def get_closest_idx(
self,
fract_mixing: float):
r"""
Helper function to retrieve the parents for any given mixing.
Example: fract_mixing = 0.4 and list_fract_mixing_prev = [0, 0.3, 0.6, 1.0]
Will return the two closest values from list_fract_mixing_prev, i.e. [1, 2]
Example: fract_mixing = 0.4 and self.tree_fracts = [0, 0.3, 0.6, 1.0]
Will return the two closest values here, i.e. [1, 2]
"""
pdist = fract_mixing - np.asarray(list_fract_mixing_prev)
pdist = fract_mixing - np.asarray(self.tree_fracts)
pdist_pos = pdist.copy()
pdist_pos[pdist_pos<0] = np.inf
pdist_pos[pdist_pos < 0] = np.inf
b_parent1 = np.argmin(pdist_pos)
pdist_neg = -pdist.copy()
pdist_neg[pdist_neg<=0] = np.inf
b_parent2= np.argmin(pdist_neg)
pdist_neg[pdist_neg <= 0] = np.inf
b_parent2 = np.argmin(pdist_neg)

if b_parent1 > b_parent2:
tmp = b_parent2

@ -953,291 +882,3 @@ def get_closest_idx(

b_parent1 = tmp

return b_parent1, b_parent2
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function always casts up to float64 for extra precision.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return p0
1 will return p1
0.x will return a mix between both, preserving angular velocity.
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'

p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1+epsilon, 1-epsilon)

theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0*s0 + p1*s1

if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()

return interp
def interpolate_linear(p0, p1, fract_mixing):
r"""
Helper function to mix two variables using standard linear interpolation.
Args:
p0:
First tensor / np.ndarray for interpolation
p1:
Second tensor / np.ndarray for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return p0
1 will return p1
0.x will return a linear mix between both.
"""
reconvert_uint8 = False
if type(p0) is np.ndarray and p0.dtype == 'uint8':
reconvert_uint8 = True
p0 = p0.astype(np.float64)

if type(p1) is np.ndarray and p1.dtype == 'uint8':
reconvert_uint8 = True
p1 = p1.astype(np.float64)

interp = (1-fract_mixing) * p0 + fract_mixing * p1

if reconvert_uint8:
interp = np.clip(interp, 0, 255).astype(np.uint8)

return interp
def add_frames_linear_interp(
list_imgs: List[np.ndarray],
fps_target: Union[float, int] = None,
duration_target: Union[float, int] = None,
nmb_frames_target: int=None,
):
r"""
Helper function to cheaply increase the number of frames given a list of images,
by virtue of standard linear interpolation.
The number of inserted frames is automatically adjusted so that the total number
of frames can be fixed precisely, using a random shuffling technique.
The function allows 1:1 comparisons between transitions as videos.

Args:
list_imgs: List[np.ndarray]
List of images, between each image new frames will be inserted via linear interpolation.
fps_target:
Option A: specify here the desired frames per second.
duration_target:
Option A: specify here the desired duration of the transition in seconds.
nmb_frames_target:
Option B: directly fix the total number of frames of the output.
"""
# Sanity
if nmb_frames_target is not None and fps_target is not None:
raise ValueError("You cannot specify both fps_target and nmb_frames_target")
if fps_target is None:
assert nmb_frames_target is not None, "Either specify fps_target or nmb_frames_target"
if nmb_frames_target is None:
assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
nmb_frames_target = fps_target*duration_target

# Get number of frames that are missing
nmb_frames_diff = len(list_imgs)-1
nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1

if nmb_frames_missing < 1:
return list_imgs

list_imgs_float = [img.astype(np.float32) for img in list_imgs]
# Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
mean_nmb_frames_insert = nmb_frames_missing/nmb_frames_diff
constfact = np.floor(mean_nmb_frames_insert)
remainder_x = 1-(mean_nmb_frames_insert - constfact)

nmb_iter = 0
while True:
nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
nmb_frames_to_insert[nmb_frames_to_insert<=remainder_x] = 0
nmb_frames_to_insert[nmb_frames_to_insert>remainder_x] = 1
nmb_frames_to_insert += constfact
if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
break
nmb_iter += 1
if nmb_iter > 100000:
print("add_frames_linear_interp: issue with inserting the right number of frames")
break

nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
list_imgs_interp = []
for i in range(len(list_imgs_float)-1):
img0 = list_imgs_float[i]
img1 = list_imgs_float[i+1]
list_imgs_interp.append(img0.astype(np.uint8))
list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i]+2)[1:-1]
for fract_linblend in list_fracts_linblend:
img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
list_imgs_interp.append(img_blend.astype(np.uint8))

if i==len(list_imgs_float)-2:
list_imgs_interp.append(img1.astype(np.uint8))

return list_imgs_interp
def get_spacing(nmb_points: int, scaling: float):
"""
Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
Args:
nmb_points: int
Number of points between [0, 1]
scaling: float
Higher values will return higher sampling density around 0.5
"""
if scaling < 1.7:
return np.linspace(0, 1, nmb_points)
nmb_points_per_side = nmb_points//2 + 1
if np.mod(nmb_points, 2) != 0: # uneven case
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
right_side = 1-left_side[::-1][1:]
else:
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
right_side = 1-left_side[::-1]
all_fracts = np.hstack([left_side, right_side])
return all_fracts
def get_time(resolution=None):
"""
Helper function returning a nicely formatted time string, e.g. 221117_1620
"""
if resolution==None:
resolution="second"
if resolution == "day":
t = time.strftime('%y%m%d', time.localtime())
elif resolution == "minute":
t = time.strftime('%y%m%d_%H%M', time.localtime())
elif resolution == "second":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
elif resolution == "millisecond":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
t += "_"
t += str("{:03d}".format(int(int(datetime.utcnow().strftime('%f'))/1000)))
else:
raise ValueError("bad resolution provided: %s" %resolution)
return t
def compare_dicts(a, b):
"""
Compares two dictionaries a and b and returns a dictionary c, with all
keys that are shared between a and b but have different values.
The values of a and b are stacked together in the output.
Example:
a = {}; a['bobo'] = 4
b = {}; b['bobo'] = 5
c = compare_dicts(a, b)
c = {"bobo": [4, 5]}
"""
c = {}
for key in a.keys():
if key in b.keys():
val_a = a[key]
val_b = b[key]
if val_a != val_b:
c[key] = [val_a, val_b]
return c
def yml_load(fp_yml, print_fields=False):
"""
Helper function for loading yaml files
"""
with open(fp_yml) as f:
data = yaml.load(f, Loader=yaml.loader.SafeLoader)
dict_data = dict(data)
print("load: loaded {}".format(fp_yml))
return dict_data

def yml_save(fp_yml, dict_stuff):
"""
Helper function for saving yaml files
"""
with open(fp_yml, 'w') as f:
data = yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
print("yml_save: saved {}".format(fp_yml))
#%% le main
if __name__ == "__main__":
# xxxx

#%% First let us spawn a stable diffusion holder
device = "cuda"
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"

sdh = StableDiffusionHolder(fp_ckpt)

xxx

#%% Next let's set up all parameters
depth_strength = 0.3 # Specifies how deep (in terms of diffusion iterations) the first branching happens
fixed_seeds = [697164, 430214]

prompt1 = "photo of a desert and a sky"
prompt2 = "photo of a tree with a lake"

duration_transition = 12 # In seconds
fps = 30

# Spawn latent blending
self = LatentBlending(sdh)

self.set_prompt1(prompt1)
self.set_prompt2(prompt2)

# Run latent blending
self.branch1_crossfeed_power = 0.3
self.branch1_crossfeed_range = 0.4
# self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
self.seed1 = 21312
img1 = self.compute_latents1(True)
#%
self.seed2 = 1234121
self.branch1_crossfeed_power = 0.7
self.branch1_crossfeed_range = 0.3
self.branch1_crossfeed_decay = 0.3
img2 = self.compute_latents2(True)
# Image.fromarray(np.concatenate((img1, img2), axis=1))

#%%
t0 = time.time()
self.t_compute_max_allowed = 30
self.parental_crossfeed_range = 1.0
self.parental_crossfeed_power = 0.0
self.parental_crossfeed_power_decay = 1.0
imgs_transition = self.run_transition(recycle_img1=True, recycle_img2=True)
t1 = time.time()
print(f"took: {t1-t0}s")
@ -1,5 +1,6 @@

# Copyright 2022 Lunar Ring. All rights reserved.
#
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

@ -17,10 +18,9 @@ import os

import numpy as np
from tqdm import tqdm
import cv2
from typing import Callable, List, Optional, Union
from typing import List
import ffmpeg # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg

#%%

class MovieSaver():
def __init__(

@ -30,10 +30,9 @@ class MovieSaver():

shape_hw: List[int] = None,
crf: int = 24,
codec: str = 'libx264',
preset: str ='fast',
preset: str = 'fast',
pix_fmt: str = 'yuv420p',
silent_ffmpeg: bool = True
):
silent_ffmpeg: bool = True):
r"""
Initializes movie saver class - a human friendly ffmpeg wrapper.
After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x).
@ -92,10 +91,8 @@ class MovieSaver():

self.shape_hw = shape_hw
self.initialize()

print(f"MovieSaver initialized. fps={fps} crf={crf} pix_fmt={pix_fmt} codec={codec} preset={preset}")

def initialize(self):
args = (
ffmpeg

@ -112,7 +109,6 @@ class MovieSaver():

self.shape_hw = tuple(self.shape_hw)
print(f"Initialization done. Movie shape: {self.shape_hw}")

def write_frame(self, out_frame: np.ndarray):
r"""
Function to dump a numpy array as frame of a movie.

@ -123,7 +119,6 @@ class MovieSaver():

Dim 1: x
Dim 2: RGB
"""
assert out_frame.dtype == np.uint8, "Convert to np.uint8 before"
assert len(out_frame.shape) == 3, "out_frame needs to be three dimensional, Y X C"
assert out_frame.shape[2] == 3, f"need three color channels, but you provided {out_frame.shape[2]}."

@ -143,7 +138,6 @@ class MovieSaver():

self.nmb_frames += 1

def finalize(self):
r"""
Call this function to finalize the movie. If you forget to call it your movie will be garbage.
@ -157,7 +151,6 @@ class MovieSaver():

print(f"Movie saved, {duration}s playtime, watch here: \n{self.fp_out}")

def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
r"""
Concatenate multiple movie segments into one long movie, using ffmpeg.

@ -189,7 +182,6 @@ def concatenate_movies(fp_final: str, list_fp_movies: List[str]):

fa.write("%s\n" % item)

cmd = f'ffmpeg -f concat -safe 0 -i {fp_list} -c copy {fp_final}'
dp_movie = os.path.split(fp_final)[0]
subprocess.call(cmd, shell=True)
os.remove(fp_list)
if os.path.isfile(fp_final):
@ -200,11 +192,12 @@ class MovieReader():

r"""
Class to read in a movie.
"""

def __init__(self, fp_movie):
self.video_player_object = cv2.VideoCapture(fp_movie)
self.nmb_frames = int(self.video_player_object.get(cv2.CAP_PROP_FRAME_COUNT))
self.fps_movie = int(self.video_player_object.get(cv2.CAP_PROP_FPS))
self.shape = [100,100,3]
self.shape = [100, 100, 3]
self.shape_is_set = False

def get_next_frame(self):

@ -217,19 +210,18 @@ class MovieReader():

else:
return np.zeros(self.shape)
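A minimal reading-loop sketch for MovieReader; the file path is illustrative and assumes a movie saved beforehand, e.g. by the demo below:

# Hypothetical sketch: iterate over all frames of a previously saved movie.
mr = MovieReader("/tmp/my_random_movie_0.mp4")
for _ in range(mr.nmb_frames):
    frame = mr.get_next_frame()  # np.ndarray; returns zeros once the stream is exhausted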
#%%
if __name__ == "__main__":
fps=2
fps = 2
list_fp_movies = []
for k in range(4):
fp_movie = f"/tmp/my_random_movie_{k}.mp4"
list_fp_movies.append(fp_movie)
ms = MovieSaver(fp_movie, fps=fps)
for fn in tqdm(range(30)):
img = (np.random.rand(512, 1024, 3)*255).astype(np.uint8)
img = (np.random.rand(512, 1024, 3) * 255).astype(np.uint8)
ms.write_frame(img)
ms.finalize()

fp_final = "/tmp/my_concatenated_movie.mp4"
concatenate_movies(fp_final, list_fp_movies)
@ -13,36 +13,25 @@

# See the License for the specific language governing permissions and
# limitations under the License.

import os, sys
dp_git = "/home/lugo/git/"
sys.path.append(os.path.join(dp_git,'garden4'))
sys.path.append('util')
import os
import torch
torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import time
import subprocess
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
import torch
from movie_util import MovieSaver
import datetime
from typing import Callable, List, Optional, Union
import inspect
from threading import Thread
torch.set_grad_enabled(False)
from typing import Optional
from omegaconf import OmegaConf
from torch import autocast
from contextlib import nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from einops import repeat, rearrange
#%%
from utils import interpolate_spherical
def pad_image(input_image):

@ -53,41 +42,11 @@ def pad_image(input_image):

return im_padded

def make_batch_inpaint(
image,
mask,
txt,
device,
num_samples=1):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)

masked_image = image * (mask < 0.5)

batch = {
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
"txt": num_samples * [txt],
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
}
return batch

def make_batch_superres(
image,
txt,
device,
num_samples=1,
):
num_samples=1):
image = np.array(image.convert("RGB"))
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
batch = {
@ -114,7 +73,7 @@ class StableDiffusionHolder:

height: Optional[int] = None,
width: Optional[int] = None,
device: str = None,
precision: str='autocast',
precision: str = 'autocast',
):
r"""
Initializes the stable diffusion holder, which contains the models and sampler.

@ -137,7 +96,7 @@ class StableDiffusionHolder:

self.precision = precision
self.init_model(fp_ckpt, fp_config)

self.f = 8 #downsampling factor, most often 8 or 16",
self.f = 8  # downsampling factor, most often 8 or 16
self.C = 4
self.ddim_eta = 0
self.num_inference_steps = num_inference_steps

@ -150,13 +109,8 @@ class StableDiffusionHolder:

self.height = height
self.width = width

# Inpainting inits
self.mask_empty = Image.fromarray(255*np.ones([self.width, self.height], dtype=np.uint8))
self.image_empty = Image.fromarray(np.zeros([self.width, self.height, 3], dtype=np.uint8))

self.negative_prompt = [""]

def init_model(self, fp_ckpt, fp_config):
r"""Loads the models and sampler.
"""
@ -169,13 +123,11 @@ class StableDiffusionHolder:

fn_ckpt = os.path.basename(fp_ckpt)
if 'depth' in fn_ckpt:
fp_config = 'configs/v2-midas-inference.yaml'
elif 'inpain' in fn_ckpt:
fp_config = 'configs/v2-inpainting-inference.yaml'
elif 'upscaler' in fn_ckpt:
fp_config = 'configs/x4-upscaling.yaml'
elif '512' in fn_ckpt:
fp_config = 'configs/v2-inference.yaml'
elif '768'in fn_ckpt:
elif '768' in fn_ckpt:
fp_config = 'configs/v2-inference-v.yaml'
elif 'v1-5' in fn_ckpt:
fp_config = 'configs/v1-inference.yaml'

@ -186,7 +138,6 @@ class StableDiffusionHolder:

assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"

config = OmegaConf.load(fp_config)

self.model = instantiate_from_config(config.model)

@ -195,7 +146,6 @@ class StableDiffusionHolder:

self.model = self.model.to(self.device)
self.sampler = DDIMSampler(self.model)

def init_auto_res(self):
r"""Automatically set the resolution to the one used in training.
"""

@ -218,7 +168,6 @@ class StableDiffusionHolder:

if len(self.negative_prompt) > 1:
self.negative_prompt = [self.negative_prompt[0]]

def get_text_embedding(self, prompt):
c = self.model.get_learned_conditioning(prompt)
return c
@ -228,7 +177,6 @@ class StableDiffusionHolder:

r"""
Initializes the conditioning for the x4 upscaling model.
"""
image = pad_image(image) # resize to integer multiple of 32
w, h = image.size
noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()

@ -240,7 +188,6 @@ class StableDiffusionHolder:

# uncond cond
uc_cross = self.model.get_unconditional_conditioning(1, "")
uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}

return cond, uc_full
@torch.no_grad()

@ -249,14 +196,12 @@ class StableDiffusionHolder:

text_embeddings: torch.FloatTensor,
latents_start: torch.FloatTensor,
idx_start: int = 0,
list_latents_mixing = None,
mixing_coeffs = 0.0,
spatial_mask = None,
return_image: Optional[bool] = False,
):
list_latents_mixing=None,
mixing_coeffs=0.0,
spatial_mask=None,
return_image: Optional[bool] = False):
r"""
Diffusion standard version.

Args:
text_embeddings: torch.FloatTensor
Text embeddings used for diffusion

@ -270,12 +215,10 @@ class StableDiffusionHolder:

experimental feature for enforcing pixels from list_latents_mixing
return_image: Optional[bool]
Optionally return image directly
"""
# Asserts
if type(mixing_coeffs) == float:
list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
elif type(mixing_coeffs) == list:
assert len(mixing_coeffs) == self.num_inference_steps
list_mixing_coeffs = mixing_coeffs
@ -285,26 +228,19 @@ class StableDiffusionHolder:

if np.sum(list_mixing_coeffs) > 0:
assert len(list_latents_mixing) == self.num_inference_steps

precision_scope = autocast if self.precision == "autocast" else nullcontext

with precision_scope("cuda"):
with self.model.ema_scope():
if self.guidance_scale != 1.0:
uc = self.model.get_learned_conditioning(self.negative_prompt)
else:
uc = None

self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)

self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
latents = latents_start.clone()

timesteps = self.sampler.ddim_timesteps

time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]

# collect latents
# Collect latents
list_latents_out = []
for i, step in enumerate(time_range):
# Set the right starting latents

@ -313,15 +249,13 @@ class StableDiffusionHolder:

continue
elif i == idx_start:
latents = latents_start.clone()

# Mix the latents.
if i > 0 and list_mixing_coeffs[i]>0:
latents_mixtarget = list_latents_mixing[i-1].clone()
# Mix latents
if i > 0 and list_mixing_coeffs[i] > 0:
latents_mixtarget = list_latents_mixing[i - 1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])

if spatial_mask is not None and list_latents_mixing is not None:
latents = interpolate_spherical(latents, list_latents_mixing[i-1], 1-spatial_mask)
# latents[:,:,-15:,:] = latents_mixtarget[:,:,-15:,:]
latents = interpolate_spherical(latents, list_latents_mixing[i - 1], 1 - spatial_mask)

index = total_steps - i - 1
ts = torch.full((1,), step, device=self.device, dtype=torch.long)

@ -334,13 +268,11 @@ class StableDiffusionHolder:

dynamic_threshold=None)
latents, pred_x0 = outs
list_latents_out.append(latents.clone())

if return_image:
return self.latent2image(latents)
else:
return list_latents_out
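Since the call surface above is easy to misread in diff form, here is a minimal driving sketch. It assumes a holder sdh constructed as in the demo further below, that guidance_scale and negative_prompt have been set on it, and that random starting latents of the standard SD2 latent shape (C=4, downsampling factor f=8) are acceptable; all of these are assumptions, not part of this commit:

import torch

# Hypothetical sketch: drive run_diffusion_standard directly.
te = sdh.get_text_embedding("photo of a beautiful forest covered in white flowers")
latents_start = torch.randn(1, sdh.C, sdh.height // sdh.f, sdh.width // sdh.f,
                            device=sdh.device)
list_latents = sdh.run_diffusion_standard(te, latents_start)  # one latent per DDIM step
img = sdh.latent2image(list_latents[-1])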
@torch.no_grad()
def run_diffusion_upscaling(
self,

@ -348,17 +280,16 @@ class StableDiffusionHolder:

uc_full,
latents_start: torch.FloatTensor,
idx_start: int = -1,
list_latents_mixing = None,
mixing_coeffs = 0.0,
return_image: Optional[bool] = False
):
list_latents_mixing: list = None,
mixing_coeffs: float = 0.0,
return_image: Optional[bool] = False):
r"""
Diffusion upscaling version.
"""
# Asserts
if type(mixing_coeffs) == float:
list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
elif type(mixing_coeffs) == list:
assert len(mixing_coeffs) == self.num_inference_steps
list_mixing_coeffs = mixing_coeffs

@ -369,27 +300,20 @@ class StableDiffusionHolder:

assert len(list_latents_mixing) == self.num_inference_steps

precision_scope = autocast if self.precision == "autocast" else nullcontext

h = uc_full['c_concat'][0].shape[2]
w = uc_full['c_concat'][0].shape[3]

with precision_scope("cuda"):
with self.model.ema_scope():

shape_latents = [self.model.channels, h, w]

self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
C, H, W = shape_latents
size = (1, C, H, W)
b = size[0]

latents = latents_start.clone()

timesteps = self.sampler.ddim_timesteps

time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]

# collect latents
list_latents_out = []
for i, step in enumerate(time_range):

@ -399,12 +323,10 @@ class StableDiffusionHolder:

continue
elif i == idx_start:
latents = latents_start.clone()

# Mix the latents.
if i > 0 and list_mixing_coeffs[i]>0:
latents_mixtarget = list_latents_mixing[i-1].clone()
if i > 0 and list_mixing_coeffs[i] > 0:
latents_mixtarget = list_latents_mixing[i - 1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])

# print(f"diffusion iter {i}")
index = total_steps - i - 1
ts = torch.full((b,), step, device=self.device, dtype=torch.long)

@ -423,121 +345,10 @@ class StableDiffusionHolder:

else:
return list_latents_out
@torch.no_grad()
def run_diffusion_inpaint(
self,
text_embeddings: torch.FloatTensor,
latents_for_injection: torch.FloatTensor = None,
idx_start: int = -1,
idx_stop: int = -1,
return_image: Optional[bool] = False
):
r"""
Runs inpaint-based diffusion. Returns a list of latents that were computed.
Adaptations allow supplying
a) a starting index for diffusion
b) a stopping index for diffusion
c) latent representations that are injected at the starting index
Furthermore the intermediate latents are collected and returned.

Adapted from diffusers (https://github.com/huggingface/diffusers)
Args:
text_embeddings: torch.FloatTensor
Text embeddings used for diffusion
latents_for_injection: torch.FloatTensor
Latents that are used for injection
idx_start: int
Index of the diffusion process start and where the latents_for_injection are injected
idx_stop: int
Index of the diffusion process end.
return_image: Optional[bool]
Optionally return image directly
"""
if latents_for_injection is None:
do_inject_latents = False
else:
do_inject_latents = True

precision_scope = autocast if self.precision == "autocast" else nullcontext
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))

with precision_scope("cuda"):
with self.model.ema_scope():
batch = make_batch_inpaint(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
c = text_embeddings
c_cat = list()
for ck in self.model.concat_keys:
cc = batch[ck].float()
if ck != self.model.masked_image_key:
bchw = [1, 4, self.height // 8, self.width // 8]
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)

# cond
cond = {"c_concat": [c_cat], "c_crossattn": [c]}

# uncond cond
uc_cross = self.model.get_unconditional_conditioning(1, "")
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}

shape_latents = [self.model.channels, self.height // 8, self.width // 8]

self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
# sampling
C, H, W = shape_latents
size = (1, C, H, W)

device = self.model.betas.device
b = size[0]
latents = torch.randn(size, generator=generator, device=device)

timesteps = self.sampler.ddim_timesteps

time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]

# collect latents
list_latents_out = []
for i, step in enumerate(time_range):
if do_inject_latents:
# Inject latent at right place
if i < idx_start:
continue
elif i == idx_start:
latents = latents_for_injection.clone()

if i == idx_stop:
return list_latents_out

index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)

outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
quantize_denoised=False, temperature=1.0,
noise_dropout=0.0, score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=self.guidance_scale,
unconditional_conditioning=uc_full,
dynamic_threshold=None)
latents, pred_x0 = outs
list_latents_out.append(latents.clone())

if return_image:
return self.latent2image(latents)
else:
return list_latents_out
@torch.no_grad()
def latent2image(
self,
latents: torch.FloatTensor
):
latents: torch.FloatTensor):
r"""
Returns an image provided a latent representation from diffusion.
Args:

@ -546,85 +357,6 @@ class StableDiffusionHolder:

"""
x_sample = self.model.decode_first_stage(latents)
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255 * x_sample[0,:,:].permute([1,2,0]).cpu().numpy()
x_sample = 255 * x_sample[0, :, :].permute([1, 2, 0]).cpu().numpy()
image = x_sample.astype(np.uint8)
return image
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function always casts up to float64 for extra precision.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return p0
1 will return p1
0.x will return a mix between both, preserving angular velocity.
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'

p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1+epsilon, 1-epsilon)

theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0*s0 + p1*s1

if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()

return interp
if __name__ == "__main__":

num_inference_steps = 20 # Number of diffusion iterations

# fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
# fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'

# fp_ckpt= "../stable_diffusion_models/ckpt/512-inpainting-ema.ckpt"
# fp_config = '../stablediffusion/configs//stable-diffusion/v2-inpainting-inference.yaml'

fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
# fp_config = 'configs/v2-inference-v.yaml'

self = StableDiffusionHolder(fp_ckpt, num_inference_steps=num_inference_steps)

xxx

#%%
self.width = 1536
self.height = 768
prompt = "360 degree equirectangular, a huge rocky hill full of pianos and keyboards, musical instruments, cinematic, masterpiece 8 k, artstation"
self.set_negative_prompt("out of frame, faces, rendering, blurry")
te = self.get_text_embedding(prompt)

img = self.run_diffusion_standard(te, return_image=True)
Image.fromarray(img).show()
@ -0,0 +1,260 @@

# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
torch.backends.cudnn.benchmark = False
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import time
import datetime
from typing import List, Union
torch.set_grad_enabled(False)
import yaml
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function always casts up to float64 for extra precision.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return p0
1 will return p1
0.x will return a mix between both, preserving angular velocity.
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'

p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1 + epsilon, 1 - epsilon)

theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0 * s0 + p1 * s1

if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()

return interp
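To make the slerp geometry concrete, a small self-contained check (values illustrative):

import torch

# Slerp of two orthogonal unit vectors at 0.5 lands at 45 degrees and keeps
# unit norm; plain linear interpolation would shrink the norm to ~0.71.
p0 = torch.tensor([1.0, 0.0])
p1 = torch.tensor([0.0, 1.0])
mid = interpolate_spherical(p0, p1, 0.5)
print(mid)                     # tensor([0.7071, 0.7071])
print(torch.linalg.norm(mid))  # tensor(1.)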

def interpolate_linear(p0, p1, fract_mixing):
r"""
Helper function to mix two variables using standard linear interpolation.
Args:
p0:
First tensor / np.ndarray for interpolation
p1:
Second tensor / np.ndarray for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return p0
1 will return p1
0.x will return a linear mix between both.
"""
reconvert_uint8 = False
if type(p0) is np.ndarray and p0.dtype == 'uint8':
reconvert_uint8 = True
p0 = p0.astype(np.float64)

if type(p1) is np.ndarray and p1.dtype == 'uint8':
reconvert_uint8 = True
p1 = p1.astype(np.float64)

interp = (1 - fract_mixing) * p0 + fract_mixing * p1

if reconvert_uint8:
interp = np.clip(interp, 0, 255).astype(np.uint8)

return interp
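A quick illustration of the uint8 round-trip handled above (values illustrative):

import numpy as np

a = np.array([0, 100, 200], dtype=np.uint8)
b = np.array([50, 150, 250], dtype=np.uint8)
print(interpolate_linear(a, b, 0.5))  # [ 25 125 225], converted back to uint8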

def add_frames_linear_interp(
list_imgs: List[np.ndarray],
fps_target: Union[float, int] = None,
duration_target: Union[float, int] = None,
nmb_frames_target: int = None):
r"""
Helper function to cheaply increase the number of frames given a list of images,
by virtue of standard linear interpolation.
The number of inserted frames is automatically adjusted so that the total number
of frames can be fixed precisely, using a random shuffling technique.
The function allows 1:1 comparisons between transitions as videos.

Args:
list_imgs: List[np.ndarray]
List of images, between each image new frames will be inserted via linear interpolation.
fps_target:
Option A: specify here the desired frames per second.
duration_target:
Option A: specify here the desired duration of the transition in seconds.
nmb_frames_target:
Option B: directly fix the total number of frames of the output.
"""
# Sanity
if nmb_frames_target is not None and fps_target is not None:
raise ValueError("You cannot specify both fps_target and nmb_frames_target")
if fps_target is None:
assert nmb_frames_target is not None, "Either specify fps_target or nmb_frames_target"
if nmb_frames_target is None:
assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
nmb_frames_target = fps_target * duration_target

# Get number of frames that are missing
nmb_frames_diff = len(list_imgs) - 1
nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1

if nmb_frames_missing < 1:
return list_imgs

list_imgs_float = [img.astype(np.float32) for img in list_imgs]
# Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
constfact = np.floor(mean_nmb_frames_insert)
remainder_x = 1 - (mean_nmb_frames_insert - constfact)
nmb_iter = 0
while True:
nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
nmb_frames_to_insert[nmb_frames_to_insert <= remainder_x] = 0
nmb_frames_to_insert[nmb_frames_to_insert > remainder_x] = 1
nmb_frames_to_insert += constfact
if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
break
nmb_iter += 1
if nmb_iter > 100000:
print("add_frames_linear_interp: issue with inserting the right number of frames")
break

nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
list_imgs_interp = []
for i in range(len(list_imgs_float) - 1):
img0 = list_imgs_float[i]
img1 = list_imgs_float[i + 1]
list_imgs_interp.append(img0.astype(np.uint8))
list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i] + 2)[1:-1]
for fract_linblend in list_fracts_linblend:
img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
list_imgs_interp.append(img_blend.astype(np.uint8))
if i == len(list_imgs_float) - 2:
list_imgs_interp.append(img1.astype(np.uint8))

return list_imgs_interp
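For example, stretching a short transition to an exact frame budget (shapes and counts illustrative):

import numpy as np

# Five keyframes stretched to 2 s at 30 fps -> exactly 60 frames out.
keyframes = [(np.random.rand(64, 64, 3) * 255).astype(np.uint8) for _ in range(5)]
frames = add_frames_linear_interp(keyframes, fps_target=30, duration_target=2)
print(len(frames))  # 60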

def get_spacing(nmb_points: int, scaling: float):
"""
Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
Args:
nmb_points: int
Number of points between [0, 1]
scaling: float
Higher values will return higher sampling density around 0.5
"""
if scaling < 1.7:
return np.linspace(0, 1, nmb_points)
nmb_points_per_side = nmb_points // 2 + 1
if np.mod(nmb_points, 2) != 0: # Uneven case
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
right_side = 1 - left_side[::-1][1:]
else:
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
right_side = 1 - left_side[::-1]
all_fracts = np.hstack([left_side, right_side])
return all_fracts
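Behavior sketch (values illustrative):

import numpy as np

print(get_spacing(5, 1.0))  # scaling < 1.7 falls back to linspace: [0. 0.25 0.5 0.75 1.]
print(get_spacing(5, 3.0))  # [0. 0.4375 0.5 0.5625 1.] -- denser around 0.5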

def get_time(resolution=None):
"""
Helper function returning a nicely formatted time string, e.g. 221117_1620
"""
if resolution is None:
resolution = "second"
if resolution == "day":
t = time.strftime('%y%m%d', time.localtime())
elif resolution == "minute":
t = time.strftime('%y%m%d_%H%M', time.localtime())
elif resolution == "second":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
elif resolution == "millisecond":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
t += "_"
t += str("{:03d}".format(int(int(datetime.datetime.utcnow().strftime('%f')) / 1000)))
else:
raise ValueError("bad resolution provided: %s" % resolution)
return t
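For instance (output depends on the wall clock):

print(get_time())          # e.g. 221117_162015
print(get_time("minute"))  # e.g. 221117_1620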

def compare_dicts(a, b):
"""
Compares two dictionaries a and b and returns a dictionary c, with all
keys that are shared between a and b but have different values.
The values of a and b are stacked together in the output.
Example:
a = {}; a['bobo'] = 4
b = {}; b['bobo'] = 5
c = compare_dicts(a, b)
c = {"bobo": [4, 5]}
"""
c = {}
for key in a.keys():
if key in b.keys():
val_a = a[key]
val_b = b[key]
if val_a != val_b:
c[key] = [val_a, val_b]
return c

def yml_load(fp_yml, print_fields=False):
"""
Helper function for loading yaml files
"""
with open(fp_yml) as f:
data = yaml.load(f, Loader=yaml.loader.SafeLoader)
dict_data = dict(data)
print("load: loaded {}".format(fp_yml))
return dict_data

def yml_save(fp_yml, dict_stuff):
"""
Helper function for saving yaml files
"""
with open(fp_yml, 'w') as f:
yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
print("yml_save: saved {}".format(fp_yml))
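A quick round-trip sketch of the two helpers (path and keys illustrative):

d = {'nmb_images': 40, 'prompt1': 'photo of a desert'}
yml_save('/tmp/lowres.yaml', d)
assert yml_load('/tmp/lowres.yaml') == d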