cleanup

parent 3ed876e0ee
commit 297bb9abe6
@@ -13,39 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os, sys
 import torch
 torch.backends.cudnn.benchmark = False
-import numpy as np
+torch.set_grad_enabled(False)
 import warnings
 warnings.filterwarnings('ignore')
 import warnings
-import torch
-from tqdm.auto import tqdm
-from PIL import Image
-# import matplotlib.pyplot as plt
-import torch
-from movie_util import MovieSaver
-from typing import Callable, List, Optional, Union
-from latent_blending import LatentBlending, add_frames_linear_interp
+from latent_blending import LatentBlending
 from stable_diffusion_holder import StableDiffusionHolder
-torch.set_grad_enabled(False)
+from huggingface_hub import hf_hub_download


-#%% First let us spawn a stable diffusion holder
-fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
+# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
+# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
+fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
 sdh = StableDiffusionHolder(fp_ckpt)

-#%% Next let's set up all parameters
+# %% Next let's set up all parameters
 depth_strength = 0.65  # Specifies how deep (in terms of diffusion iterations the first branching happens)
 t_compute_max_allowed = 15  # Determines the quality of the transition in terms of compute time you grant it
 fixed_seeds = [69731932, 504430820]

 prompt1 = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic"
 prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail"

 fp_movie = 'movie_example1.mp4'
 duration_transition = 12  # In seconds

 # Spawn latent blending
 lb = LatentBlending(sdh)
@@ -54,10 +46,9 @@ lb.set_prompt2(prompt2)

 # Run latent blending
 lb.run_transition(
-    depth_strength = depth_strength,
-    t_compute_max_allowed = t_compute_max_allowed,
-    fixed_seeds = fixed_seeds
-    )
+    depth_strength=depth_strength,
+    t_compute_max_allowed=t_compute_max_allowed,
+    fixed_seeds=fixed_seeds)

 # Save movie
 lb.write_movie_transition(fp_movie, duration_transition)
@@ -13,33 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os, sys
 import torch
 torch.backends.cudnn.benchmark = False
-import numpy as np
+torch.set_grad_enabled(False)
 import warnings
 warnings.filterwarnings('ignore')
 import warnings
-import torch
-from tqdm.auto import tqdm
-from PIL import Image
-import torch
-from movie_util import MovieSaver, concatenate_movies
-from typing import Callable, List, Optional, Union
-from latent_blending import LatentBlending, add_frames_linear_interp
+from latent_blending import LatentBlending
 from stable_diffusion_holder import StableDiffusionHolder
-torch.set_grad_enabled(False)
+from movie_util import concatenate_movies
+from huggingface_hub import hf_hub_download

-#%% First let us spawn a stable diffusion holder
-fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
-# fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
+# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
+# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
+fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
 sdh = StableDiffusionHolder(fp_ckpt)

-#%% Let's setup the multi transition
+# %% Let's setup the multi transition
 fps = 30
 duration_single_trans = 6
-depth_strength = 0.55 #Specifies how deep (in terms of diffusion iterations the first branching happens)
+depth_strength = 0.55  # Specifies how deep (in terms of diffusion iterations the first branching happens)

 # Specify a list of prompts below
 list_prompts = []
@@ -52,32 +45,29 @@ list_prompts.append("statue of an ancient cybernetic messenger annoucing good ne

 # You can optionally specify the seeds
 list_seeds = [954375479, 332539350, 956051013, 408831845, 250009012, 675588737]
 t_compute_max_allowed = 12  # per segment
 fp_movie = 'movie_example2.mp4'
 lb = LatentBlending(sdh)

-list_movie_parts = [] #
-for i in range(len(list_prompts)-1):
+list_movie_parts = []
+for i in range(len(list_prompts) - 1):
     # For a multi transition we can save some computation time and recycle the latents
-    if i==0:
+    if i == 0:
         lb.set_prompt1(list_prompts[i])
-        lb.set_prompt2(list_prompts[i+1])
+        lb.set_prompt2(list_prompts[i + 1])
         recycle_img1 = False
     else:
         lb.swap_forward()
-        lb.set_prompt2(list_prompts[i+1])
+        lb.set_prompt2(list_prompts[i + 1])
         recycle_img1 = True

     fp_movie_part = f"tmp_part_{str(i).zfill(3)}.mp4"
-    fixed_seeds = list_seeds[i:i+2]
+    fixed_seeds = list_seeds[i:i + 2]

     # Run latent blending
     lb.run_transition(
-        depth_strength = depth_strength,
-        t_compute_max_allowed = t_compute_max_allowed,
-        fixed_seeds = fixed_seeds
-        )
+        depth_strength=depth_strength,
+        t_compute_max_allowed=t_compute_max_allowed,
+        fixed_seeds=fixed_seeds)

     # Save movie
     lb.write_movie_transition(fp_movie_part, duration_single_trans)
@@ -13,25 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os, sys
 import torch
 torch.backends.cudnn.benchmark = False
-import numpy as np
+torch.set_grad_enabled(False)
 import warnings
 warnings.filterwarnings('ignore')
 import warnings
-import torch
-from tqdm.auto import tqdm
-from PIL import Image
-# import matplotlib.pyplot as plt
-import torch
-from movie_util import MovieSaver
-from typing import Callable, List, Optional, Union
-from latent_blending import LatentBlending, add_frames_linear_interp
+from latent_blending import LatentBlending
 from stable_diffusion_holder import StableDiffusionHolder
-torch.set_grad_enabled(False)
+from huggingface_hub import hf_hub_download

-#%% Define vars for low-resoltion pass
+# %% Define vars for low-resoltion pass
 prompt1 = "photo of mount vesuvius erupting a terrifying pyroclastic ash cloud"
 prompt2 = "photo of a inside a building full of ash, fire, death, destruction, explosions"
 fixed_seeds = [5054613, 1168652]
@@ -41,21 +33,18 @@ height = 384
 num_inference_steps_lores = 40
 nmb_max_branches_lores = 10
 depth_strength_lores = 0.5
-
-fp_ckpt_lores = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
+fp_ckpt_lores = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")

-#%% Define vars for high-resoltion pass
-fp_ckpt_hires = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
+# %% Define vars for high-resoltion pass
+fp_ckpt_hires = hf_hub_download(repo_id="stabilityai/stable-diffusion-x4-upscaler", filename="x4-upscaler-ema.ckpt")
 depth_strength_hires = 0.65
 num_inference_steps_hires = 100
 nmb_branches_final_hires = 6
-dp_imgs = "tmp_transition" # folder for results and intermediate steps
+dp_imgs = "tmp_transition"  # Folder for results and intermediate steps


-#%% Run low-res pass
+# %% Run low-res pass
 sdh = StableDiffusionHolder(fp_ckpt_lores)

-#%%
 lb = LatentBlending(sdh)
 lb.set_prompt1(prompt1)
 lb.set_prompt2(prompt2)
@@ -64,14 +53,13 @@ lb.set_height(height)

 # Run latent blending
 lb.run_transition(
-    depth_strength = depth_strength_lores,
-    nmb_max_branches = nmb_max_branches_lores,
-    fixed_seeds = fixed_seeds
-    )
+    depth_strength=depth_strength_lores,
+    nmb_max_branches=nmb_max_branches_lores,
+    fixed_seeds=fixed_seeds)

 lb.write_imgs_transition(dp_imgs)

-#%% Run high-res pass
+# %% Run high-res pass
 sdh = StableDiffusionHolder(fp_ckpt_hires)
 lb = LatentBlending(sdh)
 lb.run_upscaling(dp_imgs, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)
@@ -13,25 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os, sys
+import os
 import torch
 torch.backends.cudnn.benchmark = False
-import numpy as np
+torch.set_grad_enabled(False)
 import warnings
 warnings.filterwarnings('ignore')
 import warnings
-import torch
-from tqdm.auto import tqdm
-from PIL import Image
-# import matplotlib.pyplot as plt
-import torch
-from movie_util import MovieSaver, concatenate_movies
-from typing import Callable, List, Optional, Union
-from latent_blending import LatentBlending, add_frames_linear_interp
+from latent_blending import LatentBlending
 from stable_diffusion_holder import StableDiffusionHolder
-torch.set_grad_enabled(False)
+from movie_util import concatenate_movies
+from huggingface_hub import hf_hub_download

-#%% Define vars for low-resoltion pass
+# %% Define vars for low-resoltion pass
 list_prompts = []
 list_prompts.append("surrealistic statue made of glitter and dirt, standing in a lake, atmospheric light, strange glow")
 list_prompts.append("statue of a mix between a tree and human, made of marble, incredibly detailed")
@@ -50,56 +44,54 @@ num_inference_steps_lores = 40
 nmb_max_branches_lores = 10
 depth_strength_lores = 0.5

-fp_ckpt_lores = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
+fp_ckpt_lores = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")

-#%% Define vars for high-resoltion pass
-fp_ckpt_hires = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
+# %% Define vars for high-resoltion pass
+fp_ckpt_hires = hf_hub_download(repo_id="stabilityai/stable-diffusion-x4-upscaler", filename="x4-upscaler-ema.ckpt")
 depth_strength_hires = 0.65
 num_inference_steps_hires = 100
 nmb_branches_final_hires = 6
-#%% Run low-res pass
+
+# %% Run low-res pass
 sdh = StableDiffusionHolder(fp_ckpt_lores)
-t_compute_max_allowed = 12 # per segment
+t_compute_max_allowed = 12  # Per segment
 lb = LatentBlending(sdh)

-list_movie_dirs = [] #
-for i in range(len(list_prompts)-1):
+list_movie_dirs = []
+for i in range(len(list_prompts) - 1):
     # For a multi transition we can save some computation time and recycle the latents
-    if i==0:
+    if i == 0:
         lb.set_prompt1(list_prompts[i])
-        lb.set_prompt2(list_prompts[i+1])
+        lb.set_prompt2(list_prompts[i + 1])
         recycle_img1 = False
     else:
         lb.swap_forward()
-        lb.set_prompt2(list_prompts[i+1])
+        lb.set_prompt2(list_prompts[i + 1])
         recycle_img1 = True

     dp_movie_part = f"tmp_part_{str(i).zfill(3)}"
     fp_movie_part = os.path.join(dp_movie_part, "movie_lowres.mp4")
     os.makedirs(dp_movie_part, exist_ok=True)
-    fixed_seeds = list_seeds[i:i+2]
+    fixed_seeds = list_seeds[i:i + 2]

     # Run latent blending
     lb.run_transition(
-        depth_strength = depth_strength_lores,
-        nmb_max_branches = nmb_max_branches_lores,
-        fixed_seeds = fixed_seeds
-        )
+        depth_strength=depth_strength_lores,
+        nmb_max_branches=nmb_max_branches_lores,
+        fixed_seeds=fixed_seeds)

     # Save movie and images (needed for upscaling!)
     lb.write_movie_transition(fp_movie_part, duration_single_trans)
     lb.write_imgs_transition(dp_movie_part)
     list_movie_dirs.append(dp_movie_part)

-#%% Run high-res pass on each segment
+# %% Run high-res pass on each segment
 sdh = StableDiffusionHolder(fp_ckpt_hires)
 lb = LatentBlending(sdh)
 for dp_part in list_movie_dirs:
     lb.run_upscaling(dp_part, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)

-#%% concatenate into one long movie
+# %% concatenate into one long movie
 list_fp_movies = []
 for dp_part in list_movie_dirs:
     fp_movie = os.path.join(dp_part, "movie_highres.mp4")
gradio_ui.py (314 changed lines)
@@ -13,82 +13,89 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os, sys
+import os
 import torch
 torch.backends.cudnn.benchmark = False
+torch.set_grad_enabled(False)
 import numpy as np
 import warnings
 warnings.filterwarnings('ignore')
 import warnings
-import torch
 from tqdm.auto import tqdm
 from PIL import Image
-import torch
 from movie_util import MovieSaver, concatenate_movies
-from typing import Callable, List, Optional, Union
-from latent_blending import get_time, yml_save, LatentBlending, add_frames_linear_interp, compare_dicts
+from latent_blending import LatentBlending
 from stable_diffusion_holder import StableDiffusionHolder
-torch.set_grad_enabled(False)
 import gradio as gr
-import copy
 from dotenv import find_dotenv, load_dotenv
 import shutil
 import random
-import time
+from utils import get_time, add_frames_linear_interp
+from huggingface_hub import hf_hub_download


-#%%

 class BlendingFrontend():
-    def __init__(self, sdh=None):
-        self.num_inference_steps = 30
-        if sdh is None:
-            self.use_debug = True
-            self.height = 768
-            self.width = 768
-        else:
-            self.use_debug = False
-            self.lb = LatentBlending(sdh)
-            self.lb.sdh.num_inference_steps = self.num_inference_steps
-            self.height = self.lb.sdh.height
-            self.width = self.lb.sdh.width
-
-        self.init_save_dir()
-        self.save_empty_image()
-        self.share = False
-        self.transition_can_be_computed = False
+    def __init__(
+            self,
+            sdh,
+            share=False):
+        r"""
+        Gradio Helper Class to collect UI data and start latent blending.
+        Args:
+            sdh:
+                StableDiffusionHolder
+            share: bool
+                Set true to get a shareable gradio link (e.g. for running a remote server)
+        """
+        self.share = share
+
+        # UI Defaults
+        self.num_inference_steps = 30
         self.depth_strength = 0.25
         self.seed1 = 420
         self.seed2 = 420
-        self.guidance_scale = 4.0
-        self.guidance_scale_mid_damper = 0.5
-        self.mid_compression_scaler = 1.2
         self.prompt1 = ""
         self.prompt2 = ""
         self.negative_prompt = ""
-        self.state_current = {}
+        self.fps = 30
+        self.duration_video = 8
+        self.t_compute_max_allowed = 10
+
+        self.lb = LatentBlending(sdh)
+        self.lb.sdh.num_inference_steps = self.num_inference_steps
+        self.init_parameters_from_lb()
+        self.init_save_dir()
+
+        # Vars
+        self.list_fp_imgs_current = []
+        self.recycle_img1 = False
+        self.recycle_img2 = False
+        self.list_all_segments = []
+        self.dp_session = ""
+        self.user_id = None
+
+    def init_parameters_from_lb(self):
+        r"""
+        Automatically init parameters from latentblending instance
+        """
+        self.height = self.lb.sdh.height
+        self.width = self.lb.sdh.width
+        self.guidance_scale = self.lb.guidance_scale
+        self.guidance_scale_mid_damper = self.lb.guidance_scale_mid_damper
+        self.mid_compression_scaler = self.lb.mid_compression_scaler
         self.branch1_crossfeed_power = self.lb.branch1_crossfeed_power
         self.branch1_crossfeed_range = self.lb.branch1_crossfeed_range
         self.branch1_crossfeed_decay = self.lb.branch1_crossfeed_decay
         self.parental_crossfeed_power = self.lb.parental_crossfeed_power
         self.parental_crossfeed_range = self.lb.parental_crossfeed_range
         self.parental_crossfeed_power_decay = self.lb.parental_crossfeed_power_decay
-        self.fps = 30
-        self.duration_video = 10
-        self.t_compute_max_allowed = 10
-        self.list_fp_imgs_current = []
-        self.current_timestamp = None
-        self.recycle_img1 = False
-        self.recycle_img2 = False
-        self.multi_idx_current = -1
-        self.list_imgs_shown_last = 5*[self.fp_img_empty]
-        self.list_all_segments = []
-        self.dp_session = ""
-        self.user_id = None
-        self.block_transition = False

     def init_save_dir(self):
+        r"""
+        Initializes the directory where stuff is being saved.
+        You can specify this directory in a ".env" file in your latentblending root, setting
+        DIR_OUT='/path/to/saving'
+        """
         load_dotenv(find_dotenv(), verbose=False)
         self.dp_out = os.getenv("DIR_OUT")
         if self.dp_out is None:
@@ -97,124 +104,125 @@ class BlendingFrontend():
         os.makedirs(self.dp_imgs, exist_ok=True)
         self.dp_movies = os.path.join(self.dp_out, "movies")
         os.makedirs(self.dp_movies, exist_ok=True)
+        self.save_empty_image()

-    # make dummy image
     def save_empty_image(self):
+        r"""
+        Saves an empty/black dummy image.
+        """
         self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
         Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)

     def randomize_seed1(self):
-        # Dont randomize seed if we are in a multi concat mode. we don't want to change this one otherwise the movie breaks
+        r"""
+        Randomizes the first seed
+        """
         seed = np.random.randint(0, 10000000)
         self.seed1 = int(seed)
         print(f"randomize_seed1: new seed = {self.seed1}")
         return seed

     def randomize_seed2(self):
+        r"""
+        Randomizes the second seed
+        """
         seed = np.random.randint(0, 10000000)
         self.seed2 = int(seed)
         print(f"randomize_seed2: new seed = {self.seed2}")
         return seed

-    def setup_lb(self, list_ui_elem):
+    def setup_lb(self, list_ui_vals):
+        r"""
+        Sets all parameters from the UI. Since gradio does not support to pass dictionaries,
+        we have to instead pass keys (list_ui_keys, global) and values (list_ui_vals)
+        """
         # Collect latent blending variables
-        self.state_current = self.get_state_dict()
-        self.lb.set_width(list_ui_elem[list_ui_keys.index('width')])
-        self.lb.set_height(list_ui_elem[list_ui_keys.index('height')])
-        self.lb.set_prompt1(list_ui_elem[list_ui_keys.index('prompt1')])
-        self.lb.set_prompt2(list_ui_elem[list_ui_keys.index('prompt2')])
-        self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
-        self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
-        self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
-        self.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
-        self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
-        self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
-        self.duration_video = list_ui_elem[list_ui_keys.index('duration_video')]
-        self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')] #seed
-        self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
-        self.lb.branch1_crossfeed_power = list_ui_elem[list_ui_keys.index('branch1_crossfeed_power')]
-        self.lb.branch1_crossfeed_range = list_ui_elem[list_ui_keys.index('branch1_crossfeed_range')]
-        self.lb.branch1_crossfeed_decay = list_ui_elem[list_ui_keys.index('branch1_crossfeed_decay')]
-        self.lb.parental_crossfeed_power = list_ui_elem[list_ui_keys.index('parental_crossfeed_power')]
-        self.lb.parental_crossfeed_range = list_ui_elem[list_ui_keys.index('parental_crossfeed_range')]
-        self.lb.parental_crossfeed_power_decay = list_ui_elem[list_ui_keys.index('parental_crossfeed_power_decay')]
-        self.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
-        self.depth_strength = list_ui_elem[list_ui_keys.index('depth_strength')]
-
-        if len(list_ui_elem[list_ui_keys.index('user_id')]) > 1:
-            self.user_id = list_ui_elem[list_ui_keys.index('user_id')]
+        self.lb.set_width(list_ui_vals[list_ui_keys.index('width')])
+        self.lb.set_height(list_ui_vals[list_ui_keys.index('height')])
+        self.lb.set_prompt1(list_ui_vals[list_ui_keys.index('prompt1')])
+        self.lb.set_prompt2(list_ui_vals[list_ui_keys.index('prompt2')])
+        self.lb.set_negative_prompt(list_ui_vals[list_ui_keys.index('negative_prompt')])
+        self.lb.guidance_scale = list_ui_vals[list_ui_keys.index('guidance_scale')]
+        self.lb.guidance_scale_mid_damper = list_ui_vals[list_ui_keys.index('guidance_scale_mid_damper')]
+        self.t_compute_max_allowed = list_ui_vals[list_ui_keys.index('duration_compute')]
+        self.lb.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
+        self.lb.sdh.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
+        self.duration_video = list_ui_vals[list_ui_keys.index('duration_video')]
+        self.lb.seed1 = list_ui_vals[list_ui_keys.index('seed1')]
+        self.lb.seed2 = list_ui_vals[list_ui_keys.index('seed2')]
+        self.lb.branch1_crossfeed_power = list_ui_vals[list_ui_keys.index('branch1_crossfeed_power')]
+        self.lb.branch1_crossfeed_range = list_ui_vals[list_ui_keys.index('branch1_crossfeed_range')]
+        self.lb.branch1_crossfeed_decay = list_ui_vals[list_ui_keys.index('branch1_crossfeed_decay')]
+        self.lb.parental_crossfeed_power = list_ui_vals[list_ui_keys.index('parental_crossfeed_power')]
+        self.lb.parental_crossfeed_range = list_ui_vals[list_ui_keys.index('parental_crossfeed_range')]
+        self.lb.parental_crossfeed_power_decay = list_ui_vals[list_ui_keys.index('parental_crossfeed_power_decay')]
+        self.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
+        self.depth_strength = list_ui_vals[list_ui_keys.index('depth_strength')]
+
+        if len(list_ui_vals[list_ui_keys.index('user_id')]) > 1:
+            self.user_id = list_ui_vals[list_ui_keys.index('user_id')]
         else:
             # generate new user id
             self.user_id = ''.join((random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ') for i in range(8)))
-            print(f"made new user_id: {self.user_id}")
+            print(f"made new user_id: {self.user_id} at {get_time('second')}")

     def save_latents(self, fp_latents, list_latents):
+        r"""
+        Saves a latent trajectory on disk, in npy format.
+        """
         list_latents_cpu = [l.cpu().numpy() for l in list_latents]
         np.save(fp_latents, list_latents_cpu)

     def load_latents(self, fp_latents):
+        r"""
+        Loads a latent trajectory from disk, converts to torch tensor.
+        """
         list_latents_cpu = np.load(fp_latents)
         list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu]
         return list_latents

     def compute_img1(self, *args):
-        list_ui_elem = args
-        self.setup_lb(list_ui_elem)
+        r"""
+        Computes the first transition image and returns it for display.
+        Sets all other transition images and last image to empty (as they are obsolete with this operation)
+        """
+        list_ui_vals = args
+        self.setup_lb(list_ui_vals)
         fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}")
         img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
-        img1.save(fp_img1+".jpg")
-        self.save_latents(fp_img1+".npy", self.lb.tree_latents[0])
+        img1.save(fp_img1 + ".jpg")
+        self.save_latents(fp_img1 + ".npy", self.lb.tree_latents[0])

         self.recycle_img1 = True
         self.recycle_img2 = False
-        # fixme save seeds. change filenames?
-        return [fp_img1+".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
+        return [fp_img1 + ".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]

     def compute_img2(self, *args):
-        if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")): # don't do anything
+        r"""
+        Computes the last transition image and returns it for display.
+        Sets all other transition images to empty (as they are obsolete with this operation)
+        """
+        if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")):  # don't do anything
             return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
-        list_ui_elem = args
-        self.setup_lb(list_ui_elem)
+        list_ui_vals = args
+        self.setup_lb(list_ui_vals)

         self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
         fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}")
         img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
-        img2.save(fp_img2+'.jpg')
-        self.save_latents(fp_img2+".npy", self.lb.tree_latents[-1])
+        img2.save(fp_img2 + '.jpg')
+        self.save_latents(fp_img2 + ".npy", self.lb.tree_latents[-1])
         self.recycle_img2 = True
-        self.transition_can_be_computed = True
         # fixme save seeds. change filenames?
-        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2+".jpg", self.user_id]
+        return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2 + ".jpg", self.user_id]

     def compute_transition(self, *args):
-        if not self.transition_can_be_computed:
-            list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
-            return list_return
-
-        list_ui_elem = args
-        self.setup_lb(list_ui_elem)
+        r"""
+        Computes transition images and movie.
+        """
+        list_ui_vals = args
+        self.setup_lb(list_ui_vals)
         print("STARTING TRANSITION...")

         fixed_seeds = [self.seed1, self.seed2]

-        # Run Latent Blending
-        # Check if another user is blocking this... otherwise everything will become mixed.
-        # t_now = time.time()
-        # if self.block_transition:
-        #     while True:
-        #         time.sleep(1)
-        #         if not self.block_transition:
-        #             break
-        #         if time.time() - t_now > 1000:
-        #             return
-
-        self.block_transition = True
         # Inject loaded latents (other user interference)
         self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
         self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
@@ -224,24 +232,23 @@ class BlendingFrontend():
             num_inference_steps=self.num_inference_steps,
             depth_strength=self.depth_strength,
             t_compute_max_allowed=self.t_compute_max_allowed,
-            fixed_seeds=fixed_seeds
-            )
-        print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
+            fixed_seeds=fixed_seeds)
+        print(f"Latent Blending pass finished ({get_time('second')}). Resulted in {len(imgs_transition)} images")

         # Subselect three preview images
-        idx_img_prev = np.round(np.linspace(0, len(imgs_transition)-1, 5)[1:-1]).astype(np.int32)
+        idx_img_prev = np.round(np.linspace(0, len(imgs_transition) - 1, 5)[1:-1]).astype(np.int32)

         list_imgs_preview = []
         for j in idx_img_prev:
             list_imgs_preview.append(Image.fromarray(imgs_transition[j]))

         # Save the preview imgs as jpgs on disk so we are not sending umcompressed data around
-        self.current_timestamp = get_time('second')
+        current_timestamp = get_time('second')
         self.list_fp_imgs_current = []
         for i in range(len(list_imgs_preview)):
-            fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg")
+            fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{current_timestamp}.jpg")
             list_imgs_preview[i].save(fp_img)
             self.list_fp_imgs_current.append(fp_img)
-        self.block_transition = False

         # Insert cheap frames for the movie
         imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
@@ -259,16 +266,17 @@ class BlendingFrontend():
         list_return = self.list_fp_imgs_current + [self.fp_movie]
         return list_return

     def stack_forward(self, prompt2, seed2):
+        r"""
+        Allows to generate multi-segment movies. Sets last image -> first image with all
+        relevant parameters.
+        """
         # Save preview images, prompts and seeds into dictionary for stacking
         if len(self.list_all_segments) == 0:
             timestamp_session = get_time('second')
             self.dp_session = os.path.join(self.dp_out, f"session_{timestamp_session}")
             os.makedirs(self.dp_session)

-        self.transition_can_be_computed = False
-
         idx_segment = len(self.list_all_segments)
         dp_segment = os.path.join(self.dp_session, f"segment_{str(idx_segment).zfill(3)}")
@@ -285,13 +293,11 @@ class BlendingFrontend():
         self.lb.swap_forward()

         shutil.copyfile(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"), os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))

         fp_multi = self.multi_concat()
         list_out = [fp_multi]

         list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")])
-        list_out.extend([self.fp_img_empty]*4)
+        list_out.extend([self.fp_img_empty] * 4)
         list_out.append(gr.update(interactive=False, value=prompt2))
         list_out.append(gr.update(interactive=False, value=seed2))
         list_out.append("")
@@ -299,16 +305,20 @@ class BlendingFrontend():
         print(f"stack_forward: fp_multi {fp_multi}")
         return list_out

     def multi_concat(self):
+        r"""
+        Concatentates all stacked segments into one long movie.
+        """
         list_fp_movies = self.get_fp_video_all()
         # Concatenate movies and save
         fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4")
         concatenate_movies(fp_final, list_fp_movies)
         return fp_final

     def get_fp_video_all(self):
+        r"""
+        Collects all stacked movie segments.
+        """
         list_all = os.listdir(self.dp_movies)
         str_beg = f"movie_{self.user_id}_"
         list_user = [l for l in list_all if str_beg in l]
@@ -316,8 +326,10 @@ class BlendingFrontend():
         list_user = [os.path.join(self.dp_movies, l) for l in list_user]
         return list_user

     def get_fp_video_next(self):
+        r"""
+        Gets the filepath of the next movie segment.
+        """
         list_videos = self.get_fp_video_all()
         if len(list_videos) == 0:
             idx_next = 0
@@ -327,26 +339,16 @@ class BlendingFrontend():
         return fp_video_next

     def get_fp_video_last(self):
+        r"""
+        Gets the current video that was saved.
+        """
         fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4")
         return fp_video_last

-    def get_state_dict(self):
-        state_dict = {}
-        grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
-                     'num_inference_steps', 'depth_strength', 'guidance_scale',
-                     'guidance_scale_mid_damper', 'mid_compression_scaler']
-
-        for v in grab_vars:
-            state_dict[v] = getattr(self, v)
-        return state_dict
-

 if __name__ == "__main__":
-    # fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
-    fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
+    fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
+    # fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
     bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
     # self = BlendingFrontend(None)
@@ -391,7 +393,6 @@ if __name__ == "__main__":
            depth_strength = gr.Slider(0.01, 0.99, bf.depth_strength, step=0.01, label='depth_strength', interactive=True)
            guidance_scale_mid_damper = gr.Slider(0.01, 2.0, bf.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)

-
        with gr.Row():
            b_compute1 = gr.Button('compute first image', variant='primary')
            b_compute_transition = gr.Button('compute transition', variant='primary')
@@ -405,11 +406,10 @@ if __name__ == "__main__":
            img5 = gr.Image(label="5/5")

        with gr.Row():
-            vid_single = gr.Video(label="single trans")
-            vid_multi = gr.Video(label="multi trans")
+            vid_single = gr.Video(label="current single trans")
+            vid_multi = gr.Video(label="concatented multi trans")

        with gr.Row():
-            # b_restart = gr.Button("RESTART EVERYTHING")
            b_stackforward = gr.Button('append last movie segment (left) to multi movie (right)', variant='primary')

        with gr.Row():
@@ -437,8 +437,7 @@ if __name__ == "__main__":
            - parental_crossfeed_power_decay: Similar to branch1_crossfeed_decay, however applied for the images withinin the transition.
            - depth_strength: Determines when the blending process will begin in terms of diffusion steps. Low values more inventive but can cause motion.
            - guidance_scale_mid_damper: Decreases the guidance scale in the middle of a transition.
-            """
-        )
+            """)

        with gr.Row():
            user_id = gr.Textbox(label="user id", interactive=False)
@@ -471,24 +470,23 @@ if __name__ == "__main__":
        dict_ui_elem["user_id"] = user_id

        # Convert to list, as gradio doesn't seem to accept dicts
-        list_ui_elem = []
+        list_ui_vals = []
        list_ui_keys = []
        for k in dict_ui_elem.keys():
-            list_ui_elem.append(dict_ui_elem[k])
+            list_ui_vals.append(dict_ui_elem[k])
            list_ui_keys.append(k)
        bf.list_ui_keys = list_ui_keys

        b_newseed1.click(bf.randomize_seed1, outputs=seed1)
        b_newseed2.click(bf.randomize_seed2, outputs=seed2)
-        b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5, user_id])
-        b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5, user_id])
+        b_compute1.click(bf.compute_img1, inputs=list_ui_vals, outputs=[img1, img2, img3, img4, img5, user_id])
+        b_compute2.click(bf.compute_img2, inputs=list_ui_vals, outputs=[img2, img3, img4, img5, user_id])
        b_compute_transition.click(bf.compute_transition,
-                                   inputs=list_ui_elem,
+                                   inputs=list_ui_vals,
                                   outputs=[img2, img3, img4, vid_single])

        b_stackforward.click(bf.stack_forward,
                             inputs=[prompt2, seed2],
                             outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])

        demo.launch(share=bf.share, inbrowser=True, inline=False)
@@ -13,41 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os, sys
+import os
 import torch
 torch.backends.cudnn.benchmark = False
+torch.set_grad_enabled(False)
 import numpy as np
 import warnings
 warnings.filterwarnings('ignore')
 import time
-import subprocess
 import warnings
 from tqdm.auto import tqdm
 from PIL import Image
-# import matplotlib.pyplot as plt
 from movie_util import MovieSaver
-import datetime
-from typing import Callable, List, Optional, Union
-import inspect
-from threading import Thread
-torch.set_grad_enabled(False)
-from contextlib import nullcontext
+from typing import List, Optional

-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
-from stable_diffusion_holder import StableDiffusionHolder
-import yaml
 import lpips
-#%%
+from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save


 class LatentBlending():
     def __init__(
             self,
             sdh: None,
             guidance_scale: float = 4,
             guidance_scale_mid_damper: float = 0.5,
-            mid_compression_scaler: float = 1.2,
-            ):
+            mid_compression_scaler: float = 1.2):
         r"""
         Initializes the latent blending class.
         Args:
@@ -64,9 +54,10 @@ class LatentBlending():
                 Increases the sampling density in the middle (where most changes happen). Higher value
                 imply more values in the middle. However the inflection point can occur outside the middle,
                 thus high values can give rough transitions. Values around 2 should be fine.

         """
-        assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
+        assert guidance_scale_mid_damper > 0 \
+            and guidance_scale_mid_damper <= 1.0, \
+            f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"

         self.sdh = sdh
         self.device = self.sdh.device
@@ -115,10 +106,8 @@ class LatentBlending():
         self.multi_transition_img_last = None
         self.dt_per_diff = 0
         self.spatial_mask = None

         self.lpips = lpips.LPIPS(net='alex').cuda(self.device)

     def init_mode(self):
         r"""
         Sets the operational mode. Currently supported are standard, inpainting and x4 upscaling.
@@ -151,13 +140,12 @@ class LatentBlending():
         Tunes the guidance scale down as a linear function of fract_mixing,
         towards 0.5 the minimum will be reached.
         """
-        mid_factor = 1 - np.abs(fract_mixing - 0.5)/ 0.5
-        max_guidance_reduction = self.guidance_scale_base * (1-self.guidance_scale_mid_damper) - 1
-        guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction*mid_factor
+        mid_factor = 1 - np.abs(fract_mixing - 0.5) / 0.5
+        max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1
+        guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor
         self.guidance_scale = guidance_scale_effective
         self.sdh.guidance_scale = guidance_scale_effective

     def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
         r"""
         Sets the crossfeed parameters for the first branch to the last branch.
@@ -173,7 +161,6 @@ class LatentBlending():
         self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
         self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)

     def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
         r"""
         Sets the crossfeed parameters for all transition images (within the first and last branch).
@ -189,7 +176,6 @@ class LatentBlending():
|
||||||
self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
|
self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
|
||||||
self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1)
|
self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
def set_prompt1(self, prompt: str):
|
def set_prompt1(self, prompt: str):
|
||||||
r"""
|
r"""
|
||||||
Sets the first prompt (for the first keyframe) including text embeddings.
|
Sets the first prompt (for the first keyframe) including text embeddings.
|
||||||
|
@ -201,7 +187,6 @@ class LatentBlending():
|
||||||
self.prompt1 = prompt
|
self.prompt1 = prompt
|
||||||
self.text_embedding1 = self.get_text_embeddings(self.prompt1)
|
self.text_embedding1 = self.get_text_embeddings(self.prompt1)
|
||||||
|
|
||||||
|
|
||||||
def set_prompt2(self, prompt: str):
|
def set_prompt2(self, prompt: str):
|
||||||
r"""
|
r"""
|
||||||
Sets the second prompt (for the second keyframe) including text embeddings.
|
Sets the second prompt (for the second keyframe) including text embeddings.
|
||||||
|
@ -237,8 +222,7 @@ class LatentBlending():
|
||||||
depth_strength: Optional[float] = 0.3,
|
depth_strength: Optional[float] = 0.3,
|
||||||
t_compute_max_allowed: Optional[float] = None,
|
t_compute_max_allowed: Optional[float] = None,
|
||||||
nmb_max_branches: Optional[int] = None,
|
nmb_max_branches: Optional[int] = None,
|
||||||
fixed_seeds: Optional[List[int]] = None,
|
fixed_seeds: Optional[List[int]] = None):
|
||||||
):
|
|
||||||
r"""
|
r"""
|
||||||
Function for computing transitions.
|
Function for computing transitions.
|
||||||
Returns a list of transition images using spherical latent blending.
|
Returns a list of transition images using spherical latent blending.
|
||||||
|
@ -263,7 +247,6 @@ class LatentBlending():
|
||||||
fixed_seeds: Optional[List[int)]:
|
fixed_seeds: Optional[List[int)]:
|
||||||
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
|
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
|
||||||
Otherwise random seeds will be taken.
|
Otherwise random seeds will be taken.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Sanity checks first
|
# Sanity checks first
|
||||||
|
@ -275,7 +258,7 @@ class LatentBlending():
|
||||||
if fixed_seeds == 'randomize':
|
if fixed_seeds == 'randomize':
|
||||||
fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
|
fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
|
||||||
else:
|
else:
|
||||||
assert len(fixed_seeds)==2, "Supply a list with len = 2"
|
assert len(fixed_seeds) == 2, "Supply a list with len = 2"
|
||||||
|
|
||||||
self.seed1 = fixed_seeds[0]
|
self.seed1 = fixed_seeds[0]
|
||||||
self.seed2 = fixed_seeds[1]
|
self.seed2 = fixed_seeds[1]
|
||||||
|
@ -323,7 +306,6 @@ class LatentBlending():
|
||||||
|
|
||||||
return self.tree_final_imgs
|
return self.tree_final_imgs
|
||||||
|
|
||||||
|
|
||||||
def compute_latents1(self, return_image=False):
|
def compute_latents1(self, return_image=False):
|
||||||
r"""
|
r"""
|
||||||
Runs a diffusion trajectory for the first image
|
Runs a diffusion trajectory for the first image
|
||||||
|
@ -337,11 +319,10 @@ class LatentBlending():
|
||||||
latents_start = self.get_noise(self.seed1)
|
latents_start = self.get_noise(self.seed1)
|
||||||
list_latents1 = self.run_diffusion(
|
list_latents1 = self.run_diffusion(
|
||||||
list_conditionings,
|
list_conditionings,
|
||||||
latents_start = latents_start,
|
latents_start=latents_start,
|
||||||
idx_start = 0
|
idx_start=0)
|
||||||
)
|
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
self.dt_per_diff = (t1-t0) / self.num_inference_steps
|
self.dt_per_diff = (t1 - t0) / self.num_inference_steps
|
||||||
self.tree_latents[0] = list_latents1
|
self.tree_latents[0] = list_latents1
|
||||||
if return_image:
|
if return_image:
|
||||||
return self.sdh.latent2image(list_latents1[-1])
|
return self.sdh.latent2image(list_latents1[-1])
|
||||||
|
@ -361,17 +342,16 @@ class LatentBlending():
|
||||||
# Influence from branch1
|
# Influence from branch1
|
||||||
if self.branch1_crossfeed_power > 0.0:
|
if self.branch1_crossfeed_power > 0.0:
|
||||||
# Set up the mixing_coeffs
|
# Set up the mixing_coeffs
|
||||||
idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_crossfeed_range))
|
idx_mixing_stop = int(round(self.num_inference_steps * self.branch1_crossfeed_range))
|
||||||
mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power*self.branch1_crossfeed_decay, idx_mixing_stop))
|
mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power * self.branch1_crossfeed_decay, idx_mixing_stop))
|
||||||
mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
|
mixing_coeffs.extend((self.num_inference_steps - idx_mixing_stop) * [0])
|
||||||
list_latents_mixing = self.tree_latents[0]
|
list_latents_mixing = self.tree_latents[0]
|
||||||
list_latents2 = self.run_diffusion(
|
list_latents2 = self.run_diffusion(
|
||||||
list_conditionings,
|
list_conditionings,
|
||||||
latents_start = latents_start,
|
latents_start=latents_start,
|
||||||
idx_start = 0,
|
idx_start=0,
|
||||||
list_latents_mixing = list_latents_mixing,
|
list_latents_mixing=list_latents_mixing,
|
||||||
mixing_coeffs = mixing_coeffs
|
mixing_coeffs=mixing_coeffs)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
list_latents2 = self.run_diffusion(list_conditionings, latents_start)
|
list_latents2 = self.run_diffusion(list_conditionings, latents_start)
|
||||||
self.tree_latents[-1] = list_latents2
|
self.tree_latents[-1] = list_latents2
|
||||||
|
@ -381,7 +361,6 @@ class LatentBlending():
|
||||||
else:
|
else:
|
||||||
return list_latents2
|
return list_latents2
|
||||||
|
|
||||||
|
|
||||||
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
|
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
|
||||||
r"""
|
r"""
|
||||||
Runs a diffusion trajectory, using the latents from the respective parents
|
Runs a diffusion trajectory, using the latents from the respective parents
|
||||||
|
@ -409,22 +388,19 @@ class LatentBlending():
|
||||||
latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
|
latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
|
||||||
list_latents_parental_mix.append(latents_parental)
|
list_latents_parental_mix.append(latents_parental)
|
||||||
|
|
||||||
idx_mixing_stop = int(round(self.num_inference_steps*self.parental_crossfeed_range))
|
idx_mixing_stop = int(round(self.num_inference_steps * self.parental_crossfeed_range))
|
||||||
mixing_coeffs = idx_injection*[self.parental_crossfeed_power]
|
mixing_coeffs = idx_injection * [self.parental_crossfeed_power]
|
||||||
nmb_mixing = idx_mixing_stop - idx_injection
|
nmb_mixing = idx_mixing_stop - idx_injection
|
||||||
if nmb_mixing > 0:
|
if nmb_mixing > 0:
|
||||||
mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power*self.parental_crossfeed_power_decay, nmb_mixing)))
|
mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_power_decay, nmb_mixing)))
|
||||||
mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
|
mixing_coeffs.extend((self.num_inference_steps - len(mixing_coeffs)) * [0])
|
||||||
|
latents_start = list_latents_parental_mix[idx_injection - 1]
|
||||||
latents_start = list_latents_parental_mix[idx_injection-1]
|
|
||||||
list_latents = self.run_diffusion(
|
list_latents = self.run_diffusion(
|
||||||
list_conditionings,
|
list_conditionings,
|
||||||
latents_start = latents_start,
|
latents_start=latents_start,
|
||||||
idx_start = idx_injection,
|
idx_start=idx_injection,
|
||||||
list_latents_mixing = list_latents_parental_mix,
|
list_latents_mixing=list_latents_parental_mix,
|
||||||
mixing_coeffs = mixing_coeffs
|
mixing_coeffs=mixing_coeffs)
|
||||||
)
|
|
||||||
|
|
||||||
return list_latents
|
return list_latents
|
||||||
|
|
||||||
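# Standalone sketch (not part of this commit): it rebuilds the parental
# crossfeed schedule from compute_latents_mix above with assumed toy numbers
# (num_inference_steps=30, idx_injection=10, power=0.3, range=0.8, decay=0.5
# are illustrative values, not repository defaults), to show how the per-step
# mixing coefficients look before they are handed to run_diffusion.
import numpy as np

num_inference_steps = 30
idx_injection = 10
parental_crossfeed_power = 0.3
parental_crossfeed_range = 0.8
parental_crossfeed_power_decay = 0.5

idx_mixing_stop = int(round(num_inference_steps * parental_crossfeed_range))
mixing_coeffs = idx_injection * [parental_crossfeed_power]          # constant before injection
nmb_mixing = idx_mixing_stop - idx_injection
if nmb_mixing > 0:
    # linearly decaying part between the injection step and the crossfeed range limit
    mixing_coeffs.extend(np.linspace(parental_crossfeed_power,
                                     parental_crossfeed_power * parental_crossfeed_power_decay,
                                     nmb_mixing))
mixing_coeffs.extend((num_inference_steps - len(mixing_coeffs)) * [0])  # zero afterwards
print(np.round(mixing_coeffs, 3))
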
    def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):

@@ -445,8 +421,8 @@ class LatentBlending():
                results. Use this if you want to have controllable results independent
                of your computer.
        """
        idx_injection_base = int(round(self.num_inference_steps * depth_strength))
        list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps - 1, 3)
        list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
        t_compute = 0

@@ -456,20 +432,18 @@ class LatentBlending():
        elif t_compute_max_allowed is None:
            assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
            stop_criterion = "nmb_max_branches"
            nmb_max_branches -= 2  # Discounting the outer frames
        else:
            raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")

        stop_criterion_reached = False
        is_first_iteration = True

        while not stop_criterion_reached:
            list_compute_steps = self.num_inference_steps - list_idx_injection
            list_compute_steps *= list_nmb_stems
            t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
            increase_done = False
            for s_idx in range(len(list_nmb_stems) - 1):
                if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
                    list_nmb_stems[s_idx] += 1
                    increase_done = True
                    break
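# Standalone sketch (not part of this commit): it mimics the budgeting idea of
# get_time_based_branching above to show how the compute-time estimate grows as
# stems are added. All numbers (num_inference_steps=30, depth_strength=0.5,
# dt_per_diff=0.1 s, t_compute_max_allowed=15 s) are assumed, and the rule for
# where to add the next stem is deliberately simplified.
import numpy as np

num_inference_steps = 30
depth_strength = 0.5
dt_per_diff = 0.1            # seconds per diffusion step, normally measured at runtime
t_compute_max_allowed = 15   # seconds granted for the whole transition

idx_injection_base = int(round(num_inference_steps * depth_strength))
list_idx_injection = np.arange(idx_injection_base, num_inference_steps - 1, 3)
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)

while True:
    list_compute_steps = (num_inference_steps - list_idx_injection) * list_nmb_stems
    t_compute = np.sum(list_compute_steps) * dt_per_diff + 0.15 * np.sum(list_nmb_stems)
    if t_compute > t_compute_max_allowed:
        break
    # greedily add a stem at the least-populated injection depth (simplified rule)
    list_nmb_stems[np.argmin(list_nmb_stems)] += 1
print(f"estimated compute: {t_compute:.1f}s for stems {list_nmb_stems.tolist()}")
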
@@ -501,10 +475,10 @@ class LatentBlending():
        """
        # get_lpips_similarity
        similarities = []
        for i in range(len(self.tree_final_imgs) - 1):
            similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1]))
        b_closest1 = np.argmax(similarities)
        b_closest2 = b_closest1 + 1
        fract_closest1 = self.tree_fracts[b_closest1]
        fract_closest2 = self.tree_fracts[b_closest2]

@@ -515,23 +489,15 @@ class LatentBlending():
                break
            else:
                b_parent1 -= 1

        b_parent2 = b_closest2
        while True:
            if self.tree_idx_injection[b_parent2] < idx_injection:
                break
            else:
                b_parent2 += 1
        fract_mixing = (fract_closest1 + fract_closest2) / 2
        return fract_mixing, b_parent1, b_parent2

    def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
        r"""
        Inserts all necessary parameters into the trajectory tree.

@@ -543,12 +509,11 @@ class LatentBlending():
            list_latents: list
                list of the latents to be inserted
        """
        b_parent1, b_parent2 = self.get_closest_idx(fract_mixing)
        self.tree_latents.insert(b_parent1 + 1, list_latents)
        self.tree_final_imgs.insert(b_parent1 + 1, self.sdh.latent2image(list_latents[-1]))
        self.tree_fracts.insert(b_parent1 + 1, fract_mixing)
        self.tree_idx_injection.insert(b_parent1 + 1, idx_injection)

    def get_spatial_mask_template(self):
        r"""

@@ -565,9 +530,7 @@ class LatentBlending():
        Args:
            img_mask:
                mask image [0,1]. You can get a template using get_spatial_mask_template
        """
        shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
        C, H, W = shape_latents
        img_mask = np.asarray(img_mask)

@@ -577,18 +540,15 @@ class LatentBlending():
        assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
        spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
        spatial_mask = torch.unsqueeze(spatial_mask, 0)
        spatial_mask = spatial_mask.repeat((C, 1, 1))
        spatial_mask = torch.unsqueeze(spatial_mask, 0)

        self.spatial_mask = spatial_mask

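# Standalone sketch (not part of this commit): a hedged example of how a spatial
# mask could be supplied, assuming an already instantiated LatentBlending object
# named lb and a StableDiffusionHolder named sdh. The left/right split and the
# calls to get_spatial_mask_template / set_spatial_mask follow the docstrings
# above; the exact template type is assumed to be array-like, values in [0, 1].
import numpy as np

lb = LatentBlending(sdh)                        # sdh is assumed to exist
mask_template = lb.get_spatial_mask_template()  # latent-resolution template
mask = np.asarray(mask_template, dtype=np.float32)
mask[:, : mask.shape[1] // 2] = 0.0             # keep crossfeed only on the right half
mask[:, mask.shape[1] // 2:] = 1.0
lb.set_spatial_mask(mask)                       # mask is broadcast over the latent channels
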
    def get_noise(self, seed):
        r"""
        Helper function to get noise given seed.
        Args:
            seed: int
        """
        generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
        if self.mode == 'standard':

@@ -599,21 +559,17 @@ class LatentBlending():
            h = self.image1_lowres.size[1]
            shape_latents = [self.sdh.model.channels, h, w]
            C, H, W = shape_latents

        return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)

    @torch.no_grad()
    def run_diffusion(
            self,
            list_conditionings,
            latents_start: torch.FloatTensor = None,
            idx_start: int = 0,
            list_latents_mixing=None,
            mixing_coeffs=0.0,
            return_image: Optional[bool] = False):
        r"""
        Wrapper function for diffusion runners.
        Depending on the mode, the correct one will be executed.

@@ -640,14 +596,13 @@ class LatentBlending():
        if self.mode == 'standard':
            text_embeddings = list_conditionings[0]
            return self.sdh.run_diffusion_standard(
                text_embeddings=text_embeddings,
                latents_start=latents_start,
                idx_start=idx_start,
                list_latents_mixing=list_latents_mixing,
                mixing_coeffs=mixing_coeffs,
                spatial_mask=self.spatial_mask,
                return_image=return_image)

        elif self.mode == 'upscale':
            cond = list_conditionings[0]

@@ -657,11 +612,10 @@ class LatentBlending():
                uc_full,
                latents_start=latents_start,
                idx_start=idx_start,
                list_latents_mixing=list_latents_mixing,
                mixing_coeffs=mixing_coeffs,
                return_image=return_image)

    def run_upscaling(
            self,
            dp_img: str,

@@ -669,9 +623,9 @@ class LatentBlending():
            num_inference_steps: int = 100,
            nmb_max_branches_highres: int = 5,
            nmb_max_branches_lowres: int = 6,
            duration_single_segment=3,
            fps=24,
            fixed_seeds: Optional[List[int]] = None):
        r"""
        Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition.

@@ -692,13 +646,14 @@ class LatentBlending():
                Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
            duration_single_segment: float
                The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total.
            fps: float
                frames per second of the movie
            fixed_seeds: Optional[List[int]]
                You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
                Otherwise random seeds will be taken.
        """
        fp_yml = os.path.join(dp_img, "lowres.yaml")
        fp_movie = os.path.join(dp_img, "movie_highres.mp4")
        ms = MovieSaver(fp_movie, fps=fps)
        assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
        dict_stuff = yml_load(fp_yml)

@@ -707,53 +662,43 @@ class LatentBlending():
        nmb_images_lowres = dict_stuff['nmb_images']
        prompt1 = dict_stuff['prompt1']
        prompt2 = dict_stuff['prompt2']
        idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres - 1, nmb_max_branches_lowres)).astype(np.int32)
        imgs_lowres = []
        for i in idx_img_lowres:
            fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
            assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
            imgs_lowres.append(Image.open(fp_img_lowres))

        # set up upscaling
        text_embeddingA = self.sdh.get_text_embedding(prompt1)
        text_embeddingB = self.sdh.get_text_embedding(prompt2)
        list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1)
        for i in range(nmb_max_branches_lowres - 1):
            print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
            self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
            self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1 - list_fract_mixing[i])
            if i == 0:
                recycle_img1 = False
            else:
                self.swap_forward()
                recycle_img1 = True

            self.set_image1(imgs_lowres[i])
            self.set_image2(imgs_lowres[i + 1])

            list_imgs = self.run_transition(
                recycle_img1=recycle_img1,
                recycle_img2=False,
                num_inference_steps=num_inference_steps,
                depth_strength=depth_strength,
                nmb_max_branches=nmb_max_branches_highres)
            list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment)

            # Save movie frame
            for img in list_imgs_interp:
                ms.write_frame(img)
        ms.finalize()

    @torch.no_grad()
    def get_mixed_conditioning(self, fract_mixing):
        if self.mode == 'standard':

@@ -776,8 +721,7 @@ class LatentBlending():
    @torch.no_grad()
    def get_text_embeddings(
            self,
            prompt: str):
        r"""
        Computes the text embeddings provided a string with a prompt.
        Adapted from stable diffusion repo

@@ -785,10 +729,8 @@ class LatentBlending():
            prompt: str
                ABC trending on artstation painted by Old Greg.
        """
        return self.sdh.get_text_embedding(prompt)

    def write_imgs_transition(self, dp_img):
        r"""
        Writes the transition images into the folder dp_img.

@@ -802,7 +744,6 @@ class LatentBlending():
        for i, img in enumerate(imgs_transition):
            img_leaf = Image.fromarray(img)
            img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))

        fp_yml = os.path.join(dp_img, "lowres.yaml")
        self.save_statedict(fp_yml)

@@ -817,7 +758,6 @@ class LatentBlending():
                duration of the movie in seconds
            fps: int
                fps of the movie
        """

        # Let's get more cheap frames via linear interpolation (duration_transition*fps frames)

@@ -831,8 +771,6 @@ class LatentBlending():
            ms.write_frame(img)
        ms.finalize()

    def save_statedict(self, fp_yml):
        # Dump everything relevant into yaml
        imgs_transition = self.tree_final_imgs

@@ -857,9 +795,8 @@ class LatentBlending():
            else:
                try:
                    state_dict[v] = getattr(self, v)
                except Exception:
                    pass

        return state_dict

    def randomize_seed(self):

@@ -892,7 +829,6 @@ class LatentBlending():
        self.height = height
        self.sdh.height = height

    def swap_forward(self):
        r"""
        Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions.

@@ -900,15 +836,12 @@ class LatentBlending():
        """
        # Move over all latents
        self.tree_latents[0] = self.tree_latents[-1]

        # Move over prompts and text embeddings
        self.prompt1 = self.prompt2
        self.text_embedding1 = self.text_embedding2

        # Final cleanup for extra sanity
        self.tree_final_imgs = []

    def get_lpips_similarity(self, imgA, imgB):
        r"""
        Computes the image similarity between two images imgA and imgB.

@@ -916,328 +849,36 @@ class LatentBlending():
        High values indicate low similarity.
        """
        tensorA = torch.from_numpy(imgA).float().cuda(self.device)
        tensorA = 2 * tensorA / 255.0 - 1
        tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0)

        tensorB = torch.from_numpy(imgB).float().cuda(self.device)
        tensorB = 2 * tensorB / 255.0 - 1
        tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0)
        lploss = self.lpips(tensorA, tensorB)
        lploss = float(lploss[0][0][0][0])
        return lploss

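# Standalone sketch (not part of this commit): perceptual distance between two
# uint8 RGB frames with the lpips package, mirroring the preprocessing in
# get_lpips_similarity above but kept on the CPU so it runs without a GPU.
# The random images are placeholders for two real transition frames.
import lpips
import numpy as np
import torch

loss_fn = lpips.LPIPS(net='alex')
imgA = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)
imgB = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)

def to_tensor(img):
    t = torch.from_numpy(img).float()
    t = 2 * t / 255.0 - 1                  # scale to [-1, 1] as lpips expects
    return t.permute([2, 0, 1]).unsqueeze(0)

with torch.no_grad():
    d = loss_fn(to_tensor(imgA), to_tensor(imgB))
print(float(d))
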
    # Auxiliary functions
    def get_closest_idx(
            self,
            fract_mixing: float):
        r"""
        Helper function to retrieve the parents for any given mixing.
        Example: fract_mixing = 0.4 and self.tree_fracts = [0, 0.3, 0.6, 1.0]
        Will return the two closest values here, i.e. [1, 2]
        """
        pdist = fract_mixing - np.asarray(self.tree_fracts)
        pdist_pos = pdist.copy()
        pdist_pos[pdist_pos < 0] = np.inf
        b_parent1 = np.argmin(pdist_pos)
        pdist_neg = -pdist.copy()
        pdist_neg[pdist_neg <= 0] = np.inf
        b_parent2 = np.argmin(pdist_neg)

        if b_parent1 > b_parent2:
            tmp = b_parent2
            b_parent2 = b_parent1
            b_parent1 = tmp

        return b_parent1, b_parent2

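# Standalone sketch (not part of this commit): the parent lookup above on a toy
# fraction list, reproducing the docstring example without needing a
# LatentBlending instance. The list [0, 0.3, 0.6, 1.0] is the assumed tree state.
import numpy as np

tree_fracts = [0, 0.3, 0.6, 1.0]
fract_mixing = 0.4
pdist = fract_mixing - np.asarray(tree_fracts)
pdist_pos = pdist.copy()
pdist_pos[pdist_pos < 0] = np.inf
pdist_neg = -pdist.copy()
pdist_neg[pdist_neg <= 0] = np.inf
b_parent1, b_parent2 = np.argmin(pdist_pos), np.argmin(pdist_neg)
print(b_parent1, b_parent2)  # 1 2 -> the branches at 0.3 and 0.6 become the parents
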
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
    r"""
    Helper function to correctly mix two random variables using spherical interpolation.
    See https://en.wikipedia.org/wiki/Slerp
    The function will always cast up to float64 for the sake of extra precision.
    Args:
        p0:
            First tensor for interpolation
        p1:
            Second tensor for interpolation
        fract_mixing: float
            Mixing coefficient of interval [0, 1].
            0 will return p0
            1 will return p1
            0.x will return a mix between both preserving angular velocity.
    """
    if p0.dtype == torch.float16:
        recast_to = 'fp16'
    else:
        recast_to = 'fp32'

    p0 = p0.double()
    p1 = p1.double()
    norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
    epsilon = 1e-7
    dot = torch.sum(p0 * p1) / norm
    dot = dot.clamp(-1 + epsilon, 1 - epsilon)

    theta_0 = torch.arccos(dot)
    sin_theta_0 = torch.sin(theta_0)
    theta_t = theta_0 * fract_mixing
    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = torch.sin(theta_t) / sin_theta_0
    interp = p0 * s0 + p1 * s1

    if recast_to == 'fp16':
        interp = interp.half()
    elif recast_to == 'fp32':
        interp = interp.float()

    return interp

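# Standalone sketch (not part of this commit): spherical vs. linear mixing of two
# Gaussian noise tensors. With slerp the interpolated tensor keeps a norm close
# to its endpoints, which is why the latents are mixed spherically; the shapes
# and the 0.5 mixing fraction are arbitrary illustrative choices.
import torch

torch.manual_seed(0)
p0 = torch.randn(4, 64, 64)
p1 = torch.randn(4, 64, 64)

mid_slerp = interpolate_spherical(p0, p1, 0.5)
mid_lerp = 0.5 * p0 + 0.5 * p1
print(p0.norm().item(), mid_slerp.norm().item(), mid_lerp.norm().item())
# the linear mix shrinks towards the origin, the spherical mix does not
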
def interpolate_linear(p0, p1, fract_mixing):
    r"""
    Helper function to mix two variables using standard linear interpolation.
    Args:
        p0:
            First tensor / np.ndarray for interpolation
        p1:
            Second tensor / np.ndarray for interpolation
        fract_mixing: float
            Mixing coefficient of interval [0, 1].
            0 will return p0
            1 will return p1
            0.x will return a linear mix between both.
    """
    reconvert_uint8 = False
    if type(p0) is np.ndarray and p0.dtype == 'uint8':
        reconvert_uint8 = True
        p0 = p0.astype(np.float64)

    if type(p1) is np.ndarray and p1.dtype == 'uint8':
        reconvert_uint8 = True
        p1 = p1.astype(np.float64)

    interp = (1 - fract_mixing) * p0 + fract_mixing * p1

    if reconvert_uint8:
        interp = np.clip(interp, 0, 255).astype(np.uint8)

    return interp

def add_frames_linear_interp(
        list_imgs: List[np.ndarray],
        fps_target: Union[float, int] = None,
        duration_target: Union[float, int] = None,
        nmb_frames_target: int = None):
    r"""
    Helper function to cheaply increase the number of frames given a list of images,
    by virtue of standard linear interpolation.
    The number of inserted frames will be automatically adjusted so that the total number
    of frames can be fixed precisely, using a random shuffling technique.
    The function allows 1:1 comparisons between transitions as videos.

    Args:
        list_imgs: List[np.ndarray]
            List of images, between each image new frames will be inserted via linear interpolation.
        fps_target:
            OptionA: specify here the desired frames per second.
        duration_target:
            OptionA: specify here the desired duration of the transition in seconds.
        nmb_frames_target:
            OptionB: directly fix the total number of frames of the output.
    """

    # Sanity
    if nmb_frames_target is not None and fps_target is not None:
        raise ValueError("You cannot specify both fps_target and nmb_frames_target")
    if fps_target is None:
        assert nmb_frames_target is not None, "Either specify fps_target or nmb_frames_target"
    if nmb_frames_target is None:
        assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
        assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
        nmb_frames_target = fps_target * duration_target

    # Get number of frames that are missing
    nmb_frames_diff = len(list_imgs) - 1
    nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1

    if nmb_frames_missing < 1:
        return list_imgs

    list_imgs_float = [img.astype(np.float32) for img in list_imgs]
    # Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
    mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
    constfact = np.floor(mean_nmb_frames_insert)
    remainder_x = 1 - (mean_nmb_frames_insert - constfact)

    nmb_iter = 0
    while True:
        nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
        nmb_frames_to_insert[nmb_frames_to_insert <= remainder_x] = 0
        nmb_frames_to_insert[nmb_frames_to_insert > remainder_x] = 1
        nmb_frames_to_insert += constfact
        if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
            break
        nmb_iter += 1
        if nmb_iter > 100000:
            print("add_frames_linear_interp: issue with inserting the right number of frames")
            break

    nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
    list_imgs_interp = []
    for i in range(len(list_imgs_float) - 1):
        img0 = list_imgs_float[i]
        img1 = list_imgs_float[i + 1]
        list_imgs_interp.append(img0.astype(np.uint8))
        list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i] + 2)[1:-1]
        for fract_linblend in list_fracts_linblend:
            img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
            list_imgs_interp.append(img_blend.astype(np.uint8))

        if i == len(list_imgs_float) - 2:
            list_imgs_interp.append(img1.astype(np.uint8))

    return list_imgs_interp

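# Standalone sketch (not part of this commit): stretching a handful of keyframes
# to a 2 second clip at 30 fps with the helper above. The random frames stand in
# for real transition images; only the frame count matters for the illustration.
import numpy as np

list_imgs = [(np.random.rand(64, 64, 3) * 255).astype(np.uint8) for _ in range(5)]
list_imgs_interp = add_frames_linear_interp(list_imgs, fps_target=30, duration_target=2)
print(len(list_imgs), "->", len(list_imgs_interp))  # 5 -> 60 frames
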
def get_spacing(nmb_points: int, scaling: float):
    """
    Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
    Args:
        nmb_points: int
            Number of points between [0, 1]
        scaling: float
            Higher values will return higher sampling density around 0.5
    """
    if scaling < 1.7:
        return np.linspace(0, 1, nmb_points)
    nmb_points_per_side = nmb_points // 2 + 1
    if np.mod(nmb_points, 2) != 0:  # uneven case
        left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
        right_side = 1 - left_side[::-1][1:]
    else:
        left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
        right_side = 1 - left_side[::-1]
    all_fracts = np.hstack([left_side, right_side])
    return all_fracts

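# Standalone sketch (not part of this commit): what the nonlinear spacing looks
# like for 7 points. scaling=1.0 falls back to a uniform grid, scaling=3.0
# crowds the fractions around 0.5, where most visual change happens; both
# values are picked purely for illustration.
import numpy as np

print(np.round(get_spacing(7, 1.0), 3))  # uniform: 0.0, 0.167, ..., 1.0
print(np.round(get_spacing(7, 3.0), 3))  # denser near the middle
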
def get_time(resolution=None):
    """
    Helper function returning a nicely formatted time string, e.g. 221117_1620
    """
    if resolution is None:
        resolution = "second"
    if resolution == "day":
        t = time.strftime('%y%m%d', time.localtime())
    elif resolution == "minute":
        t = time.strftime('%y%m%d_%H%M', time.localtime())
    elif resolution == "second":
        t = time.strftime('%y%m%d_%H%M%S', time.localtime())
    elif resolution == "millisecond":
        t = time.strftime('%y%m%d_%H%M%S', time.localtime())
        t += "_"
        t += str("{:03d}".format(int(int(datetime.utcnow().strftime('%f')) / 1000)))
    else:
        raise ValueError("bad resolution provided: %s" % resolution)
    return t


def compare_dicts(a, b):
    """
    Compares two dictionaries a and b and returns a dictionary c, containing all
    keys that are shared between a and b but whose values differ.
    The values of a and b are stacked together in the output.
    Example:
        a = {}; a['bobo'] = 4
        b = {}; b['bobo'] = 5
        c = compare_dicts(a, b)
        c = {"bobo": [4, 5]}
    """
    c = {}
    for key in a.keys():
        if key in b.keys():
            val_a = a[key]
            val_b = b[key]
            if val_a != val_b:
                c[key] = [val_a, val_b]
    return c


def yml_load(fp_yml, print_fields=False):
    """
    Helper function for loading yaml files
    """
    with open(fp_yml) as f:
        data = yaml.load(f, Loader=yaml.loader.SafeLoader)
    dict_data = dict(data)
    print("load: loaded {}".format(fp_yml))
    return dict_data


def yml_save(fp_yml, dict_stuff):
    """
    Helper function for saving yaml files
    """
    with open(fp_yml, 'w') as f:
        yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
    print("yml_save: saved {}".format(fp_yml))

#%% le main
if __name__ == "__main__":

    #%% First let us spawn a stable diffusion holder
    device = "cuda"
    fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"

    sdh = StableDiffusionHolder(fp_ckpt)

    #%% Next let's set up all parameters
    depth_strength = 0.3  # Specifies how deep (in terms of diffusion iterations the first branching happens)
    fixed_seeds = [697164, 430214]

    prompt1 = "photo of a desert and a sky"
    prompt2 = "photo of a tree with a lake"

    duration_transition = 12  # In seconds
    fps = 30

    # Spawn latent blending
    self = LatentBlending(sdh)
    self.set_prompt1(prompt1)
    self.set_prompt2(prompt2)

    # Run latent blending
    self.branch1_crossfeed_power = 0.3
    self.branch1_crossfeed_range = 0.4
    # self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
    self.seed1 = 21312
    img1 = self.compute_latents1(True)
    #%%
    self.seed2 = 1234121
    self.branch1_crossfeed_power = 0.7
    self.branch1_crossfeed_range = 0.3
    self.branch1_crossfeed_decay = 0.3
    img2 = self.compute_latents2(True)
    # Image.fromarray(np.concatenate((img1, img2), axis=1))

    #%%
    t0 = time.time()
    self.t_compute_max_allowed = 30
    self.parental_crossfeed_range = 1.0
    self.parental_crossfeed_power = 0.0
    self.parental_crossfeed_power_decay = 1.0
    imgs_transition = self.run_transition(recycle_img1=True, recycle_img2=True)
    t1 = time.time()
    print(f"took: {t1-t0}s")

@@ -1,5 +1,6 @@
# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

@@ -17,10 +18,9 @@ import os
import numpy as np
from tqdm import tqdm
import cv2
from typing import List
import ffmpeg  # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg


class MovieSaver():
    def __init__(

@@ -30,10 +30,9 @@ class MovieSaver():
            shape_hw: List[int] = None,
            crf: int = 24,
            codec: str = 'libx264',
            preset: str = 'fast',
            pix_fmt: str = 'yuv420p',
            silent_ffmpeg: bool = True):
        r"""
        Initializes movie saver class - a human friendly ffmpeg wrapper.
        After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x).

@@ -92,10 +91,8 @@ class MovieSaver():
        self.shape_hw = shape_hw
        self.initialize()

        print(f"MovieSaver initialized. fps={fps} crf={crf} pix_fmt={pix_fmt} codec={codec} preset={preset}")

    def initialize(self):
        args = (
            ffmpeg

@@ -112,7 +109,6 @@ class MovieSaver():
        self.shape_hw = tuple(self.shape_hw)
        print(f"Initialization done. Movie shape: {self.shape_hw}")

    def write_frame(self, out_frame: np.ndarray):
        r"""
        Function to dump a numpy array as frame of a movie.

@@ -123,7 +119,6 @@ class MovieSaver():
            Dim 1: x
            Dim 2: RGB
        """
        assert out_frame.dtype == np.uint8, "Convert to np.uint8 before"
        assert len(out_frame.shape) == 3, "out_frame needs to be three dimensional, Y X C"
        assert out_frame.shape[2] == 3, f"need three color channels, but you provided {out_frame.shape[2]}."

@@ -143,7 +138,6 @@ class MovieSaver():
        self.nmb_frames += 1

    def finalize(self):
        r"""
        Call this function to finalize the movie. If you forget to call it your movie will be garbage.

@@ -157,7 +151,6 @@ class MovieSaver():
        print(f"Movie saved, {duration}s playtime, watch here: \n{self.fp_out}")


def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
    r"""
    Concatenate multiple movie segments into one long movie, using ffmpeg.

@@ -189,7 +182,6 @@ def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
        fa.write("%s\n" % item)

    cmd = f'ffmpeg -f concat -safe 0 -i {fp_list} -c copy {fp_final}'
    subprocess.call(cmd, shell=True)
    os.remove(fp_list)
    if os.path.isfile(fp_final):

@@ -200,11 +192,12 @@ class MovieReader():
    r"""
    Class to read in a movie.
    """

    def __init__(self, fp_movie):
        self.video_player_object = cv2.VideoCapture(fp_movie)
        self.nmb_frames = int(self.video_player_object.get(cv2.CAP_PROP_FRAME_COUNT))
        self.fps_movie = int(self.video_player_object.get(cv2.CAP_PROP_FPS))
        self.shape = [100, 100, 3]
        self.shape_is_set = False

    def get_next_frame(self):

@@ -217,19 +210,18 @@ class MovieReader():
        else:
            return np.zeros(self.shape)


if __name__ == "__main__":
    fps = 2
    list_fp_movies = []
    for k in range(4):
        fp_movie = f"/tmp/my_random_movie_{k}.mp4"
        list_fp_movies.append(fp_movie)
        ms = MovieSaver(fp_movie, fps=fps)
        for fn in tqdm(range(30)):
            img = (np.random.rand(512, 1024, 3) * 255).astype(np.uint8)
            ms.write_frame(img)
        ms.finalize()

    fp_final = "/tmp/my_concatenated_movie.mp4"
    concatenate_movies(fp_final, list_fp_movies)

@ -13,36 +13,25 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os, sys
|
import os
|
||||||
dp_git = "/home/lugo/git/"
|
|
||||||
sys.path.append(os.path.join(dp_git,'garden4'))
|
|
||||||
sys.path.append('util')
|
|
||||||
import torch
|
import torch
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.set_grad_enabled(False)
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
import time
|
|
||||||
import subprocess
|
|
||||||
import warnings
|
import warnings
|
||||||
import torch
|
import torch
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
# import matplotlib.pyplot as plt
|
|
||||||
import torch
|
import torch
|
||||||
from movie_util import MovieSaver
|
from typing import Optional
|
||||||
import datetime
|
|
||||||
from typing import Callable, List, Optional, Union
|
|
||||||
import inspect
|
|
||||||
from threading import Thread
|
|
||||||
torch.set_grad_enabled(False)
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from einops import repeat, rearrange
|
from einops import repeat, rearrange
|
||||||
#%%
|
from utils import interpolate_spherical
|
||||||
|
|
||||||
|
|
||||||
def pad_image(input_image):
|
def pad_image(input_image):
|
||||||
|
@ -53,41 +42,11 @@ def pad_image(input_image):
|
||||||
return im_padded
|
return im_padded
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def make_batch_inpaint(
|
|
||||||
image,
|
|
||||||
mask,
|
|
||||||
txt,
|
|
||||||
device,
|
|
||||||
num_samples=1):
|
|
||||||
image = np.array(image.convert("RGB"))
|
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
|
||||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
|
||||||
|
|
||||||
mask = np.array(mask.convert("L"))
|
|
||||||
mask = mask.astype(np.float32) / 255.0
|
|
||||||
mask = mask[None, None]
|
|
||||||
mask[mask < 0.5] = 0
|
|
||||||
mask[mask >= 0.5] = 1
|
|
||||||
mask = torch.from_numpy(mask)
|
|
||||||
|
|
||||||
masked_image = image * (mask < 0.5)
|
|
||||||
|
|
||||||
batch = {
|
|
||||||
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
|
|
||||||
"txt": num_samples * [txt],
|
|
||||||
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
|
|
||||||
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
|
|
||||||
}
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
def make_batch_superres(
|
def make_batch_superres(
|
||||||
image,
|
image,
|
||||||
txt,
|
txt,
|
||||||
device,
|
device,
|
||||||
num_samples=1,
|
num_samples=1):
|
||||||
):
|
|
||||||
image = np.array(image.convert("RGB"))
|
image = np.array(image.convert("RGB"))
|
||||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||||
batch = {
|
batch = {
|
||||||
|
@ -114,7 +73,7 @@ class StableDiffusionHolder:
|
||||||
height: Optional[int] = None,
|
height: Optional[int] = None,
|
||||||
width: Optional[int] = None,
|
width: Optional[int] = None,
|
||||||
device: str = None,
|
device: str = None,
|
||||||
precision: str='autocast',
|
precision: str = 'autocast',
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Initializes the stable diffusion holder, which contains the models and sampler.
|
Initializes the stable diffusion holder, which contains the models and sampler.
|
||||||
|
@ -137,7 +96,7 @@ class StableDiffusionHolder:
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.init_model(fp_ckpt, fp_config)
|
self.init_model(fp_ckpt, fp_config)
|
||||||
|
|
||||||
self.f = 8 #downsampling factor, most often 8 or 16",
|
self.f = 8 # downsampling factor, most often 8 or 16"
|
||||||
self.C = 4
|
self.C = 4
|
||||||
self.ddim_eta = 0
|
self.ddim_eta = 0
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
|
@ -150,13 +109,8 @@ class StableDiffusionHolder:
|
||||||
self.height = height
|
self.height = height
|
||||||
self.width = width
|
self.width = width
|
||||||
|
|
||||||
# Inpainting inits
|
|
||||||
self.mask_empty = Image.fromarray(255*np.ones([self.width, self.height], dtype=np.uint8))
|
|
||||||
self.image_empty = Image.fromarray(np.zeros([self.width, self.height, 3], dtype=np.uint8))
|
|
||||||
|
|
||||||
self.negative_prompt = [""]
|
self.negative_prompt = [""]
|
||||||
|
|
||||||
|
|
||||||
    def init_model(self, fp_ckpt, fp_config):
        r"""Loads the models and sampler.
        """
@@ -169,13 +123,11 @@
        fn_ckpt = os.path.basename(fp_ckpt)
        if 'depth' in fn_ckpt:
            fp_config = 'configs/v2-midas-inference.yaml'
-        elif 'inpain' in fn_ckpt:
-            fp_config = 'configs/v2-inpainting-inference.yaml'
        elif 'upscaler' in fn_ckpt:
            fp_config = 'configs/x4-upscaling.yaml'
        elif '512' in fn_ckpt:
            fp_config = 'configs/v2-inference.yaml'
-        elif '768'in fn_ckpt:
+        elif '768' in fn_ckpt:
            fp_config = 'configs/v2-inference-v.yaml'
        elif 'v1-5' in fn_ckpt:
            fp_config = 'configs/v1-inference.yaml'
@@ -186,7 +138,6 @@
        assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"

        config = OmegaConf.load(fp_config)

        self.model = instantiate_from_config(config.model)
@@ -195,7 +146,6 @@
        self.model = self.model.to(self.device)
        self.sampler = DDIMSampler(self.model)

    def init_auto_res(self):
        r"""Automatically set the resolution to the one used in training.
        """
@@ -218,7 +168,6 @@
        if len(self.negative_prompt) > 1:
            self.negative_prompt = [self.negative_prompt[0]]

    def get_text_embedding(self, prompt):
        c = self.model.get_learned_conditioning(prompt)
        return c

@@ -228,7 +177,6 @@
        r"""
        Initializes the conditioning for the x4 upscaling model.
        """
        image = pad_image(image)  # resize to integer multiple of 32
        w, h = image.size
        noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()
@@ -240,7 +188,6 @@
        # uncond cond
        uc_cross = self.model.get_unconditional_conditioning(1, "")
        uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}

        return cond, uc_full

    @torch.no_grad()
@@ -249,14 +196,12 @@
            text_embeddings: torch.FloatTensor,
            latents_start: torch.FloatTensor,
            idx_start: int = 0,
-            list_latents_mixing = None,
-            mixing_coeffs = 0.0,
-            spatial_mask = None,
-            return_image: Optional[bool] = False,
-            ):
+            list_latents_mixing=None,
+            mixing_coeffs=0.0,
+            spatial_mask=None,
+            return_image: Optional[bool] = False):
        r"""
        Diffusion standard version.

        Args:
            text_embeddings: torch.FloatTensor
                Text embeddings used for diffusion
@@ -270,12 +215,10 @@
                experimental feature for enforcing pixels from list_latents_mixing
            return_image: Optional[bool]
                Optionally return image directly

        """
        # Asserts
        if type(mixing_coeffs) == float:
-            list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
+            list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
        elif type(mixing_coeffs) == list:
            assert len(mixing_coeffs) == self.num_inference_steps
            list_mixing_coeffs = mixing_coeffs
@@ -285,26 +228,19 @@
        if np.sum(list_mixing_coeffs) > 0:
            assert len(list_latents_mixing) == self.num_inference_steps

        precision_scope = autocast if self.precision == "autocast" else nullcontext

        with precision_scope("cuda"):
            with self.model.ema_scope():
                if self.guidance_scale != 1.0:
                    uc = self.model.get_learned_conditioning(self.negative_prompt)
                else:
                    uc = None
-                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)

                latents = latents_start.clone()

                timesteps = self.sampler.ddim_timesteps

                time_range = np.flip(timesteps)
                total_steps = timesteps.shape[0]
-                # collect latents
+                # Collect latents
                list_latents_out = []
                for i, step in enumerate(time_range):
                    # Set the right starting latents
@@ -313,34 +249,30 @@
                        continue
                    elif i == idx_start:
                        latents = latents_start.clone()
-                    # Mix the latents.
-                    if i > 0 and list_mixing_coeffs[i]>0:
-                        latents_mixtarget = list_latents_mixing[i-1].clone()
+                    # Mix latents
+                    if i > 0 and list_mixing_coeffs[i] > 0:
+                        latents_mixtarget = list_latents_mixing[i - 1].clone()
                        latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])

                    if spatial_mask is not None and list_latents_mixing is not None:
-                        latents = interpolate_spherical(latents, list_latents_mixing[i-1], 1-spatial_mask)
-                        # latents[:,:,-15:,:] = latents_mixtarget[:,:,-15:,:]
+                        latents = interpolate_spherical(latents, list_latents_mixing[i - 1], 1 - spatial_mask)

                    index = total_steps - i - 1
                    ts = torch.full((1,), step, device=self.device, dtype=torch.long)
                    outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
                                                      quantize_denoised=False, temperature=1.0,
                                                      noise_dropout=0.0, score_corrector=None,
                                                      corrector_kwargs=None,
                                                      unconditional_guidance_scale=self.guidance_scale,
                                                      unconditional_conditioning=uc,
                                                      dynamic_threshold=None)
                    latents, pred_x0 = outs
                    list_latents_out.append(latents.clone())

                if return_image:
                    return self.latent2image(latents)
                else:
                    return list_latents_out

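The docstring above only lists the arguments, so here is a minimal usage sketch that is not part of this commit: it assumes an already initialized holder sdh, assumes guidance_scale is set as a plain attribute before sampling (the loop above reads self.guidance_scale), and uses Gaussian noise of shape (1, C, height//f, width//f) as starting latents.

# Hypothetical usage sketch, not part of this diff:
text_embeddings = sdh.get_text_embedding("a foggy mountain lake at dawn")
sdh.guidance_scale = 5.0  # assumed value; run_diffusion_standard reads this attribute
latents_start = torch.randn((1, sdh.C, sdh.height // sdh.f, sdh.width // sdh.f), device=sdh.device)
img = sdh.run_diffusion_standard(text_embeddings, latents_start, return_image=True)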
    @torch.no_grad()
    def run_diffusion_upscaling(
            self,
@@ -348,17 +280,16 @@
            uc_full,
            latents_start: torch.FloatTensor,
            idx_start: int = -1,
-            list_latents_mixing = None,
-            mixing_coeffs = 0.0,
-            return_image: Optional[bool] = False
-            ):
+            list_latents_mixing: list = None,
+            mixing_coeffs: float = 0.0,
+            return_image: Optional[bool] = False):
        r"""
        Diffusion upscaling version.
        """
        # Asserts
        if type(mixing_coeffs) == float:
-            list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
+            list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
        elif type(mixing_coeffs) == list:
            assert len(mixing_coeffs) == self.num_inference_steps
            list_mixing_coeffs = mixing_coeffs
@@ -369,27 +300,20 @@
            assert len(list_latents_mixing) == self.num_inference_steps

        precision_scope = autocast if self.precision == "autocast" else nullcontext

        h = uc_full['c_concat'][0].shape[2]
        w = uc_full['c_concat'][0].shape[3]

        with precision_scope("cuda"):
            with self.model.ema_scope():
                shape_latents = [self.model.channels, h, w]
-                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
                C, H, W = shape_latents
                size = (1, C, H, W)
                b = size[0]

                latents = latents_start.clone()

                timesteps = self.sampler.ddim_timesteps

                time_range = np.flip(timesteps)
                total_steps = timesteps.shape[0]

                # collect latents
                list_latents_out = []
                for i, step in enumerate(time_range):
@@ -399,132 +323,20 @@
                        continue
                    elif i == idx_start:
                        latents = latents_start.clone()

                    # Mix the latents.
-                    if i > 0 and list_mixing_coeffs[i]>0:
-                        latents_mixtarget = list_latents_mixing[i-1].clone()
+                    if i > 0 and list_mixing_coeffs[i] > 0:
+                        latents_mixtarget = list_latents_mixing[i - 1].clone()
                        latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])

                    # print(f"diffusion iter {i}")
                    index = total_steps - i - 1
                    ts = torch.full((b,), step, device=self.device, dtype=torch.long)
                    outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
                                                      quantize_denoised=False, temperature=1.0,
                                                      noise_dropout=0.0, score_corrector=None,
                                                      corrector_kwargs=None,
                                                      unconditional_guidance_scale=self.guidance_scale,
                                                      unconditional_conditioning=uc_full,
                                                      dynamic_threshold=None)
-                    latents, pred_x0 = outs
-                    list_latents_out.append(latents.clone())
-
-                if return_image:
-                    return self.latent2image(latents)
-                else:
-                    return list_latents_out
-
-    @torch.no_grad()
-    def run_diffusion_inpaint(
-            self,
-            text_embeddings: torch.FloatTensor,
-            latents_for_injection: torch.FloatTensor = None,
-            idx_start: int = -1,
-            idx_stop: int = -1,
-            return_image: Optional[bool] = False
-            ):
-        r"""
-        Runs inpaint-based diffusion. Returns a list of latents that were computed.
-        Adaptations allow to supply
-        a) starting index for diffusion
-        b) stopping index for diffusion
-        c) latent representations that are injected at the starting index
-        Furthermore the intermittent latents are collected and returned.
-
-        Adapted from diffusers (https://github.com/huggingface/diffusers)
-        Args:
-            text_embeddings: torch.FloatTensor
-                Text embeddings used for diffusion
-            latents_for_injection: torch.FloatTensor
-                Latents that are used for injection
-            idx_start: int
-                Index of the diffusion process start and where the latents_for_injection are injected
-            idx_stop: int
-                Index of the diffusion process end.
-            return_image: Optional[bool]
-                Optionally return image directly
-
-        """
-        if latents_for_injection is None:
-            do_inject_latents = False
-        else:
-            do_inject_latents = True
-
-        precision_scope = autocast if self.precision == "autocast" else nullcontext
-        generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
-
-        with precision_scope("cuda"):
-            with self.model.ema_scope():
-                batch = make_batch_inpaint(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
-                c = text_embeddings
-                c_cat = list()
-                for ck in self.model.concat_keys:
-                    cc = batch[ck].float()
-                    if ck != self.model.masked_image_key:
-                        bchw = [1, 4, self.height // 8, self.width // 8]
-                        cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
-                    else:
-                        cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
-                    c_cat.append(cc)
-                c_cat = torch.cat(c_cat, dim=1)
-
-                # cond
-                cond = {"c_concat": [c_cat], "c_crossattn": [c]}
-
-                # uncond cond
-                uc_cross = self.model.get_unconditional_conditioning(1, "")
-                uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
-
-                shape_latents = [self.model.channels, self.height // 8, self.width // 8]
-
-                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
-                # sampling
-                C, H, W = shape_latents
-                size = (1, C, H, W)
-
-                device = self.model.betas.device
-                b = size[0]
-                latents = torch.randn(size, generator=generator, device=device)
-
-                timesteps = self.sampler.ddim_timesteps
-
-                time_range = np.flip(timesteps)
-                total_steps = timesteps.shape[0]
-
-                # collect latents
-                list_latents_out = []
-                for i, step in enumerate(time_range):
-                    if do_inject_latents:
-                        # Inject latent at right place
-                        if i < idx_start:
-                            continue
-                        elif i == idx_start:
-                            latents = latents_for_injection.clone()
-
-                    if i == idx_stop:
-                        return list_latents_out
-
-                    index = total_steps - i - 1
-                    ts = torch.full((b,), step, device=device, dtype=torch.long)
-
-                    outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
-                                                      quantize_denoised=False, temperature=1.0,
-                                                      noise_dropout=0.0, score_corrector=None,
-                                                      corrector_kwargs=None,
-                                                      unconditional_guidance_scale=self.guidance_scale,
-                                                      unconditional_conditioning=uc_full,
-                                                      dynamic_threshold=None)
                    latents, pred_x0 = outs
                    list_latents_out.append(latents.clone())

@@ -536,8 +348,7 @@
    @torch.no_grad()
    def latent2image(
            self,
-            latents: torch.FloatTensor
-            ):
+            latents: torch.FloatTensor):
        r"""
        Returns an image provided a latent representation from diffusion.
        Args:
@@ -546,85 +357,6 @@
        """
        x_sample = self.model.decode_first_stage(latents)
        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
-        x_sample = 255 * x_sample[0,:,:].permute([1,2,0]).cpu().numpy()
+        x_sample = 255 * x_sample[0, :, :].permute([1, 2, 0]).cpu().numpy()
        image = x_sample.astype(np.uint8)
        return image

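For orientation, a hedged sketch of how these two methods connect (assuming the same sdh holder and inputs as in the earlier sketch): run_diffusion_standard returns one latent per DDIM step, and latent2image decodes any of them into an HxWx3 uint8 array.

# Hypothetical sketch, not part of this diff:
list_latents = sdh.run_diffusion_standard(text_embeddings, latents_start)
img = sdh.latent2image(list_latents[-1])   # decode the final latent
print(img.shape, img.dtype)                # e.g. (768, 768, 3) uint8 for the 768 checkpoint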
-
-@torch.no_grad()
-def interpolate_spherical(p0, p1, fract_mixing: float):
-    r"""
-    Helper function to correctly mix two random variables using spherical interpolation.
-    See https://en.wikipedia.org/wiki/Slerp
-    The function will always cast up to float64 for sake of extra 4.
-    Args:
-        p0:
-            First tensor for interpolation
-        p1:
-            Second tensor for interpolation
-        fract_mixing: float
-            Mixing coefficient of interval [0, 1].
-            0 will return in p0
-            1 will return in p1
-            0.x will return a mix between both preserving angular velocity.
-    """
-    if p0.dtype == torch.float16:
-        recast_to = 'fp16'
-    else:
-        recast_to = 'fp32'
-
-    p0 = p0.double()
-    p1 = p1.double()
-    norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
-    epsilon = 1e-7
-    dot = torch.sum(p0 * p1) / norm
-    dot = dot.clamp(-1+epsilon, 1-epsilon)
-
-    theta_0 = torch.arccos(dot)
-    sin_theta_0 = torch.sin(theta_0)
-    theta_t = theta_0 * fract_mixing
-    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
-    s1 = torch.sin(theta_t) / sin_theta_0
-    interp = p0*s0 + p1*s1
-
-    if recast_to == 'fp16':
-        interp = interp.half()
-    elif recast_to == 'fp32':
-        interp = interp.float()
-
-    return interp
-
-
-if __name__ == "__main__":
-    num_inference_steps = 20 # Number of diffusion interations
-
-    # fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
-    # fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
-
-    # fp_ckpt= "../stable_diffusion_models/ckpt/512-inpainting-ema.ckpt"
-    # fp_config = '../stablediffusion/configs//stable-diffusion/v2-inpainting-inference.yaml'
-
-    fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
-    # fp_config = 'configs/v2-inference-v.yaml'
-
-    self = StableDiffusionHolder(fp_ckpt, num_inference_steps=num_inference_steps)
-
-    xxx
-
-    #%%
-    self.width = 1536
-    self.height = 768
-    prompt = "360 degree equirectangular, a huge rocky hill full of pianos and keyboards, musical instruments, cinematic, masterpiece 8 k, artstation"
-    self.set_negative_prompt("out of frame, faces, rendering, blurry")
-    te = self.get_text_embedding(prompt)
-
-    img = self.run_diffusion_standard(te, return_image=True)
-    Image.fromarray(img).show()

@@ -0,0 +1,260 @@
+# Copyright 2022 Lunar Ring. All rights reserved.
+# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+torch.backends.cudnn.benchmark = False
+import numpy as np
+import warnings
+warnings.filterwarnings('ignore')
+import time
+import warnings
+import datetime
+from typing import List, Union
+torch.set_grad_enabled(False)
+import yaml
+
+
+@torch.no_grad()
+def interpolate_spherical(p0, p1, fract_mixing: float):
+    r"""
+    Helper function to correctly mix two random variables using spherical interpolation.
+    See https://en.wikipedia.org/wiki/Slerp
+    The function will always cast up to float64 for sake of extra 4.
+    Args:
+        p0:
+            First tensor for interpolation
+        p1:
+            Second tensor for interpolation
+        fract_mixing: float
+            Mixing coefficient of interval [0, 1].
+            0 will return in p0
+            1 will return in p1
+            0.x will return a mix between both preserving angular velocity.
+    """
+    if p0.dtype == torch.float16:
+        recast_to = 'fp16'
+    else:
+        recast_to = 'fp32'
+
+    p0 = p0.double()
+    p1 = p1.double()
+    norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
+    epsilon = 1e-7
+    dot = torch.sum(p0 * p1) / norm
+    dot = dot.clamp(-1 + epsilon, 1 - epsilon)
+
+    theta_0 = torch.arccos(dot)
+    sin_theta_0 = torch.sin(theta_0)
+    theta_t = theta_0 * fract_mixing
+    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
+    s1 = torch.sin(theta_t) / sin_theta_0
+    interp = p0 * s0 + p1 * s1
+
+    if recast_to == 'fp16':
+        interp = interp.half()
+    elif recast_to == 'fp32':
+        interp = interp.float()
+
+    return interp
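In the notation of the code above, the helper computes interp = sin((1 - t) * theta) / sin(theta) * p0 + sin(t * theta) / sin(theta) * p1, where theta is the angle between the two flattened tensors and t = fract_mixing. A small illustrative check (not part of the commit), showing why slerp rather than a plain average is used for Gaussian latents:

# Illustrative only: the slerp midpoint keeps a norm comparable to the endpoints,
# whereas 0.5 * (a + b) would shrink it by roughly 1/sqrt(2) for independent noise.
a = torch.randn(4, 96, 96)
b = torch.randn(4, 96, 96)
mid = interpolate_spherical(a, b, 0.5)
print(a.norm().item(), b.norm().item(), mid.norm().item())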
+
+
+def interpolate_linear(p0, p1, fract_mixing):
+    r"""
+    Helper function to mix two variables using standard linear interpolation.
+    Args:
+        p0:
+            First tensor / np.ndarray for interpolation
+        p1:
+            Second tensor / np.ndarray for interpolation
+        fract_mixing: float
+            Mixing coefficient of interval [0, 1].
+            0 will return in p0
+            1 will return in p1
+            0.x will return a linear mix between both.
+    """
+    reconvert_uint8 = False
+    if type(p0) is np.ndarray and p0.dtype == 'uint8':
+        reconvert_uint8 = True
+        p0 = p0.astype(np.float64)
+
+    if type(p1) is np.ndarray and p1.dtype == 'uint8':
+        reconvert_uint8 = True
+        p1 = p1.astype(np.float64)
+
+    interp = (1 - fract_mixing) * p0 + fract_mixing * p1
+
+    if reconvert_uint8:
+        interp = np.clip(interp, 0, 255).astype(np.uint8)
+
+    return interp
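A small illustrative check of the uint8 handling (not part of the commit): uint8 inputs are up-cast to float64, blended, then clipped back to uint8.

img_a = np.zeros((64, 64, 3), dtype=np.uint8)
img_b = np.full((64, 64, 3), 200, dtype=np.uint8)
mix = interpolate_linear(img_a, img_b, 0.25)
print(mix.dtype, mix[0, 0, 0])  # uint8 50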
|
def add_frames_linear_interp(
|
||||||
|
list_imgs: List[np.ndarray],
|
||||||
|
fps_target: Union[float, int] = None,
|
||||||
|
duration_target: Union[float, int] = None,
|
||||||
|
nmb_frames_target: int = None):
|
||||||
|
r"""
|
||||||
|
Helper function to cheaply increase the number of frames given a list of images,
|
||||||
|
by virtue of standard linear interpolation.
|
||||||
|
The number of inserted frames will be automatically adjusted so that the total of number
|
||||||
|
of frames can be fixed precisely, using a random shuffling technique.
|
||||||
|
The function allows 1:1 comparisons between transitions as videos.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
list_imgs: List[np.ndarray)
|
||||||
|
List of images, between each image new frames will be inserted via linear interpolation.
|
||||||
|
fps_target:
|
||||||
|
OptionA: specify here the desired frames per second.
|
||||||
|
duration_target:
|
||||||
|
OptionA: specify here the desired duration of the transition in seconds.
|
||||||
|
nmb_frames_target:
|
||||||
|
OptionB: directly fix the total number of frames of the output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Sanity
|
||||||
|
if nmb_frames_target is not None and fps_target is not None:
|
||||||
|
raise ValueError("You cannot specify both fps_target and nmb_frames_target")
|
||||||
|
if fps_target is None:
|
||||||
|
assert nmb_frames_target is not None, "Either specify nmb_frames_target or nmb_frames_target"
|
||||||
|
if nmb_frames_target is None:
|
||||||
|
assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
|
||||||
|
assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
|
||||||
|
nmb_frames_target = fps_target * duration_target
|
||||||
|
|
||||||
|
# Get number of frames that are missing
|
||||||
|
nmb_frames_diff = len(list_imgs) - 1
|
||||||
|
nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
|
||||||
|
|
||||||
|
if nmb_frames_missing < 1:
|
||||||
|
return list_imgs
|
||||||
|
|
||||||
|
list_imgs_float = [img.astype(np.float32) for img in list_imgs]
|
||||||
|
# Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
|
||||||
|
mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
|
||||||
|
constfact = np.floor(mean_nmb_frames_insert)
|
||||||
|
remainder_x = 1 - (mean_nmb_frames_insert - constfact)
|
||||||
|
nmb_iter = 0
|
||||||
|
while True:
|
||||||
|
nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
|
||||||
|
nmb_frames_to_insert[nmb_frames_to_insert <= remainder_x] = 0
|
||||||
|
nmb_frames_to_insert[nmb_frames_to_insert > remainder_x] = 1
|
||||||
|
nmb_frames_to_insert += constfact
|
||||||
|
if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
|
||||||
|
break
|
||||||
|
nmb_iter += 1
|
||||||
|
if nmb_iter > 100000:
|
||||||
|
print("add_frames_linear_interp: issue with inserting the right number of frames")
|
||||||
|
break
|
||||||
|
|
||||||
|
nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
|
||||||
|
list_imgs_interp = []
|
||||||
|
for i in range(len(list_imgs_float) - 1):
|
||||||
|
img0 = list_imgs_float[i]
|
||||||
|
img1 = list_imgs_float[i + 1]
|
||||||
|
list_imgs_interp.append(img0.astype(np.uint8))
|
||||||
|
list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i] + 2)[1:-1]
|
||||||
|
for fract_linblend in list_fracts_linblend:
|
||||||
|
img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
|
||||||
|
list_imgs_interp.append(img_blend.astype(np.uint8))
|
||||||
|
if i == len(list_imgs_float) - 2:
|
||||||
|
list_imgs_interp.append(img1.astype(np.uint8))
|
||||||
|
|
||||||
|
return list_imgs_interp
|
||||||
|
|
||||||
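A hedged usage sketch (frame count and resolution are just illustrative assumptions): given a short list of transition frames, the helper pads it to an exact frame budget of fps_target * duration_target, which is what makes rendered transitions comparable 1:1.

frames = [np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8) for _ in range(10)]
frames_interp = add_frames_linear_interp(frames, fps_target=30, duration_target=2)
print(len(frames), "->", len(frames_interp))  # 10 -> 60 frames (30 fps * 2 s)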
+
+
+def get_spacing(nmb_points: int, scaling: float):
+    """
+    Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
+    Args:
+        nmb_points: int
+            Number of points between [0, 1]
+        scaling: float
+            Higher values will return higher sampling density around 0.5
+    """
+    if scaling < 1.7:
+        return np.linspace(0, 1, nmb_points)
+    nmb_points_per_side = nmb_points // 2 + 1
+    if np.mod(nmb_points, 2) != 0:  # Uneven case
+        left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
+        right_side = 1 - left_side[::-1][1:]
+    else:
+        left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
+        right_side = 1 - left_side[::-1]
+    all_fracts = np.hstack([left_side, right_side])
+    return all_fracts
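Illustrative values computed from the code above: below a scaling of 1.7 the spacing is uniform; above it, points crowd toward 0.5.

print(get_spacing(5, 1.0))  # [0.   0.25 0.5  0.75 1.  ]
print(get_spacing(5, 3.0))  # [0.     0.4375 0.5    0.5625 1.    ] -- denser around the midpoint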
+
+
+def get_time(resolution=None):
+    """
+    Helper function returning a nicely formatted time string, e.g. 221117_1620
+    """
+    if resolution is None:
+        resolution = "second"
+    if resolution == "day":
+        t = time.strftime('%y%m%d', time.localtime())
+    elif resolution == "minute":
+        t = time.strftime('%y%m%d_%H%M', time.localtime())
+    elif resolution == "second":
+        t = time.strftime('%y%m%d_%H%M%S', time.localtime())
+    elif resolution == "millisecond":
+        t = time.strftime('%y%m%d_%H%M%S', time.localtime())
+        t += "_"
+        t += str("{:03d}".format(int(int(datetime.datetime.utcnow().strftime('%f')) / 1000)))
+    else:
+        raise ValueError("bad resolution provided: %s" % resolution)
+    return t
+
+
+def compare_dicts(a, b):
+    """
+    Compares two dictionaries a and b and returns a dictionary c, with all
+    keys,values that are shared between a and b but have different values.
+    The values of a and b are stacked together in the output.
+    Example:
+        a = {}; a['bobo'] = 4
+        b = {}; b['bobo'] = 5
+        c = compare_dicts(a, b)
+        c = {"bobo": [4, 5]}
+    """
+    c = {}
+    for key in a.keys():
+        if key in b.keys():
+            val_a = a[key]
+            val_b = b[key]
+            if val_a != val_b:
+                c[key] = [val_a, val_b]
+    return c
+
+
+def yml_load(fp_yml, print_fields=False):
+    """
+    Helper function for loading yaml files
+    """
+    with open(fp_yml) as f:
+        data = yaml.load(f, Loader=yaml.loader.SafeLoader)
+    dict_data = dict(data)
+    print("load: loaded {}".format(fp_yml))
+    return dict_data
+
+
+def yml_save(fp_yml, dict_stuff):
+    """
+    Helper function for saving yaml files
+    """
+    with open(fp_yml, 'w') as f:
+        yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
+    print("yml_save: saved {}".format(fp_yml))
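A round-trip sketch for the yaml helpers, e.g. for persisting transition settings; the file name and settings keys here are just example assumptions.

settings = {"prompt": "a forest at dawn", "depth_strength": 0.65}
yml_save("settings_example.yaml", settings)
assert yml_load("settings_example.yaml") == settings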