Commit 297bb9abe6 — "cleanup" (parent 3ed876e0ee)
@@ -13,39 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os, sys
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
import torch
from movie_util import MovieSaver
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
from huggingface_hub import hf_hub_download

#%% First let us spawn a stable diffusion holder
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
sdh = StableDiffusionHolder(fp_ckpt)

#%% Next let's set up all parameters
depth_strength = 0.65  # Specifies how deep (in terms of diffusion iterations) the first branching happens
t_compute_max_allowed = 15  # Determines the quality of the transition in terms of compute time you grant it
# %% Next let's set up all parameters
depth_strength = 0.65  # Specifies how deep (in terms of diffusion iterations) the first branching happens
t_compute_max_allowed = 15  # Determines the quality of the transition in terms of compute time you grant it
fixed_seeds = [69731932, 504430820]

prompt1 = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic"
prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail"

fp_movie = 'movie_example1.mp4'
duration_transition = 12  # In seconds

# Spawn latent blending
lb = LatentBlending(sdh)

@@ -54,10 +46,9 @@ lb.set_prompt2(prompt2)

# Run latent blending
lb.run_transition(
    depth_strength = depth_strength,
    t_compute_max_allowed = t_compute_max_allowed,
    fixed_seeds = fixed_seeds
    )
    depth_strength=depth_strength,
    t_compute_max_allowed=t_compute_max_allowed,
    fixed_seeds=fixed_seeds)

# Save movie
lb.write_movie_transition(fp_movie, duration_transition)
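The two knobs above trade fidelity against compute: depth_strength controls how late in the diffusion trajectory the first branch is spawned, and t_compute_max_allowed caps the compute budget per transition. A quick way to build intuition is to render the same prompt pair at a few depth_strength values and compare the movies. The sketch below is illustrative only — the value grid and output filenames are made up — and it reuses exactly the API calls shown in the example above.

# Hypothetical sweep over depth_strength; assumes sdh, prompt1/prompt2, fixed_seeds,
# t_compute_max_allowed and duration_transition from the example above.
for ds in [0.35, 0.5, 0.65, 0.8]:
    lb = LatentBlending(sdh)
    lb.set_prompt1(prompt1)
    lb.set_prompt2(prompt2)
    lb.run_transition(
        depth_strength=ds,
        t_compute_max_allowed=t_compute_max_allowed,
        fixed_seeds=fixed_seeds)
    # One movie per setting, e.g. movie_ds_0.35.mp4 (hypothetical naming)
    lb.write_movie_transition(f"movie_ds_{ds}.mp4", duration_transition)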
@@ -13,33 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os, sys
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
import torch
from movie_util import MovieSaver, concatenate_movies
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
from movie_util import concatenate_movies
from huggingface_hub import hf_hub_download

#%% First let us spawn a stable diffusion holder
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
# fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
sdh = StableDiffusionHolder(fp_ckpt)

#%% Let's set up the multi transition
# %% Let's set up the multi transition
fps = 30
duration_single_trans = 6
depth_strength = 0.55  # Specifies how deep (in terms of diffusion iterations) the first branching happens

# Specify a list of prompts below
list_prompts = []

@@ -52,36 +45,33 @@ list_prompts.append("statue of an ancient cybernetic messenger annoucing good ne

# You can optionally specify the seeds
list_seeds = [954375479, 332539350, 956051013, 408831845, 250009012, 675588737]
t_compute_max_allowed = 12  # per segment
fp_movie = 'movie_example2.mp4'
lb = LatentBlending(sdh)

list_movie_parts = []  #
for i in range(len(list_prompts)-1):
list_movie_parts = []
for i in range(len(list_prompts) - 1):
    # For a multi transition we can save some computation time and recycle the latents
    if i==0:
    if i == 0:
        lb.set_prompt1(list_prompts[i])
        lb.set_prompt2(list_prompts[i+1])
        lb.set_prompt2(list_prompts[i + 1])
        recycle_img1 = False
    else:
        lb.swap_forward()
        lb.set_prompt2(list_prompts[i+1])
        recycle_img1 = True

        lb.set_prompt2(list_prompts[i + 1])
        recycle_img1 = True

    fp_movie_part = f"tmp_part_{str(i).zfill(3)}.mp4"

    fixed_seeds = list_seeds[i:i+2]

    fixed_seeds = list_seeds[i:i + 2]
    # Run latent blending
    lb.run_transition(
        depth_strength = depth_strength,
        t_compute_max_allowed = t_compute_max_allowed,
        fixed_seeds = fixed_seeds
        )

        depth_strength=depth_strength,
        t_compute_max_allowed=t_compute_max_allowed,
        fixed_seeds=fixed_seeds)

    # Save movie
    lb.write_movie_transition(fp_movie_part, duration_single_trans)
    list_movie_parts.append(fp_movie_part)

# Finally, concatenate the result
concatenate_movies(fp_movie, list_movie_parts)
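The loop above is the core multi-transition pattern: the first segment computes both endpoint branches, and every later segment calls swap_forward() so the previous last image is recycled as the new first image. Below is a hedged sketch of the same logic wrapped into a reusable helper — the function name and temp-file naming are assumptions, while the calls are exactly the ones used above.

def render_multi_transition(lb, list_prompts, list_seeds, fp_movie,
                            depth_strength, t_compute_max_allowed, duration_single_trans):
    # Illustrative helper, not part of the library API.
    list_movie_parts = []
    for i in range(len(list_prompts) - 1):
        if i == 0:
            lb.set_prompt1(list_prompts[i])
            lb.set_prompt2(list_prompts[i + 1])
            recycle_img1 = False
        else:
            lb.swap_forward()  # reuse the previous last image as the new first image
            lb.set_prompt2(list_prompts[i + 1])
            recycle_img1 = True
        fp_movie_part = f"tmp_part_{str(i).zfill(3)}.mp4"
        lb.run_transition(
            recycle_img1=recycle_img1,
            depth_strength=depth_strength,
            t_compute_max_allowed=t_compute_max_allowed,
            fixed_seeds=list_seeds[i:i + 2])
        lb.write_movie_transition(fp_movie_part, duration_single_trans)
        list_movie_parts.append(fp_movie_part)
    concatenate_movies(fp_movie, list_movie_parts)
    return fp_movie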
@@ -13,25 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os, sys
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
import torch
from movie_util import MovieSaver
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
from huggingface_hub import hf_hub_download

#%% Define vars for low-resolution pass
# %% Define vars for low-resolution pass
prompt1 = "photo of mount vesuvius erupting a terrifying pyroclastic ash cloud"
prompt2 = "photo of a inside a building full of ash, fire, death, destruction, explosions"
fixed_seeds = [5054613, 1168652]

@@ -41,21 +33,18 @@ height = 384
num_inference_steps_lores = 40
nmb_max_branches_lores = 10
depth_strength_lores = 0.5
fp_ckpt_lores = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")

fp_ckpt_lores = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"

#%% Define vars for high-resolution pass
fp_ckpt_hires = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
# %% Define vars for high-resolution pass
fp_ckpt_hires = hf_hub_download(repo_id="stabilityai/stable-diffusion-x4-upscaler", filename="x4-upscaler-ema.ckpt")
depth_strength_hires = 0.65
num_inference_steps_hires = 100
nmb_branches_final_hires = 6
dp_imgs = "tmp_transition"  # Folder for results and intermediate steps

#%% Run low-res pass
# %% Run low-res pass
sdh = StableDiffusionHolder(fp_ckpt_lores)

#%%
lb = LatentBlending(sdh)
lb.set_prompt1(prompt1)
lb.set_prompt2(prompt2)

@@ -64,14 +53,13 @@ lb.set_height(height)

# Run latent blending
lb.run_transition(
    depth_strength = depth_strength_lores,
    nmb_max_branches = nmb_max_branches_lores,
    fixed_seeds = fixed_seeds
    )
    depth_strength=depth_strength_lores,
    nmb_max_branches=nmb_max_branches_lores,
    fixed_seeds=fixed_seeds)

lb.write_imgs_transition(dp_imgs)

#%% Run high-res pass
# %% Run high-res pass
sdh = StableDiffusionHolder(fp_ckpt_hires)
lb = LatentBlending(sdh)
lb.run_upscaling(dp_imgs, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)
@@ -13,25 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os, sys
import os
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
import torch
from movie_util import MovieSaver, concatenate_movies
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
from movie_util import concatenate_movies
from huggingface_hub import hf_hub_download

#%% Define vars for low-resolution pass
# %% Define vars for low-resolution pass
list_prompts = []
list_prompts.append("surrealistic statue made of glitter and dirt, standing in a lake, atmospheric light, strange glow")
list_prompts.append("statue of a mix between a tree and human, made of marble, incredibly detailed")

@@ -50,61 +44,59 @@ num_inference_steps_lores = 40
nmb_max_branches_lores = 10
depth_strength_lores = 0.5

fp_ckpt_lores = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
fp_ckpt_lores = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")

#%% Define vars for high-resolution pass
fp_ckpt_hires = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
# %% Define vars for high-resolution pass
fp_ckpt_hires = hf_hub_download(repo_id="stabilityai/stable-diffusion-x4-upscaler", filename="x4-upscaler-ema.ckpt")
depth_strength_hires = 0.65
num_inference_steps_hires = 100
nmb_branches_final_hires = 6

#%% Run low-res pass
# %% Run low-res pass
sdh = StableDiffusionHolder(fp_ckpt_lores)
t_compute_max_allowed = 12  # Per segment
lb = LatentBlending(sdh)

list_movie_dirs = []  #
for i in range(len(list_prompts)-1):
list_movie_dirs = []
for i in range(len(list_prompts) - 1):
    # For a multi transition we can save some computation time and recycle the latents
    if i==0:
    if i == 0:
        lb.set_prompt1(list_prompts[i])
        lb.set_prompt2(list_prompts[i+1])
        lb.set_prompt2(list_prompts[i + 1])
        recycle_img1 = False
    else:
        lb.swap_forward()
        lb.set_prompt2(list_prompts[i+1])
        recycle_img1 = True

        lb.set_prompt2(list_prompts[i + 1])
        recycle_img1 = True

    dp_movie_part = f"tmp_part_{str(i).zfill(3)}"
    fp_movie_part = os.path.join(dp_movie_part, "movie_lowres.mp4")
    os.makedirs(dp_movie_part, exist_ok=True)
    fixed_seeds = list_seeds[i:i+2]

    fixed_seeds = list_seeds[i:i + 2]

    # Run latent blending
    lb.run_transition(
        depth_strength = depth_strength_lores,
        nmb_max_branches = nmb_max_branches_lores,
        fixed_seeds = fixed_seeds
        )

        depth_strength=depth_strength_lores,
        nmb_max_branches=nmb_max_branches_lores,
        fixed_seeds=fixed_seeds)

    # Save movie and images (needed for upscaling!)
    lb.write_movie_transition(fp_movie_part, duration_single_trans)
    lb.write_imgs_transition(dp_movie_part)
    list_movie_dirs.append(dp_movie_part)

#%% Run high-res pass on each segment
# %% Run high-res pass on each segment
sdh = StableDiffusionHolder(fp_ckpt_hires)
lb = LatentBlending(sdh)
for dp_part in list_movie_dirs:
    lb.run_upscaling(dp_part, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)

#%% concatenate into one long movie
# %% concatenate into one long movie
list_fp_movies = []
for dp_part in list_movie_dirs:
    fp_movie = os.path.join(dp_part, "movie_highres.mp4")
    assert os.path.isfile(fp_movie)
    list_fp_movies.append(fp_movie)

fp_final = "example4.mp4"
concatenate_movies(fp_final, list_fp_movies)
gradio_ui.py
@@ -13,83 +13,90 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os, sys
import os
import torch
torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
import torch
from movie_util import MovieSaver, concatenate_movies
from typing import Callable, List, Optional, Union
from latent_blending import get_time, yml_save, LatentBlending, add_frames_linear_interp, compare_dicts
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
import gradio as gr
import copy
from dotenv import find_dotenv, load_dotenv
import shutil
import random
import time
from utils import get_time, add_frames_linear_interp
from huggingface_hub import hf_hub_download

#%%

class BlendingFrontend():
    def __init__(self, sdh=None):
    def __init__(
            self,
            sdh,
            share=False):
        r"""
        Gradio Helper Class to collect UI data and start latent blending.
        Args:
            sdh:
                StableDiffusionHolder
            share: bool
                Set true to get a shareable gradio link (e.g. for running a remote server)
        """
        self.share = share

        # UI Defaults
        self.num_inference_steps = 30
        if sdh is None:
            self.use_debug = True
            self.height = 768
            self.width = 768
        else:
            self.use_debug = False
            self.lb = LatentBlending(sdh)
            self.lb.sdh.num_inference_steps = self.num_inference_steps
            self.height = self.lb.sdh.height
            self.width = self.lb.sdh.width

        self.init_save_dir()
        self.save_empty_image()
        self.share = False
        self.transition_can_be_computed = False
        self.depth_strength = 0.25
        self.seed1 = 420
        self.seed2 = 420
        self.guidance_scale = 4.0
        self.guidance_scale_mid_damper = 0.5
        self.mid_compression_scaler = 1.2
        self.prompt1 = ""
        self.prompt2 = ""
        self.negative_prompt = ""
        self.state_current = {}
        self.fps = 30
        self.duration_video = 8
        self.t_compute_max_allowed = 10

        self.lb = LatentBlending(sdh)
        self.lb.sdh.num_inference_steps = self.num_inference_steps
        self.init_parameters_from_lb()
        self.init_save_dir()

        # Vars
        self.list_fp_imgs_current = []
        self.recycle_img1 = False
        self.recycle_img2 = False
        self.list_all_segments = []
        self.dp_session = ""
        self.user_id = None

    def init_parameters_from_lb(self):
        r"""
        Automatically initialize parameters from the latentblending instance.
        """
        self.height = self.lb.sdh.height
        self.width = self.lb.sdh.width
        self.guidance_scale = self.lb.guidance_scale
        self.guidance_scale_mid_damper = self.lb.guidance_scale_mid_damper
        self.mid_compression_scaler = self.lb.mid_compression_scaler
        self.branch1_crossfeed_power = self.lb.branch1_crossfeed_power
        self.branch1_crossfeed_range = self.lb.branch1_crossfeed_range
        self.branch1_crossfeed_decay = self.lb.branch1_crossfeed_decay
        self.parental_crossfeed_power = self.lb.parental_crossfeed_power
        self.parental_crossfeed_range = self.lb.parental_crossfeed_range
        self.parental_crossfeed_power_decay = self.lb.parental_crossfeed_power_decay
        self.fps = 30
        self.duration_video = 10
        self.t_compute_max_allowed = 10
        self.list_fp_imgs_current = []
        self.current_timestamp = None
        self.recycle_img1 = False
        self.recycle_img2 = False
        self.multi_idx_current = -1
        self.list_imgs_shown_last = 5*[self.fp_img_empty]
        self.list_all_segments = []
        self.dp_session = ""
        self.user_id = None
        self.block_transition = False

    def init_save_dir(self):
        load_dotenv(find_dotenv(), verbose=False)
        r"""
        Initializes the directory where results are saved.
        You can specify this directory in a ".env" file in your latentblending root, setting
        DIR_OUT='/path/to/saving'
        """
        load_dotenv(find_dotenv(), verbose=False)
        self.dp_out = os.getenv("DIR_OUT")
        if self.dp_out is None:
            self.dp_out = ""
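As the docstring above notes, the output directory is read from a ".env" file at the latentblending root via python-dotenv. A minimal example of that file (the path is a placeholder):

# .env — placed next to the code; picked up by load_dotenv(find_dotenv())
DIR_OUT='/path/to/saving'

If the variable is absent, os.getenv("DIR_OUT") returns None and the code above falls back to the current directory ("").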
@@ -97,151 +104,151 @@ class BlendingFrontend():
        os.makedirs(self.dp_imgs, exist_ok=True)
        self.dp_movies = os.path.join(self.dp_out, "movies")
        os.makedirs(self.dp_movies, exist_ok=True)

        # make dummy image
        self.save_empty_image()

    def save_empty_image(self):
        r"""
        Saves an empty/black dummy image.
        """
        self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
        Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)

    def randomize_seed1(self):
        # Don't randomize the seed if we are in multi-concat mode; we don't want to change it, otherwise the movie breaks
        r"""
        Randomizes the first seed
        """
        seed = np.random.randint(0, 10000000)
        self.seed1 = int(seed)
        print(f"randomize_seed1: new seed = {self.seed1}")
        return seed

    def randomize_seed2(self):
        r"""
        Randomizes the second seed
        """
        seed = np.random.randint(0, 10000000)
        self.seed2 = int(seed)
        print(f"randomize_seed2: new seed = {self.seed2}")
        return seed

    def setup_lb(self, list_ui_elem):

    def setup_lb(self, list_ui_vals):
        r"""
        Sets all parameters from the UI. Since gradio does not support passing dictionaries,
        we instead pass keys (list_ui_keys, global) and values (list_ui_vals)
        """
        # Collect latent blending variables
        self.state_current = self.get_state_dict()
        self.lb.set_width(list_ui_elem[list_ui_keys.index('width')])
        self.lb.set_height(list_ui_elem[list_ui_keys.index('height')])
        self.lb.set_prompt1(list_ui_elem[list_ui_keys.index('prompt1')])
        self.lb.set_prompt2(list_ui_elem[list_ui_keys.index('prompt2')])
        self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
        self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
        self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
        self.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
        self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
        self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
        self.duration_video = list_ui_elem[list_ui_keys.index('duration_video')]
        self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')]  # seed
        self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]

        self.lb.branch1_crossfeed_power = list_ui_elem[list_ui_keys.index('branch1_crossfeed_power')]
        self.lb.branch1_crossfeed_range = list_ui_elem[list_ui_keys.index('branch1_crossfeed_range')]
        self.lb.branch1_crossfeed_decay = list_ui_elem[list_ui_keys.index('branch1_crossfeed_decay')]
        self.lb.parental_crossfeed_power = list_ui_elem[list_ui_keys.index('parental_crossfeed_power')]
        self.lb.parental_crossfeed_range = list_ui_elem[list_ui_keys.index('parental_crossfeed_range')]
        self.lb.parental_crossfeed_power_decay = list_ui_elem[list_ui_keys.index('parental_crossfeed_power_decay')]
        self.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
        self.depth_strength = list_ui_elem[list_ui_keys.index('depth_strength')]

        if len(list_ui_elem[list_ui_keys.index('user_id')]) > 1:
            self.user_id = list_ui_elem[list_ui_keys.index('user_id')]
        self.lb.set_width(list_ui_vals[list_ui_keys.index('width')])
        self.lb.set_height(list_ui_vals[list_ui_keys.index('height')])
        self.lb.set_prompt1(list_ui_vals[list_ui_keys.index('prompt1')])
        self.lb.set_prompt2(list_ui_vals[list_ui_keys.index('prompt2')])
        self.lb.set_negative_prompt(list_ui_vals[list_ui_keys.index('negative_prompt')])
        self.lb.guidance_scale = list_ui_vals[list_ui_keys.index('guidance_scale')]
        self.lb.guidance_scale_mid_damper = list_ui_vals[list_ui_keys.index('guidance_scale_mid_damper')]
        self.t_compute_max_allowed = list_ui_vals[list_ui_keys.index('duration_compute')]
        self.lb.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
        self.lb.sdh.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
        self.duration_video = list_ui_vals[list_ui_keys.index('duration_video')]
        self.lb.seed1 = list_ui_vals[list_ui_keys.index('seed1')]
        self.lb.seed2 = list_ui_vals[list_ui_keys.index('seed2')]
        self.lb.branch1_crossfeed_power = list_ui_vals[list_ui_keys.index('branch1_crossfeed_power')]
        self.lb.branch1_crossfeed_range = list_ui_vals[list_ui_keys.index('branch1_crossfeed_range')]
        self.lb.branch1_crossfeed_decay = list_ui_vals[list_ui_keys.index('branch1_crossfeed_decay')]
        self.lb.parental_crossfeed_power = list_ui_vals[list_ui_keys.index('parental_crossfeed_power')]
        self.lb.parental_crossfeed_range = list_ui_vals[list_ui_keys.index('parental_crossfeed_range')]
        self.lb.parental_crossfeed_power_decay = list_ui_vals[list_ui_keys.index('parental_crossfeed_power_decay')]
        self.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
        self.depth_strength = list_ui_vals[list_ui_keys.index('depth_strength')]

        if len(list_ui_vals[list_ui_keys.index('user_id')]) > 1:
            self.user_id = list_ui_vals[list_ui_keys.index('user_id')]
        else:
            # generate new user id
            self.user_id = ''.join((random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ') for i in range(8)))
            print(f"made new user_id: {self.user_id}")

            print(f"made new user_id: {self.user_id} at {get_time('second')}")

    def save_latents(self, fp_latents, list_latents):
        r"""
        Saves a latent trajectory on disk, in npy format.
        """
        list_latents_cpu = [l.cpu().numpy() for l in list_latents]
        np.save(fp_latents, list_latents_cpu)

    def load_latents(self, fp_latents):
        r"""
        Loads a latent trajectory from disk, converts to torch tensor.
        """
        list_latents_cpu = np.load(fp_latents)
        list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu]
        return list_latents

    def compute_img1(self, *args):
        list_ui_elem = args
        self.setup_lb(list_ui_elem)
        r"""
        Computes the first transition image and returns it for display.
        Sets all other transition images and last image to empty (as they are obsolete with this operation)
        """
        list_ui_vals = args
        self.setup_lb(list_ui_vals)
        fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}")
        img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
        img1.save(fp_img1+".jpg")
        self.save_latents(fp_img1+".npy", self.lb.tree_latents[0])

        img1.save(fp_img1 + ".jpg")
        self.save_latents(fp_img1 + ".npy", self.lb.tree_latents[0])
        self.recycle_img1 = True
        self.recycle_img2 = False
        # fixme save seeds. change filenames?
        return [fp_img1+".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]

        return [fp_img1 + ".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]

    def compute_img2(self, *args):
        if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")):  # don't do anything
        r"""
        Computes the last transition image and returns it for display.
        Sets all other transition images to empty (as they are obsolete with this operation)
        """
        if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")):  # don't do anything
            return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
        list_ui_elem = args
        self.setup_lb(list_ui_elem)
        list_ui_vals = args
        self.setup_lb(list_ui_vals)

        self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
        fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}")
        img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
        img2.save(fp_img2+'.jpg')
        self.save_latents(fp_img2+".npy", self.lb.tree_latents[-1])
        img2.save(fp_img2 + '.jpg')
        self.save_latents(fp_img2 + ".npy", self.lb.tree_latents[-1])
        self.recycle_img2 = True
        self.transition_can_be_computed = True
        # fixme save seeds. change filenames?
        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2+".jpg", self.user_id]

        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2 + ".jpg", self.user_id]

    def compute_transition(self, *args):
        if not self.transition_can_be_computed:
            list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
            return list_return

        list_ui_elem = args
        self.setup_lb(list_ui_elem)
        r"""
        Computes transition images and movie.
        """
        list_ui_vals = args
        self.setup_lb(list_ui_vals)
        print("STARTING TRANSITION...")

        fixed_seeds = [self.seed1, self.seed2]

        # Run Latent Blending
        # Check if another user is blocking this... otherwise everything will become mixed.
        # t_now = time.time()
        # if self.block_transition:
        #     while True:
        #         time.sleep(1)
        #         if not self.block_transition:
        #             break
        #         if time.time() - t_now > 1000:
        #             return

        self.block_transition = True
        # Inject loaded latents (other user interference)
        self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
        self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
        imgs_transition = self.lb.run_transition(
            recycle_img1=self.recycle_img1,
            recycle_img2=self.recycle_img2,
            num_inference_steps=self.num_inference_steps,
            depth_strength=self.depth_strength,
            t_compute_max_allowed=self.t_compute_max_allowed,
            fixed_seeds=fixed_seeds
            )
        print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")

            fixed_seeds=fixed_seeds)
        print(f"Latent Blending pass finished ({get_time('second')}). Resulted in {len(imgs_transition)} images")

        # Subselect three preview images
        idx_img_prev = np.round(np.linspace(0, len(imgs_transition)-1, 5)[1:-1]).astype(np.int32)
        idx_img_prev = np.round(np.linspace(0, len(imgs_transition) - 1, 5)[1:-1]).astype(np.int32)

        list_imgs_preview = []
        for j in idx_img_prev:
            list_imgs_preview.append(Image.fromarray(imgs_transition[j]))

        # Save the preview imgs as jpgs on disk so we are not sending uncompressed data around
        self.current_timestamp = get_time('second')
        current_timestamp = get_time('second')
        self.list_fp_imgs_current = []
        for i in range(len(list_imgs_preview)):
            fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg")
            fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{current_timestamp}.jpg")
            list_imgs_preview[i].save(fp_img)
            self.list_fp_imgs_current.append(fp_img)
        self.block_transition = False
        # Insert cheap frames for the movie
        imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
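add_frames_linear_interp pads the relatively few diffusion keyframes up to duration_video * fps frames so the final movie plays back smoothly at a constant frame rate. The library's own implementation is not shown in this diff; the sketch below is only a minimal illustration of the idea (linear cross-fades between consecutive keyframes), not the actual function.

import numpy as np

def linear_interp_frames(imgs, duration_s, fps):
    # Illustrative only: blend consecutive keyframes to reach duration_s * fps frames.
    imgs = [np.asarray(img, dtype=np.float32) for img in imgs]
    nmb_frames_target = int(round(duration_s * fps))
    # Fractional positions of the target frames along the keyframe sequence
    positions = np.linspace(0, len(imgs) - 1, nmb_frames_target)
    frames = []
    for p in positions:
        idx = int(np.floor(p))
        idx_next = min(idx + 1, len(imgs) - 1)
        w = p - idx
        frame = (1 - w) * imgs[idx] + w * imgs[idx_next]
        frames.append(frame.astype(np.uint8))
    return frames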
@@ -254,44 +261,43 @@ class BlendingFrontend():
            ms.write_frame(img)
        ms.finalize()
        print("DONE SAVING MOVIE! SENDING BACK...")

        # Assemble Output, updating the preview images and the movie
        list_return = self.list_fp_imgs_current + [self.fp_movie]
        return list_return

    def stack_forward(self, prompt2, seed2):
        r"""
        Allows generating multi-segment movies. Sets last image -> first image with all
        relevant parameters.
        """
        # Save preview images, prompts and seeds into dictionary for stacking
        if len(self.list_all_segments) == 0:
            timestamp_session = get_time('second')
            self.dp_session = os.path.join(self.dp_out, f"session_{timestamp_session}")
            os.makedirs(self.dp_session)

        self.transition_can_be_computed = False

        idx_segment = len(self.list_all_segments)
        dp_segment = os.path.join(self.dp_session, f"segment_{str(idx_segment).zfill(3)}")

        self.list_all_segments.append(dp_segment)
        self.lb.write_imgs_transition(dp_segment)

        fp_movie_last = self.get_fp_video_last()
        fp_movie_next = self.get_fp_video_next()

        shutil.copyfile(fp_movie_last, fp_movie_next)

        self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
        self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
        self.lb.swap_forward()

        shutil.copyfile(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"), os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))

        fp_multi = self.multi_concat()
        list_out = [fp_multi]

        list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")])
        list_out.extend([self.fp_img_empty]*4)
        list_out.extend([self.fp_img_empty] * 4)
        list_out.append(gr.update(interactive=False, value=prompt2))
        list_out.append(gr.update(interactive=False, value=seed2))
        list_out.append("")

@@ -299,25 +305,31 @@ class BlendingFrontend():
        print(f"stack_forward: fp_multi {fp_multi}")
        return list_out

    def multi_concat(self):
        r"""
        Concatenates all stacked segments into one long movie.
        """
        list_fp_movies = self.get_fp_video_all()
        # Concatenate movies and save
        fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4")
        concatenate_movies(fp_final, list_fp_movies)
        return fp_final

    def get_fp_video_all(self):
        r"""
        Collects all stacked movie segments.
        """
        list_all = os.listdir(self.dp_movies)
        str_beg = f"movie_{self.user_id}_"
        list_user = [l for l in list_all if str_beg in l]
        list_user.sort()
        list_user = [os.path.join(self.dp_movies, l) for l in list_user]
        return list_user

    def get_fp_video_next(self):
        r"""
        Gets the filepath of the next movie segment.
        """
        list_videos = self.get_fp_video_all()
        if len(list_videos) == 0:
            idx_next = 0

@@ -325,93 +337,81 @@ class BlendingFrontend():
            idx_next = len(list_videos)
        fp_video_next = os.path.join(self.dp_movies, f"movie_{self.user_id}_{str(idx_next).zfill(3)}.mp4")
        return fp_video_next

    def get_fp_video_last(self):
        r"""
        Gets the video that was saved last.
        """
        fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4")
        return fp_video_last

    def get_state_dict(self):
        state_dict = {}
        grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
                     'num_inference_steps', 'depth_strength', 'guidance_scale',
                     'guidance_scale_mid_damper', 'mid_compression_scaler']

        for v in grab_vars:
            state_dict[v] = getattr(self, v)
        return state_dict

if __name__ == "__main__":
    fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
    # fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
    bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
    # self = BlendingFrontend(None)

if __name__ == "__main__":

    # fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
    fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
    bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
    # self = BlendingFrontend(None)

    with gr.Blocks() as demo:
        with gr.Row():
            prompt1 = gr.Textbox(label="prompt 1")
            prompt2 = gr.Textbox(label="prompt 2")

        with gr.Row():
            duration_compute = gr.Slider(5, 200, bf.t_compute_max_allowed, step=1, label='compute budget', interactive=True)
            duration_video = gr.Slider(1, 100, bf.duration_video, step=0.1, label='video duration', interactive=True)
            height = gr.Slider(256, 2048, bf.height, step=128, label='height', interactive=True)
            width = gr.Slider(256, 2048, bf.width, step=128, label='width', interactive=True)

        with gr.Accordion("Advanced Settings (click to expand)", open=False):

            with gr.Accordion("Diffusion settings", open=True):
                with gr.Row():
                    num_inference_steps = gr.Slider(5, 100, bf.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
                    guidance_scale = gr.Slider(1, 25, bf.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
                    negative_prompt = gr.Textbox(label="negative prompt")

            with gr.Accordion("Seed control: adjust seeds for first and last images", open=True):
                with gr.Row():
                    b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
                    seed1 = gr.Number(bf.seed1, label="seed 1", interactive=True)
                    seed2 = gr.Number(bf.seed2, label="seed 2", interactive=True)
                    b_newseed2 = gr.Button("randomize seed 2", variant='secondary')

            with gr.Accordion("Last image crossfeeding.", open=True):
                with gr.Row():
                    branch1_crossfeed_power = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_power, step=0.01, label='branch1 crossfeed power', interactive=True)
                    branch1_crossfeed_range = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_range, step=0.01, label='branch1 crossfeed range', interactive=True)
                    branch1_crossfeed_decay = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_decay, step=0.01, label='branch1 crossfeed decay', interactive=True)

            with gr.Accordion("Transition settings", open=True):
                with gr.Row():
                    parental_crossfeed_power = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power, step=0.01, label='parental crossfeed power', interactive=True)
                    parental_crossfeed_range = gr.Slider(0.0, 1.0, bf.parental_crossfeed_range, step=0.01, label='parental crossfeed range', interactive=True)
                    parental_crossfeed_power_decay = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power_decay, step=0.01, label='parental crossfeed decay', interactive=True)
                with gr.Row():
                    depth_strength = gr.Slider(0.01, 0.99, bf.depth_strength, step=0.01, label='depth_strength', interactive=True)
                    guidance_scale_mid_damper = gr.Slider(0.01, 2.0, bf.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)

        with gr.Row():
            b_compute1 = gr.Button('compute first image', variant='primary')
            b_compute_transition = gr.Button('compute transition', variant='primary')
            b_compute2 = gr.Button('compute last image', variant='primary')

        with gr.Row():
            img1 = gr.Image(label="1/5")
            img2 = gr.Image(label="2/5", show_progress=False)
            img3 = gr.Image(label="3/5", show_progress=False)
            img4 = gr.Image(label="4/5", show_progress=False)
            img5 = gr.Image(label="5/5")

        with gr.Row():
            vid_single = gr.Video(label="single trans")
            vid_multi = gr.Video(label="multi trans")

            vid_single = gr.Video(label="current single trans")
            vid_multi = gr.Video(label="concatenated multi trans")

        with gr.Row():
            # b_restart = gr.Button("RESTART EVERYTHING")
            b_stackforward = gr.Button('append last movie segment (left) to multi movie (right)', variant='primary')

        with gr.Row():
            gr.Markdown(
                """
@@ -420,75 +420,73 @@ if __name__ == "__main__":
                - compute budget: set your waiting time for the transition. high values = better quality
                - video duration: seconds per segment
                - height/width: in pixels

                ## Diffusion settings
                - num_inference_steps: number of diffusion steps
                - guidance_scale: latent blending seems to prefer lower values here
                - negative prompt: enter negative prompt here, applied for all images

                ## Last image crossfeeding
                - branch1_crossfeed_power: Controls the level of cross-feeding between the first and last image branch. For preserving structures.
                - branch1_crossfeed_range: Sets the duration of active crossfeed during development. High values enforce strong structural similarity.
                - branch1_crossfeed_decay: Sets decay for branch1_crossfeed_power. Lower values make the decay stronger across the range.

                ## Transition settings
                - parental_crossfeed_power: Similar to branch1_crossfeed_power, however applied to the images within the transition.
                - parental_crossfeed_range: Similar to branch1_crossfeed_range, however applied to the images within the transition.
                - parental_crossfeed_power_decay: Similar to branch1_crossfeed_decay, however applied to the images within the transition.
                - depth_strength: Determines when the blending process will begin in terms of diffusion steps. Low values are more inventive but can cause motion.
                - guidance_scale_mid_damper: Decreases the guidance scale in the middle of a transition.
                """
            )

                """)
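The same knobs documented in the help text above can also be set directly on a LatentBlending instance when scripting instead of using the UI; setup_lb() above does exactly this with the slider values. The numbers below are purely illustrative, not recommendations:

# Illustrative values; the attribute names are the ones assigned in setup_lb() above.
lb.branch1_crossfeed_power = 0.1
lb.branch1_crossfeed_range = 0.6
lb.branch1_crossfeed_decay = 0.8
lb.parental_crossfeed_power = 0.1
lb.parental_crossfeed_range = 0.8
lb.parental_crossfeed_power_decay = 1.0
lb.guidance_scale_mid_damper = 0.5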
        with gr.Row():
            user_id = gr.Textbox(label="user id", interactive=False)

        # Collect all UI elements in a list to easily pass as inputs in gradio
        dict_ui_elem = {}
        dict_ui_elem["prompt1"] = prompt1
        dict_ui_elem["negative_prompt"] = negative_prompt
        dict_ui_elem["prompt2"] = prompt2

        dict_ui_elem["duration_compute"] = duration_compute
        dict_ui_elem["duration_video"] = duration_video
        dict_ui_elem["height"] = height
        dict_ui_elem["width"] = width

        dict_ui_elem["depth_strength"] = depth_strength
        dict_ui_elem["branch1_crossfeed_power"] = branch1_crossfeed_power
        dict_ui_elem["branch1_crossfeed_range"] = branch1_crossfeed_range
        dict_ui_elem["branch1_crossfeed_decay"] = branch1_crossfeed_decay

        dict_ui_elem["num_inference_steps"] = num_inference_steps
        dict_ui_elem["guidance_scale"] = guidance_scale
        dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
        dict_ui_elem["seed1"] = seed1
        dict_ui_elem["seed2"] = seed2

        dict_ui_elem["parental_crossfeed_range"] = parental_crossfeed_range
        dict_ui_elem["parental_crossfeed_power"] = parental_crossfeed_power
        dict_ui_elem["parental_crossfeed_power_decay"] = parental_crossfeed_power_decay
        dict_ui_elem["user_id"] = user_id

        # Convert to list, as gradio doesn't seem to accept dicts
        list_ui_elem = []
        list_ui_vals = []
        list_ui_keys = []
        for k in dict_ui_elem.keys():
            list_ui_elem.append(dict_ui_elem[k])
            list_ui_vals.append(dict_ui_elem[k])
            list_ui_keys.append(k)
        bf.list_ui_keys = list_ui_keys

        b_newseed1.click(bf.randomize_seed1, outputs=seed1)
        b_newseed2.click(bf.randomize_seed2, outputs=seed2)
        b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5, user_id])
        b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5, user_id])
        b_compute_transition.click(bf.compute_transition,
                                   inputs=list_ui_elem,
                                   outputs=[img2, img3, img4, vid_single])

        b_stackforward.click(bf.stack_forward,
                             inputs=[prompt2, seed2],
                             outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
        b_compute1.click(bf.compute_img1, inputs=list_ui_vals, outputs=[img1, img2, img3, img4, img5, user_id])
        b_compute2.click(bf.compute_img2, inputs=list_ui_vals, outputs=[img2, img3, img4, img5, user_id])
        b_compute_transition.click(bf.compute_transition,
                                   inputs=list_ui_vals,
                                   outputs=[img2, img3, img4, vid_single])

        b_stackforward.click(bf.stack_forward,
                             inputs=[prompt2, seed2],
                             outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])

    demo.launch(share=bf.share, inbrowser=True, inline=False)
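To expose the UI beyond localhost, the share flag documented in BlendingFrontend above can be switched on at construction time; a hedged sketch (the constructor arguments are the ones defined above, the checkpoint choice is just one of the two shown):

fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt), share=True)  # share=True yields a public gradio link via demo.launch(share=bf.share, ...)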
(File diff suppressed because it is too large.)

movie_util.py
@@ -1,5 +1,6 @@
# Copyright 2022 Lunar Ring. All rights reserved.
#
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

@@ -17,26 +18,24 @@ import os
import numpy as np
from tqdm import tqdm
import cv2
from typing import Callable, List, Optional, Union
import ffmpeg  # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg
from typing import List
import ffmpeg  # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg

#%%

class MovieSaver():
    def __init__(
            self,
            fp_out: str,
            fps: int = 24,
            shape_hw: List[int] = None,
            crf: int = 24,
            codec: str = 'libx264',
            preset: str ='fast',
            pix_fmt: str = 'yuv420p',
            silent_ffmpeg: bool = True
            ):
            preset: str = 'fast',
            pix_fmt: str = 'yuv420p',
            silent_ffmpeg: bool = True):
        r"""
        Initializes movie saver class - a human friendly ffmpeg wrapper.
        After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x).
        Don't forget to finalize the movie file with moviesaver.finalize().
        Args:
            fp_out: str

@@ -47,22 +46,22 @@ class MovieSaver():
                Output shape, optional argument. Can be initialized automatically when first frame is written.
            crf: int
                ffmpeg doc: the range of the CRF scale is 0–51, where 0 is lossless
                (for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
                A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
                Consider 17 or 18 to be visually lossless or nearly so;
                it should look the same or nearly the same as the input but it isn't technically lossless.
                The range is exponential, so increasing the CRF value +6 results in
                roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
            codec: str
                Video codec to use, e.g. 'libx264'.
            preset: str
                Choose between ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow.
                ffmpeg doc: A preset is a collection of options that will provide a certain encoding speed
                to compression ratio. A slower preset will provide better compression
                (compression is quality per filesize).
                This means that, for example, if you target a certain file size or constant bit rate,
                you will achieve better quality with a slower preset. Similarly, for constant quality encoding,
                you will simply save bitrate by choosing a slower preset.
            pix_fmt: str
                Pixel format. Run 'ffmpeg -pix_fmts' in your shell to see all options.
            silent_ffmpeg: bool

@@ -70,7 +69,7 @@ class MovieSaver():
        """
        if len(os.path.split(fp_out)[0]) > 0:
            assert os.path.isdir(os.path.split(fp_out)[0]), "Directory does not exist!"

        self.fp_out = fp_out
        self.fps = fps
        self.crf = crf

@@ -78,10 +77,10 @@ class MovieSaver():
        self.codec = codec
        self.preset = preset
        self.silent_ffmpeg = silent_ffmpeg

        if os.path.isfile(fp_out):
            os.remove(fp_out)

        self.init_done = False
        self.nmb_frames = 0
        if shape_hw is None:

@@ -91,11 +90,9 @@ class MovieSaver():
            shape_hw.append(3)
        self.shape_hw = shape_hw
        self.initialize()

        print(f"MovieSaver initialized. fps={fps} crf={crf} pix_fmt={pix_fmt} codec={codec} preset={preset}")

    def initialize(self):
        args = (
            ffmpeg

@@ -111,8 +108,7 @@ class MovieSaver():
        self.init_done = True
        self.shape_hw = tuple(self.shape_hw)
        print(f"Initialization done. Movie shape: {self.shape_hw}")

    def write_frame(self, out_frame: np.ndarray):
        r"""
        Function to dump a numpy array as frame of a movie.

@@ -123,18 +119,17 @@ class MovieSaver():
            Dim 1: x
            Dim 2: RGB
        """
        assert out_frame.dtype == np.uint8, "Convert to np.uint8 before"
        assert len(out_frame.shape) == 3, "out_frame needs to be three dimensional, Y X C"
        assert out_frame.shape[2] == 3, f"need three color channels, but you provided {out_frame.shape[2]}."

        if not self.init_done:
            self.shape_hw = out_frame.shape
            self.initialize()

        assert self.shape_hw == out_frame.shape, f"You cannot change the image size after init. Initialized with {self.shape_hw}, out_frame {out_frame.shape}"

        # write frame
        self.ffmpg_process.stdin.write(
            out_frame
            .astype(np.uint8)

@@ -142,8 +137,7 @@ class MovieSaver():
        )

        self.nmb_frames += 1

    def finalize(self):
        r"""
        Call this function to finalize the movie. If you forget to call it, your movie will be garbage.

@@ -157,7 +151,6 @@ class MovieSaver():
        print(f"Movie saved, {duration}s playtime, watch here: \n{self.fp_out}")
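A minimal usage sketch for MovieSaver with the quality-relevant arguments spelled out. The values are illustrative; the constructor signature and the write_frame/finalize calls are the ones defined above.

import numpy as np

ms = MovieSaver("/tmp/demo.mp4", fps=24, crf=20, codec='libx264',
                preset='medium', pix_fmt='yuv420p')
for _ in range(48):  # two seconds at 24 fps
    frame = np.random.randint(0, 255, (576, 1024, 3), dtype=np.uint8)
    ms.write_frame(frame)  # expects uint8 arrays of shape (H, W, 3)
ms.finalize()  # must be called, otherwise the file stays incomplete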
def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
    r"""
    Concatenate multiple movie segments into one long movie, using ffmpeg.

@@ -167,13 +160,13 @@ def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
        fp_final : str
            Full path of the final movie file. Should end with .mp4
        list_fp_movies : list[str]
            List of full paths of movie segments.
    """
    assert fp_final[-4] == ".", f"fp_final seems to miss file extension: {fp_final}"
    for fp in list_fp_movies:
        assert os.path.isfile(fp), f"Input movie does not exist: {fp}"
        assert os.path.getsize(fp) > 100, f"Input movie seems empty: {fp}"

    if os.path.isfile(fp_final):
        os.remove(fp_final)

@@ -181,32 +174,32 @@ def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
    list_concat = []
    for fp_part in list_fp_movies:
        list_concat.append(f"""file '{fp_part}'""")

    # save this list
    fp_list = "tmp_move.txt"
    with open(fp_list, "w") as fa:
        for item in list_concat:
            fa.write("%s\n" % item)

    cmd = f'ffmpeg -f concat -safe 0 -i {fp_list} -c copy {fp_final}'
    dp_movie = os.path.split(fp_final)[0]
    subprocess.call(cmd, shell=True)
    os.remove(fp_list)
    if os.path.isfile(fp_final):
        print(f"concatenate_movies: success! Watch here: {fp_final}")
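For reference, concatenate_movies() drives ffmpeg's concat demuxer: it writes a temporary list file and then copies the streams without re-encoding. With two segments, the generated tmp_move.txt would look roughly like this (the paths are examples):

file 'tmp_part_000.mp4'
file 'tmp_part_001.mp4'

The command it runs is of the form ffmpeg -f concat -safe 0 -i tmp_move.txt -c copy output.mp4, exactly as built in the code above.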
class MovieReader():
    r"""
    Class to read in a movie.
    """

    def __init__(self, fp_movie):
        self.video_player_object = cv2.VideoCapture(fp_movie)
        self.nmb_frames = int(self.video_player_object.get(cv2.CAP_PROP_FRAME_COUNT))
        self.fps_movie = int(self.video_player_object.get(cv2.CAP_PROP_FPS))
        self.shape = [100,100,3]
        self.shape = [100, 100, 3]
        self.shape_is_set = False

    def get_next_frame(self):
        success, image = self.video_player_object.read()
        if success:

@@ -217,19 +210,18 @@ class MovieReader():
        else:
            return np.zeros(self.shape)

#%%
if __name__ == "__main__":
    fps=2

if __name__ == "__main__":
    fps = 2
    list_fp_movies = []
    for k in range(4):
        fp_movie = f"/tmp/my_random_movie_{k}.mp4"
        list_fp_movies.append(fp_movie)
        ms = MovieSaver(fp_movie, fps=fps)
        for fn in tqdm(range(30)):
            img = (np.random.rand(512, 1024, 3)*255).astype(np.uint8)
            img = (np.random.rand(512, 1024, 3) * 255).astype(np.uint8)
            ms.write_frame(img)
        ms.finalize()

    fp_final = "/tmp/my_concatenated_movie.mp4"
    concatenate_movies(fp_final, list_fp_movies)
|
@ -13,36 +13,25 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os, sys
|
||||
dp_git = "/home/lugo/git/"
|
||||
sys.path.append(os.path.join(dp_git,'garden4'))
|
||||
sys.path.append('util')
|
||||
import os
|
||||
import torch
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.set_grad_enabled(False)
|
||||
import numpy as np
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
import time
|
||||
import subprocess
|
||||
import warnings
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
from PIL import Image
|
||||
# import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from movie_util import MovieSaver
|
||||
import datetime
|
||||
from typing import Callable, List, Optional, Union
|
||||
import inspect
|
||||
from threading import Thread
|
||||
torch.set_grad_enabled(False)
|
||||
from typing import Optional
|
||||
from omegaconf import OmegaConf
|
||||
from torch import autocast
|
||||
from contextlib import nullcontext
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from einops import repeat, rearrange
|
||||
#%%
|
||||
from utils import interpolate_spherical
|
||||
|
||||
|
||||
def pad_image(input_image):
|
||||
|
@@ -53,41 +42,11 @@ def pad_image(input_image):
|
|||
return im_padded
|
||||
|
||||
|
||||
|
||||
def make_batch_inpaint(
|
||||
image,
|
||||
mask,
|
||||
txt,
|
||||
device,
|
||||
num_samples=1):
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
mask = np.array(mask.convert("L"))
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
mask = mask[None, None]
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
batch = {
|
||||
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
"txt": num_samples * [txt],
|
||||
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
def make_batch_superres(
|
||||
image,
|
||||
txt,
|
||||
device,
|
||||
num_samples=1,
|
||||
):
|
||||
num_samples=1):
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
batch = {
|
||||
|
@@ -107,14 +66,14 @@ def make_noise_augmentation(model, batch, noise_level=None):
|
|||
|
||||
|
||||
class StableDiffusionHolder:
|
||||
def __init__(self,
|
||||
fp_ckpt: str = None,
|
||||
def __init__(self,
|
||||
fp_ckpt: str = None,
|
||||
fp_config: str = None,
|
||||
num_inference_steps: int = 30,
|
||||
num_inference_steps: int = 30,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
device: str = None,
|
||||
precision: str='autocast',
|
||||
precision: str = 'autocast',
|
||||
):
|
||||
r"""
|
||||
Initializes the stable diffusion holder, which contains the models and sampler.
|
||||
|
@@ -122,26 +81,26 @@ class StableDiffusionHolder:
|
|||
fp_ckpt: File pointer to the .ckpt model file
|
||||
fp_config: File pointer to the .yaml config file
|
||||
num_inference_steps: Number of diffusion iterations. Will be overwritten by latent blending.
|
||||
height: Height of the resulting image.
|
||||
width: Width of the resulting image.
|
||||
height: Height of the resulting image.
|
||||
width: Width of the resulting image.
|
||||
device: Device to run the model on.
|
||||
precision: Precision to run the model on.
|
||||
"""
|
||||
self.seed = 42
|
||||
self.guidance_scale = 5.0
|
||||
|
||||
|
||||
if device is None:
|
||||
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
else:
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.init_model(fp_ckpt, fp_config)
|
||||
|
||||
self.f = 8 #downsampling factor, most often 8 or 16",
|
||||
|
||||
self.f = 8  # downsampling factor, most often 8 or 16
|
||||
self.C = 4
|
||||
self.ddim_eta = 0
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
|
||||
if height is None and width is None:
|
||||
self.init_auto_res()
|
||||
else:
|
||||
|
@@ -149,53 +108,44 @@ class StableDiffusionHolder:
|
|||
assert width is not None, "specify both width and height"
|
||||
self.height = height
|
||||
self.width = width
|
||||
|
||||
# Inpainting inits
|
||||
self.mask_empty = Image.fromarray(255*np.ones([self.width, self.height], dtype=np.uint8))
|
||||
self.image_empty = Image.fromarray(np.zeros([self.width, self.height, 3], dtype=np.uint8))
|
||||
|
||||
|
||||
self.negative_prompt = [""]
|
||||
|
||||
|
||||
|
||||
def init_model(self, fp_ckpt, fp_config):
|
||||
r"""Loads the models and sampler.
|
||||
"""
|
||||
|
||||
assert os.path.isfile(fp_ckpt), f"Your model checkpoint file does not exist: {fp_ckpt}"
|
||||
self.fp_ckpt = fp_ckpt
|
||||
|
||||
|
||||
# Auto init the config?
|
||||
if fp_config is None:
|
||||
fn_ckpt = os.path.basename(fp_ckpt)
|
||||
if 'depth' in fn_ckpt:
|
||||
fp_config = 'configs/v2-midas-inference.yaml'
|
||||
elif 'inpain' in fn_ckpt:
|
||||
fp_config = 'configs/v2-inpainting-inference.yaml'
|
||||
elif 'upscaler' in fn_ckpt:
|
||||
fp_config = 'configs/x4-upscaling.yaml'
|
||||
fp_config = 'configs/x4-upscaling.yaml'
|
||||
elif '512' in fn_ckpt:
|
||||
fp_config = 'configs/v2-inference.yaml'
|
||||
elif '768'in fn_ckpt:
|
||||
fp_config = 'configs/v2-inference-v.yaml'
|
||||
fp_config = 'configs/v2-inference.yaml'
|
||||
elif '768' in fn_ckpt:
|
||||
fp_config = 'configs/v2-inference-v.yaml'
|
||||
elif 'v1-5' in fn_ckpt:
|
||||
fp_config = 'configs/v1-inference.yaml'
|
||||
fp_config = 'configs/v1-inference.yaml'
|
||||
else:
|
||||
raise ValueError("auto detect of config failed. please specify fp_config manually!")
|
||||
|
||||
|
||||
assert os.path.isfile(fp_config), "Auto-init of the config file failed. Please specify manually."
|
||||
|
||||
|
||||
assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"
|
||||
|
||||
|
||||
config = OmegaConf.load(fp_config)
|
||||
|
||||
|
||||
self.model = instantiate_from_config(config.model)
|
||||
self.model.load_state_dict(torch.load(fp_ckpt)["state_dict"], strict=False)
|
||||
|
||||
self.model = self.model.to(self.device)
|
||||
self.sampler = DDIMSampler(self.model)
|
||||
|
||||
|
||||
|
||||
def init_auto_res(self):
|
||||
r"""Automatically set the resolution to the one used in training.
|
||||
"""
|
||||
|
@@ -205,7 +155,7 @@ class StableDiffusionHolder:
|
|||
else:
|
||||
self.height = 512
|
||||
self.width = 512
|
||||
|
||||
|
||||
def set_negative_prompt(self, negative_prompt):
|
||||
r"""Set the negative prompt. Currenty only one negative prompt is supported
|
||||
"""
|
||||
|
@@ -214,51 +164,46 @@ class StableDiffusionHolder:
|
|||
self.negative_prompt = [negative_prompt]
|
||||
else:
|
||||
self.negative_prompt = negative_prompt
|
||||
|
||||
|
||||
if len(self.negative_prompt) > 1:
|
||||
self.negative_prompt = [self.negative_prompt[0]]
|
||||
|
||||
|
||||
def get_text_embedding(self, prompt):
|
||||
c = self.model.get_learned_conditioning(prompt)
|
||||
return c
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_cond_upscaling(self, image, text_embedding, noise_level):
|
||||
r"""
|
||||
Initializes the conditioning for the x4 upscaling model.
|
||||
"""
|
||||
|
||||
image = pad_image(image) # resize to integer multiple of 32
|
||||
w, h = image.size
|
||||
noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()
|
||||
batch = make_batch_superres(image, txt="placeholder", device=self.device, num_samples=1)
|
||||
|
||||
x_augment, noise_level = make_noise_augmentation(self.model, batch, noise_level)
|
||||
|
||||
|
||||
cond = {"c_concat": [x_augment], "c_crossattn": [text_embedding], "c_adm": noise_level}
|
||||
# uncond cond
|
||||
uc_cross = self.model.get_unconditional_conditioning(1, "")
|
||||
uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}
|
||||
|
||||
return cond, uc_full
|
||||
|
||||
@torch.no_grad()
|
||||
def run_diffusion_standard(
|
||||
self,
|
||||
text_embeddings: torch.FloatTensor,
|
||||
self,
|
||||
text_embeddings: torch.FloatTensor,
|
||||
latents_start: torch.FloatTensor,
|
||||
idx_start: int = 0,
|
||||
list_latents_mixing = None,
|
||||
mixing_coeffs = 0.0,
|
||||
spatial_mask = None,
|
||||
return_image: Optional[bool] = False,
|
||||
):
|
||||
idx_start: int = 0,
|
||||
list_latents_mixing=None,
|
||||
mixing_coeffs=0.0,
|
||||
spatial_mask=None,
|
||||
return_image: Optional[bool] = False):
|
||||
r"""
|
||||
Diffusion standard version.
|
||||
|
||||
Diffusion standard version.
|
||||
Args:
|
||||
text_embeddings: torch.FloatTensor
|
||||
text_embeddings: torch.FloatTensor
|
||||
Text embeddings used for diffusion
|
||||
latents_for_injection: torch.FloatTensor or list
|
||||
Latents that are used for injection
|
||||
|
@@ -270,41 +215,32 @@ class StableDiffusionHolder:
|
|||
experimental feature for enforcing pixels from list_latents_mixing
|
||||
return_image: Optional[bool]
|
||||
Optionally return image directly
|
||||
|
||||
"""
|
||||
|
||||
# Asserts
|
||||
if type(mixing_coeffs) == float:
|
||||
list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
|
||||
list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
|
||||
elif type(mixing_coeffs) == list:
|
||||
assert len(mixing_coeffs) == self.num_inference_steps
|
||||
list_mixing_coeffs = mixing_coeffs
|
||||
else:
|
||||
raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
|
||||
|
||||
|
||||
if np.sum(list_mixing_coeffs) > 0:
|
||||
assert len(list_latents_mixing) == self.num_inference_steps
|
||||
|
||||
|
||||
|
||||
precision_scope = autocast if self.precision == "autocast" else nullcontext
|
||||
|
||||
with precision_scope("cuda"):
|
||||
with self.model.ema_scope():
|
||||
if self.guidance_scale != 1.0:
|
||||
uc = self.model.get_learned_conditioning(self.negative_prompt)
|
||||
else:
|
||||
uc = None
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
|
||||
latents = latents_start.clone()
|
||||
|
||||
timesteps = self.sampler.ddim_timesteps
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
|
||||
# collect latents
|
||||
# Collect latents
|
||||
list_latents_out = []
|
||||
for i, step in enumerate(time_range):
|
||||
# Set the right starting latents
|
||||
|
@@ -313,83 +249,71 @@ class StableDiffusionHolder:
|
|||
continue
|
||||
elif i == idx_start:
|
||||
latents = latents_start.clone()
|
||||
|
||||
# Mix the latents.
|
||||
if i > 0 and list_mixing_coeffs[i]>0:
|
||||
latents_mixtarget = list_latents_mixing[i-1].clone()
|
||||
# Mix latents
|
||||
if i > 0 and list_mixing_coeffs[i] > 0:
|
||||
latents_mixtarget = list_latents_mixing[i - 1].clone()
|
||||
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
|
||||
|
||||
|
||||
if spatial_mask is not None and list_latents_mixing is not None:
|
||||
latents = interpolate_spherical(latents, list_latents_mixing[i-1], 1-spatial_mask)
|
||||
# latents[:,:,-15:,:] = latents_mixtarget[:,:,-15:,:]
|
||||
|
||||
latents = interpolate_spherical(latents, list_latents_mixing[i - 1], 1 - spatial_mask)
|
||||
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((1,), step, device=self.device, dtype=torch.long)
|
||||
outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
|
||||
quantize_denoised=False, temperature=1.0,
|
||||
noise_dropout=0.0, score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=self.guidance_scale,
|
||||
unconditional_conditioning=uc,
|
||||
dynamic_threshold=None)
|
||||
quantize_denoised=False, temperature=1.0,
|
||||
noise_dropout=0.0, score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=self.guidance_scale,
|
||||
unconditional_conditioning=uc,
|
||||
dynamic_threshold=None)
|
||||
latents, pred_x0 = outs
|
||||
list_latents_out.append(latents.clone())
|
||||
|
||||
if return_image:
|
||||
if return_image:
|
||||
return self.latent2image(latents)
|
||||
else:
|
||||
return list_latents_out
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def run_diffusion_upscaling(
|
||||
self,
|
||||
self,
|
||||
cond,
|
||||
uc_full,
|
||||
latents_start: torch.FloatTensor,
|
||||
idx_start: int = -1,
|
||||
list_latents_mixing = None,
|
||||
mixing_coeffs = 0.0,
|
||||
return_image: Optional[bool] = False
|
||||
):
|
||||
latents_start: torch.FloatTensor,
|
||||
idx_start: int = -1,
|
||||
list_latents_mixing: list = None,
|
||||
mixing_coeffs: float = 0.0,
|
||||
return_image: Optional[bool] = False):
|
||||
r"""
|
||||
Diffusion upscaling version.
|
||||
Diffusion upscaling version.
|
||||
"""
|
||||
|
||||
|
||||
# Asserts
|
||||
if type(mixing_coeffs) == float:
|
||||
list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
|
||||
list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
|
||||
elif type(mixing_coeffs) == list:
|
||||
assert len(mixing_coeffs) == self.num_inference_steps
|
||||
list_mixing_coeffs = mixing_coeffs
|
||||
else:
|
||||
raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
|
||||
|
||||
|
||||
if np.sum(list_mixing_coeffs) > 0:
|
||||
assert len(list_latents_mixing) == self.num_inference_steps
|
||||
|
||||
|
||||
precision_scope = autocast if self.precision == "autocast" else nullcontext
|
||||
|
||||
h = uc_full['c_concat'][0].shape[2]
|
||||
w = uc_full['c_concat'][0].shape[3]
|
||||
|
||||
h = uc_full['c_concat'][0].shape[2]
|
||||
w = uc_full['c_concat'][0].shape[3]
|
||||
with precision_scope("cuda"):
|
||||
with self.model.ema_scope():
|
||||
|
||||
shape_latents = [self.model.channels, h, w]
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
|
||||
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
|
||||
C, H, W = shape_latents
|
||||
size = (1, C, H, W)
|
||||
b = size[0]
|
||||
|
||||
latents = latents_start.clone()
|
||||
|
||||
timesteps = self.sampler.ddim_timesteps
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
|
||||
# collect latents
|
||||
list_latents_out = []
|
||||
for i, step in enumerate(time_range):
|
||||
|
@@ -399,232 +323,40 @@ class StableDiffusionHolder:
|
|||
continue
|
||||
elif i == idx_start:
|
||||
latents = latents_start.clone()
|
||||
|
||||
# Mix the latents.
|
||||
if i > 0 and list_mixing_coeffs[i]>0:
|
||||
latents_mixtarget = list_latents_mixing[i-1].clone()
|
||||
# Mix the latents.
|
||||
if i > 0 and list_mixing_coeffs[i] > 0:
|
||||
latents_mixtarget = list_latents_mixing[i - 1].clone()
|
||||
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
|
||||
|
||||
# print(f"diffusion iter {i}")
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=self.device, dtype=torch.long)
|
||||
outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
|
||||
quantize_denoised=False, temperature=1.0,
|
||||
noise_dropout=0.0, score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=self.guidance_scale,
|
||||
unconditional_conditioning=uc_full,
|
||||
dynamic_threshold=None)
|
||||
quantize_denoised=False, temperature=1.0,
|
||||
noise_dropout=0.0, score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=self.guidance_scale,
|
||||
unconditional_conditioning=uc_full,
|
||||
dynamic_threshold=None)
|
||||
latents, pred_x0 = outs
|
||||
list_latents_out.append(latents.clone())
|
||||
|
||||
if return_image:
|
||||
return self.latent2image(latents)
|
||||
else:
|
||||
return list_latents_out
|
||||
|
||||
@torch.no_grad()
|
||||
def run_diffusion_inpaint(
|
||||
self,
|
||||
text_embeddings: torch.FloatTensor,
|
||||
latents_for_injection: torch.FloatTensor = None,
|
||||
idx_start: int = -1,
|
||||
idx_stop: int = -1,
|
||||
return_image: Optional[bool] = False
|
||||
):
|
||||
r"""
|
||||
Runs inpaint-based diffusion. Returns a list of latents that were computed.
|
||||
Adaptations allow to supply
|
||||
a) starting index for diffusion
|
||||
b) stopping index for diffusion
|
||||
c) latent representations that are injected at the starting index
|
||||
Furthermore the intermittent latents are collected and returned.
|
||||
|
||||
Adapted from diffusers (https://github.com/huggingface/diffusers)
|
||||
Args:
|
||||
text_embeddings: torch.FloatTensor
|
||||
Text embeddings used for diffusion
|
||||
latents_for_injection: torch.FloatTensor
|
||||
Latents that are used for injection
|
||||
idx_start: int
|
||||
Index of the diffusion process start and where the latents_for_injection are injected
|
||||
idx_stop: int
|
||||
Index of the diffusion process end.
|
||||
return_image: Optional[bool]
|
||||
Optionally return image directly
|
||||
|
||||
"""
|
||||
|
||||
if latents_for_injection is None:
|
||||
do_inject_latents = False
|
||||
else:
|
||||
do_inject_latents = True
|
||||
|
||||
precision_scope = autocast if self.precision == "autocast" else nullcontext
|
||||
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
|
||||
|
||||
with precision_scope("cuda"):
|
||||
with self.model.ema_scope():
|
||||
|
||||
batch = make_batch_inpaint(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
|
||||
c = text_embeddings
|
||||
c_cat = list()
|
||||
for ck in self.model.concat_keys:
|
||||
cc = batch[ck].float()
|
||||
if ck != self.model.masked_image_key:
|
||||
bchw = [1, 4, self.height // 8, self.width // 8]
|
||||
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
|
||||
else:
|
||||
cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
|
||||
c_cat.append(cc)
|
||||
c_cat = torch.cat(c_cat, dim=1)
|
||||
|
||||
# cond
|
||||
cond = {"c_concat": [c_cat], "c_crossattn": [c]}
|
||||
|
||||
# uncond cond
|
||||
uc_cross = self.model.get_unconditional_conditioning(1, "")
|
||||
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
|
||||
|
||||
shape_latents = [self.model.channels, self.height // 8, self.width // 8]
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
|
||||
# sampling
|
||||
C, H, W = shape_latents
|
||||
size = (1, C, H, W)
|
||||
|
||||
device = self.model.betas.device
|
||||
b = size[0]
|
||||
latents = torch.randn(size, generator=generator, device=device)
|
||||
|
||||
timesteps = self.sampler.ddim_timesteps
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
|
||||
# collect latents
|
||||
list_latents_out = []
|
||||
for i, step in enumerate(time_range):
|
||||
if do_inject_latents:
|
||||
# Inject latent at right place
|
||||
if i < idx_start:
|
||||
continue
|
||||
elif i == idx_start:
|
||||
latents = latents_for_injection.clone()
|
||||
|
||||
if i == idx_stop:
|
||||
return list_latents_out
|
||||
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
|
||||
quantize_denoised=False, temperature=1.0,
|
||||
noise_dropout=0.0, score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=self.guidance_scale,
|
||||
unconditional_conditioning=uc_full,
|
||||
dynamic_threshold=None)
|
||||
latents, pred_x0 = outs
|
||||
list_latents_out.append(latents.clone())
|
||||
|
||||
if return_image:
|
||||
if return_image:
|
||||
return self.latent2image(latents)
|
||||
else:
|
||||
return list_latents_out
|
||||
|
||||
@torch.no_grad()
|
||||
def latent2image(
|
||||
self,
|
||||
latents: torch.FloatTensor
|
||||
):
|
||||
self,
|
||||
latents: torch.FloatTensor):
|
||||
r"""
|
||||
Returns an image provided a latent representation from diffusion.
|
||||
Args:
|
||||
latents: torch.FloatTensor
|
||||
Result of the diffusion process.
|
||||
Result of the diffusion process.
|
||||
"""
|
||||
x_sample = self.model.decode_first_stage(latents)
|
||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255 * x_sample[0,:,:].permute([1,2,0]).cpu().numpy()
|
||||
x_sample = 255 * x_sample[0, :, :].permute([1, 2, 0]).cpu().numpy()
|
||||
image = x_sample.astype(np.uint8)
|
||||
return image
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate_spherical(p0, p1, fract_mixing: float):
|
||||
r"""
|
||||
Helper function to correctly mix two random variables using spherical interpolation.
|
||||
See https://en.wikipedia.org/wiki/Slerp
|
||||
The function will always cast up to float64 for the sake of extra precision.
|
||||
Args:
|
||||
p0:
|
||||
First tensor for interpolation
|
||||
p1:
|
||||
Second tensor for interpolation
|
||||
fract_mixing: float
|
||||
Mixing coefficient of interval [0, 1].
|
||||
0 will return in p0
|
||||
1 will return in p1
|
||||
0.x will return a mix between both preserving angular velocity.
|
||||
"""
|
||||
|
||||
if p0.dtype == torch.float16:
|
||||
recast_to = 'fp16'
|
||||
else:
|
||||
recast_to = 'fp32'
|
||||
|
||||
p0 = p0.double()
|
||||
p1 = p1.double()
|
||||
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
|
||||
epsilon = 1e-7
|
||||
dot = torch.sum(p0 * p1) / norm
|
||||
dot = dot.clamp(-1+epsilon, 1-epsilon)
|
||||
|
||||
theta_0 = torch.arccos(dot)
|
||||
sin_theta_0 = torch.sin(theta_0)
|
||||
theta_t = theta_0 * fract_mixing
|
||||
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = torch.sin(theta_t) / sin_theta_0
|
||||
interp = p0*s0 + p1*s1
|
||||
|
||||
if recast_to == 'fp16':
|
||||
interp = interp.half()
|
||||
elif recast_to == 'fp32':
|
||||
interp = interp.float()
|
||||
|
||||
return interp
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
num_inference_steps = 20 # Number of diffusion iterations
|
||||
|
||||
# fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
|
||||
# fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
|
||||
|
||||
# fp_ckpt= "../stable_diffusion_models/ckpt/512-inpainting-ema.ckpt"
|
||||
# fp_config = '../stablediffusion/configs//stable-diffusion/v2-inpainting-inference.yaml'
|
||||
|
||||
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
|
||||
# fp_config = 'configs/v2-inference-v.yaml'
|
||||
|
||||
|
||||
self = StableDiffusionHolder(fp_ckpt, num_inference_steps=num_inference_steps)
|
||||
|
||||
xxx
|
||||
|
||||
#%%
|
||||
self.width = 1536
|
||||
self.height = 768
|
||||
prompt = "360 degree equirectangular, a huge rocky hill full of pianos and keyboards, musical instruments, cinematic, masterpiece 8 k, artstation"
|
||||
self.set_negative_prompt("out of frame, faces, rendering, blurry")
|
||||
te = self.get_text_embedding(prompt)
|
||||
|
||||
img = self.run_diffusion_standard(te, return_image=True)
|
||||
Image.fromarray(img).show()
|
||||
|
||||
|
|
|
@@ -0,0 +1,260 @@
|
|||
# Copyright 2022 Lunar Ring. All rights reserved.
|
||||
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
torch.backends.cudnn.benchmark = False
|
||||
import numpy as np
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
import time
|
||||
import warnings
|
||||
import datetime
|
||||
from typing import List, Union
|
||||
torch.set_grad_enabled(False)
|
||||
import yaml
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate_spherical(p0, p1, fract_mixing: float):
|
||||
r"""
|
||||
Helper function to correctly mix two random variables using spherical interpolation.
|
||||
See https://en.wikipedia.org/wiki/Slerp
|
||||
The function will always cast up to float64 for the sake of extra precision.
|
||||
Args:
|
||||
p0:
|
||||
First tensor for interpolation
|
||||
p1:
|
||||
Second tensor for interpolation
|
||||
fract_mixing: float
|
||||
Mixing coefficient of interval [0, 1].
|
||||
0 will return in p0
|
||||
1 will return in p1
|
||||
0.x will return a mix between both preserving angular velocity.
|
||||
"""
|
||||
|
||||
if p0.dtype == torch.float16:
|
||||
recast_to = 'fp16'
|
||||
else:
|
||||
recast_to = 'fp32'
|
||||
|
||||
p0 = p0.double()
|
||||
p1 = p1.double()
|
||||
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
|
||||
epsilon = 1e-7
|
||||
dot = torch.sum(p0 * p1) / norm
|
||||
dot = dot.clamp(-1 + epsilon, 1 - epsilon)
|
||||
|
||||
theta_0 = torch.arccos(dot)
|
||||
sin_theta_0 = torch.sin(theta_0)
|
||||
theta_t = theta_0 * fract_mixing
|
||||
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = torch.sin(theta_t) / sin_theta_0
|
||||
interp = p0 * s0 + p1 * s1
|
||||
|
||||
if recast_to == 'fp16':
|
||||
interp = interp.half()
|
||||
elif recast_to == 'fp32':
|
||||
interp = interp.float()
|
||||
|
||||
return interp
|
||||
|
||||
|
||||
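# A small sanity-check sketch for interpolate_spherical (illustrative only): the endpoints
# are recovered at fract_mixing = 0 and 1, and intermediate values mix the two tensors
# while preserving angular velocity.
#
#     a = torch.randn(1, 4, 64, 64)
#     b = torch.randn(1, 4, 64, 64)
#     assert torch.allclose(interpolate_spherical(a, b, 0.0), a, atol=1e-5)
#     assert torch.allclose(interpolate_spherical(a, b, 1.0), b, atol=1e-5)
#     latents_mid = interpolate_spherical(a, b, 0.5)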
def interpolate_linear(p0, p1, fract_mixing):
|
||||
r"""
|
||||
Helper function to mix two variables using standard linear interpolation.
|
||||
Args:
|
||||
p0:
|
||||
First tensor / np.ndarray for interpolation
|
||||
p1:
|
||||
Second tensor / np.ndarray for interpolation
|
||||
fract_mixing: float
|
||||
Mixing coefficient of interval [0, 1].
|
||||
0 will return in p0
|
||||
1 will return in p1
|
||||
0.x will return a linear mix between both.
|
||||
"""
|
||||
reconvert_uint8 = False
|
||||
if type(p0) is np.ndarray and p0.dtype == 'uint8':
|
||||
reconvert_uint8 = True
|
||||
p0 = p0.astype(np.float64)
|
||||
|
||||
if type(p1) is np.ndarray and p1.dtype == 'uint8':
|
||||
reconvert_uint8 = True
|
||||
p1 = p1.astype(np.float64)
|
||||
|
||||
interp = (1 - fract_mixing) * p0 + fract_mixing * p1
|
||||
|
||||
if reconvert_uint8:
|
||||
interp = np.clip(interp, 0, 255).astype(np.uint8)
|
||||
|
||||
return interp
|
||||
|
||||
|
||||
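# A tiny illustration of interpolate_linear on uint8 image data (hypothetical shapes): the
# inputs are promoted to float64 for the blend and the result is clipped back to uint8.
#
#     img0 = np.zeros((8, 8, 3), dtype=np.uint8)
#     img1 = np.full((8, 8, 3), 200, dtype=np.uint8)
#     img_mid = interpolate_linear(img0, img1, 0.5)  # ~100 everywhere, dtype uint8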
def add_frames_linear_interp(
|
||||
list_imgs: List[np.ndarray],
|
||||
fps_target: Union[float, int] = None,
|
||||
duration_target: Union[float, int] = None,
|
||||
nmb_frames_target: int = None):
|
||||
r"""
|
||||
Helper function to cheaply increase the number of frames given a list of images,
|
||||
by virtue of standard linear interpolation.
|
||||
The number of inserted frames will be automatically adjusted so that the total number
|
||||
of frames can be fixed precisely, using a random shuffling technique.
|
||||
The function allows 1:1 comparisons between transitions as videos.
|
||||
|
||||
Args:
|
||||
list_imgs: List[np.ndarray]
|
||||
List of images, between each image new frames will be inserted via linear interpolation.
|
||||
fps_target:
|
||||
OptionA: specify here the desired frames per second.
|
||||
duration_target:
|
||||
OptionA: specify here the desired duration of the transition in seconds.
|
||||
nmb_frames_target:
|
||||
OptionB: directly fix the total number of frames of the output.
|
||||
"""
|
||||
|
||||
# Sanity
|
||||
if nmb_frames_target is not None and fps_target is not None:
|
||||
raise ValueError("You cannot specify both fps_target and nmb_frames_target")
|
||||
if fps_target is None:
|
||||
assert nmb_frames_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
|
||||
if nmb_frames_target is None:
|
||||
assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
|
||||
assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
|
||||
nmb_frames_target = fps_target * duration_target
|
||||
|
||||
# Get number of frames that are missing
|
||||
nmb_frames_diff = len(list_imgs) - 1
|
||||
nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
|
||||
|
||||
if nmb_frames_missing < 1:
|
||||
return list_imgs
|
||||
|
||||
list_imgs_float = [img.astype(np.float32) for img in list_imgs]
|
||||
# Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
|
||||
mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
|
||||
constfact = np.floor(mean_nmb_frames_insert)
|
||||
remainder_x = 1 - (mean_nmb_frames_insert - constfact)
|
||||
nmb_iter = 0
|
||||
while True:
|
||||
nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
|
||||
nmb_frames_to_insert[nmb_frames_to_insert <= remainder_x] = 0
|
||||
nmb_frames_to_insert[nmb_frames_to_insert > remainder_x] = 1
|
||||
nmb_frames_to_insert += constfact
|
||||
if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
|
||||
break
|
||||
nmb_iter += 1
|
||||
if nmb_iter > 100000:
|
||||
print("add_frames_linear_interp: issue with inserting the right number of frames")
|
||||
break
|
||||
|
||||
nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
|
||||
list_imgs_interp = []
|
||||
for i in range(len(list_imgs_float) - 1):
|
||||
img0 = list_imgs_float[i]
|
||||
img1 = list_imgs_float[i + 1]
|
||||
list_imgs_interp.append(img0.astype(np.uint8))
|
||||
list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i] + 2)[1:-1]
|
||||
for fract_linblend in list_fracts_linblend:
|
||||
img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
|
||||
list_imgs_interp.append(img_blend.astype(np.uint8))
|
||||
if i == len(list_imgs_float) - 2:
|
||||
list_imgs_interp.append(img1.astype(np.uint8))
|
||||
|
||||
return list_imgs_interp
|
||||
|
||||
|
||||
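# A usage sketch for add_frames_linear_interp (hypothetical numbers): stretch a short list
# of frames to exactly fps_target * duration_target frames by inserting linear blends, so
# transitions rendered with different settings can be compared at the same length.
#
#     frames = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in range(10)]
#     frames_long = add_frames_linear_interp(frames, fps_target=30, duration_target=2)
#     assert len(frames_long) == 30 * 2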
def get_spacing(nmb_points: int, scaling: float):
|
||||
"""
|
||||
Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
|
||||
Args:
|
||||
nmb_points: int
|
||||
Number of points between [0, 1]
|
||||
scaling: float
|
||||
Higher values will return higher sampling density around 0.5
|
||||
"""
|
||||
if scaling < 1.7:
|
||||
return np.linspace(0, 1, nmb_points)
|
||||
nmb_points_per_side = nmb_points // 2 + 1
|
||||
if np.mod(nmb_points, 2) != 0: # Uneven case
|
||||
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
|
||||
right_side = 1 - left_side[::-1][1:]
|
||||
else:
|
||||
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
|
||||
right_side = 1 - left_side[::-1]
|
||||
all_fracts = np.hstack([left_side, right_side])
|
||||
return all_fracts
|
||||
|
||||
|
||||
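# Illustration of get_spacing (values computed for this sketch): scaling below 1.7 falls
# back to a plain linspace, larger values concentrate the points around 0.5.
#
#     get_spacing(5, 1.0)  # array([0., 0.25, 0.5, 0.75, 1.])
#     get_spacing(5, 3.0)  # array([0., 0.4375, 0.5, 0.5625, 1.]) - denser near 0.5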
def get_time(resolution=None):
|
||||
"""
|
||||
Helper function returning a nicely formatted time string, e.g. 221117_1620
|
||||
"""
|
||||
if resolution is None:
|
||||
resolution = "second"
|
||||
if resolution == "day":
|
||||
t = time.strftime('%y%m%d', time.localtime())
|
||||
elif resolution == "minute":
|
||||
t = time.strftime('%y%m%d_%H%M', time.localtime())
|
||||
elif resolution == "second":
|
||||
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
|
||||
elif resolution == "millisecond":
|
||||
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
|
||||
t += "_"
|
||||
t += str("{:03d}".format(int(int(datetime.utcnow().strftime('%f')) / 1000)))
|
||||
else:
|
||||
raise ValueError("bad resolution provided: %s" % resolution)
|
||||
return t
|
||||
|
||||
|
||||
def compare_dicts(a, b):
|
||||
"""
|
||||
Compares two dictionaries a and b and returns a dictionary c, with all
|
||||
keys that are shared between a and b but whose values differ.
|
||||
The values of a and b are stacked together in the output.
|
||||
Example:
|
||||
a = {}; a['bobo'] = 4
|
||||
b = {}; b['bobo'] = 5
|
||||
c = compare_dicts(a, b)
|
||||
c = {"bobo",[4,5]}
|
||||
"""
|
||||
c = {}
|
||||
for key in a.keys():
|
||||
if key in b.keys():
|
||||
val_a = a[key]
|
||||
val_b = b[key]
|
||||
if val_a != val_b:
|
||||
c[key] = [val_a, val_b]
|
||||
return c
|
||||
|
||||
|
||||
def yml_load(fp_yml, print_fields=False):
|
||||
"""
|
||||
Helper function for loading yaml files
|
||||
"""
|
||||
with open(fp_yml) as f:
|
||||
data = yaml.load(f, Loader=yaml.loader.SafeLoader)
|
||||
dict_data = dict(data)
|
||||
print("load: loaded {}".format(fp_yml))
|
||||
return dict_data
|
||||
|
||||
|
||||
def yml_save(fp_yml, dict_stuff):
|
||||
"""
|
||||
Helper function for saving yaml files
|
||||
"""
|
||||
with open(fp_yml, 'w') as f:
|
||||
yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
|
||||
print("yml_save: saved {}".format(fp_yml))
|
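# A minimal round-trip sketch for yml_save / yml_load (hypothetical path):
#
#     settings = {"depth_strength": 0.65, "num_inference_steps": 30}
#     yml_save("/tmp/lb_settings.yaml", settings)
#     assert yml_load("/tmp/lb_settings.yaml") == settings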