Johannes Stelzer 2023-02-22 10:15:03 +01:00
parent 3ed876e0ee
commit 297bb9abe6
9 changed files with 906 additions and 1322 deletions

View File

@ -13,30 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
import torch
from movie_util import MovieSaver
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
from huggingface_hub import hf_hub_download
#%% First let us spawn a stable diffusion holder
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
sdh = StableDiffusionHolder(fp_ckpt)
#%% Next let's set up all parameters
# %% Next let's set up all parameters
depth_strength = 0.65 # Specifies how deep (in terms of diffusion iterations) the first branching happens
t_compute_max_allowed = 15 # Determines the quality of the transition in terms of compute time you grant it
fixed_seeds = [69731932, 504430820]
@ -54,10 +46,9 @@ lb.set_prompt2(prompt2)
# Run latent blending
lb.run_transition(
depth_strength = depth_strength,
t_compute_max_allowed = t_compute_max_allowed,
fixed_seeds = fixed_seeds
)
depth_strength=depth_strength,
t_compute_max_allowed=t_compute_max_allowed,
fixed_seeds=fixed_seeds)
# Save movie
lb.write_movie_transition(fp_movie, duration_transition)
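# Rough intuition for depth_strength (assuming 30 diffusion steps, the UI default):
# the first branching then starts around step int(round(30 * 0.65)) == 20,
# mirroring idx_injection_base in LatentBlending.get_time_based_branching.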

View File

@ -13,33 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
import torch
from movie_util import MovieSaver, concatenate_movies
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
from movie_util import concatenate_movies
from huggingface_hub import hf_hub_download
#%% First let us spawn a stable diffusion holder
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
# fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
# %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
sdh = StableDiffusionHolder(fp_ckpt)
#%% Let's setup the multi transition
# %% Let's setup the multi transition
fps = 30
duration_single_trans = 6
depth_strength = 0.55 #Specifies how deep (in terms of diffusion iterations) the first branching happens
depth_strength = 0.55 # Specifies how deep (in terms of diffusion iterations) the first branching happens
# Specify a list of prompts below
list_prompts = []
@ -56,28 +49,25 @@ t_compute_max_allowed = 12 # per segment
fp_movie = 'movie_example2.mp4'
lb = LatentBlending(sdh)
list_movie_parts = [] #
for i in range(len(list_prompts)-1):
list_movie_parts = []
for i in range(len(list_prompts) - 1):
# For a multi transition we can save some computation time and recycle the latents
if i==0:
if i == 0:
lb.set_prompt1(list_prompts[i])
lb.set_prompt2(list_prompts[i+1])
lb.set_prompt2(list_prompts[i + 1])
recycle_img1 = False
else:
lb.swap_forward()
lb.set_prompt2(list_prompts[i+1])
lb.set_prompt2(list_prompts[i + 1])
recycle_img1 = True
fp_movie_part = f"tmp_part_{str(i).zfill(3)}.mp4"
fixed_seeds = list_seeds[i:i+2]
fixed_seeds = list_seeds[i:i + 2]
# Run latent blending
lb.run_transition(
depth_strength = depth_strength,
t_compute_max_allowed = t_compute_max_allowed,
fixed_seeds = fixed_seeds
)
depth_strength=depth_strength,
t_compute_max_allowed=t_compute_max_allowed,
fixed_seeds=fixed_seeds)
# Save movie
lb.write_movie_transition(fp_movie_part, duration_single_trans)
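# Sketch of the likely final step (not shown in this hunk): collect the parts
# and stitch them with the imported helper, e.g.
#     list_movie_parts.append(fp_movie_part)  # inside the loop
#     concatenate_movies(fp_movie, list_movie_parts)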

View File

@ -13,25 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
import torch
from movie_util import MovieSaver
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
from huggingface_hub import hf_hub_download
#%% Define vars for low-resolution pass
# %% Define vars for low-resolution pass
prompt1 = "photo of mount vesuvius erupting a terrifying pyroclastic ash cloud"
prompt2 = "photo of a inside a building full of ash, fire, death, destruction, explosions"
fixed_seeds = [5054613, 1168652]
@ -41,21 +33,18 @@ height = 384
num_inference_steps_lores = 40
nmb_max_branches_lores = 10
depth_strength_lores = 0.5
fp_ckpt_lores = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
fp_ckpt_lores = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
#%% Define vars for high-resolution pass
fp_ckpt_hires = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
# %% Define vars for high-resolution pass
fp_ckpt_hires = hf_hub_download(repo_id="stabilityai/stable-diffusion-x4-upscaler", filename="x4-upscaler-ema.ckpt")
depth_strength_hires = 0.65
num_inference_steps_hires = 100
nmb_branches_final_hires = 6
dp_imgs = "tmp_transition" # folder for results and intermediate steps
dp_imgs = "tmp_transition" # Folder for results and intermediate steps
#%% Run low-res pass
# %% Run low-res pass
sdh = StableDiffusionHolder(fp_ckpt_lores)
#%%
lb = LatentBlending(sdh)
lb.set_prompt1(prompt1)
lb.set_prompt2(prompt2)
@ -64,14 +53,13 @@ lb.set_height(height)
# Run latent blending
lb.run_transition(
depth_strength = depth_strength_lores,
nmb_max_branches = nmb_max_branches_lores,
fixed_seeds = fixed_seeds
)
depth_strength=depth_strength_lores,
nmb_max_branches=nmb_max_branches_lores,
fixed_seeds=fixed_seeds)
lb.write_imgs_transition(dp_imgs)
#%% Run high-res pass
# %% Run high-res pass
sdh = StableDiffusionHolder(fp_ckpt_hires)
lb = LatentBlending(sdh)
lb.run_upscaling(dp_imgs, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)
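# Note: write_imgs_transition(dp_imgs) stores lowres.yaml plus the numbered
# lowres_img_XXXX.jpg frames in dp_imgs; run_upscaling reads exactly these files
# back in (see the asserts in LatentBlending.run_upscaling further below).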

View File

@ -13,25 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import os
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
torch.set_grad_enabled(False)
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
import torch
from movie_util import MovieSaver, concatenate_movies
from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
from movie_util import concatenate_movies
from huggingface_hub import hf_hub_download
#%% Define vars for low-resolution pass
# %% Define vars for low-resolution pass
list_prompts = []
list_prompts.append("surrealistic statue made of glitter and dirt, standing in a lake, atmospheric light, strange glow")
list_prompts.append("statue of a mix between a tree and human, made of marble, incredibly detailed")
@ -50,56 +44,54 @@ num_inference_steps_lores = 40
nmb_max_branches_lores = 10
depth_strength_lores = 0.5
fp_ckpt_lores = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
fp_ckpt_lores = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
#%% Define vars for high-resolution pass
fp_ckpt_hires = "../stable_diffusion_models/ckpt/x4-upscaler-ema.ckpt"
# %% Define vars for high-resolution pass
fp_ckpt_hires = hf_hub_download(repo_id="stabilityai/stable-diffusion-x4-upscaler", filename="x4-upscaler-ema.ckpt")
depth_strength_hires = 0.65
num_inference_steps_hires = 100
nmb_branches_final_hires = 6
#%% Run low-res pass
# %% Run low-res pass
sdh = StableDiffusionHolder(fp_ckpt_lores)
t_compute_max_allowed = 12 # per segment
t_compute_max_allowed = 12 # Per segment
lb = LatentBlending(sdh)
list_movie_dirs = [] #
for i in range(len(list_prompts)-1):
list_movie_dirs = []
for i in range(len(list_prompts) - 1):
# For a multi transition we can save some computation time and recycle the latents
if i==0:
if i == 0:
lb.set_prompt1(list_prompts[i])
lb.set_prompt2(list_prompts[i+1])
lb.set_prompt2(list_prompts[i + 1])
recycle_img1 = False
else:
lb.swap_forward()
lb.set_prompt2(list_prompts[i+1])
lb.set_prompt2(list_prompts[i + 1])
recycle_img1 = True
dp_movie_part = f"tmp_part_{str(i).zfill(3)}"
fp_movie_part = os.path.join(dp_movie_part, "movie_lowres.mp4")
os.makedirs(dp_movie_part, exist_ok=True)
fixed_seeds = list_seeds[i:i+2]
fixed_seeds = list_seeds[i:i + 2]
# Run latent blending
lb.run_transition(
depth_strength = depth_strength_lores,
nmb_max_branches = nmb_max_branches_lores,
fixed_seeds = fixed_seeds
)
depth_strength=depth_strength_lores,
nmb_max_branches=nmb_max_branches_lores,
fixed_seeds=fixed_seeds)
# Save movie and images (needed for upscaling!)
lb.write_movie_transition(fp_movie_part, duration_single_trans)
lb.write_imgs_transition(dp_movie_part)
list_movie_dirs.append(dp_movie_part)
#%% Run high-res pass on each segment
# %% Run high-res pass on each segment
sdh = StableDiffusionHolder(fp_ckpt_hires)
lb = LatentBlending(sdh)
for dp_part in list_movie_dirs:
lb.run_upscaling(dp_part, depth_strength_hires, num_inference_steps_hires, nmb_branches_final_hires)
#%% concatenate into one long movie
# %% concatenate into one long movie
list_fp_movies = []
for dp_part in list_movie_dirs:
fp_movie = os.path.join(dp_part, "movie_highres.mp4")
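# Presumably the loop then collects the high-res parts and concatenates them,
# analogous to the helper used elsewhere, e.g.
#     list_fp_movies.append(fp_movie)
#     concatenate_movies(fp_final, list_fp_movies)  # fp_final: hypothetical output path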

View File

@ -13,82 +13,89 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import os
import torch
torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
import torch
from movie_util import MovieSaver, concatenate_movies
from typing import Callable, List, Optional, Union
from latent_blending import get_time, yml_save, LatentBlending, add_frames_linear_interp, compare_dicts
from latent_blending import LatentBlending
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
import gradio as gr
import copy
from dotenv import find_dotenv, load_dotenv
import shutil
import random
import time
from utils import get_time, add_frames_linear_interp
from huggingface_hub import hf_hub_download
#%%
class BlendingFrontend():
def __init__(self, sdh=None):
self.num_inference_steps = 30
if sdh is None:
self.use_debug = True
self.height = 768
self.width = 768
else:
self.use_debug = False
self.lb = LatentBlending(sdh)
self.lb.sdh.num_inference_steps = self.num_inference_steps
self.height = self.lb.sdh.height
self.width = self.lb.sdh.width
def __init__(
self,
sdh,
share=False):
r"""
Gradio Helper Class to collect UI data and start latent blending.
Args:
sdh:
StableDiffusionHolder
share: bool
Set true to get a shareable gradio link (e.g. for running a remote server)
"""
self.share = share
self.init_save_dir()
self.save_empty_image()
self.share = False
self.transition_can_be_computed = False
# UI Defaults
self.num_inference_steps = 30
self.depth_strength = 0.25
self.seed1 = 420
self.seed2 = 420
self.guidance_scale = 4.0
self.guidance_scale_mid_damper = 0.5
self.mid_compression_scaler = 1.2
self.prompt1 = ""
self.prompt2 = ""
self.negative_prompt = ""
self.state_current = {}
self.fps = 30
self.duration_video = 8
self.t_compute_max_allowed = 10
self.lb = LatentBlending(sdh)
self.lb.sdh.num_inference_steps = self.num_inference_steps
self.init_parameters_from_lb()
self.init_save_dir()
# Vars
self.list_fp_imgs_current = []
self.recycle_img1 = False
self.recycle_img2 = False
self.list_all_segments = []
self.dp_session = ""
self.user_id = None
def init_parameters_from_lb(self):
r"""
Automatically init parameters from latentblending instance
"""
self.height = self.lb.sdh.height
self.width = self.lb.sdh.width
self.guidance_scale = self.lb.guidance_scale
self.guidance_scale_mid_damper = self.lb.guidance_scale_mid_damper
self.mid_compression_scaler = self.lb.mid_compression_scaler
self.branch1_crossfeed_power = self.lb.branch1_crossfeed_power
self.branch1_crossfeed_range = self.lb.branch1_crossfeed_range
self.branch1_crossfeed_decay = self.lb.branch1_crossfeed_decay
self.parental_crossfeed_power = self.lb.parental_crossfeed_power
self.parental_crossfeed_range = self.lb.parental_crossfeed_range
self.parental_crossfeed_power_decay = self.lb.parental_crossfeed_power_decay
self.fps = 30
self.duration_video = 10
self.t_compute_max_allowed = 10
self.list_fp_imgs_current = []
self.current_timestamp = None
self.recycle_img1 = False
self.recycle_img2 = False
self.multi_idx_current = -1
self.list_imgs_shown_last = 5*[self.fp_img_empty]
self.list_all_segments = []
self.dp_session = ""
self.user_id = None
self.block_transition = False
def init_save_dir(self):
r"""
Initializes the directory where results are saved.
You can specify this directory in a ".env" file in your latentblending root, setting
DIR_OUT='/path/to/saving'
"""
load_dotenv(find_dotenv(), verbose=False)
self.dp_out = os.getenv("DIR_OUT")
if self.dp_out is None:
@ -97,124 +104,125 @@ class BlendingFrontend():
os.makedirs(self.dp_imgs, exist_ok=True)
self.dp_movies = os.path.join(self.dp_out, "movies")
os.makedirs(self.dp_movies, exist_ok=True)
self.save_empty_image()
# make dummy image
def save_empty_image(self):
r"""
Saves an empty/black dummy image.
"""
self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
def randomize_seed1(self):
# Don't randomize the seed if we are in multi-concat mode; changing it would break the movie.
r"""
Randomizes the first seed
"""
seed = np.random.randint(0, 10000000)
self.seed1 = int(seed)
print(f"randomize_seed1: new seed = {self.seed1}")
return seed
def randomize_seed2(self):
r"""
Randomizes the second seed
"""
seed = np.random.randint(0, 10000000)
self.seed2 = int(seed)
print(f"randomize_seed2: new seed = {self.seed2}")
return seed
def setup_lb(self, list_ui_elem):
def setup_lb(self, list_ui_vals):
r"""
Sets all parameters from the UI. Since gradio does not support passing dictionaries,
we have to instead pass keys (list_ui_keys, global) and values (list_ui_vals)
"""
# Collect latent blending variables
self.state_current = self.get_state_dict()
self.lb.set_width(list_ui_elem[list_ui_keys.index('width')])
self.lb.set_height(list_ui_elem[list_ui_keys.index('height')])
self.lb.set_prompt1(list_ui_elem[list_ui_keys.index('prompt1')])
self.lb.set_prompt2(list_ui_elem[list_ui_keys.index('prompt2')])
self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
self.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
self.duration_video = list_ui_elem[list_ui_keys.index('duration_video')]
self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')] #seed
self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
self.lb.set_width(list_ui_vals[list_ui_keys.index('width')])
self.lb.set_height(list_ui_vals[list_ui_keys.index('height')])
self.lb.set_prompt1(list_ui_vals[list_ui_keys.index('prompt1')])
self.lb.set_prompt2(list_ui_vals[list_ui_keys.index('prompt2')])
self.lb.set_negative_prompt(list_ui_vals[list_ui_keys.index('negative_prompt')])
self.lb.guidance_scale = list_ui_vals[list_ui_keys.index('guidance_scale')]
self.lb.guidance_scale_mid_damper = list_ui_vals[list_ui_keys.index('guidance_scale_mid_damper')]
self.t_compute_max_allowed = list_ui_vals[list_ui_keys.index('duration_compute')]
self.lb.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
self.lb.sdh.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
self.duration_video = list_ui_vals[list_ui_keys.index('duration_video')]
self.lb.seed1 = list_ui_vals[list_ui_keys.index('seed1')]
self.lb.seed2 = list_ui_vals[list_ui_keys.index('seed2')]
self.lb.branch1_crossfeed_power = list_ui_vals[list_ui_keys.index('branch1_crossfeed_power')]
self.lb.branch1_crossfeed_range = list_ui_vals[list_ui_keys.index('branch1_crossfeed_range')]
self.lb.branch1_crossfeed_decay = list_ui_vals[list_ui_keys.index('branch1_crossfeed_decay')]
self.lb.parental_crossfeed_power = list_ui_vals[list_ui_keys.index('parental_crossfeed_power')]
self.lb.parental_crossfeed_range = list_ui_vals[list_ui_keys.index('parental_crossfeed_range')]
self.lb.parental_crossfeed_power_decay = list_ui_vals[list_ui_keys.index('parental_crossfeed_power_decay')]
self.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
self.depth_strength = list_ui_vals[list_ui_keys.index('depth_strength')]
self.lb.branch1_crossfeed_power = list_ui_elem[list_ui_keys.index('branch1_crossfeed_power')]
self.lb.branch1_crossfeed_range = list_ui_elem[list_ui_keys.index('branch1_crossfeed_range')]
self.lb.branch1_crossfeed_decay = list_ui_elem[list_ui_keys.index('branch1_crossfeed_decay')]
self.lb.parental_crossfeed_power = list_ui_elem[list_ui_keys.index('parental_crossfeed_power')]
self.lb.parental_crossfeed_range = list_ui_elem[list_ui_keys.index('parental_crossfeed_range')]
self.lb.parental_crossfeed_power_decay = list_ui_elem[list_ui_keys.index('parental_crossfeed_power_decay')]
self.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
self.depth_strength = list_ui_elem[list_ui_keys.index('depth_strength')]
if len(list_ui_elem[list_ui_keys.index('user_id')]) > 1:
self.user_id = list_ui_elem[list_ui_keys.index('user_id')]
if len(list_ui_vals[list_ui_keys.index('user_id')]) > 1:
self.user_id = list_ui_vals[list_ui_keys.index('user_id')]
else:
# generate new user id
self.user_id = ''.join((random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ') for i in range(8)))
print(f"made new user_id: {self.user_id}")
print(f"made new user_id: {self.user_id} at {get_time('second')}")
def save_latents(self, fp_latents, list_latents):
r"""
Saves a latent trajectory on disk, in npy format.
"""
list_latents_cpu = [l.cpu().numpy() for l in list_latents]
np.save(fp_latents, list_latents_cpu)
def load_latents(self, fp_latents):
r"""
Loads a latent trajectory from disk, converts to torch tensor.
"""
list_latents_cpu = np.load(fp_latents)
list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu]
return list_latents
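# Minimal sketch of how the two helpers above are paired further below
# (compute_img1 saves the first trajectory, compute_img2/compute_transition load it back):
#     self.save_latents(fp_img1 + ".npy", self.lb.tree_latents[0])
#     self.lb.tree_latents[0] = self.load_latents(fp_img1 + ".npy")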
def compute_img1(self, *args):
list_ui_elem = args
self.setup_lb(list_ui_elem)
r"""
Computes the first transition image and returns it for display.
Sets all other transition images and last image to empty (as they are obsolete with this operation)
"""
list_ui_vals = args
self.setup_lb(list_ui_vals)
fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}")
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
img1.save(fp_img1+".jpg")
self.save_latents(fp_img1+".npy", self.lb.tree_latents[0])
img1.save(fp_img1 + ".jpg")
self.save_latents(fp_img1 + ".npy", self.lb.tree_latents[0])
self.recycle_img1 = True
self.recycle_img2 = False
# fixme save seeds. change filenames?
return [fp_img1+".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
return [fp_img1 + ".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
def compute_img2(self, *args):
r"""
Computes the last transition image and returns it for display.
Sets all other transition images to empty (as they are obsolete with this operation)
"""
if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")): # don't do anything
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
list_ui_elem = args
self.setup_lb(list_ui_elem)
list_ui_vals = args
self.setup_lb(list_ui_vals)
self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}")
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
img2.save(fp_img2+'.jpg')
self.save_latents(fp_img2+".npy", self.lb.tree_latents[-1])
img2.save(fp_img2 + '.jpg')
self.save_latents(fp_img2 + ".npy", self.lb.tree_latents[-1])
self.recycle_img2 = True
self.transition_can_be_computed = True
# fixme save seeds. change filenames?
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2+".jpg", self.user_id]
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2 + ".jpg", self.user_id]
def compute_transition(self, *args):
if not self.transition_can_be_computed:
list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
return list_return
list_ui_elem = args
self.setup_lb(list_ui_elem)
r"""
Computes transition images and movie.
"""
list_ui_vals = args
self.setup_lb(list_ui_vals)
print("STARTING TRANSITION...")
fixed_seeds = [self.seed1, self.seed2]
# Run Latent Blending
# Check if another user is blocking this... otherwise everything will become mixed.
# t_now = time.time()
# if self.block_transition:
# while True:
# time.sleep(1)
# if not self.block_transition:
# break
# if time.time() - t_now > 1000:
# return
self.block_transition = True
# Inject loaded latents (other user interference)
self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
@ -224,24 +232,23 @@ class BlendingFrontend():
num_inference_steps=self.num_inference_steps,
depth_strength=self.depth_strength,
t_compute_max_allowed=self.t_compute_max_allowed,
fixed_seeds=fixed_seeds
)
print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
fixed_seeds=fixed_seeds)
print(f"Latent Blending pass finished ({get_time('second')}). Resulted in {len(imgs_transition)} images")
# Subselect three preview images
idx_img_prev = np.round(np.linspace(0, len(imgs_transition)-1, 5)[1:-1]).astype(np.int32)
idx_img_prev = np.round(np.linspace(0, len(imgs_transition) - 1, 5)[1:-1]).astype(np.int32)
list_imgs_preview = []
for j in idx_img_prev:
list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
# Save the preview imgs as jpgs on disk so we are not sending uncompressed data around
self.current_timestamp = get_time('second')
current_timestamp = get_time('second')
self.list_fp_imgs_current = []
for i in range(len(list_imgs_preview)):
fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg")
fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{current_timestamp}.jpg")
list_imgs_preview[i].save(fp_img)
self.list_fp_imgs_current.append(fp_img)
self.block_transition = False
# Insert cheap frames for the movie
imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
@ -259,16 +266,17 @@ class BlendingFrontend():
list_return = self.list_fp_imgs_current + [self.fp_movie]
return list_return
def stack_forward(self, prompt2, seed2):
r"""
Allows generating multi-segment movies. Sets last image -> first image with all
relevant parameters.
"""
# Save preview images, prompts and seeds into dictionary for stacking
if len(self.list_all_segments) == 0:
timestamp_session = get_time('second')
self.dp_session = os.path.join(self.dp_out, f"session_{timestamp_session}")
os.makedirs(self.dp_session)
self.transition_can_be_computed = False
idx_segment = len(self.list_all_segments)
dp_segment = os.path.join(self.dp_session, f"segment_{str(idx_segment).zfill(3)}")
@ -285,13 +293,11 @@ class BlendingFrontend():
self.lb.swap_forward()
shutil.copyfile(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"), os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
fp_multi = self.multi_concat()
list_out = [fp_multi]
list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")])
list_out.extend([self.fp_img_empty]*4)
list_out.extend([self.fp_img_empty] * 4)
list_out.append(gr.update(interactive=False, value=prompt2))
list_out.append(gr.update(interactive=False, value=seed2))
list_out.append("")
@ -299,16 +305,20 @@ class BlendingFrontend():
print(f"stack_forward: fp_multi {fp_multi}")
return list_out
def multi_concat(self):
r"""
Concatenates all stacked segments into one long movie.
"""
list_fp_movies = self.get_fp_video_all()
# Concatenate movies and save
fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4")
concatenate_movies(fp_final, list_fp_movies)
return fp_final
def get_fp_video_all(self):
r"""
Collects all stacked movie segments.
"""
list_all = os.listdir(self.dp_movies)
str_beg = f"movie_{self.user_id}_"
list_user = [l for l in list_all if str_beg in l]
@ -316,8 +326,10 @@ class BlendingFrontend():
list_user = [os.path.join(self.dp_movies, l) for l in list_user]
return list_user
def get_fp_video_next(self):
r"""
Gets the filepath of the next movie segment.
"""
list_videos = self.get_fp_video_all()
if len(list_videos) == 0:
idx_next = 0
@ -327,26 +339,16 @@ class BlendingFrontend():
return fp_video_next
def get_fp_video_last(self):
r"""
Gets the current video that was saved.
"""
fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4")
return fp_video_last
def get_state_dict(self):
state_dict = {}
grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
'num_inference_steps', 'depth_strength', 'guidance_scale',
'guidance_scale_mid_damper', 'mid_compression_scaler']
for v in grab_vars:
state_dict[v] = getattr(self, v)
return state_dict
if __name__ == "__main__":
# fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
# fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
# self = BlendingFrontend(None)
@ -391,7 +393,6 @@ if __name__ == "__main__":
depth_strength = gr.Slider(0.01, 0.99, bf.depth_strength, step=0.01, label='depth_strength', interactive=True)
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, bf.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
with gr.Row():
b_compute1 = gr.Button('compute first image', variant='primary')
b_compute_transition = gr.Button('compute transition', variant='primary')
@ -405,11 +406,10 @@ if __name__ == "__main__":
img5 = gr.Image(label="5/5")
with gr.Row():
vid_single = gr.Video(label="single trans")
vid_multi = gr.Video(label="multi trans")
vid_single = gr.Video(label="current single trans")
vid_multi = gr.Video(label="concatenated multi trans")
with gr.Row():
# b_restart = gr.Button("RESTART EVERYTHING")
b_stackforward = gr.Button('append last movie segment (left) to multi movie (right)', variant='primary')
with gr.Row():
@ -437,8 +437,7 @@ if __name__ == "__main__":
- parental_crossfeed_power_decay: Similar to branch1_crossfeed_decay, however applied to the images within the transition.
- depth_strength: Determines when the blending process will begin in terms of diffusion steps. Low values are more inventive but can cause motion.
- guidance_scale_mid_damper: Decreases the guidance scale in the middle of a transition.
"""
)
""")
with gr.Row():
user_id = gr.Textbox(label="user id", interactive=False)
@ -471,24 +470,23 @@ if __name__ == "__main__":
dict_ui_elem["user_id"] = user_id
# Convert to list, as gradio doesn't seem to accept dicts
list_ui_elem = []
list_ui_vals = []
list_ui_keys = []
for k in dict_ui_elem.keys():
list_ui_elem.append(dict_ui_elem[k])
list_ui_vals.append(dict_ui_elem[k])
list_ui_keys.append(k)
bf.list_ui_keys = list_ui_keys
b_newseed1.click(bf.randomize_seed1, outputs=seed1)
b_newseed2.click(bf.randomize_seed2, outputs=seed2)
b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5, user_id])
b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5, user_id])
b_compute1.click(bf.compute_img1, inputs=list_ui_vals, outputs=[img1, img2, img3, img4, img5, user_id])
b_compute2.click(bf.compute_img2, inputs=list_ui_vals, outputs=[img2, img3, img4, img5, user_id])
b_compute_transition.click(bf.compute_transition,
inputs=list_ui_elem,
inputs=list_ui_vals,
outputs=[img2, img3, img4, vid_single])
b_stackforward.click(bf.stack_forward,
inputs=[prompt2, seed2],
outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
demo.launch(share=bf.share, inbrowser=True, inline=False)

View File

@ -13,41 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import os
import torch
torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import time
import subprocess
import warnings
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
from movie_util import MovieSaver
import datetime
from typing import Callable, List, Optional, Union
import inspect
from threading import Thread
torch.set_grad_enabled(False)
from contextlib import nullcontext
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
from typing import List, Optional
from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
from stable_diffusion_holder import StableDiffusionHolder
import yaml
import lpips
#%%
from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save
class LatentBlending():
def __init__(
self,
sdh: None,
guidance_scale: float = 4,
guidance_scale_mid_damper: float = 0.5,
mid_compression_scaler: float = 1.2,
):
mid_compression_scaler: float = 1.2):
r"""
Initializes the latent blending class.
Args:
@ -64,9 +54,10 @@ class LatentBlending():
Increases the sampling density in the middle (where most changes happen). Higher values
imply more samples in the middle. However, the inflection point can occur outside the middle,
thus high values can give rough transitions. Values around 2 should be fine.
"""
assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper needs to be in interval (0,1], you provided {guidance_scale_mid_damper}"
assert guidance_scale_mid_damper > 0 \
and guidance_scale_mid_damper <= 1.0, \
f"guidance_scale_mid_damper needs to be in interval (0,1], you provided {guidance_scale_mid_damper}"
self.sdh = sdh
self.device = self.sdh.device
@ -115,10 +106,8 @@ class LatentBlending():
self.multi_transition_img_last = None
self.dt_per_diff = 0
self.spatial_mask = None
self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
def init_mode(self):
r"""
Sets the operational mode. Currently supported are standard, inpainting and x4 upscaling.
@ -151,13 +140,12 @@ class LatentBlending():
Tunes the guidance scale down as a linear function of fract_mixing,
towards 0.5 the minimum will be reached.
"""
mid_factor = 1 - np.abs(fract_mixing - 0.5)/ 0.5
max_guidance_reduction = self.guidance_scale_base * (1-self.guidance_scale_mid_damper) - 1
guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction*mid_factor
mid_factor = 1 - np.abs(fract_mixing - 0.5) / 0.5
max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1
guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor
self.guidance_scale = guidance_scale_effective
self.sdh.guidance_scale = guidance_scale_effective
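# Worked example with the constructor defaults (guidance_scale=4.0,
# guidance_scale_mid_damper=0.5): at fract_mixing=0.5, mid_factor=1.0 and
# max_guidance_reduction = 4.0*0.5 - 1 = 1.0, so guidance_scale_effective = 3.0;
# at fract_mixing=0.0 or 1.0, mid_factor=0.0 and the full guidance scale 4.0 is kept.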
def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
r"""
Sets the crossfeed parameters for the first branch to the last branch.
@ -173,7 +161,6 @@ class LatentBlending():
self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
r"""
Sets the crossfeed parameters for all transition images (within the first and last branch).
@ -189,7 +176,6 @@ class LatentBlending():
self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1)
def set_prompt1(self, prompt: str):
r"""
Sets the first prompt (for the first keyframe) including text embeddings.
@ -201,7 +187,6 @@ class LatentBlending():
self.prompt1 = prompt
self.text_embedding1 = self.get_text_embeddings(self.prompt1)
def set_prompt2(self, prompt: str):
r"""
Sets the second prompt (for the second keyframe) including text embeddings.
@ -237,8 +222,7 @@ class LatentBlending():
depth_strength: Optional[float] = 0.3,
t_compute_max_allowed: Optional[float] = None,
nmb_max_branches: Optional[int] = None,
fixed_seeds: Optional[List[int]] = None,
):
fixed_seeds: Optional[List[int]] = None):
r"""
Function for computing transitions.
Returns a list of transition images using spherical latent blending.
@ -263,7 +247,6 @@ class LatentBlending():
fixed_seeds: Optional[List[int]]:
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
Otherwise random seeds will be taken.
"""
# Sanity checks first
@ -275,7 +258,7 @@ class LatentBlending():
if fixed_seeds == 'randomize':
fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
else:
assert len(fixed_seeds)==2, "Supply a list with len = 2"
assert len(fixed_seeds) == 2, "Supply a list with len = 2"
self.seed1 = fixed_seeds[0]
self.seed2 = fixed_seeds[1]
@ -323,7 +306,6 @@ class LatentBlending():
return self.tree_final_imgs
def compute_latents1(self, return_image=False):
r"""
Runs a diffusion trajectory for the first image
@ -337,11 +319,10 @@ class LatentBlending():
latents_start = self.get_noise(self.seed1)
list_latents1 = self.run_diffusion(
list_conditionings,
latents_start = latents_start,
idx_start = 0
)
latents_start=latents_start,
idx_start=0)
t1 = time.time()
self.dt_per_diff = (t1-t0) / self.num_inference_steps
self.dt_per_diff = (t1 - t0) / self.num_inference_steps
self.tree_latents[0] = list_latents1
if return_image:
return self.sdh.latent2image(list_latents1[-1])
@ -361,17 +342,16 @@ class LatentBlending():
# Influence from branch1
if self.branch1_crossfeed_power > 0.0:
# Set up the mixing_coeffs
idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_crossfeed_range))
mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power*self.branch1_crossfeed_decay, idx_mixing_stop))
mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
idx_mixing_stop = int(round(self.num_inference_steps * self.branch1_crossfeed_range))
mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power * self.branch1_crossfeed_decay, idx_mixing_stop))
mixing_coeffs.extend((self.num_inference_steps - idx_mixing_stop) * [0])
list_latents_mixing = self.tree_latents[0]
list_latents2 = self.run_diffusion(
list_conditionings,
latents_start = latents_start,
idx_start = 0,
list_latents_mixing = list_latents_mixing,
mixing_coeffs = mixing_coeffs
)
latents_start=latents_start,
idx_start=0,
list_latents_mixing=list_latents_mixing,
mixing_coeffs=mixing_coeffs)
else:
list_latents2 = self.run_diffusion(list_conditionings, latents_start)
self.tree_latents[-1] = list_latents2
@ -381,7 +361,6 @@ class LatentBlending():
else:
return list_latents2
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
r"""
Runs a diffusion trajectory, using the latents from the respective parents
@ -409,22 +388,19 @@ class LatentBlending():
latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
list_latents_parental_mix.append(latents_parental)
idx_mixing_stop = int(round(self.num_inference_steps*self.parental_crossfeed_range))
mixing_coeffs = idx_injection*[self.parental_crossfeed_power]
idx_mixing_stop = int(round(self.num_inference_steps * self.parental_crossfeed_range))
mixing_coeffs = idx_injection * [self.parental_crossfeed_power]
nmb_mixing = idx_mixing_stop - idx_injection
if nmb_mixing > 0:
mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power*self.parental_crossfeed_power_decay, nmb_mixing)))
mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
latents_start = list_latents_parental_mix[idx_injection-1]
mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_power_decay, nmb_mixing)))
mixing_coeffs.extend((self.num_inference_steps - len(mixing_coeffs)) * [0])
latents_start = list_latents_parental_mix[idx_injection - 1]
list_latents = self.run_diffusion(
list_conditionings,
latents_start = latents_start,
idx_start = idx_injection,
list_latents_mixing = list_latents_parental_mix,
mixing_coeffs = mixing_coeffs
)
latents_start=latents_start,
idx_start=idx_injection,
list_latents_mixing=list_latents_parental_mix,
mixing_coeffs=mixing_coeffs)
return list_latents
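# Shape of mixing_coeffs built above: idx_injection entries at full
# parental_crossfeed_power, then a ramp decaying towards power*power_decay until
# idx_mixing_stop, then zeros, so len(mixing_coeffs) == num_inference_steps.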
def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
@ -445,8 +421,8 @@ class LatentBlending():
results. Use this if you want to have controllable results independent
of your computer.
"""
idx_injection_base = int(round(self.num_inference_steps*depth_strength))
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3)
idx_injection_base = int(round(self.num_inference_steps * depth_strength))
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps - 1, 3)
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
t_compute = 0
@ -456,20 +432,18 @@ class LatentBlending():
elif t_compute_max_allowed is None:
assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
stop_criterion = "nmb_max_branches"
nmb_max_branches -= 2 # discounting the outer frames
nmb_max_branches -= 2 # Discounting the outer frames
else:
raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
stop_criterion_reached = False
is_first_iteration = True
while not stop_criterion_reached:
list_compute_steps = self.num_inference_steps - list_idx_injection
list_compute_steps *= list_nmb_stems
t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15*np.sum(list_nmb_stems)
t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
increase_done = False
for s_idx in range(len(list_nmb_stems)-1):
if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
for s_idx in range(len(list_nmb_stems) - 1):
if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
list_nmb_stems[s_idx] += 1
increase_done = True
break
@ -501,10 +475,10 @@ class LatentBlending():
"""
# get_lpips_similarity
similarities = []
for i in range(len(self.tree_final_imgs)-1):
similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i+1]))
for i in range(len(self.tree_final_imgs) - 1):
similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1]))
b_closest1 = np.argmax(similarities)
b_closest2 = b_closest1+1
b_closest2 = b_closest1 + 1
fract_closest1 = self.tree_fracts[b_closest1]
fract_closest2 = self.tree_fracts[b_closest2]
@ -515,23 +489,15 @@ class LatentBlending():
break
else:
b_parent1 -= 1
b_parent2 = b_closest2
while True:
if self.tree_idx_injection[b_parent2] < idx_injection:
break
else:
b_parent2 += 1
# print(f"\n\nb_closest: {b_closest1} {b_closest2} fract_closest1 {fract_closest1} fract_closest2 {fract_closest2}")
# print(f"b_parent: {b_parent1} {b_parent2}")
# print(f"similarities {similarities}")
# print(f"idx_injection {idx_injection} tree_idx_injection {self.tree_idx_injection}")
fract_mixing = (fract_closest1 + fract_closest2) /2
fract_mixing = (fract_closest1 + fract_closest2) / 2
return fract_mixing, b_parent1, b_parent2
def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
r"""
Inserts all necessary parameters into the trajectory tree.
@ -543,12 +509,11 @@ class LatentBlending():
list_latents: list
list of the latents to be inserted
"""
b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
self.tree_latents.insert(b_parent1+1, list_latents)
self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
self.tree_fracts.insert(b_parent1+1, fract_mixing)
self.tree_idx_injection.insert(b_parent1+1, idx_injection)
b_parent1, b_parent2 = self.get_closest_idx(fract_mixing)
self.tree_latents.insert(b_parent1 + 1, list_latents)
self.tree_final_imgs.insert(b_parent1 + 1, self.sdh.latent2image(list_latents[-1]))
self.tree_fracts.insert(b_parent1 + 1, fract_mixing)
self.tree_idx_injection.insert(b_parent1 + 1, idx_injection)
def get_spatial_mask_template(self):
r"""
@ -565,9 +530,7 @@ class LatentBlending():
Args:
img_mask:
mask image [0,1]. You can get a template using get_spatial_mask_template
"""
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
C, H, W = shape_latents
img_mask = np.asarray(img_mask)
@ -577,18 +540,15 @@ class LatentBlending():
assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
spatial_mask = torch.unsqueeze(spatial_mask, 0)
spatial_mask = spatial_mask.repeat((C,1,1))
spatial_mask = spatial_mask.repeat((C, 1, 1))
spatial_mask = torch.unsqueeze(spatial_mask, 0)
self.spatial_mask = spatial_mask
def get_noise(self, seed):
r"""
Helper function to get noise given seed.
Args:
seed: int
"""
generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
if self.mode == 'standard':
@ -599,21 +559,17 @@ class LatentBlending():
h = self.image1_lowres.size[1]
shape_latents = [self.sdh.model.channels, h, w]
C, H, W = shape_latents
return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
@torch.no_grad()
def run_diffusion(
self,
list_conditionings,
latents_start: torch.FloatTensor = None,
idx_start: int = 0,
list_latents_mixing = None,
mixing_coeffs = 0.0,
return_image: Optional[bool] = False
):
list_latents_mixing=None,
mixing_coeffs=0.0,
return_image: Optional[bool] = False):
r"""
Wrapper function for diffusion runners.
Depending on the mode, the correct one will be executed.
@ -640,14 +596,13 @@ class LatentBlending():
if self.mode == 'standard':
text_embeddings = list_conditionings[0]
return self.sdh.run_diffusion_standard(
text_embeddings = text_embeddings,
latents_start = latents_start,
idx_start = idx_start,
list_latents_mixing = list_latents_mixing,
mixing_coeffs = mixing_coeffs,
spatial_mask = self.spatial_mask,
return_image = return_image,
)
text_embeddings=text_embeddings,
latents_start=latents_start,
idx_start=idx_start,
list_latents_mixing=list_latents_mixing,
mixing_coeffs=mixing_coeffs,
spatial_mask=self.spatial_mask,
return_image=return_image)
elif self.mode == 'upscale':
cond = list_conditionings[0]
@ -657,11 +612,10 @@ class LatentBlending():
uc_full,
latents_start=latents_start,
idx_start=idx_start,
list_latents_mixing = list_latents_mixing,
mixing_coeffs = mixing_coeffs,
list_latents_mixing=list_latents_mixing,
mixing_coeffs=mixing_coeffs,
return_image=return_image)
def run_upscaling(
self,
dp_img: str,
@ -669,9 +623,9 @@ class LatentBlending():
num_inference_steps: int = 100,
nmb_max_branches_highres: int = 5,
nmb_max_branches_lowres: int = 6,
duration_single_segment = 3,
fixed_seeds: Optional[List[int]] = None,
):
duration_single_segment=3,
fps=24,
fixed_seeds: Optional[List[int]] = None):
r"""
Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition.
@ -692,13 +646,14 @@ class LatentBlending():
Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
duration_single_segment: float
The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total.
fps: float
frames per second of movie
fixed_seeds: Optional[List[int]]:
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
Otherwise random seeds will be taken.
"""
fp_yml = os.path.join(dp_img, "lowres.yaml")
fp_movie = os.path.join(dp_img, "movie_highres.mp4")
fps = 24
ms = MovieSaver(fp_movie, fps=fps)
assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
dict_stuff = yml_load(fp_yml)
@ -707,53 +662,43 @@ class LatentBlending():
nmb_images_lowres = dict_stuff['nmb_images']
prompt1 = dict_stuff['prompt1']
prompt2 = dict_stuff['prompt2']
idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres-1, nmb_max_branches_lowres)).astype(np.int32)
idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres - 1, nmb_max_branches_lowres)).astype(np.int32)
imgs_lowres = []
for i in idx_img_lowres:
fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
imgs_lowres.append(Image.open(fp_img_lowres))
# set up upscaling
text_embeddingA = self.sdh.get_text_embedding(prompt1)
text_embeddingB = self.sdh.get_text_embedding(prompt2)
list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres-1)
for i in range(nmb_max_branches_lowres-1):
list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1)
for i in range(nmb_max_branches_lowres - 1):
print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1-list_fract_mixing[i])
if i==0:
self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1 - list_fract_mixing[i])
if i == 0:
recycle_img1 = False
else:
self.swap_forward()
recycle_img1 = True
self.set_image1(imgs_lowres[i])
self.set_image2(imgs_lowres[i+1])
self.set_image2(imgs_lowres[i + 1])
list_imgs = self.run_transition(
recycle_img1 = recycle_img1,
recycle_img2 = False,
num_inference_steps = num_inference_steps,
depth_strength = depth_strength,
nmb_max_branches = nmb_max_branches_highres,
)
recycle_img1=recycle_img1,
recycle_img2=False,
num_inference_steps=num_inference_steps,
depth_strength=depth_strength,
nmb_max_branches=nmb_max_branches_highres)
list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment)
# Save movie frame
for img in list_imgs_interp:
ms.write_frame(img)
ms.finalize()
@torch.no_grad()
def get_mixed_conditioning(self, fract_mixing):
if self.mode == 'standard':
@ -776,8 +721,7 @@ class LatentBlending():
@torch.no_grad()
def get_text_embeddings(
self,
prompt: str
):
prompt: str):
r"""
Computes the text embeddings for a given prompt string.
Adapted from stable diffusion repo
@ -785,10 +729,8 @@ class LatentBlending():
prompt: str
ABC trending on artstation painted by Old Greg.
"""
return self.sdh.get_text_embedding(prompt)
def write_imgs_transition(self, dp_img):
r"""
Writes the transition images into the folder dp_img.
@ -802,7 +744,6 @@ class LatentBlending():
for i, img in enumerate(imgs_transition):
img_leaf = Image.fromarray(img)
img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
fp_yml = os.path.join(dp_img, "lowres.yaml")
self.save_statedict(fp_yml)
@ -817,7 +758,6 @@ class LatentBlending():
duration of the movie in seconds
fps: int
fps of the movie
"""
# Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
@ -831,8 +771,6 @@ class LatentBlending():
ms.write_frame(img)
ms.finalize()
def save_statedict(self, fp_yml):
# Dump everything relevant into yaml
imgs_transition = self.tree_final_imgs
@ -857,9 +795,8 @@ class LatentBlending():
else:
try:
state_dict[v] = getattr(self, v)
except Exception as e:
except Exception:
pass
return state_dict
def randomize_seed(self):
@ -892,7 +829,6 @@ class LatentBlending():
self.height = height
self.sdh.height = height
def swap_forward(self):
r"""
Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions
@ -900,15 +836,12 @@ class LatentBlending():
"""
# Move over all latents
self.tree_latents[0] = self.tree_latents[-1]
# Move over prompts and text embeddings
self.prompt1 = self.prompt2
self.text_embedding1 = self.text_embedding2
# Final cleanup for extra sanity
self.tree_final_imgs = []
def get_lpips_similarity(self, imgA, imgB):
r"""
Computes the image similarity between two images imgA and imgB.
@ -916,36 +849,32 @@ class LatentBlending():
High values indicate low similarity.
"""
tensorA = torch.from_numpy(imgA).float().cuda(self.device)
tensorA = 2*tensorA/255.0 - 1
tensorA = tensorA.permute([2,0,1]).unsqueeze(0)
tensorA = 2 * tensorA / 255.0 - 1
tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0)
tensorB = torch.from_numpy(imgB).float().cuda(self.device)
tensorB = 2*tensorB/255.0 - 1
tensorB = tensorB.permute([2,0,1]).unsqueeze(0)
tensorB = 2 * tensorB / 255.0 - 1
tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0)
lploss = self.lpips(tensorA, tensorB)
lploss = float(lploss[0][0][0][0])
return lploss
# Auxiliary functions
def get_closest_idx(
fract_mixing: float,
list_fract_mixing_prev: List[float],
):
# Auxiliary functions
def get_closest_idx(
self,
fract_mixing: float):
r"""
Helper function to retrieve the parents for any given mixing.
Example: fract_mixing = 0.4 and list_fract_mixing_prev = [0, 0.3, 0.6, 1.0]
Will return the two closest values from list_fract_mixing_prev, i.e. [1, 2]
Example: fract_mixing = 0.4 and self.tree_fracts = [0, 0.3, 0.6, 1.0]
Will return the two closest values here, i.e. [1, 2]
"""
pdist = fract_mixing - np.asarray(list_fract_mixing_prev)
pdist = fract_mixing - np.asarray(self.tree_fracts)
pdist_pos = pdist.copy()
pdist_pos[pdist_pos<0] = np.inf
pdist_pos[pdist_pos < 0] = np.inf
b_parent1 = np.argmin(pdist_pos)
pdist_neg = -pdist.copy()
pdist_neg[pdist_neg<=0] = np.inf
b_parent2= np.argmin(pdist_neg)
pdist_neg[pdist_neg <= 0] = np.inf
b_parent2 = np.argmin(pdist_neg)
if b_parent1 > b_parent2:
tmp = b_parent2
@ -953,291 +882,3 @@ def get_closest_idx(
b_parent1 = tmp
return b_parent1, b_parent2
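# Worked example matching the docstring: with self.tree_fracts = [0, 0.3, 0.6, 1.0]
# and fract_mixing = 0.4, pdist = [0.4, 0.1, -0.2, -0.6]; the closest fraction below
# is index 1 (0.3) and the closest above is index 2 (0.6), so (1, 2) is returned.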
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function will always cast up to float64 for the sake of extra precision.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return p0
1 will return p1
0.x will return a mix between both, preserving angular velocity.
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'
p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1+epsilon, 1-epsilon)
theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0*s0 + p1*s1
if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()
return interp
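# Quick standalone check of the slerp above (only torch needed):
#     import torch
#     p0, p1 = torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0])
#     interpolate_spherical(p0, p1, 0.5)  # -> tensor([0.7071, 0.7071]), i.e. the angle is halved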
def interpolate_linear(p0, p1, fract_mixing):
r"""
Helper function to mix two variables using standard linear interpolation.
Args:
p0:
First tensor / np.ndarray for interpolation
p1:
Second tensor / np.ndarray for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return p0
1 will return p1
0.x will return a linear mix between both.
"""
reconvert_uint8 = False
if type(p0) is np.ndarray and p0.dtype == 'uint8':
reconvert_uint8 = True
p0 = p0.astype(np.float64)
if type(p1) is np.ndarray and p1.dtype == 'uint8':
reconvert_uint8 = True
p1 = p1.astype(np.float64)
interp = (1-fract_mixing) * p0 + fract_mixing * p1
if reconvert_uint8:
interp = np.clip(interp, 0, 255).astype(np.uint8)
return interp
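# Example of the uint8 round-trip handled above:
#     import numpy as np
#     a = np.array([0, 100, 200], dtype=np.uint8)
#     b = np.array([255, 100, 0], dtype=np.uint8)
#     interpolate_linear(a, b, 0.5)  # -> array([127, 100, 100], dtype=uint8)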
def add_frames_linear_interp(
list_imgs: List[np.ndarray],
fps_target: Union[float, int] = None,
duration_target: Union[float, int] = None,
nmb_frames_target: int=None,
):
r"""
Helper function to cheaply increase the number of frames given a list of images,
by virtue of standard linear interpolation.
The number of inserted frames will be automatically adjusted so that the total number
of frames can be fixed precisely, using a random shuffling technique.
The function allows 1:1 comparisons between transitions as videos.
Args:
list_imgs: List[np.ndarray]
List of images, between each image new frames will be inserted via linear interpolation.
fps_target:
OptionA: specify here the desired frames per second.
duration_target:
OptionA: specify here the desired duration of the transition in seconds.
nmb_frames_target:
OptionB: directly fix the total number of frames of the output.
"""
# Sanity
if nmb_frames_target is not None and fps_target is not None:
raise ValueError("You cannot specify both fps_target and nmb_frames_target")
if fps_target is None:
assert nmb_frames_target is not None, "Either specify fps_target or nmb_frames_target"
if nmb_frames_target is None:
assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
nmb_frames_target = fps_target*duration_target
# Get number of frames that are missing
nmb_frames_diff = len(list_imgs)-1
nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
if nmb_frames_missing < 1:
return list_imgs
list_imgs_float = [img.astype(np.float32) for img in list_imgs]
# Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
mean_nmb_frames_insert = nmb_frames_missing/nmb_frames_diff
constfact = np.floor(mean_nmb_frames_insert)
remainder_x = 1-(mean_nmb_frames_insert - constfact)
nmb_iter = 0
while True:
nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
nmb_frames_to_insert[nmb_frames_to_insert<=remainder_x] = 0
nmb_frames_to_insert[nmb_frames_to_insert>remainder_x] = 1
nmb_frames_to_insert += constfact
if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
break
nmb_iter += 1
if nmb_iter > 100000:
print("add_frames_linear_interp: issue with inserting the right number of frames")
break
nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
list_imgs_interp = []
for i in range(len(list_imgs_float)-1):#, desc="STAGE linear interp"):
img0 = list_imgs_float[i]
img1 = list_imgs_float[i+1]
list_imgs_interp.append(img0.astype(np.uint8))
list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i]+2)[1:-1]
for fract_linblend in list_fracts_linblend:
img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
list_imgs_interp.append(img_blend.astype(np.uint8))
if i==len(list_imgs_float)-2:
list_imgs_interp.append(img1.astype(np.uint8))
return list_imgs_interp
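# Usage sketch, as in write_movie_transition: stretch e.g. 12 keyframes to
# duration_target * fps_target frames (8 s * 30 fps = 240 frames in total):
#     imgs_ext = add_frames_linear_interp(imgs_transition, fps_target=30, duration_target=8)
#     len(imgs_ext)  # -> 240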
def get_spacing(nmb_points: int, scaling: float):
"""
Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
Args:
nmb_points: int
Number of points between [0, 1]
scaling: float
Higher values will return higher sampling density around 0.5
"""
if scaling < 1.7:
return np.linspace(0, 1, nmb_points)
nmb_points_per_side = nmb_points//2 + 1
if np.mod(nmb_points, 2) != 0: # uneven case
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
right_side = 1-left_side[::-1][1:]
else:
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
right_side = 1-left_side[::-1]
all_fracts = np.hstack([left_side, right_side])
return all_fracts
def get_time(resolution=None):
"""
Helper function returning a nicely formatted time string, e.g. 221117_1620
"""
if resolution==None:
resolution="second"
if resolution == "day":
t = time.strftime('%y%m%d', time.localtime())
elif resolution == "minute":
t = time.strftime('%y%m%d_%H%M', time.localtime())
elif resolution == "second":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
elif resolution == "millisecond":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
t += "_"
t += str("{:03d}".format(int(int(datetime.datetime.utcnow().strftime('%f'))/1000)))
else:
raise ValueError("bad resolution provided: %s" %resolution)
return t
def compare_dicts(a, b):
"""
Compares two dictionaries a and b and returns a dictionary c containing all
keys that are shared between a and b but whose values differ.
The differing values of a and b are stacked together in the output.
Example:
a = {}; a['bobo'] = 4
b = {}; b['bobo'] = 5
c = compare_dicts(a, b)
c = {"bobo": [4, 5]}
"""
c = {}
for key in a.keys():
if key in b.keys():
val_a = a[key]
val_b = b[key]
if val_a != val_b:
c[key] = [val_a, val_b]
return c
def yml_load(fp_yml, print_fields=False):
"""
Helper function for loading yaml files
"""
with open(fp_yml) as f:
data = yaml.load(f, Loader=yaml.loader.SafeLoader)
dict_data = dict(data)
print("load: loaded {}".format(fp_yml))
return dict_data
def yml_save(fp_yml, dict_stuff):
"""
Helper function for saving yaml files
"""
with open(fp_yml, 'w') as f:
data = yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
print("yml_save: saved {}".format(fp_yml))
#%% le main
if __name__ == "__main__":
# xxxx
#%% First let us spawn a stable diffusion holder
device = "cuda"
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
sdh = StableDiffusionHolder(fp_ckpt)
#%% Next let's set up all parameters
depth_strength = 0.3 # Specifies how deep (in terms of diffusion iterations the first branching happens)
fixed_seeds = [697164, 430214]
prompt1 = "photo of a desert and a sky"
prompt2 = "photo of a tree with a lake"
duration_transition = 12 # In seconds
fps = 30
# Spawn latent blending
self = LatentBlending(sdh)
self.set_prompt1(prompt1)
self.set_prompt2(prompt2)
# Run latent blending
self.branch1_crossfeed_power = 0.3
self.branch1_crossfeed_range = 0.4
# self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
self.seed1=21312
img1 =self.compute_latents1(True)
#%
self.seed2=1234121
self.branch1_crossfeed_power = 0.7
self.branch1_crossfeed_range = 0.3
self.branch1_crossfeed_decay = 0.3
img2 =self.compute_latents2(True)
# Image.fromarray(np.concatenate((img1, img2), axis=1))
#%%
t0 = time.time()
self.t_compute_max_allowed = 30
self.parental_crossfeed_range = 1.0
self.parental_crossfeed_power = 0.0
self.parental_crossfeed_power_decay = 1.0
imgs_transition = self.run_transition(recycle_img1=True, recycle_img2=True)
t1 = time.time()
print(f"took: {t1-t0}s")

View File

@ -1,5 +1,6 @@
# Copyright 2022 Lunar Ring. All rights reserved.
#
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@ -17,10 +18,9 @@ import os
import numpy as np
from tqdm import tqdm
import cv2
from typing import Callable, List, Optional, Union
from typing import List
import ffmpeg # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg
#%%
class MovieSaver():
def __init__(
@ -30,10 +30,9 @@ class MovieSaver():
shape_hw: List[int] = None,
crf: int = 24,
codec: str = 'libx264',
preset: str ='fast',
preset: str = 'fast',
pix_fmt: str = 'yuv420p',
silent_ffmpeg: bool = True
):
silent_ffmpeg: bool = True):
r"""
Initializes the movie saver class - a human-friendly ffmpeg wrapper.
After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x).
@ -92,10 +91,8 @@ class MovieSaver():
self.shape_hw = shape_hw
self.initialize()
print(f"MovieSaver initialized. fps={fps} crf={crf} pix_fmt={pix_fmt} codec={codec} preset={preset}")
def initialize(self):
args = (
ffmpeg
@ -112,7 +109,6 @@ class MovieSaver():
self.shape_hw = tuple(self.shape_hw)
print(f"Initialization done. Movie shape: {self.shape_hw}")
def write_frame(self, out_frame: np.ndarray):
r"""
Function to dump a numpy array as a frame of a movie.
@ -123,7 +119,6 @@ class MovieSaver():
Dim 1: x
Dim 2: RGB
"""
assert out_frame.dtype == np.uint8, "Convert to np.uint8 before"
assert len(out_frame.shape) == 3, "out_frame needs to be three dimensional, Y X C"
assert out_frame.shape[2] == 3, f"need three color channels, but you provided {out_frame.shape[2]}."
@ -143,7 +138,6 @@ class MovieSaver():
self.nmb_frames += 1
def finalize(self):
r"""
Call this function to finalize the movie. If you forget to call it your movie will be garbage.
@ -157,7 +151,6 @@ class MovieSaver():
print(f"Movie saved, {duration}s playtime, watch here: \n{self.fp_out}")
def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
r"""
Concatenate multiple movie segments into one long movie, using ffmpeg.
@ -189,7 +182,6 @@ def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
fa.write("%s\n" % item)
cmd = f'ffmpeg -f concat -safe 0 -i {fp_list} -c copy {fp_final}'
dp_movie = os.path.split(fp_final)[0]
subprocess.call(cmd, shell=True)
os.remove(fp_list)
if os.path.isfile(fp_final):
@ -200,11 +192,12 @@ class MovieReader():
r"""
Class to read in a movie.
"""
def __init__(self, fp_movie):
self.video_player_object = cv2.VideoCapture(fp_movie)
self.nmb_frames = int(self.video_player_object.get(cv2.CAP_PROP_FRAME_COUNT))
self.fps_movie = int(self.video_player_object.get(cv2.CAP_PROP_FPS))
self.shape = [100,100,3]
self.shape = [100, 100, 3]
self.shape_is_set = False
def get_next_frame(self):
@ -217,19 +210,18 @@ class MovieReader():
else:
return np.zeros(self.shape)
#%%
if __name__ == "__main__":
fps=2
fps = 2
list_fp_movies = []
for k in range(4):
fp_movie = f"/tmp/my_random_movie_{k}.mp4"
list_fp_movies.append(fp_movie)
ms = MovieSaver(fp_movie, fps=fps)
for fn in tqdm(range(30)):
img = (np.random.rand(512, 1024, 3)*255).astype(np.uint8)
img = (np.random.rand(512, 1024, 3) * 255).astype(np.uint8)
ms.write_frame(img)
ms.finalize()
fp_final = "/tmp/my_concatenated_movie.mp4"
concatenate_movies(fp_final, list_fp_movies)
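# Illustrative sketch: read the concatenated movie back in with MovieReader,
# reusing fp_final from the demo above (assumes the concatenation succeeded).
mr = MovieReader(fp_final)
frames_back = [mr.get_next_frame() for _ in range(mr.nmb_frames)]
print(f"read back {len(frames_back)} frames, last frame shape: {frames_back[-1].shape}")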

View File

@ -13,36 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
dp_git = "/home/lugo/git/"
sys.path.append(os.path.join(dp_git,'garden4'))
sys.path.append('util')
import os
import torch
torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import time
import subprocess
import warnings
import torch
from tqdm.auto import tqdm
from PIL import Image
# import matplotlib.pyplot as plt
import torch
from movie_util import MovieSaver
import datetime
from typing import Callable, List, Optional, Union
import inspect
from threading import Thread
torch.set_grad_enabled(False)
from typing import Optional
from omegaconf import OmegaConf
from torch import autocast
from contextlib import nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from einops import repeat, rearrange
#%%
from utils import interpolate_spherical
def pad_image(input_image):
@ -53,41 +42,11 @@ def pad_image(input_image):
return im_padded
def make_batch_inpaint(
image,
mask,
txt,
device,
num_samples=1):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
batch = {
"image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
"txt": num_samples * [txt],
"mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
"masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
}
return batch
def make_batch_superres(
image,
txt,
device,
num_samples=1,
):
num_samples=1):
image = np.array(image.convert("RGB"))
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
batch = {
@ -114,7 +73,7 @@ class StableDiffusionHolder:
height: Optional[int] = None,
width: Optional[int] = None,
device: str = None,
precision: str='autocast',
precision: str = 'autocast',
):
r"""
Initializes the stable diffusion holder, which contains the models and sampler.
@ -137,7 +96,7 @@ class StableDiffusionHolder:
self.precision = precision
self.init_model(fp_ckpt, fp_config)
self.f = 8 #downsampling factor, most often 8 or 16",
self.f = 8 # downsampling factor, most often 8 or 16
self.C = 4
self.ddim_eta = 0
self.num_inference_steps = num_inference_steps
@ -150,13 +109,8 @@ class StableDiffusionHolder:
self.height = height
self.width = width
# Inpainting inits
self.mask_empty = Image.fromarray(255*np.ones([self.width, self.height], dtype=np.uint8))
self.image_empty = Image.fromarray(np.zeros([self.width, self.height, 3], dtype=np.uint8))
self.negative_prompt = [""]
def init_model(self, fp_ckpt, fp_config):
r"""Loads the models and sampler.
"""
@ -169,13 +123,11 @@ class StableDiffusionHolder:
fn_ckpt = os.path.basename(fp_ckpt)
if 'depth' in fn_ckpt:
fp_config = 'configs/v2-midas-inference.yaml'
elif 'inpain' in fn_ckpt:
fp_config = 'configs/v2-inpainting-inference.yaml'
elif 'upscaler' in fn_ckpt:
fp_config = 'configs/x4-upscaling.yaml'
elif '512' in fn_ckpt:
fp_config = 'configs/v2-inference.yaml'
elif '768'in fn_ckpt:
elif '768' in fn_ckpt:
fp_config = 'configs/v2-inference-v.yaml'
elif 'v1-5' in fn_ckpt:
fp_config = 'configs/v1-inference.yaml'
@ -186,7 +138,6 @@ class StableDiffusionHolder:
assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"
config = OmegaConf.load(fp_config)
self.model = instantiate_from_config(config.model)
@ -195,7 +146,6 @@ class StableDiffusionHolder:
self.model = self.model.to(self.device)
self.sampler = DDIMSampler(self.model)
def init_auto_res(self):
r"""Automatically set the resolution to the one used in training.
"""
@ -218,7 +168,6 @@ class StableDiffusionHolder:
if len(self.negative_prompt) > 1:
self.negative_prompt = [self.negative_prompt[0]]
def get_text_embedding(self, prompt):
c = self.model.get_learned_conditioning(prompt)
return c
@ -228,7 +177,6 @@ class StableDiffusionHolder:
r"""
Initializes the conditioning for the x4 upscaling model.
"""
image = pad_image(image) # resize to integer multiple of 32
w, h = image.size
noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()
@ -240,7 +188,6 @@ class StableDiffusionHolder:
# uncond cond
uc_cross = self.model.get_unconditional_conditioning(1, "")
uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}
return cond, uc_full
@torch.no_grad()
@ -249,14 +196,12 @@ class StableDiffusionHolder:
text_embeddings: torch.FloatTensor,
latents_start: torch.FloatTensor,
idx_start: int = 0,
list_latents_mixing = None,
mixing_coeffs = 0.0,
spatial_mask = None,
return_image: Optional[bool] = False,
):
list_latents_mixing=None,
mixing_coeffs=0.0,
spatial_mask=None,
return_image: Optional[bool] = False):
r"""
Diffusion standard version.
Args:
text_embeddings: torch.FloatTensor
Text embeddings used for diffusion
@ -270,12 +215,10 @@ class StableDiffusionHolder:
experimental feature for enforcing pixels from list_latents_mixing
return_image: Optional[bool]
Optionally return image directly
"""
# Asserts
if type(mixing_coeffs) == float:
list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
elif type(mixing_coeffs) == list:
assert len(mixing_coeffs) == self.num_inference_steps
list_mixing_coeffs = mixing_coeffs
@ -285,26 +228,19 @@ class StableDiffusionHolder:
if np.sum(list_mixing_coeffs) > 0:
assert len(list_latents_mixing) == self.num_inference_steps
precision_scope = autocast if self.precision == "autocast" else nullcontext
with precision_scope("cuda"):
with self.model.ema_scope():
if self.guidance_scale != 1.0:
uc = self.model.get_learned_conditioning(self.negative_prompt)
else:
uc = None
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
latents = latents_start.clone()
timesteps = self.sampler.ddim_timesteps
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
# collect latents
# Collect latents
list_latents_out = []
for i, step in enumerate(time_range):
# Set the right starting latents
@ -313,15 +249,13 @@ class StableDiffusionHolder:
continue
elif i == idx_start:
latents = latents_start.clone()
# Mix the latents.
if i > 0 and list_mixing_coeffs[i]>0:
latents_mixtarget = list_latents_mixing[i-1].clone()
# Mix latents
if i > 0 and list_mixing_coeffs[i] > 0:
latents_mixtarget = list_latents_mixing[i - 1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
if spatial_mask is not None and list_latents_mixing is not None:
latents = interpolate_spherical(latents, list_latents_mixing[i-1], 1-spatial_mask)
# latents[:,:,-15:,:] = latents_mixtarget[:,:,-15:,:]
latents = interpolate_spherical(latents, list_latents_mixing[i - 1], 1 - spatial_mask)
index = total_steps - i - 1
ts = torch.full((1,), step, device=self.device, dtype=torch.long)
@ -334,13 +268,11 @@ class StableDiffusionHolder:
dynamic_threshold=None)
latents, pred_x0 = outs
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
@torch.no_grad()
def run_diffusion_upscaling(
self,
@ -348,17 +280,16 @@ class StableDiffusionHolder:
uc_full,
latents_start: torch.FloatTensor,
idx_start: int = -1,
list_latents_mixing = None,
mixing_coeffs = 0.0,
return_image: Optional[bool] = False
):
list_latents_mixing: list = None,
mixing_coeffs: float = 0.0,
return_image: Optional[bool] = False):
r"""
Diffusion upscaling version.
"""
# Asserts
if type(mixing_coeffs) == float:
list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
elif type(mixing_coeffs) == list:
assert len(mixing_coeffs) == self.num_inference_steps
list_mixing_coeffs = mixing_coeffs
@ -369,27 +300,20 @@ class StableDiffusionHolder:
assert len(list_latents_mixing) == self.num_inference_steps
precision_scope = autocast if self.precision == "autocast" else nullcontext
h = uc_full['c_concat'][0].shape[2]
w = uc_full['c_concat'][0].shape[3]
with precision_scope("cuda"):
with self.model.ema_scope():
shape_latents = [self.model.channels, h, w]
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
C, H, W = shape_latents
size = (1, C, H, W)
b = size[0]
latents = latents_start.clone()
timesteps = self.sampler.ddim_timesteps
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
# collect latents
list_latents_out = []
for i, step in enumerate(time_range):
@ -399,12 +323,10 @@ class StableDiffusionHolder:
continue
elif i == idx_start:
latents = latents_start.clone()
# Mix the latents.
if i > 0 and list_mixing_coeffs[i]>0:
latents_mixtarget = list_latents_mixing[i-1].clone()
if i > 0 and list_mixing_coeffs[i] > 0:
latents_mixtarget = list_latents_mixing[i - 1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
# print(f"diffusion iter {i}")
index = total_steps - i - 1
ts = torch.full((b,), step, device=self.device, dtype=torch.long)
@ -423,121 +345,10 @@ class StableDiffusionHolder:
else:
return list_latents_out
@torch.no_grad()
def run_diffusion_inpaint(
self,
text_embeddings: torch.FloatTensor,
latents_for_injection: torch.FloatTensor = None,
idx_start: int = -1,
idx_stop: int = -1,
return_image: Optional[bool] = False
):
r"""
Runs inpaint-based diffusion. Returns a list of latents that were computed.
Adaptations allow supplying
a) a starting index for diffusion
b) a stopping index for diffusion
c) latent representations that are injected at the starting index
Furthermore, the intermediate latents are collected and returned.
Adapted from diffusers (https://github.com/huggingface/diffusers)
Args:
text_embeddings: torch.FloatTensor
Text embeddings used for diffusion
latents_for_injection: torch.FloatTensor
Latents that are used for injection
idx_start: int
Index of the diffusion process start and where the latents_for_injection are injected
idx_stop: int
Index of the diffusion process end.
return_image: Optional[bool]
Optionally return image directly
"""
if latents_for_injection is None:
do_inject_latents = False
else:
do_inject_latents = True
precision_scope = autocast if self.precision == "autocast" else nullcontext
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
with precision_scope("cuda"):
with self.model.ema_scope():
batch = make_batch_inpaint(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
c = text_embeddings
c_cat = list()
for ck in self.model.concat_keys:
cc = batch[ck].float()
if ck != self.model.masked_image_key:
bchw = [1, 4, self.height // 8, self.width // 8]
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
# cond
cond = {"c_concat": [c_cat], "c_crossattn": [c]}
# uncond cond
uc_cross = self.model.get_unconditional_conditioning(1, "")
uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
shape_latents = [self.model.channels, self.height // 8, self.width // 8]
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
# sampling
C, H, W = shape_latents
size = (1, C, H, W)
device = self.model.betas.device
b = size[0]
latents = torch.randn(size, generator=generator, device=device)
timesteps = self.sampler.ddim_timesteps
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
# collect latents
list_latents_out = []
for i, step in enumerate(time_range):
if do_inject_latents:
# Inject latent at right place
if i < idx_start:
continue
elif i == idx_start:
latents = latents_for_injection.clone()
if i == idx_stop:
return list_latents_out
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
quantize_denoised=False, temperature=1.0,
noise_dropout=0.0, score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=self.guidance_scale,
unconditional_conditioning=uc_full,
dynamic_threshold=None)
latents, pred_x0 = outs
list_latents_out.append(latents.clone())
if return_image:
return self.latent2image(latents)
else:
return list_latents_out
@torch.no_grad()
def latent2image(
self,
latents: torch.FloatTensor
):
latents: torch.FloatTensor):
r"""
Returns an image provided a latent representation from diffusion.
Args:
@ -546,85 +357,6 @@ class StableDiffusionHolder:
"""
x_sample = self.model.decode_first_stage(latents)
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255 * x_sample[0,:,:].permute([1,2,0]).cpu().numpy()
x_sample = 255 * x_sample[0, :, :].permute([1, 2, 0]).cpu().numpy()
image = x_sample.astype(np.uint8)
return image
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function will always cast up to float64 for the sake of extra precision.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient in the interval [0, 1].
0 will return p0
1 will return p1
0.x will return a mix of both, preserving angular velocity.
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'
p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1+epsilon, 1-epsilon)
theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0*s0 + p1*s1
if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()
return interp
if __name__ == "__main__":
num_inference_steps = 20 # Number of diffusion iterations
# fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
# fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
# fp_ckpt= "../stable_diffusion_models/ckpt/512-inpainting-ema.ckpt"
# fp_config = '../stablediffusion/configs//stable-diffusion/v2-inpainting-inference.yaml'
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
# fp_config = 'configs/v2-inference-v.yaml'
self = StableDiffusionHolder(fp_ckpt, num_inference_steps=num_inference_steps)
#%%
self.width = 1536
self.height = 768
prompt = "360 degree equirectangular, a huge rocky hill full of pianos and keyboards, musical instruments, cinematic, masterpiece 8 k, artstation"
self.set_negative_prompt("out of frame, faces, rendering, blurry")
te = self.get_text_embedding(prompt)
img = self.run_diffusion_standard(te, return_image=True)
Image.fromarray(img).show()

utils.py (new file, 260 lines)
View File

@ -0,0 +1,260 @@
# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import time
import warnings
import datetime
from typing import List, Union
torch.set_grad_enabled(False)
import yaml
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function will always cast up to float64 for the sake of extra precision.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient in the interval [0, 1].
0 will return p0
1 will return p1
0.x will return a mix of both, preserving angular velocity.
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'
p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1 + epsilon, 1 - epsilon)
theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0 * s0 + p1 * s1
if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()
return interp
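# Illustrative usage sketch (tensor shape is arbitrary, chosen only for this example):
# slerp two Gaussian latent tensors halfway; the inputs' dtype is restored afterwards.
example_latents_a = torch.randn(1, 4, 96, 96)
example_latents_b = torch.randn(1, 4, 96, 96)
example_latents_mid = interpolate_spherical(example_latents_a, example_latents_b, 0.5)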
def interpolate_linear(p0, p1, fract_mixing):
r"""
Helper function to mix two variables using standard linear interpolation.
Args:
p0:
First tensor / np.ndarray for interpolation
p1:
Second tensor / np.ndarray for interpolation
fract_mixing: float
Mixing coefficient in the interval [0, 1].
0 will return p0
1 will return p1
0.x will return a linear mix of both.
"""
reconvert_uint8 = False
if type(p0) is np.ndarray and p0.dtype == 'uint8':
reconvert_uint8 = True
p0 = p0.astype(np.float64)
if type(p1) is np.ndarray and p1.dtype == 'uint8':
reconvert_uint8 = True
p1 = p1.astype(np.float64)
interp = (1 - fract_mixing) * p0 + fract_mixing * p1
if reconvert_uint8:
interp = np.clip(interp, 0, 255).astype(np.uint8)
return interp
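# Illustrative usage sketch: blend two uint8 images a quarter of the way; since both
# inputs are uint8, the result is clipped and converted back to uint8 (all values 63 here).
example_img0 = np.zeros((64, 64, 3), dtype=np.uint8)
example_img1 = np.full((64, 64, 3), 255, dtype=np.uint8)
example_img_mix = interpolate_linear(example_img0, example_img1, 0.25)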
def add_frames_linear_interp(
list_imgs: List[np.ndarray],
fps_target: Union[float, int] = None,
duration_target: Union[float, int] = None,
nmb_frames_target: int = None):
r"""
Helper function to cheaply increase the number of frames given a list of images,
by virtue of standard linear interpolation.
The number of inserted frames will be automatically adjusted so that the total number
of frames can be fixed precisely, using a random shuffling technique.
The function allows 1:1 comparisons between transitions as videos.
Args:
list_imgs: List[np.ndarray]
List of images, between each image new frames will be inserted via linear interpolation.
fps_target:
OptionA: specify here the desired frames per second.
duration_target:
OptionA: specify here the desired duration of the transition in seconds.
nmb_frames_target:
OptionB: directly fix the total number of frames of the output.
"""
# Sanity
if nmb_frames_target is not None and fps_target is not None:
raise ValueError("You cannot specify both fps_target and nmb_frames_target")
if fps_target is None:
assert nmb_frames_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
if nmb_frames_target is None:
assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
nmb_frames_target = fps_target * duration_target
# Get number of frames that are missing
nmb_frames_diff = len(list_imgs) - 1
nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
if nmb_frames_missing < 1:
return list_imgs
list_imgs_float = [img.astype(np.float32) for img in list_imgs]
# Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
constfact = np.floor(mean_nmb_frames_insert)
remainder_x = 1 - (mean_nmb_frames_insert - constfact)
nmb_iter = 0
while True:
nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
nmb_frames_to_insert[nmb_frames_to_insert <= remainder_x] = 0
nmb_frames_to_insert[nmb_frames_to_insert > remainder_x] = 1
nmb_frames_to_insert += constfact
if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
break
nmb_iter += 1
if nmb_iter > 100000:
print("add_frames_linear_interp: issue with inserting the right number of frames")
break
nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
list_imgs_interp = []
for i in range(len(list_imgs_float) - 1):
img0 = list_imgs_float[i]
img1 = list_imgs_float[i + 1]
list_imgs_interp.append(img0.astype(np.uint8))
list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i] + 2)[1:-1]
for fract_linblend in list_fracts_linblend:
img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
list_imgs_interp.append(img_blend.astype(np.uint8))
if i == len(list_imgs_float) - 2:
list_imgs_interp.append(img1.astype(np.uint8))
return list_imgs_interp
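# Illustrative usage sketch: stretch a short 10-frame clip to exactly fps_target * duration_target
# = 60 frames; passing nmb_frames_target=60 instead would give the same result.
example_keyframes = [np.full((32, 32, 3), 10 * k, dtype=np.uint8) for k in range(10)]
example_frames = add_frames_linear_interp(example_keyframes, fps_target=30, duration_target=2)
# len(example_frames) == 60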
def get_spacing(nmb_points: int, scaling: float):
"""
Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
Args:
nmb_points: int
Number of points between [0, 1]
scaling: float
Higher values will return higher sampling density around 0.5
"""
if scaling < 1.7:
return np.linspace(0, 1, nmb_points)
nmb_points_per_side = nmb_points // 2 + 1
if np.mod(nmb_points, 2) != 0: # Uneven case
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
right_side = 1 - left_side[::-1][1:]
else:
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
right_side = 1 - left_side[::-1]
all_fracts = np.hstack([left_side, right_side])
return all_fracts
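# Illustrative usage sketch: 11 fractions between 0 and 1; scaling values below 1.7 fall
# back to a plain linspace, larger values cluster the samples around 0.5.
example_fracts_uniform = get_spacing(11, 1.0)
example_fracts_dense = get_spacing(11, 3.0)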
def get_time(resolution=None):
"""
Helper function returning a nicely formatted time string, e.g. 221117_1620
"""
if resolution is None:
resolution = "second"
if resolution == "day":
t = time.strftime('%y%m%d', time.localtime())
elif resolution == "minute":
t = time.strftime('%y%m%d_%H%M', time.localtime())
elif resolution == "second":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
elif resolution == "millisecond":
t = time.strftime('%y%m%d_%H%M%S', time.localtime())
t += "_"
t += str("{:03d}".format(int(int(datetime.datetime.utcnow().strftime('%f')) / 1000)))
else:
raise ValueError("bad resolution provided: %s" % resolution)
return t
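# Illustrative usage sketch: timestamps at different resolutions, e.g. for unique output filenames.
example_stamp_day = get_time("day")       # e.g. '230222'
example_stamp_sec = get_time("second")    # e.g. '230222_101503'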
def compare_dicts(a, b):
"""
Compares two dictionaries a and b and returns a dictionary c containing all
keys that are shared between a and b but whose values differ.
The differing values of a and b are stacked together in the output.
Example:
a = {}; a['bobo'] = 4
b = {}; b['bobo'] = 5
c = compare_dicts(a, b)
c = {"bobo": [4, 5]}
"""
c = {}
for key in a.keys():
if key in b.keys():
val_a = a[key]
val_b = b[key]
if val_a != val_b:
c[key] = [val_a, val_b]
return c
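# Illustrative usage sketch: only keys present in both dicts with differing values are
# returned, carrying both values side by side.
example_diff = compare_dicts({'crf': 24, 'fps': 30}, {'crf': 17, 'fps': 30})
# example_diff == {'crf': [24, 17]}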
def yml_load(fp_yml, print_fields=False):
"""
Helper function for loading yaml files
"""
with open(fp_yml) as f:
data = yaml.load(f, Loader=yaml.loader.SafeLoader)
dict_data = dict(data)
print("load: loaded {}".format(fp_yml))
return dict_data
def yml_save(fp_yml, dict_stuff):
"""
Helper function for saving yaml files
"""
with open(fp_yml, 'w') as f:
yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
print("yml_save: saved {}".format(fp_yml))