stabel2.0 for mulit trans

This commit is contained in:
lugo 2022-11-28 08:46:21 +01:00
parent 578385e411
commit 6f977ebb7e
2 changed files with 25 additions and 29 deletions

View File

@ -21,34 +21,23 @@ warnings.filterwarnings('ignore')
import warnings import warnings
import torch import torch
from tqdm.auto import tqdm from tqdm.auto import tqdm
from diffusers import StableDiffusionPipeline
from diffusers.schedulers import DDIMScheduler
from PIL import Image from PIL import Image
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
from movie_util import MovieSaver from movie_util import MovieSaver
from typing import Callable, List, Optional, Union from typing import Callable, List, Optional, Union
from latent_blending import LatentBlending, add_frames_linear_interp from latent_blending import LatentBlending, add_frames_linear_interp
from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
#%% First let us spawn a diffusers pipe using DDIMScheduler #%% First let us spawn a stable diffusion holder
device = "cuda:0" device = "cuda:0"
model_path = "../stable_diffusion_models/stable-diffusion-v1-5" num_inference_steps = 20 # Number of diffusion interations
fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
sdh = StableDiffusionHolder(fp_ckpt, fp_config, device, num_inference_steps=num_inference_steps)
scheduler = DDIMScheduler(beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False)
pipe = StableDiffusionPipeline.from_pretrained(
model_path,
revision="fp16",
torch_dtype=torch.float16,
scheduler=scheduler,
use_auth_token=True
)
pipe = pipe.to(device)
#%% MULTITRANS #%% MULTITRANS
@ -57,18 +46,21 @@ list_nmb_branches = [2, 10, 50, 100, 200] #
list_injection_strength = list(np.linspace(0.5, 0.95, 4)) # Branching structure: how deep is the blending list_injection_strength = list(np.linspace(0.5, 0.95, 4)) # Branching structure: how deep is the blending
list_injection_strength.insert(0, 0.0) list_injection_strength.insert(0, 0.0)
width = 512
height = 512
guidance_scale = 5 guidance_scale = 5
fps = 30 fps = 30
duration_single_trans = 20 duration_single_trans = 20
width = 512 width = 768
height = 512 height = 768
lb = LatentBlending(pipe, device, height, width, num_inference_steps, guidance_scale) lb = LatentBlending(sdh, num_inference_steps, guidance_scale)
# deepth_strength = 0.5
# num_inference_steps, list_injection_idx, list_nmb_branches = lb.get_branching('medium', deepth_strength, fps*duration_single_trans)
#list_nmb_branches = [2, 3, 10, 24] # Branching structure: how many branches
#list_injection_strength = [0.0, 0.6, 0.8, 0.9] #
list_prompts = [] list_prompts = []
list_prompts.append("surrealistic statue made of glitter and dirt, standing in a lake, atmospheric light, strange glow") list_prompts.append("surrealistic statue made of glitter and dirt, standing in a lake, atmospheric light, strange glow")
@ -82,13 +74,14 @@ list_prompts.append("statue of an ancient cybernetic messenger annoucing good ne
list_seeds = [234187386, 422209351, 241845736, 28652396, 783279867, 831049796, 234903931] list_seeds = [234187386, 422209351, 241845736, 28652396, 783279867, 831049796, 234903931]
fp_movie = "/home/lugo/tmp/latentblending/bubua.mp4" fp_movie = "movie_example3.mp4"
ms = MovieSaver(fp_movie, fps=fps) ms = MovieSaver(fp_movie, fps=fps)
lb.run_multi_transition( lb.run_multi_transition(
list_prompts, list_prompts,
list_seeds, list_seeds,
list_nmb_branches, list_nmb_branches,
# list_injection_idx=list_injection_idx,
list_injection_strength=list_injection_strength, list_injection_strength=list_injection_strength,
ms=ms, ms=ms,
fps=fps, fps=fps,

View File

@ -764,12 +764,15 @@ def get_branching(
idx_last_check +=1 idx_last_check +=1
list_injection_idx_clean = [int(l) for l in list_injection_idx_clean] list_injection_idx_clean = [int(l) for l in list_injection_idx_clean]
list_nmb_branches_clean = [int(l) for l in list_nmb_branches_clean] list_nmb_branches_clean = [int(l) for l in list_nmb_branches_clean]
list_injection_idx = list_injection_idx_clean
list_nmb_branches = list_nmb_branches_clean
print(f"num_inference_steps: {num_inference_steps}") print(f"num_inference_steps: {num_inference_steps}")
print(f"list_injection_idx: {list_injection_idx_clean}") print(f"list_injection_idx: {list_injection_idx}")
print(f"list_nmb_branches: {list_nmb_branches_clean}") print(f"list_nmb_branches: {list_nmb_branches}")
return num_inference_steps, list_injection_idx_clean, list_nmb_branches_clean return num_inference_steps, list_injection_idx, list_nmb_branches