small fixes
This commit is contained in:
parent
bb573b2f9e
commit
3c6015782f
386
gradio_ui.py
386
gradio_ui.py
|
@ -32,23 +32,35 @@ torch.set_grad_enabled(False)
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import copy
|
import copy
|
||||||
from dotenv import find_dotenv, load_dotenv
|
from dotenv import find_dotenv, load_dotenv
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
"""
|
||||||
|
never hit compute trans -> multi movie add fail
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
#%%
|
#%%
|
||||||
|
|
||||||
class BlendingFrontend():
|
class BlendingFrontend():
|
||||||
def __init__(self, sdh=None):
|
def __init__(self, sdh=None):
|
||||||
|
self.num_inference_steps = 30
|
||||||
if sdh is None:
|
if sdh is None:
|
||||||
self.use_debug = True
|
self.use_debug = True
|
||||||
|
self.height = 768
|
||||||
|
self.width = 768
|
||||||
else:
|
else:
|
||||||
self.use_debug = False
|
self.use_debug = False
|
||||||
self.lb = LatentBlending(sdh)
|
self.lb = LatentBlending(sdh)
|
||||||
|
self.lb.sdh.num_inference_steps = self.num_inference_steps
|
||||||
|
self.height = self.lb.sdh.height
|
||||||
|
self.width = self.lb.sdh.width
|
||||||
|
|
||||||
self.share = True
|
self.init_save_dir()
|
||||||
self.num_inference_steps = 30
|
self.save_empty_image()
|
||||||
|
self.share = False
|
||||||
self.depth_strength = 0.25
|
self.depth_strength = 0.25
|
||||||
self.seed1 = 42
|
self.seed1 = 420
|
||||||
self.seed2 = 420
|
self.seed2 = 420
|
||||||
self.guidance_scale = 4.0
|
self.guidance_scale = 4.0
|
||||||
self.guidance_scale_mid_damper = 0.5
|
self.guidance_scale_mid_damper = 0.5
|
||||||
|
@ -72,16 +84,13 @@ class BlendingFrontend():
|
||||||
self.current_timestamp = None
|
self.current_timestamp = None
|
||||||
self.recycle_img1 = False
|
self.recycle_img1 = False
|
||||||
self.recycle_img2 = False
|
self.recycle_img2 = False
|
||||||
|
self.fp_img1 = None
|
||||||
|
self.fp_img2 = None
|
||||||
|
self.multi_idx_current = -1
|
||||||
|
self.multi_list_concat = []
|
||||||
|
self.list_imgs_shown_last = 5*[self.fp_img_empty]
|
||||||
|
self.nmb_trans_stack = 6
|
||||||
|
|
||||||
if not self.use_debug:
|
|
||||||
self.lb.sdh.num_inference_steps = self.num_inference_steps
|
|
||||||
self.height = self.lb.sdh.height
|
|
||||||
self.width = self.lb.sdh.width
|
|
||||||
else:
|
|
||||||
self.height = 768
|
|
||||||
self.width = 768
|
|
||||||
|
|
||||||
self.init_save_dir()
|
|
||||||
|
|
||||||
|
|
||||||
def init_save_dir(self):
|
def init_save_dir(self):
|
||||||
|
@ -89,12 +98,18 @@ class BlendingFrontend():
|
||||||
try:
|
try:
|
||||||
self.dp_out = os.getenv("dp_out")
|
self.dp_out = os.getenv("dp_out")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(f"did not find .env file. using local folder. {e}")
|
||||||
self.dp_out = ""
|
self.dp_out = ""
|
||||||
|
self.dp_imgs = os.path.join(self.dp_out, "imgs")
|
||||||
|
os.makedirs(self.dp_imgs, exist_ok=True)
|
||||||
|
self.dp_movies = os.path.join(self.dp_out, "movies")
|
||||||
|
os.makedirs(self.dp_movies, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# make dummy image
|
# make dummy image
|
||||||
def save_empty_image(self):
|
def save_empty_image(self):
|
||||||
self.fp_img_empty = os.path.join(self.dp_out, 'empty.jpg')
|
self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
|
||||||
Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
|
Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
|
||||||
|
|
||||||
|
|
||||||
|
@ -140,22 +155,21 @@ class BlendingFrontend():
|
||||||
def compute_img1(self, *args):
|
def compute_img1(self, *args):
|
||||||
list_ui_elem = args
|
list_ui_elem = args
|
||||||
self.setup_lb(list_ui_elem)
|
self.setup_lb(list_ui_elem)
|
||||||
fp_img1 = os.path.join(self.dp_out, f"img1_{get_time('second')}.jpg")
|
self.fp_img1 = os.path.join(self.dp_imgs, f"img1_{get_time('second')}.jpg")
|
||||||
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
|
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
|
||||||
img1.save(fp_img1)
|
img1.save(self.fp_img1)
|
||||||
self.save_empty_image()
|
|
||||||
self.recycle_img1 = True
|
self.recycle_img1 = True
|
||||||
self.recycle_img2 = False
|
self.recycle_img2 = False
|
||||||
return [fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
|
return [self.fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
|
||||||
|
|
||||||
def compute_img2(self, *args):
|
def compute_img2(self, *args):
|
||||||
list_ui_elem = args
|
list_ui_elem = args
|
||||||
self.setup_lb(list_ui_elem)
|
self.setup_lb(list_ui_elem)
|
||||||
fp_img2 = os.path.join(self.dp_out, f"img2_{get_time('second')}.jpg")
|
self.fp_img2 = os.path.join(self.dp_imgs, f"img2_{get_time('second')}.jpg")
|
||||||
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
|
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
|
||||||
img2.save(fp_img2)
|
img2.save(self.fp_img2)
|
||||||
self.recycle_img2 = True
|
self.recycle_img2 = True
|
||||||
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2]
|
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img2]
|
||||||
|
|
||||||
def compute_transition(self, *args):
|
def compute_transition(self, *args):
|
||||||
|
|
||||||
|
@ -199,7 +213,7 @@ class BlendingFrontend():
|
||||||
self.current_timestamp = get_time('second')
|
self.current_timestamp = get_time('second')
|
||||||
self.list_fp_imgs_current = []
|
self.list_fp_imgs_current = []
|
||||||
for i in range(len(list_imgs_preview)):
|
for i in range(len(list_imgs_preview)):
|
||||||
fp_img = f"img_preview_{i}_{self.current_timestamp}.jpg"
|
fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg")
|
||||||
list_imgs_preview[i].save(fp_img)
|
list_imgs_preview[i].save(fp_img)
|
||||||
self.list_fp_imgs_current.append(fp_img)
|
self.list_fp_imgs_current.append(fp_img)
|
||||||
|
|
||||||
|
@ -207,52 +221,39 @@ class BlendingFrontend():
|
||||||
imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
|
imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
|
||||||
|
|
||||||
# Save as movie
|
# Save as movie
|
||||||
fp_movie = self.get_fp_movie(self.current_timestamp)
|
self.fp_movie = os.path.join(self.dp_movies, f"movie_{self.current_timestamp}.mp4")
|
||||||
if os.path.isfile(fp_movie):
|
if os.path.isfile(self.fp_movie):
|
||||||
os.remove(fp_movie)
|
os.remove(self.fp_movie)
|
||||||
ms = MovieSaver(fp_movie, fps=self.fps)
|
ms = MovieSaver(self.fp_movie, fps=self.fps)
|
||||||
for img in tqdm(imgs_transition_ext):
|
for img in tqdm(imgs_transition_ext):
|
||||||
ms.write_frame(img)
|
ms.write_frame(img)
|
||||||
ms.finalize()
|
ms.finalize()
|
||||||
print("DONE SAVING MOVIE! SENDING BACK...")
|
print("DONE SAVING MOVIE! SENDING BACK...")
|
||||||
|
|
||||||
# Assemble Output, updating the preview images and le movie
|
# Assemble Output, updating the preview images and le movie
|
||||||
list_return = self.list_fp_imgs_current + [fp_movie]
|
list_return = self.list_fp_imgs_current + [self.fp_movie]
|
||||||
return list_return
|
return list_return
|
||||||
|
|
||||||
def get_fp_movie(self, timestamp, is_stacked=False):
|
|
||||||
if not is_stacked:
|
|
||||||
fn = f"movie_{timestamp}.mp4"
|
|
||||||
else:
|
|
||||||
fn = f"movie_stacked_{timestamp}.mp4"
|
|
||||||
fp = os.path.join(self.dp_out, fn)
|
|
||||||
return fp
|
|
||||||
|
|
||||||
|
|
||||||
def stack_forward(self, prompt2, seed2):
|
def stack_forward(self, prompt2, seed2):
|
||||||
# Save preview images, prompts and seeds into dictionary for stacking
|
# Save preview images, prompts and seeds into dictionary for stacking
|
||||||
dp_out = os.path.join(self.dp_out, get_time('second'))
|
# self.list_imgs_shown_last = self.get_multi_trans_imgs_preview(f"lowres_{self.current_timestamp}")[0:5]
|
||||||
self.lb.write_imgs_transition(dp_out)
|
timestamp_section = get_time('second')
|
||||||
|
self.lb.write_imgs_transition(os.path.join(self.dp_out, f"lowres_{timestamp_section}"))
|
||||||
|
self.lb.write_imgs_transition(os.path.join(self.dp_out, "lowres_current"))
|
||||||
|
shutil.copyfile(self.fp_movie, os.path.join(self.dp_out, f"lowres_{timestamp_section}", "movie.mp4"))
|
||||||
|
|
||||||
self.lb.swap_forward()
|
self.lb.swap_forward()
|
||||||
list_out = [self.list_fp_imgs_current[-1]]
|
list_out = [self.fp_img2]
|
||||||
list_out.extend([self.fp_img_empty]*4)
|
list_out.extend([self.fp_img_empty]*4)
|
||||||
list_out.append(prompt2)
|
list_out.append(prompt2)
|
||||||
list_out.append(seed2)
|
list_out.append(seed2)
|
||||||
list_out.append("")
|
list_out.append("")
|
||||||
list_out.append(np.random.randint(0, 10000000))
|
list_out.append(np.random.randint(0, 10000000))
|
||||||
|
|
||||||
return list_out
|
return list_out
|
||||||
|
|
||||||
|
|
||||||
def stack_movie(self):
|
|
||||||
# collect all that are in...
|
|
||||||
list_fp_movies = []
|
|
||||||
|
|
||||||
list_fp_movies.append(self.get_fp_movie(timestamp))
|
|
||||||
|
|
||||||
fp_stacked = self.get_fp_movie(get_time('second'), True)
|
|
||||||
concatenate_movies(fp_stacked, list_fp_movies)
|
|
||||||
return fp_stacked
|
|
||||||
|
|
||||||
|
|
||||||
def get_state_dict(self):
|
def get_state_dict(self):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
|
@ -265,6 +266,89 @@ class BlendingFrontend():
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_list_all_stacked(self):
|
||||||
|
list_all = os.listdir(os.path.join(self.dp_out))
|
||||||
|
list_all = [l for l in list_all if l[:8]=="lowres_2"]
|
||||||
|
list_all.sort()
|
||||||
|
return list_all
|
||||||
|
|
||||||
|
def multi_trans_show_older(self):
|
||||||
|
list_all = self.get_list_all_stacked()
|
||||||
|
if self.multi_idx_current == -1:
|
||||||
|
self.multi_idx_current = len(list_all) - 1
|
||||||
|
else:
|
||||||
|
self.multi_idx_current -= 1
|
||||||
|
|
||||||
|
if self.multi_idx_current < 0:
|
||||||
|
self.multi_idx_current = 0
|
||||||
|
dn = list_all[self.multi_idx_current]
|
||||||
|
return self.get_multi_trans_imgs_preview(dn)
|
||||||
|
|
||||||
|
def multi_trans_show_newer(self):
|
||||||
|
list_all = self.get_list_all_stacked()
|
||||||
|
if self.multi_idx_current == -1:
|
||||||
|
self.multi_idx_current = len(list_all) - 1
|
||||||
|
else:
|
||||||
|
self.multi_idx_current += 1
|
||||||
|
|
||||||
|
if self.multi_idx_current >= len(list_all):
|
||||||
|
self.multi_idx_current = len(list_all) - 1
|
||||||
|
dn = list_all[self.multi_idx_current]
|
||||||
|
return self.get_multi_trans_imgs_preview(dn)
|
||||||
|
|
||||||
|
def get_multi_trans_imgs_preview(self, dn):
|
||||||
|
dp_show = os.path.join(self.dp_out, dn)
|
||||||
|
list_imgs_transition = os.listdir(dp_show)
|
||||||
|
list_imgs_transition = [l for l in list_imgs_transition if l[:11]=="lowres_img_"]
|
||||||
|
list_imgs_transition.sort()
|
||||||
|
|
||||||
|
idx_img_prev = np.round(np.linspace(0, len(list_imgs_transition)-1, 5)).astype(np.int32)
|
||||||
|
list_imgs_preview = []
|
||||||
|
for j in idx_img_prev:
|
||||||
|
list_imgs_preview.append(os.path.join(dp_show, list_imgs_transition[j]))
|
||||||
|
|
||||||
|
list_out = list_imgs_preview
|
||||||
|
list_out.append(dn[7:])
|
||||||
|
|
||||||
|
return list_out
|
||||||
|
|
||||||
|
def multi_append(self):
|
||||||
|
list_all = self.get_list_all_stacked()
|
||||||
|
dn = list_all[self.multi_idx_current]
|
||||||
|
self.multi_list_concat.append(dn)
|
||||||
|
list_short = [dn[7:] for dn in self.multi_list_concat]
|
||||||
|
str_out = "\n".join(list_short)
|
||||||
|
return str_out
|
||||||
|
|
||||||
|
def multi_reset(self):
|
||||||
|
self.multi_list_concat = []
|
||||||
|
str_out = ""
|
||||||
|
return str_out
|
||||||
|
|
||||||
|
def multi_concat(self):
|
||||||
|
# Make new output directory
|
||||||
|
dp_multi = os.path.join(self.dp_out, f"multi_{get_time('second')}")
|
||||||
|
os.makedirs(dp_multi, exist_ok=False)
|
||||||
|
|
||||||
|
# Copy all low-res folders (prepending multi001_xxxx), however leave out the movie.mp4
|
||||||
|
# also collect all movie.mp4
|
||||||
|
list_fp_movies = []
|
||||||
|
for i, dn in enumerate(self.multi_list_concat):
|
||||||
|
dp_source = os.path.join(self.dp_out, dn)
|
||||||
|
dp_sequence = os.path.join(dp_multi, f"{str(i).zfill(3)}_{dn}")
|
||||||
|
os.makedirs(dp_sequence, exist_ok=False)
|
||||||
|
list_source = os.listdir(dp_source)
|
||||||
|
list_source = [l for l in list_source if not l.endswith(".mp4")]
|
||||||
|
for fn in list_source:
|
||||||
|
shutil.copyfile(os.path.join(dp_source, fn), os.path.join(dp_sequence, fn))
|
||||||
|
list_fp_movies.append(os.path.join(dp_source, "movie.mp4"))
|
||||||
|
|
||||||
|
# Concatenate movies and save
|
||||||
|
fp_final = os.path.join(dp_multi, "movie.mp4")
|
||||||
|
concatenate_movies(fp_final, list_fp_movies)
|
||||||
|
return fp_final
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_img_rand():
|
def get_img_rand():
|
||||||
return (255*np.random.rand(self.height,self.width,3)).astype(np.uint8)
|
return (255*np.random.rand(self.height,self.width,3)).astype(np.uint8)
|
||||||
|
@ -287,119 +371,149 @@ def generate_list_output(
|
||||||
return list_output
|
return list_output
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
# fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
|
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
|
||||||
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
|
# fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
|
||||||
sdh = StableDiffusionHolder(fp_ckpt)
|
self = BlendingFrontend(StableDiffusionHolder(fp_ckpt)) # Yes this is possible in python and yes it is an awesome trick
|
||||||
|
|
||||||
self = BlendingFrontend(sdh) # Yes this is possible in python and yes it is an awesome trick
|
|
||||||
# self = BlendingFrontend(None) # Yes this is possible in python and yes it is an awesome trick
|
# self = BlendingFrontend(None) # Yes this is possible in python and yes it is an awesome trick
|
||||||
|
|
||||||
dict_ui_elem = {}
|
dict_ui_elem = {}
|
||||||
|
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
with gr.Row():
|
with gr.Tab("Single Transition"):
|
||||||
prompt1 = gr.Textbox(label="prompt 1")
|
with gr.Row():
|
||||||
prompt2 = gr.Textbox(label="prompt 2")
|
prompt1 = gr.Textbox(label="prompt 1")
|
||||||
|
prompt2 = gr.Textbox(label="prompt 2")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
duration_compute = gr.Slider(5, 45, self.t_compute_max_allowed, step=1, label='compute budget for transition (seconds)', interactive=True)
|
duration_compute = gr.Slider(5, 45, self.t_compute_max_allowed, step=1, label='compute budget for transition (seconds)', interactive=True)
|
||||||
duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True)
|
duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True)
|
||||||
height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
|
height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
|
||||||
width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True)
|
width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True)
|
||||||
|
|
||||||
with gr.Accordion("Advanced Settings (click to expand)", open=False):
|
with gr.Accordion("Advanced Settings (click to expand)", open=False):
|
||||||
|
|
||||||
with gr.Accordion("Diffusion settings", open=True):
|
with gr.Accordion("Diffusion settings", open=True):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
|
num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
|
||||||
guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
|
guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
|
||||||
negative_prompt = gr.Textbox(label="negative prompt")
|
negative_prompt = gr.Textbox(label="negative prompt")
|
||||||
|
|
||||||
with gr.Accordion("Seeds control", open=True):
|
with gr.Accordion("Seeds control", open=True):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
seed1 = gr.Number(420, label="seed 1", interactive=True)
|
seed1 = gr.Number(self.seed1, label="seed 1", interactive=True)
|
||||||
b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
|
b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
|
||||||
seed2 = gr.Number(420, label="seed 2", interactive=True)
|
seed2 = gr.Number(self.seed2, label="seed 2", interactive=True)
|
||||||
b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
|
b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
|
||||||
|
|
||||||
with gr.Accordion("Crossfeeding for last image", open=True):
|
with gr.Accordion("Crossfeeding for last image", open=True):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='crossfeed power', interactive=True)
|
branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='crossfeed power', interactive=True)
|
||||||
branch1_max_depth_influence = gr.Slider(0.0, 1.0, self.branch1_max_depth_influence, step=0.01, label='crossfeed range', interactive=True)
|
branch1_max_depth_influence = gr.Slider(0.0, 1.0, self.branch1_max_depth_influence, step=0.01, label='crossfeed range', interactive=True)
|
||||||
branch1_influence_decay = gr.Slider(0.0, 1.0, self.branch1_influence_decay, step=0.01, label='crossfeed decay', interactive=True)
|
branch1_influence_decay = gr.Slider(0.0, 1.0, self.branch1_influence_decay, step=0.01, label='crossfeed decay', interactive=True)
|
||||||
|
|
||||||
with gr.Accordion("Transition settings", open=True):
|
with gr.Accordion("Transition settings", open=True):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
|
depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
|
||||||
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
|
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
|
||||||
parental_influence = gr.Slider(0.0, 1.0, self.parental_influence, step=0.01, label='parental power', interactive=True)
|
parental_influence = gr.Slider(0.0, 1.0, self.parental_influence, step=0.01, label='parental power', interactive=True)
|
||||||
parental_max_depth_influence = gr.Slider(0.0, 1.0, self.parental_max_depth_influence, step=0.01, label='parental range', interactive=True)
|
parental_max_depth_influence = gr.Slider(0.0, 1.0, self.parental_max_depth_influence, step=0.01, label='parental range', interactive=True)
|
||||||
parental_influence_decay = gr.Slider(0.0, 1.0, self.parental_influence_decay, step=0.01, label='parental decay', interactive=True)
|
parental_influence_decay = gr.Slider(0.0, 1.0, self.parental_influence_decay, step=0.01, label='parental decay', interactive=True)
|
||||||
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
b_compute1 = gr.Button('compute first image', variant='primary')
|
b_compute1 = gr.Button('compute first image', variant='primary')
|
||||||
b_compute_transition = gr.Button('compute transition', variant='primary')
|
b_compute_transition = gr.Button('compute transition', variant='primary')
|
||||||
b_compute2 = gr.Button('compute last image', variant='primary')
|
b_compute2 = gr.Button('compute last image', variant='primary')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
img1 = gr.Image(label="1/5")
|
img1 = gr.Image(label="1/5")
|
||||||
img2 = gr.Image(label="2/5")
|
img2 = gr.Image(label="2/5")
|
||||||
img3 = gr.Image(label="3/5")
|
img3 = gr.Image(label="3/5")
|
||||||
img4 = gr.Image(label="4/5")
|
img4 = gr.Image(label="4/5")
|
||||||
img5 = gr.Image(label="5/5")
|
img5 = gr.Image(label="5/5")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
vid_transition = gr.Video()
|
vid_transition = gr.Video()
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
b_stackforward = gr.Button('multi-movie start next segment (move last image -> first image)')
|
b_stackforward = gr.Button('multi-movie start next segment (move last image -> first image)')
|
||||||
|
|
||||||
# Collect all UI elemts in list to easily pass as inputs
|
# Collect all UI elemts in list to easily pass as inputs
|
||||||
dict_ui_elem["prompt1"] = prompt1
|
dict_ui_elem["prompt1"] = prompt1
|
||||||
dict_ui_elem["negative_prompt"] = negative_prompt
|
dict_ui_elem["negative_prompt"] = negative_prompt
|
||||||
dict_ui_elem["prompt2"] = prompt2
|
dict_ui_elem["prompt2"] = prompt2
|
||||||
|
|
||||||
dict_ui_elem["duration_compute"] = duration_compute
|
dict_ui_elem["duration_compute"] = duration_compute
|
||||||
dict_ui_elem["duration_video"] = duration_video
|
dict_ui_elem["duration_video"] = duration_video
|
||||||
dict_ui_elem["height"] = height
|
dict_ui_elem["height"] = height
|
||||||
dict_ui_elem["width"] = width
|
dict_ui_elem["width"] = width
|
||||||
|
|
||||||
dict_ui_elem["depth_strength"] = depth_strength
|
dict_ui_elem["depth_strength"] = depth_strength
|
||||||
dict_ui_elem["branch1_influence"] = branch1_influence
|
dict_ui_elem["branch1_influence"] = branch1_influence
|
||||||
dict_ui_elem["branch1_max_depth_influence"] = branch1_max_depth_influence
|
dict_ui_elem["branch1_max_depth_influence"] = branch1_max_depth_influence
|
||||||
dict_ui_elem["branch1_influence_decay"] = branch1_influence_decay
|
dict_ui_elem["branch1_influence_decay"] = branch1_influence_decay
|
||||||
|
|
||||||
dict_ui_elem["num_inference_steps"] = num_inference_steps
|
dict_ui_elem["num_inference_steps"] = num_inference_steps
|
||||||
dict_ui_elem["guidance_scale"] = guidance_scale
|
dict_ui_elem["guidance_scale"] = guidance_scale
|
||||||
dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
|
dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
|
||||||
dict_ui_elem["seed1"] = seed1
|
dict_ui_elem["seed1"] = seed1
|
||||||
dict_ui_elem["seed2"] = seed2
|
dict_ui_elem["seed2"] = seed2
|
||||||
|
|
||||||
dict_ui_elem["parental_max_depth_influence"] = parental_max_depth_influence
|
dict_ui_elem["parental_max_depth_influence"] = parental_max_depth_influence
|
||||||
dict_ui_elem["parental_influence"] = parental_influence
|
dict_ui_elem["parental_influence"] = parental_influence
|
||||||
dict_ui_elem["parental_influence_decay"] = parental_influence_decay
|
dict_ui_elem["parental_influence_decay"] = parental_influence_decay
|
||||||
|
|
||||||
# Convert to list, as gradio doesn't seem to accept dicts
|
# Convert to list, as gradio doesn't seem to accept dicts
|
||||||
list_ui_elem = []
|
list_ui_elem = []
|
||||||
list_ui_keys = []
|
list_ui_keys = []
|
||||||
for k in dict_ui_elem.keys():
|
for k in dict_ui_elem.keys():
|
||||||
list_ui_elem.append(dict_ui_elem[k])
|
list_ui_elem.append(dict_ui_elem[k])
|
||||||
list_ui_keys.append(k)
|
list_ui_keys.append(k)
|
||||||
self.list_ui_keys = list_ui_keys
|
self.list_ui_keys = list_ui_keys
|
||||||
|
|
||||||
|
b_newseed1.click(self.randomize_seed1, outputs=seed1)
|
||||||
|
b_newseed2.click(self.randomize_seed2, outputs=seed2)
|
||||||
|
b_compute1.click(self.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5])
|
||||||
|
b_compute2.click(self.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5])
|
||||||
|
b_compute_transition.click(self.compute_transition,
|
||||||
|
inputs=list_ui_elem,
|
||||||
|
outputs=[img2, img3, img4, vid_transition])
|
||||||
|
|
||||||
|
b_stackforward.click(self.stack_forward,
|
||||||
|
inputs=[prompt2, seed2],
|
||||||
|
outputs=[img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
|
||||||
|
|
||||||
|
with gr.Tab("Multi Transition"):
|
||||||
|
with gr.Row():
|
||||||
|
multi_img1_prev = gr.Image(value=self.list_imgs_shown_last[0], label="1/5")
|
||||||
|
multi_img2_prev = gr.Image(value=self.list_imgs_shown_last[1], label="2/5")
|
||||||
|
multi_img3_prev = gr.Image(value=self.list_imgs_shown_last[2], label="3/5")
|
||||||
|
multi_img4_prev = gr.Image(value=self.list_imgs_shown_last[3], label="4/5")
|
||||||
|
multi_img5_prev = gr.Image(value=self.list_imgs_shown_last[4], label="5/5")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
b_older = gr.Button("show older")
|
||||||
|
b_newer = gr.Button("show newer")
|
||||||
|
text_timestamp = gr.Textbox(label="created", interactive=False)
|
||||||
|
b_append = gr.Button("append this transition")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
text_all_timestamps = gr.Textbox(label="movie list", interactive=False)
|
||||||
|
with gr.Row():
|
||||||
|
b_reset = gr.Button("reset")
|
||||||
|
b_concat = gr.Button("merge together", variant='primary')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
vid_multi = gr.Video()
|
||||||
|
|
||||||
|
|
||||||
|
b_older.click(self.multi_trans_show_older, inputs=[], outputs=[multi_img1_prev, multi_img2_prev, multi_img3_prev, multi_img4_prev, multi_img5_prev, text_timestamp])
|
||||||
|
b_newer.click(self.multi_trans_show_newer, inputs=[], outputs=[multi_img1_prev, multi_img2_prev, multi_img3_prev, multi_img4_prev, multi_img5_prev, text_timestamp])
|
||||||
|
b_append.click(self.multi_append, inputs=[], outputs=[text_all_timestamps])
|
||||||
|
b_reset.click(self.multi_reset, inputs=[], outputs=[text_all_timestamps])
|
||||||
|
b_concat.click(self.multi_concat, inputs=[], outputs=[vid_multi])
|
||||||
|
|
||||||
b_newseed1.click(self.randomize_seed1, outputs=seed1)
|
|
||||||
b_newseed2.click(self.randomize_seed2, outputs=seed2)
|
|
||||||
b_compute1.click(self.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5])
|
|
||||||
b_compute2.click(self.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5])
|
|
||||||
b_compute_transition.click(self.compute_transition,
|
|
||||||
inputs=list_ui_elem,
|
|
||||||
outputs=[img2, img3, img4, vid_transition])
|
|
||||||
|
|
||||||
b_stackforward.click(self.stack_forward,
|
|
||||||
inputs=[prompt2, seed2],
|
|
||||||
outputs=[img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
|
|
||||||
demo.launch(share=self.share, inbrowser=True, inline=False)
|
demo.launch(share=self.share, inbrowser=True, inline=False)
|
||||||
|
|
|
@ -231,36 +231,36 @@ class LatentBlending():
|
||||||
|
|
||||||
if quality == 'lowest':
|
if quality == 'lowest':
|
||||||
num_inference_steps = 12
|
num_inference_steps = 12
|
||||||
nmb_branches_final = 5
|
nmb_max_branches = 5
|
||||||
elif quality == 'low':
|
elif quality == 'low':
|
||||||
num_inference_steps = 15
|
num_inference_steps = 15
|
||||||
nmb_branches_final = nmb_frames//16
|
nmb_max_branches = nmb_frames//16
|
||||||
elif quality == 'medium':
|
elif quality == 'medium':
|
||||||
num_inference_steps = 30
|
num_inference_steps = 30
|
||||||
nmb_branches_final = nmb_frames//8
|
nmb_max_branches = nmb_frames//8
|
||||||
elif quality == 'high':
|
elif quality == 'high':
|
||||||
num_inference_steps = 60
|
num_inference_steps = 60
|
||||||
nmb_branches_final = nmb_frames//4
|
nmb_max_branches = nmb_frames//4
|
||||||
elif quality == 'ultra':
|
elif quality == 'ultra':
|
||||||
num_inference_steps = 100
|
num_inference_steps = 100
|
||||||
nmb_branches_final = nmb_frames//2
|
nmb_max_branches = nmb_frames//2
|
||||||
elif quality == 'upscaling_step1':
|
elif quality == 'upscaling_step1':
|
||||||
num_inference_steps = 40
|
num_inference_steps = 40
|
||||||
nmb_branches_final = 12
|
nmb_max_branches = 12
|
||||||
elif quality == 'upscaling_step2':
|
elif quality == 'upscaling_step2':
|
||||||
num_inference_steps = 100
|
num_inference_steps = 100
|
||||||
nmb_branches_final = 6
|
nmb_max_branches = 6
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"quality = '{quality}' not supported")
|
raise ValueError(f"quality = '{quality}' not supported")
|
||||||
|
|
||||||
self.autosetup_branching(depth_strength, num_inference_steps, nmb_branches_final)
|
self.autosetup_branching(depth_strength, num_inference_steps, nmb_max_branches)
|
||||||
|
|
||||||
|
|
||||||
def autosetup_branching(
|
def autosetup_branching(
|
||||||
self,
|
self,
|
||||||
depth_strength: float = 0.65,
|
depth_strength: float = 0.65,
|
||||||
num_inference_steps: int = 30,
|
num_inference_steps: int = 30,
|
||||||
nmb_branches_final: int = 20,
|
nmb_max_branches: int = 20,
|
||||||
nmb_mindist: int = 3,
|
nmb_mindist: int = 3,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
|
@ -273,7 +273,7 @@ class LatentBlending():
|
||||||
more shallow values will go into alpha-blendy land.
|
more shallow values will go into alpha-blendy land.
|
||||||
num_inference_steps: int
|
num_inference_steps: int
|
||||||
Number of diffusion steps. Higher values will take more compute time.
|
Number of diffusion steps. Higher values will take more compute time.
|
||||||
nmb_branches_final (int): The number of diffusion-generated images
|
nmb_max_branches (int): The number of diffusion-generated images
|
||||||
at the end of the inference.
|
at the end of the inference.
|
||||||
nmb_mindist (int): The minimum number of diffusion steps
|
nmb_mindist (int): The minimum number of diffusion steps
|
||||||
between two injections.
|
between two injections.
|
||||||
|
@ -285,7 +285,7 @@ class LatentBlending():
|
||||||
|
|
||||||
list_injection_idx = [0]
|
list_injection_idx = [0]
|
||||||
list_injection_idx.extend(np.linspace(idx_injection_first, idx_injection_last, nmb_injections).astype(int))
|
list_injection_idx.extend(np.linspace(idx_injection_first, idx_injection_last, nmb_injections).astype(int))
|
||||||
list_nmb_branches = np.round(np.logspace(np.log10(2), np.log10(nmb_branches_final), nmb_injections+1)).astype(int)
|
list_nmb_branches = np.round(np.logspace(np.log10(2), np.log10(nmb_max_branches), nmb_injections+1)).astype(int)
|
||||||
|
|
||||||
# Cleanup. There should be at least nmb_mindist diffusion steps between each injection and list_nmb_branches increases
|
# Cleanup. There should be at least nmb_mindist diffusion steps between each injection and list_nmb_branches increases
|
||||||
list_nmb_branches_clean = [list_nmb_branches[0]]
|
list_nmb_branches_clean = [list_nmb_branches[0]]
|
||||||
|
@ -294,7 +294,7 @@ class LatentBlending():
|
||||||
if idx_injection - list_injection_idx_clean[-1] >= nmb_mindist and nmb_branches > list_nmb_branches_clean[-1]:
|
if idx_injection - list_injection_idx_clean[-1] >= nmb_mindist and nmb_branches > list_nmb_branches_clean[-1]:
|
||||||
list_nmb_branches_clean.append(nmb_branches)
|
list_nmb_branches_clean.append(nmb_branches)
|
||||||
list_injection_idx_clean.append(idx_injection)
|
list_injection_idx_clean.append(idx_injection)
|
||||||
list_nmb_branches_clean[-1] = nmb_branches_final
|
list_nmb_branches_clean[-1] = nmb_max_branches
|
||||||
|
|
||||||
list_injection_idx_clean = [int(l) for l in list_injection_idx_clean]
|
list_injection_idx_clean = [int(l) for l in list_injection_idx_clean]
|
||||||
list_nmb_branches_clean = [int(l) for l in list_nmb_branches_clean]
|
list_nmb_branches_clean = [int(l) for l in list_nmb_branches_clean]
|
||||||
|
@ -394,8 +394,36 @@ class LatentBlending():
|
||||||
recycle_img2: Optional[bool] = False,
|
recycle_img2: Optional[bool] = False,
|
||||||
num_inference_steps: Optional[int] = 30,
|
num_inference_steps: Optional[int] = 30,
|
||||||
depth_strength: Optional[float] = 0.3,
|
depth_strength: Optional[float] = 0.3,
|
||||||
|
t_compute_max_allowed: Optional[float] = None,
|
||||||
|
nmb_max_branches: Optional[int] = None,
|
||||||
fixed_seeds: Optional[List[int]] = None,
|
fixed_seeds: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
|
r"""
|
||||||
|
Function for computing transitions.
|
||||||
|
Returns a list of transition images using spherical latent blending.
|
||||||
|
Args:
|
||||||
|
recycle_img1: Optional[bool]:
|
||||||
|
Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
|
||||||
|
recycle_img2: Optional[bool]:
|
||||||
|
Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
|
||||||
|
num_inference_steps:
|
||||||
|
Number of diffusion steps. Higher values will take more compute time.
|
||||||
|
depth_strength:
|
||||||
|
Determines how deep the first injection will happen.
|
||||||
|
Deeper injections will cause (unwanted) formation of new structures,
|
||||||
|
more shallow values will go into alpha-blendy land.
|
||||||
|
t_compute_max_allowed:
|
||||||
|
Either provide t_compute_max_allowed or nmb_max_branches.
|
||||||
|
The maximum time allowed for computation. Higher values give better results but take longer.
|
||||||
|
nmb_max_branches: int
|
||||||
|
Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
|
||||||
|
results. Use this if you want to have controllable results independent
|
||||||
|
of your computer.
|
||||||
|
fixed_seeds: Optional[List[int)]:
|
||||||
|
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
|
||||||
|
Otherwise random seeds will be taken.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
# Sanity checks first
|
# Sanity checks first
|
||||||
assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
|
assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
|
||||||
|
@ -412,7 +440,8 @@ class LatentBlending():
|
||||||
self.seed2 = fixed_seeds[1]
|
self.seed2 = fixed_seeds[1]
|
||||||
|
|
||||||
# Ensure correct num_inference_steps in holder
|
# Ensure correct num_inference_steps in holder
|
||||||
self.sdh.num_inference_steps = self.num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
|
self.sdh.num_inference_steps = num_inference_steps
|
||||||
|
|
||||||
# Compute / Recycle first image
|
# Compute / Recycle first image
|
||||||
if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
|
if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
|
||||||
|
@ -433,11 +462,61 @@ class LatentBlending():
|
||||||
self.tree_idx_injection = [0, 0]
|
self.tree_idx_injection = [0, 0]
|
||||||
|
|
||||||
# Set up branching scheme (dependent on provided compute time)
|
# Set up branching scheme (dependent on provided compute time)
|
||||||
|
list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
|
||||||
|
|
||||||
|
# Run iteratively, starting with the longest trajectory.
|
||||||
|
# Always inserting new branches where they are needed most according to image similarity
|
||||||
|
for s_idx in tqdm(range(len(list_idx_injection))):
|
||||||
|
nmb_stems = list_nmb_stems[s_idx]
|
||||||
|
idx_injection = list_idx_injection[s_idx]
|
||||||
|
|
||||||
|
for i in range(nmb_stems):
|
||||||
|
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
|
||||||
|
self.set_guidance_mid_dampening(fract_mixing)
|
||||||
|
list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
|
||||||
|
self.insert_into_tree(fract_mixing, idx_injection, list_latents)
|
||||||
|
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
|
||||||
|
|
||||||
|
return self.tree_final_imgs
|
||||||
|
|
||||||
|
|
||||||
|
def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
|
||||||
|
r"""
|
||||||
|
Sets up the branching scheme dependent on the time that is granted for compute.
|
||||||
|
The scheme uses an estimation derived from the first image's computation speed.
|
||||||
|
Either provide t_compute_max_allowed or nmb_max_branches
|
||||||
|
Args:
|
||||||
|
depth_strength:
|
||||||
|
Determines how deep the first injection will happen.
|
||||||
|
Deeper injections will cause (unwanted) formation of new structures,
|
||||||
|
more shallow values will go into alpha-blendy land.
|
||||||
|
t_compute_max_allowed: float
|
||||||
|
The maximum time allowed for computation. Higher values give better results
|
||||||
|
but take longer. Use this if you want to fix your waiting time for the results.
|
||||||
|
nmb_max_branches: int
|
||||||
|
The maximum number of branches to be computed. Higher values give better
|
||||||
|
results. Use this if you want to have controllable results independent
|
||||||
|
of your computer.
|
||||||
|
"""
|
||||||
idx_injection_base = int(round(self.num_inference_steps*depth_strength))
|
idx_injection_base = int(round(self.num_inference_steps*depth_strength))
|
||||||
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3)
|
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3)
|
||||||
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
|
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
|
||||||
t_compute = 0
|
t_compute = 0
|
||||||
while t_compute < self.t_compute_max_allowed:
|
|
||||||
|
if nmb_max_branches is None:
|
||||||
|
assert t_compute_max_allowed is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
|
||||||
|
stop_criterion = "t_compute_max_allowed"
|
||||||
|
elif t_compute_max_allowed is None:
|
||||||
|
assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
|
||||||
|
stop_criterion = "nmb_max_branches"
|
||||||
|
nmb_max_branches -= 2 # discounting the outer frames
|
||||||
|
else:
|
||||||
|
raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
|
||||||
|
|
||||||
|
stop_criterion_reached = False
|
||||||
|
is_first_iteration = True
|
||||||
|
|
||||||
|
while not stop_criterion_reached:
|
||||||
list_compute_steps = self.num_inference_steps - list_idx_injection
|
list_compute_steps = self.num_inference_steps - list_idx_injection
|
||||||
list_compute_steps *= list_nmb_stems
|
list_compute_steps *= list_nmb_stems
|
||||||
t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15*np.sum(list_nmb_stems)
|
t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15*np.sum(list_nmb_stems)
|
||||||
|
@ -449,23 +528,21 @@ class LatentBlending():
|
||||||
break
|
break
|
||||||
if not increase_done:
|
if not increase_done:
|
||||||
list_nmb_stems[-1] += 1
|
list_nmb_stems[-1] += 1
|
||||||
|
|
||||||
|
if stop_criterion == "t_compute_max_allowed" and t_compute > t_compute_max_allowed:
|
||||||
|
stop_criterion_reached = True
|
||||||
|
# FIXME: also undersample here... but how... maybe drop them iteratively?
|
||||||
|
elif stop_criterion == "nmb_max_branches" and np.sum(list_nmb_stems) >= nmb_max_branches:
|
||||||
|
stop_criterion_reached = True
|
||||||
|
if is_first_iteration:
|
||||||
|
# Need to undersample.
|
||||||
|
list_idx_injection = np.linspace(list_idx_injection[0], list_idx_injection[-1], nmb_max_branches).astype(np.int32)
|
||||||
|
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
|
||||||
|
else:
|
||||||
|
is_first_iteration = False
|
||||||
|
|
||||||
# print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
|
# print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
|
||||||
|
return list_idx_injection, list_nmb_stems
|
||||||
# Run iteratively, always inserting new branches where they are needed most
|
|
||||||
for s_idx in tqdm(range(len(list_idx_injection))):
|
|
||||||
nmb_stems = list_nmb_stems[s_idx]
|
|
||||||
idx_injection = list_idx_injection[s_idx]
|
|
||||||
|
|
||||||
for i in range(nmb_stems):
|
|
||||||
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
|
|
||||||
self.set_guidance_mid_dampening(fract_mixing)
|
|
||||||
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
|
|
||||||
list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
|
|
||||||
self.insert_into_tree(fract_mixing, idx_injection, list_latents)
|
|
||||||
|
|
||||||
return self.tree_final_imgs
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_mixing_parameters(self, idx_injection):
|
def get_mixing_parameters(self, idx_injection):
|
||||||
r"""
|
r"""
|
||||||
|
@ -581,7 +658,7 @@ class LatentBlending():
|
||||||
whether to return an image or the list of latents
|
whether to return an image or the list of latents
|
||||||
"""
|
"""
|
||||||
print("starting compute_latents1")
|
print("starting compute_latents1")
|
||||||
list_conditionings = [self.text_embedding1]
|
list_conditionings = self.get_mixed_conditioning(0)
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
latents_start = self.get_noise(self.seed1)
|
latents_start = self.get_noise(self.seed1)
|
||||||
list_latents1 = self.run_diffusion(
|
list_latents1 = self.run_diffusion(
|
||||||
|
@ -604,7 +681,8 @@ class LatentBlending():
|
||||||
return_image: bool
|
return_image: bool
|
||||||
whether to return an image or the list of latents
|
whether to return an image or the list of latents
|
||||||
"""
|
"""
|
||||||
list_conditionings = [self.text_embedding2]
|
print("starting compute_latents2")
|
||||||
|
list_conditionings = self.get_mixed_conditioning(1)
|
||||||
latents_start = self.get_noise(self.seed2)
|
latents_start = self.get_noise(self.seed2)
|
||||||
# Influence from branch1
|
# Influence from branch1
|
||||||
if self.branch1_influence > 0.0:
|
if self.branch1_influence > 0.0:
|
||||||
|
@ -630,177 +708,23 @@ class LatentBlending():
|
||||||
return list_latents2
|
return list_latents2
|
||||||
|
|
||||||
def get_noise(self, seed):
|
def get_noise(self, seed):
|
||||||
generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
|
|
||||||
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
|
|
||||||
C, H, W = shape_latents
|
|
||||||
return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
|
|
||||||
|
|
||||||
|
|
||||||
def run_transition_legacy(
|
|
||||||
self,
|
|
||||||
recycle_img1: Optional[bool] = False,
|
|
||||||
recycle_img2: Optional[bool] = False,
|
|
||||||
fixed_seeds: Optional[List[int]] = None,
|
|
||||||
premature_stop: Optional[int] = np.inf,
|
|
||||||
):
|
|
||||||
r"""
|
r"""
|
||||||
Old legacy function for computing transitions.
|
Helper function to get noise given seed.
|
||||||
Returns a list of transition images using spherical latent blending.
|
|
||||||
Args:
|
Args:
|
||||||
recycle_img1: Optional[bool]:
|
seed: int
|
||||||
Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
|
|
||||||
recycle_img2: Optional[bool]:
|
|
||||||
Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
|
|
||||||
fixed_seeds: Optional[List[int)]:
|
|
||||||
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
|
|
||||||
Otherwise random seeds will be taken.
|
|
||||||
premature_stop: Optional[int]:
|
|
||||||
Stop the computation after premature_stop frames have been computed in the transition
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Sanity checks first
|
generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
|
||||||
assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
|
if self.mode == 'standard':
|
||||||
assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
|
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
|
||||||
assert self.list_injection_idx is not None, 'Set the branching structure before, by calling autosetup_branching or setup_branching'
|
C, H, W = shape_latents
|
||||||
|
elif self.mode == 'upscale':
|
||||||
|
w = self.image1_lowres.size[0]
|
||||||
|
h = self.image1_lowres.size[1]
|
||||||
|
shape_latents = [self.sdh.model.channels, h, w]
|
||||||
|
C, H, W = shape_latents
|
||||||
|
|
||||||
if fixed_seeds is not None:
|
return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
|
||||||
if fixed_seeds == 'randomize':
|
|
||||||
fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
|
|
||||||
else:
|
|
||||||
assert len(fixed_seeds)==2, "Supply a list with len = 2"
|
|
||||||
|
|
||||||
self.seed1 = fixed_seeds[0]
|
|
||||||
self.seed2 = fixed_seeds[1]
|
|
||||||
|
|
||||||
# Process interruption variable
|
|
||||||
self.stop_diffusion = False
|
|
||||||
|
|
||||||
# Ensure correct num_inference_steps in holder
|
|
||||||
self.sdh.num_inference_steps = self.num_inference_steps
|
|
||||||
|
|
||||||
# Make a backup for future reference
|
|
||||||
self.list_nmb_branches_prev = self.list_nmb_branches[:]
|
|
||||||
self.list_injection_idx_prev = self.list_injection_idx[:]
|
|
||||||
|
|
||||||
# Split the first block if there is branch1 crossfeeding
|
|
||||||
if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
|
|
||||||
assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0'
|
|
||||||
self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
|
|
||||||
self.branch1_max_depth_influence = np.clip(self.branch1_max_depth_influence, 0, 1)
|
|
||||||
self.list_nmb_branches.insert(1, 2)
|
|
||||||
idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_max_depth_influence))
|
|
||||||
self.list_injection_idx_ext.insert(1, idx_crossfeed)
|
|
||||||
self.tree_fracts.insert(1, self.tree_fracts[0])
|
|
||||||
self.tree_status.insert(1, self.tree_status[0])
|
|
||||||
self.tree_latents.insert(1, self.tree_latents[0])
|
|
||||||
self.branch1_insertion_completed = True
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Pre-define entire branching tree structures
|
|
||||||
self.tree_final_imgs = [None]*self.list_nmb_branches[-1]
|
|
||||||
nmb_blocks_time = len(self.list_injection_idx_ext)-1
|
|
||||||
if not recycle_img1 and not recycle_img2:
|
|
||||||
self.init_tree_struct()
|
|
||||||
else:
|
|
||||||
self.tree_final_imgs = [None]*self.list_nmb_branches[-1]
|
|
||||||
for t_block in range(nmb_blocks_time):
|
|
||||||
nmb_branches = self.list_nmb_branches[t_block]
|
|
||||||
for idx_branch in range(nmb_branches):
|
|
||||||
self.tree_status[t_block][idx_branch] = 'untouched'
|
|
||||||
if recycle_img1:
|
|
||||||
self.tree_status[t_block][0] = 'computed'
|
|
||||||
self.tree_final_imgs[0] = self.sdh.latent2image(self.tree_latents[-1][0][-1])
|
|
||||||
self.tree_final_imgs_timing[0] = 0
|
|
||||||
if recycle_img2:
|
|
||||||
self.tree_status[t_block][-1] = 'computed'
|
|
||||||
self.tree_final_imgs[-1] = self.sdh.latent2image(self.tree_latents[-1][-1][-1])
|
|
||||||
self.tree_final_imgs_timing[-1] = 0
|
|
||||||
|
|
||||||
# setup compute order: goal: try to get last branch computed asap.
|
|
||||||
# first compute the right keyframe. needs to be there in any case
|
|
||||||
list_compute = []
|
|
||||||
list_local_stem = []
|
|
||||||
for t_block in range(nmb_blocks_time - 1, -1, -1):
|
|
||||||
if self.tree_status[t_block][0] == 'untouched':
|
|
||||||
self.tree_status[t_block][0] = 'prefetched'
|
|
||||||
list_local_stem.append([t_block, 0])
|
|
||||||
list_compute.extend(list_local_stem[::-1])
|
|
||||||
|
|
||||||
# setup compute order: start from last leafs (the final transition images) and work way down. what parents do they need?
|
|
||||||
for idx_leaf in range(1, self.list_nmb_branches[-1]):
|
|
||||||
list_local_stem = []
|
|
||||||
t_block = nmb_blocks_time - 1
|
|
||||||
t_block_prev = t_block - 1
|
|
||||||
self.tree_status[t_block][idx_leaf] = 'prefetched'
|
|
||||||
list_local_stem.append([t_block, idx_leaf])
|
|
||||||
idx_leaf_deep = idx_leaf
|
|
||||||
|
|
||||||
for t_block in range(nmb_blocks_time-1, 0, -1):
|
|
||||||
t_block_prev = t_block - 1
|
|
||||||
fract_mixing = self.tree_fracts[t_block][idx_leaf_deep]
|
|
||||||
list_fract_mixing_prev = self.tree_fracts[t_block_prev]
|
|
||||||
b_parent1, b_parent2 = get_closest_idx(fract_mixing, list_fract_mixing_prev)
|
|
||||||
assert self.tree_status[t_block_prev][b_parent1] != 'untouched', 'Branch destruction??? This should never happen!'
|
|
||||||
if self.tree_status[t_block_prev][b_parent2] == 'untouched':
|
|
||||||
self.tree_status[t_block_prev][b_parent2] = 'prefetched'
|
|
||||||
list_local_stem.append([t_block_prev, b_parent2])
|
|
||||||
idx_leaf_deep = b_parent2
|
|
||||||
list_compute.extend(list_local_stem[::-1])
|
|
||||||
|
|
||||||
# Diffusion computations start here
|
|
||||||
time_start = time.time()
|
|
||||||
for t_block, idx_branch in tqdm(list_compute, desc="computing transition", smoothing=0.01):
|
|
||||||
if self.stop_diffusion:
|
|
||||||
print("run_transition: process interrupted")
|
|
||||||
return self.tree_final_imgs
|
|
||||||
if idx_branch > premature_stop:
|
|
||||||
print(f"run_transition: premature_stop criterion reached. returning tree with {premature_stop} branches")
|
|
||||||
return self.tree_final_imgs
|
|
||||||
|
|
||||||
# print(f"computing t_block {t_block} idx_branch {idx_branch}")
|
|
||||||
idx_stop = self.list_injection_idx_ext[t_block+1]
|
|
||||||
fract_mixing = self.tree_fracts[t_block][idx_branch]
|
|
||||||
|
|
||||||
list_conditionings = self.get_mixed_conditioning(fract_mixing)
|
|
||||||
self.set_guidance_mid_dampening(fract_mixing)
|
|
||||||
# print(f"fract_mixing {fract_mixing} guid {self.sdh.guidance_scale}")
|
|
||||||
if t_block == 0:
|
|
||||||
if fixed_seeds is not None:
|
|
||||||
if idx_branch == 0:
|
|
||||||
self.set_seed(fixed_seeds[0])
|
|
||||||
elif idx_branch == self.list_nmb_branches[0] -1:
|
|
||||||
self.set_seed(fixed_seeds[1])
|
|
||||||
|
|
||||||
list_latents = self.run_diffusion(list_conditionings, idx_stop=idx_stop)
|
|
||||||
|
|
||||||
# Inject latents from first branch for very first block
|
|
||||||
if idx_branch==1 and self.branch1_influence > 0:
|
|
||||||
fract_base_influence = np.clip(self.branch1_influence, 0, 1)
|
|
||||||
for i in range(len(list_latents)):
|
|
||||||
list_latents[i] = interpolate_spherical(list_latents[i], self.tree_latents[0][0][i], fract_base_influence)
|
|
||||||
else:
|
|
||||||
# find parents latents
|
|
||||||
b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts[t_block-1])
|
|
||||||
latents1 = self.tree_latents[t_block-1][b_parent1][-1]
|
|
||||||
if fract_mixing == 0:
|
|
||||||
latents2 = latents1
|
|
||||||
else:
|
|
||||||
latents2 = self.tree_latents[t_block-1][b_parent2][-1]
|
|
||||||
idx_start = self.list_injection_idx_ext[t_block]
|
|
||||||
fract_mixing_parental = (fract_mixing - self.tree_fracts[t_block-1][b_parent1]) / (self.tree_fracts[t_block-1][b_parent2] - self.tree_fracts[t_block-1][b_parent1])
|
|
||||||
latents_for_injection = interpolate_spherical(latents1, latents2, fract_mixing_parental)
|
|
||||||
list_latents = self.run_diffusion(list_conditionings, latents_for_injection, idx_start=idx_start, idx_stop=idx_stop)
|
|
||||||
|
|
||||||
self.tree_latents[t_block][idx_branch] = list_latents
|
|
||||||
self.tree_status[t_block][idx_branch] = 'computed'
|
|
||||||
|
|
||||||
# Convert latents to image directly for the last t_block
|
|
||||||
if t_block == nmb_blocks_time-1:
|
|
||||||
self.tree_final_imgs[idx_branch] = self.sdh.latent2image(list_latents[-1])
|
|
||||||
self.tree_final_imgs_timing[idx_branch] = time.time() - time_start
|
|
||||||
|
|
||||||
return self.tree_final_imgs
|
|
||||||
|
|
||||||
|
|
||||||
def run_multi_transition(
|
def run_multi_transition(
|
||||||
|
@ -906,24 +830,31 @@ class LatentBlending():
|
||||||
return_image = return_image,
|
return_image = return_image,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif self.mode == 'upscale':
|
||||||
|
cond = list_conditionings[0]
|
||||||
|
uc_full = list_conditionings[1]
|
||||||
|
return self.sdh.run_diffusion_upscaling(
|
||||||
|
cond,
|
||||||
|
uc_full,
|
||||||
|
latents_start=latents_start,
|
||||||
|
idx_start=idx_start,
|
||||||
|
list_latents_mixing = list_latents_mixing,
|
||||||
|
mixing_coeffs = mixing_coeffs,
|
||||||
|
return_image=return_image)
|
||||||
|
|
||||||
# elif self.mode == 'inpaint':
|
# elif self.mode == 'inpaint':
|
||||||
# text_embeddings = list_conditionings[0]
|
# text_embeddings = list_conditionings[0]
|
||||||
# assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
|
# assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
|
||||||
# assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
|
# assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
|
||||||
# return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
# return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
||||||
|
|
||||||
# # FIXME LONG LINE and bad args
|
# FIXME. new transition engine
|
||||||
# elif self.mode == 'upscale':
|
|
||||||
# cond = list_conditionings[0]
|
|
||||||
# uc_full = list_conditionings[1]
|
|
||||||
# return self.sdh.run_diffusion_upscaling(cond, uc_full, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
|
||||||
|
|
||||||
def run_upscaling_step1(
|
def run_upscaling_step1(
|
||||||
self,
|
self,
|
||||||
dp_img: str,
|
dp_img: str,
|
||||||
depth_strength: float = 0.65,
|
depth_strength: float = 0.65,
|
||||||
num_inference_steps: int = 30,
|
num_inference_steps: int = 30,
|
||||||
nmb_branches_final: int = 10,
|
nmb_max_branches: int = 10,
|
||||||
fixed_seeds: Optional[List[int]] = None,
|
fixed_seeds: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
|
@ -932,6 +863,7 @@ class LatentBlending():
|
||||||
dp_img:
|
dp_img:
|
||||||
Path to directory where the low-res images and yaml will be saved to.
|
Path to directory where the low-res images and yaml will be saved to.
|
||||||
This directory cannot exist and will be created here.
|
This directory cannot exist and will be created here.
|
||||||
|
FIXME
|
||||||
quality: str
|
quality: str
|
||||||
Determines how many diffusion steps are being made + how many branches in total.
|
Determines how many diffusion steps are being made + how many branches in total.
|
||||||
We suggest to leave it with upscaling_step1 which has 10 final branches.
|
We suggest to leave it with upscaling_step1 which has 10 final branches.
|
||||||
|
@ -951,7 +883,6 @@ class LatentBlending():
|
||||||
fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
|
fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
|
||||||
|
|
||||||
# Run latent blending
|
# Run latent blending
|
||||||
self.autosetup_branching(depth_strength, num_inference_steps, nmb_branches_final)
|
|
||||||
imgs_transition = self.run_transition(fixed_seeds=fixed_seeds)
|
imgs_transition = self.run_transition(fixed_seeds=fixed_seeds)
|
||||||
self.write_imgs_transition(dp_img, imgs_transition)
|
self.write_imgs_transition(dp_img, imgs_transition)
|
||||||
|
|
||||||
|
@ -962,13 +893,14 @@ class LatentBlending():
|
||||||
self,
|
self,
|
||||||
dp_img: str,
|
dp_img: str,
|
||||||
depth_strength: float = 0.65,
|
depth_strength: float = 0.65,
|
||||||
num_inference_steps: int = 30,
|
num_inference_steps: int = 100,
|
||||||
nmb_branches_final: int = 10,
|
nmb_max_branches_highres: int = 5,
|
||||||
|
nmb_max_branches_lowres: int = 6,
|
||||||
fixed_seeds: Optional[List[int]] = None,
|
fixed_seeds: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
fp_yml = os.path.join(dp_img, "lowres.yaml")
|
fp_yml = os.path.join(dp_img, "lowres.yaml")
|
||||||
fp_movie = os.path.join(dp_img, "movie.mp4")
|
fp_movie = os.path.join(dp_img, "movie_highres.mp4")
|
||||||
fps = 24
|
fps = 24
|
||||||
ms = MovieSaver(fp_movie, fps=fps)
|
ms = MovieSaver(fp_movie, fps=fps)
|
||||||
assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
|
assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
|
||||||
|
@ -978,8 +910,9 @@ class LatentBlending():
|
||||||
nmb_images_lowres = dict_stuff['nmb_images']
|
nmb_images_lowres = dict_stuff['nmb_images']
|
||||||
prompt1 = dict_stuff['prompt1']
|
prompt1 = dict_stuff['prompt1']
|
||||||
prompt2 = dict_stuff['prompt2']
|
prompt2 = dict_stuff['prompt2']
|
||||||
|
idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres-1, nmb_max_branches_lowres)).astype(np.int32)
|
||||||
imgs_lowres = []
|
imgs_lowres = []
|
||||||
for i in range(nmb_images_lowres):
|
for i in idx_img_lowres:
|
||||||
fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
|
fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
|
||||||
assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
|
assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
|
||||||
imgs_lowres.append(Image.open(fp_img_lowres))
|
imgs_lowres.append(Image.open(fp_img_lowres))
|
||||||
|
@ -989,13 +922,12 @@ class LatentBlending():
|
||||||
text_embeddingA = self.sdh.get_text_embedding(prompt1)
|
text_embeddingA = self.sdh.get_text_embedding(prompt1)
|
||||||
text_embeddingB = self.sdh.get_text_embedding(prompt2)
|
text_embeddingB = self.sdh.get_text_embedding(prompt2)
|
||||||
|
|
||||||
self.autosetup_branching(depth_strength, num_inference_steps, nmb_branches_final)
|
#FIXME: have a total length for the whole video section
|
||||||
|
|
||||||
duration_single_trans = 3
|
duration_single_trans = 3
|
||||||
list_fract_mixing = np.linspace(0, 1, nmb_images_lowres-1)
|
list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres-1)
|
||||||
|
|
||||||
for i in range(nmb_images_lowres-1):
|
for i in range(nmb_max_branches_lowres-1):
|
||||||
print(f"Starting movie segment {i+1}/{nmb_images_lowres-1}")
|
print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
|
||||||
|
|
||||||
self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
|
self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
|
||||||
self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1-list_fract_mixing[i])
|
self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1-list_fract_mixing[i])
|
||||||
|
@ -1008,7 +940,15 @@ class LatentBlending():
|
||||||
|
|
||||||
self.set_image1(imgs_lowres[i])
|
self.set_image1(imgs_lowres[i])
|
||||||
self.set_image2(imgs_lowres[i+1])
|
self.set_image2(imgs_lowres[i+1])
|
||||||
list_imgs = self.run_transition(recycle_img1=recycle_img1)
|
|
||||||
|
list_imgs = self.run_transition(
|
||||||
|
recycle_img1 = recycle_img1,
|
||||||
|
recycle_img2 = False,
|
||||||
|
num_inference_steps = num_inference_steps,
|
||||||
|
depth_strength = depth_strength,
|
||||||
|
nmb_max_branches = nmb_max_branches_highres,
|
||||||
|
)
|
||||||
|
|
||||||
list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_trans)
|
list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_trans)
|
||||||
|
|
||||||
# Save movie frame
|
# Save movie frame
|
||||||
|
@ -1075,11 +1015,12 @@ class LatentBlending():
|
||||||
return self.sdh.get_text_embedding(prompt)
|
return self.sdh.get_text_embedding(prompt)
|
||||||
|
|
||||||
|
|
||||||
def write_imgs_transition(self, dp_img, imgs_transition):
|
def write_imgs_transition(self, dp_img):
|
||||||
r"""
|
r"""
|
||||||
Writes the transition images into the folder dp_img.
|
Writes the transition images into the folder dp_img.
|
||||||
"""
|
"""
|
||||||
os.makedirs(dp_img)
|
imgs_transition = self.tree_final_imgs
|
||||||
|
os.makedirs(dp_img, exist_ok=True)
|
||||||
for i, img in enumerate(imgs_transition):
|
for i, img in enumerate(imgs_transition):
|
||||||
img_leaf = Image.fromarray(img)
|
img_leaf = Image.fromarray(img)
|
||||||
img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
|
img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
|
||||||
|
@ -1090,6 +1031,7 @@ class LatentBlending():
|
||||||
|
|
||||||
def save_statedict(self, fp_yml):
|
def save_statedict(self, fp_yml):
|
||||||
# Dump everything relevant into yaml
|
# Dump everything relevant into yaml
|
||||||
|
imgs_transition = self.tree_final_imgs
|
||||||
state_dict = self.get_state_dict()
|
state_dict = self.get_state_dict()
|
||||||
state_dict['nmb_images'] = len(imgs_transition)
|
state_dict['nmb_images'] = len(imgs_transition)
|
||||||
yml_save(fp_yml, state_dict)
|
yml_save(fp_yml, state_dict)
|
||||||
|
@ -1098,7 +1040,9 @@ class LatentBlending():
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
|
grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
|
||||||
'num_inference_steps', 'depth_strength', 'guidance_scale',
|
'num_inference_steps', 'depth_strength', 'guidance_scale',
|
||||||
'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt']
|
'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt',
|
||||||
|
'branch1_influence', 'branch1_max_depth_influence', 'branch1_influence_decay'
|
||||||
|
'parental_influence', 'parental_max_depth_influence', 'parental_influence_decay']
|
||||||
for v in grab_vars:
|
for v in grab_vars:
|
||||||
if hasattr(self, v):
|
if hasattr(self, v):
|
||||||
if v == 'seed1' or v == 'seed2':
|
if v == 'seed1' or v == 'seed2':
|
||||||
|
@ -1107,8 +1051,10 @@ class LatentBlending():
|
||||||
state_dict[v] = float(getattr(self, v))
|
state_dict[v] = float(getattr(self, v))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
state_dict[v] = getattr(self, v)
|
try:
|
||||||
|
state_dict[v] = getattr(self, v)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
@ -1163,8 +1109,7 @@ class LatentBlending():
|
||||||
as in run_multi_transition()
|
as in run_multi_transition()
|
||||||
"""
|
"""
|
||||||
# Move over all latents
|
# Move over all latents
|
||||||
for t_block in range(len(self.tree_latents)):
|
self.tree_latents[0] = self.tree_latents[-1]
|
||||||
self.tree_latents[t_block][0] = self.tree_latents[t_block][-1]
|
|
||||||
|
|
||||||
# Move over prompts and text embeddings
|
# Move over prompts and text embeddings
|
||||||
self.prompt1 = self.prompt2
|
self.prompt1 = self.prompt2
|
||||||
|
|
|
@ -285,8 +285,7 @@ class StableDiffusionHolder:
|
||||||
return_image: Optional[bool] = False,
|
return_image: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
|
Diffusion standard version.
|
||||||
Depending on the mode, the correct one will be executed.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text_embeddings: torch.FloatTensor
|
text_embeddings: torch.FloatTensor
|
||||||
|
@ -364,6 +363,99 @@ class StableDiffusionHolder:
|
||||||
else:
|
else:
|
||||||
return list_latents_out
|
return list_latents_out
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def run_diffusion_upscaling(
|
||||||
|
self,
|
||||||
|
cond,
|
||||||
|
uc_full,
|
||||||
|
latents_start: torch.FloatTensor,
|
||||||
|
idx_start: int = -1,
|
||||||
|
list_latents_mixing = None,
|
||||||
|
mixing_coeffs = 0.0,
|
||||||
|
return_image: Optional[bool] = False
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Diffusion upscaling version.
|
||||||
|
# FIXME
|
||||||
|
Args:
|
||||||
|
??
|
||||||
|
latents_for_injection: torch.FloatTensor
|
||||||
|
Latents that are used for injection
|
||||||
|
idx_start: int
|
||||||
|
Index of the diffusion process start and where the latents_for_injection are injected
|
||||||
|
return_image: Optional[bool]
|
||||||
|
Optionally return image directly
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Asserts
|
||||||
|
if type(mixing_coeffs) == float:
|
||||||
|
list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
|
||||||
|
elif type(mixing_coeffs) == list:
|
||||||
|
assert len(mixing_coeffs) == self.num_inference_steps
|
||||||
|
list_mixing_coeffs = mixing_coeffs
|
||||||
|
else:
|
||||||
|
raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
|
||||||
|
|
||||||
|
if np.sum(list_mixing_coeffs) > 0:
|
||||||
|
assert len(list_latents_mixing) == self.num_inference_steps
|
||||||
|
|
||||||
|
precision_scope = autocast if self.precision == "autocast" else nullcontext
|
||||||
|
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
|
||||||
|
|
||||||
|
h = uc_full['c_concat'][0].shape[2]
|
||||||
|
w = uc_full['c_concat'][0].shape[3]
|
||||||
|
|
||||||
|
with precision_scope("cuda"):
|
||||||
|
with self.model.ema_scope():
|
||||||
|
|
||||||
|
shape_latents = [self.model.channels, h, w]
|
||||||
|
|
||||||
|
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
|
||||||
|
C, H, W = shape_latents
|
||||||
|
size = (1, C, H, W)
|
||||||
|
b = size[0]
|
||||||
|
|
||||||
|
latents = latents_start.clone()
|
||||||
|
|
||||||
|
timesteps = self.sampler.ddim_timesteps
|
||||||
|
|
||||||
|
time_range = np.flip(timesteps)
|
||||||
|
total_steps = timesteps.shape[0]
|
||||||
|
|
||||||
|
# collect latents
|
||||||
|
list_latents_out = []
|
||||||
|
for i, step in enumerate(time_range):
|
||||||
|
# Set the right starting latents
|
||||||
|
if i < idx_start:
|
||||||
|
list_latents_out.append(None)
|
||||||
|
continue
|
||||||
|
elif i == idx_start:
|
||||||
|
latents = latents_start.clone()
|
||||||
|
|
||||||
|
# Mix the latents.
|
||||||
|
if i > 0 and list_mixing_coeffs[i]>0:
|
||||||
|
latents_mixtarget = list_latents_mixing[i-1].clone()
|
||||||
|
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
|
||||||
|
|
||||||
|
# print(f"diffusion iter {i}")
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full((b,), step, device=self.device, dtype=torch.long)
|
||||||
|
outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
|
||||||
|
quantize_denoised=False, temperature=1.0,
|
||||||
|
noise_dropout=0.0, score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=self.guidance_scale,
|
||||||
|
unconditional_conditioning=uc_full,
|
||||||
|
dynamic_threshold=None)
|
||||||
|
latents, pred_x0 = outs
|
||||||
|
list_latents_out.append(latents.clone())
|
||||||
|
|
||||||
|
if return_image:
|
||||||
|
return self.latent2image(latents)
|
||||||
|
else:
|
||||||
|
return list_latents_out
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def run_diffusion_inpaint(
|
def run_diffusion_inpaint(
|
||||||
self,
|
self,
|
||||||
|
@ -474,93 +566,6 @@ class StableDiffusionHolder:
|
||||||
else:
|
else:
|
||||||
return list_latents_out
|
return list_latents_out
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def run_diffusion_upscaling(
|
|
||||||
self,
|
|
||||||
cond,
|
|
||||||
uc_full,
|
|
||||||
latents_for_injection: torch.FloatTensor = None,
|
|
||||||
idx_start: int = -1,
|
|
||||||
idx_stop: int = -1,
|
|
||||||
return_image: Optional[bool] = False
|
|
||||||
):
|
|
||||||
r"""
|
|
||||||
Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
|
|
||||||
Depending on the mode, the correct one will be executed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
??
|
|
||||||
latents_for_injection: torch.FloatTensor
|
|
||||||
Latents that are used for injection
|
|
||||||
idx_start: int
|
|
||||||
Index of the diffusion process start and where the latents_for_injection are injected
|
|
||||||
idx_stop: int
|
|
||||||
Index of the diffusion process end.
|
|
||||||
return_image: Optional[bool]
|
|
||||||
Optionally return image directly
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
if latents_for_injection is None:
|
|
||||||
do_inject_latents = False
|
|
||||||
else:
|
|
||||||
do_inject_latents = True
|
|
||||||
|
|
||||||
precision_scope = autocast if self.precision == "autocast" else nullcontext
|
|
||||||
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
|
|
||||||
|
|
||||||
h = uc_full['c_concat'][0].shape[2]
|
|
||||||
w = uc_full['c_concat'][0].shape[3]
|
|
||||||
|
|
||||||
with precision_scope("cuda"):
|
|
||||||
with self.model.ema_scope():
|
|
||||||
|
|
||||||
|
|
||||||
shape_latents = [self.model.channels, h, w]
|
|
||||||
|
|
||||||
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
|
|
||||||
C, H, W = shape_latents
|
|
||||||
size = (1, C, H, W)
|
|
||||||
b = size[0]
|
|
||||||
|
|
||||||
latents = torch.randn(size, generator=generator, device=self.device)
|
|
||||||
|
|
||||||
timesteps = self.sampler.ddim_timesteps
|
|
||||||
|
|
||||||
time_range = np.flip(timesteps)
|
|
||||||
total_steps = timesteps.shape[0]
|
|
||||||
|
|
||||||
# collect latents
|
|
||||||
list_latents_out = []
|
|
||||||
for i, step in enumerate(time_range):
|
|
||||||
if do_inject_latents:
|
|
||||||
# Inject latent at right place
|
|
||||||
if i < idx_start:
|
|
||||||
continue
|
|
||||||
elif i == idx_start:
|
|
||||||
latents = latents_for_injection.clone()
|
|
||||||
|
|
||||||
if i == idx_stop:
|
|
||||||
return list_latents_out
|
|
||||||
|
|
||||||
# print(f"diffusion iter {i}")
|
|
||||||
index = total_steps - i - 1
|
|
||||||
ts = torch.full((b,), step, device=self.device, dtype=torch.long)
|
|
||||||
outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
|
|
||||||
quantize_denoised=False, temperature=1.0,
|
|
||||||
noise_dropout=0.0, score_corrector=None,
|
|
||||||
corrector_kwargs=None,
|
|
||||||
unconditional_guidance_scale=self.guidance_scale,
|
|
||||||
unconditional_conditioning=uc_full,
|
|
||||||
dynamic_threshold=None)
|
|
||||||
latents, pred_x0 = outs
|
|
||||||
list_latents_out.append(latents.clone())
|
|
||||||
|
|
||||||
if return_image:
|
|
||||||
return self.latent2image(latents)
|
|
||||||
else:
|
|
||||||
return list_latents_out
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def latent2image(
|
def latent2image(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue