From 3c6015782f7134fcea1dd60f2605364b918c8f7c Mon Sep 17 00:00:00 2001
From: Johannes Stelzer
Date: Sat, 18 Feb 2023 07:56:30 +0100
Subject: [PATCH] small fixes

---
 gradio_ui.py               | 424 +++++++++++++++++++++++--------------
 latent_blending.py         | 391 +++++++++++++++------------------
 stable_diffusion_holder.py | 183 ++++++++--------
 3 files changed, 531 insertions(+), 467 deletions(-)

diff --git a/gradio_ui.py b/gradio_ui.py
index 33116b2..200abe5 100644
--- a/gradio_ui.py
+++ b/gradio_ui.py
@@ -32,23 +32,35 @@ torch.set_grad_enabled(False)
 import gradio as gr
 import copy
 from dotenv import find_dotenv, load_dotenv
+import shutil
 
+"""
+FIXME: if "compute transition" was never run, appending a segment to a
+multi-movie fails.
+"""
 
 #%%
 class BlendingFrontend():
     def __init__(self, sdh=None):
+        self.num_inference_steps = 30
         if sdh is None:
             self.use_debug = True
+            self.height = 768
+            self.width = 768
         else:
             self.use_debug = False
             self.lb = LatentBlending(sdh)
-
-        self.share = True
-        self.num_inference_steps = 30
+            self.lb.sdh.num_inference_steps = self.num_inference_steps
+            self.height = self.lb.sdh.height
+            self.width = self.lb.sdh.width
+
+        self.init_save_dir()
+        self.save_empty_image()
+        self.share = False
         self.depth_strength = 0.25
-        self.seed1 = 42
+        self.seed1 = 420
         self.seed2 = 420
         self.guidance_scale = 4.0
         self.guidance_scale_mid_damper = 0.5
@@ -72,16 +84,13 @@ class BlendingFrontend():
         self.current_timestamp = None
         self.recycle_img1 = False
         self.recycle_img2 = False
+        self.fp_img1 = None
+        self.fp_img2 = None
+        self.multi_idx_current = -1
+        self.multi_list_concat = []
+        self.list_imgs_shown_last = 5*[self.fp_img_empty]
+        self.nmb_trans_stack = 6
 
-        if not self.use_debug:
-            self.lb.sdh.num_inference_steps = self.num_inference_steps
-            self.height = self.lb.sdh.height
-            self.width = self.lb.sdh.width
-        else:
-            self.height = 768
-            self.width = 768
-
-        self.init_save_dir()
 
     def init_save_dir(self):
         load_dotenv(find_dotenv(), verbose=False)
         try:
             self.dp_out = os.getenv("dp_out")
         except Exception as e:
+            print(f"Did not find .env file; using local folder. 
{e}") self.dp_out = "" + self.dp_imgs = os.path.join(self.dp_out, "imgs") + os.makedirs(self.dp_imgs, exist_ok=True) + self.dp_movies = os.path.join(self.dp_out, "movies") + os.makedirs(self.dp_movies, exist_ok=True) + # make dummy image def save_empty_image(self): - self.fp_img_empty = os.path.join(self.dp_out, 'empty.jpg') + self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg') Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5) @@ -140,22 +155,21 @@ class BlendingFrontend(): def compute_img1(self, *args): list_ui_elem = args self.setup_lb(list_ui_elem) - fp_img1 = os.path.join(self.dp_out, f"img1_{get_time('second')}.jpg") + self.fp_img1 = os.path.join(self.dp_imgs, f"img1_{get_time('second')}.jpg") img1 = Image.fromarray(self.lb.compute_latents1(return_image=True)) - img1.save(fp_img1) - self.save_empty_image() + img1.save(self.fp_img1) self.recycle_img1 = True self.recycle_img2 = False - return [fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty] + return [self.fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty] def compute_img2(self, *args): list_ui_elem = args self.setup_lb(list_ui_elem) - fp_img2 = os.path.join(self.dp_out, f"img2_{get_time('second')}.jpg") + self.fp_img2 = os.path.join(self.dp_imgs, f"img2_{get_time('second')}.jpg") img2 = Image.fromarray(self.lb.compute_latents2(return_image=True)) - img2.save(fp_img2) + img2.save(self.fp_img2) self.recycle_img2 = True - return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2] + return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img2] def compute_transition(self, *args): @@ -199,7 +213,7 @@ class BlendingFrontend(): self.current_timestamp = get_time('second') self.list_fp_imgs_current = [] for i in range(len(list_imgs_preview)): - fp_img = f"img_preview_{i}_{self.current_timestamp}.jpg" + fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg") list_imgs_preview[i].save(fp_img) self.list_fp_imgs_current.append(fp_img) @@ -207,51 +221,38 @@ class BlendingFrontend(): imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps) # Save as movie - fp_movie = self.get_fp_movie(self.current_timestamp) - if os.path.isfile(fp_movie): - os.remove(fp_movie) - ms = MovieSaver(fp_movie, fps=self.fps) + self.fp_movie = os.path.join(self.dp_movies, f"movie_{self.current_timestamp}.mp4") + if os.path.isfile(self.fp_movie): + os.remove(self.fp_movie) + ms = MovieSaver(self.fp_movie, fps=self.fps) for img in tqdm(imgs_transition_ext): ms.write_frame(img) ms.finalize() print("DONE SAVING MOVIE! 
SENDING BACK...") # Assemble Output, updating the preview images and le movie - list_return = self.list_fp_imgs_current + [fp_movie] + list_return = self.list_fp_imgs_current + [self.fp_movie] return list_return - def get_fp_movie(self, timestamp, is_stacked=False): - if not is_stacked: - fn = f"movie_{timestamp}.mp4" - else: - fn = f"movie_stacked_{timestamp}.mp4" - fp = os.path.join(self.dp_out, fn) - return fp - def stack_forward(self, prompt2, seed2): # Save preview images, prompts and seeds into dictionary for stacking - dp_out = os.path.join(self.dp_out, get_time('second')) - self.lb.write_imgs_transition(dp_out) + # self.list_imgs_shown_last = self.get_multi_trans_imgs_preview(f"lowres_{self.current_timestamp}")[0:5] + timestamp_section = get_time('second') + self.lb.write_imgs_transition(os.path.join(self.dp_out, f"lowres_{timestamp_section}")) + self.lb.write_imgs_transition(os.path.join(self.dp_out, "lowres_current")) + shutil.copyfile(self.fp_movie, os.path.join(self.dp_out, f"lowres_{timestamp_section}", "movie.mp4")) + self.lb.swap_forward() - list_out = [self.list_fp_imgs_current[-1]] + list_out = [self.fp_img2] list_out.extend([self.fp_img_empty]*4) list_out.append(prompt2) list_out.append(seed2) list_out.append("") list_out.append(np.random.randint(0, 10000000)) + return list_out - - - def stack_movie(self): - # collect all that are in... - list_fp_movies = [] - - list_fp_movies.append(self.get_fp_movie(timestamp)) - - fp_stacked = self.get_fp_movie(get_time('second'), True) - concatenate_movies(fp_stacked, list_fp_movies) - return fp_stacked + def get_state_dict(self): @@ -264,7 +265,90 @@ class BlendingFrontend(): state_dict[v] = getattr(self, v) return state_dict - + + def get_list_all_stacked(self): + list_all = os.listdir(os.path.join(self.dp_out)) + list_all = [l for l in list_all if l[:8]=="lowres_2"] + list_all.sort() + return list_all + + def multi_trans_show_older(self): + list_all = self.get_list_all_stacked() + if self.multi_idx_current == -1: + self.multi_idx_current = len(list_all) - 1 + else: + self.multi_idx_current -= 1 + + if self.multi_idx_current < 0: + self.multi_idx_current = 0 + dn = list_all[self.multi_idx_current] + return self.get_multi_trans_imgs_preview(dn) + + def multi_trans_show_newer(self): + list_all = self.get_list_all_stacked() + if self.multi_idx_current == -1: + self.multi_idx_current = len(list_all) - 1 + else: + self.multi_idx_current += 1 + + if self.multi_idx_current >= len(list_all): + self.multi_idx_current = len(list_all) - 1 + dn = list_all[self.multi_idx_current] + return self.get_multi_trans_imgs_preview(dn) + + def get_multi_trans_imgs_preview(self, dn): + dp_show = os.path.join(self.dp_out, dn) + list_imgs_transition = os.listdir(dp_show) + list_imgs_transition = [l for l in list_imgs_transition if l[:11]=="lowres_img_"] + list_imgs_transition.sort() + + idx_img_prev = np.round(np.linspace(0, len(list_imgs_transition)-1, 5)).astype(np.int32) + list_imgs_preview = [] + for j in idx_img_prev: + list_imgs_preview.append(os.path.join(dp_show, list_imgs_transition[j])) + + list_out = list_imgs_preview + list_out.append(dn[7:]) + + return list_out + + def multi_append(self): + list_all = self.get_list_all_stacked() + dn = list_all[self.multi_idx_current] + self.multi_list_concat.append(dn) + list_short = [dn[7:] for dn in self.multi_list_concat] + str_out = "\n".join(list_short) + return str_out + + def multi_reset(self): + self.multi_list_concat = [] + str_out = "" + return str_out + + def multi_concat(self): + # Make new output 
directory + dp_multi = os.path.join(self.dp_out, f"multi_{get_time('second')}") + os.makedirs(dp_multi, exist_ok=False) + + # Copy all low-res folders (prepending multi001_xxxx), however leave out the movie.mp4 + # also collect all movie.mp4 + list_fp_movies = [] + for i, dn in enumerate(self.multi_list_concat): + dp_source = os.path.join(self.dp_out, dn) + dp_sequence = os.path.join(dp_multi, f"{str(i).zfill(3)}_{dn}") + os.makedirs(dp_sequence, exist_ok=False) + list_source = os.listdir(dp_source) + list_source = [l for l in list_source if not l.endswith(".mp4")] + for fn in list_source: + shutil.copyfile(os.path.join(dp_source, fn), os.path.join(dp_sequence, fn)) + list_fp_movies.append(os.path.join(dp_source, "movie.mp4")) + + # Concatenate movies and save + fp_final = os.path.join(dp_multi, "movie.mp4") + concatenate_movies(fp_final, list_fp_movies) + return fp_final + + def get_img_rand(): return (255*np.random.rand(self.height,self.width,3)).astype(np.uint8) @@ -286,120 +370,150 @@ def generate_list_output( return list_output - if __name__ == "__main__": - # fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt" - fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt" - sdh = StableDiffusionHolder(fp_ckpt) - - self = BlendingFrontend(sdh) # Yes this is possible in python and yes it is an awesome trick + fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt" + # fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt" + self = BlendingFrontend(StableDiffusionHolder(fp_ckpt)) # Yes this is possible in python and yes it is an awesome trick # self = BlendingFrontend(None) # Yes this is possible in python and yes it is an awesome trick dict_ui_elem = {} with gr.Blocks() as demo: - with gr.Row(): - prompt1 = gr.Textbox(label="prompt 1") - prompt2 = gr.Textbox(label="prompt 2") - - with gr.Row(): - duration_compute = gr.Slider(5, 45, self.t_compute_max_allowed, step=1, label='compute budget for transition (seconds)', interactive=True) - duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True) - height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True) - width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True) + with gr.Tab("Single Transition"): + with gr.Row(): + prompt1 = gr.Textbox(label="prompt 1") + prompt2 = gr.Textbox(label="prompt 2") - with gr.Accordion("Advanced Settings (click to expand)", open=False): - - with gr.Accordion("Diffusion settings", open=True): - with gr.Row(): - num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True) - guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True) - negative_prompt = gr.Textbox(label="negative prompt") - - with gr.Accordion("Seeds control", open=True): - with gr.Row(): - seed1 = gr.Number(420, label="seed 1", interactive=True) - b_newseed1 = gr.Button("randomize seed 1", variant='secondary') - seed2 = gr.Number(420, label="seed 2", interactive=True) - b_newseed2 = gr.Button("randomize seed 2", variant='secondary') - - with gr.Accordion("Crossfeeding for last image", open=True): - with gr.Row(): - branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='crossfeed power', interactive=True) - branch1_max_depth_influence = gr.Slider(0.0, 1.0, self.branch1_max_depth_influence, step=0.01, label='crossfeed range', interactive=True) 
- branch1_influence_decay = gr.Slider(0.0, 1.0, self.branch1_influence_decay, step=0.01, label='crossfeed decay', interactive=True) - - with gr.Accordion("Transition settings", open=True): - with gr.Row(): - depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True) - guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True) - parental_influence = gr.Slider(0.0, 1.0, self.parental_influence, step=0.01, label='parental power', interactive=True) - parental_max_depth_influence = gr.Slider(0.0, 1.0, self.parental_max_depth_influence, step=0.01, label='parental range', interactive=True) - parental_influence_decay = gr.Slider(0.0, 1.0, self.parental_influence_decay, step=0.01, label='parental decay', interactive=True) - + with gr.Row(): + duration_compute = gr.Slider(5, 45, self.t_compute_max_allowed, step=1, label='compute budget for transition (seconds)', interactive=True) + duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True) + height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True) + width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True) - with gr.Row(): - b_compute1 = gr.Button('compute first image', variant='primary') - b_compute_transition = gr.Button('compute transition', variant='primary') - b_compute2 = gr.Button('compute last image', variant='primary') - - with gr.Row(): - img1 = gr.Image(label="1/5") - img2 = gr.Image(label="2/5") - img3 = gr.Image(label="3/5") - img4 = gr.Image(label="4/5") - img5 = gr.Image(label="5/5") - - with gr.Row(): - vid_transition = gr.Video() + with gr.Accordion("Advanced Settings (click to expand)", open=False): + + with gr.Accordion("Diffusion settings", open=True): + with gr.Row(): + num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True) + guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True) + negative_prompt = gr.Textbox(label="negative prompt") + + with gr.Accordion("Seeds control", open=True): + with gr.Row(): + seed1 = gr.Number(self.seed1, label="seed 1", interactive=True) + b_newseed1 = gr.Button("randomize seed 1", variant='secondary') + seed2 = gr.Number(self.seed2, label="seed 2", interactive=True) + b_newseed2 = gr.Button("randomize seed 2", variant='secondary') + + with gr.Accordion("Crossfeeding for last image", open=True): + with gr.Row(): + branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='crossfeed power', interactive=True) + branch1_max_depth_influence = gr.Slider(0.0, 1.0, self.branch1_max_depth_influence, step=0.01, label='crossfeed range', interactive=True) + branch1_influence_decay = gr.Slider(0.0, 1.0, self.branch1_influence_decay, step=0.01, label='crossfeed decay', interactive=True) + + with gr.Accordion("Transition settings", open=True): + with gr.Row(): + depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True) + guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True) + parental_influence = gr.Slider(0.0, 1.0, self.parental_influence, step=0.01, label='parental power', interactive=True) + parental_max_depth_influence = gr.Slider(0.0, 1.0, 
self.parental_max_depth_influence, step=0.01, label='parental range', interactive=True) + parental_influence_decay = gr.Slider(0.0, 1.0, self.parental_influence_decay, step=0.01, label='parental decay', interactive=True) + + + with gr.Row(): + b_compute1 = gr.Button('compute first image', variant='primary') + b_compute_transition = gr.Button('compute transition', variant='primary') + b_compute2 = gr.Button('compute last image', variant='primary') + + with gr.Row(): + img1 = gr.Image(label="1/5") + img2 = gr.Image(label="2/5") + img3 = gr.Image(label="3/5") + img4 = gr.Image(label="4/5") + img5 = gr.Image(label="5/5") + + with gr.Row(): + vid_transition = gr.Video() + + with gr.Row(): + b_stackforward = gr.Button('multi-movie start next segment (move last image -> first image)') + + # Collect all UI elemts in list to easily pass as inputs + dict_ui_elem["prompt1"] = prompt1 + dict_ui_elem["negative_prompt"] = negative_prompt + dict_ui_elem["prompt2"] = prompt2 + + dict_ui_elem["duration_compute"] = duration_compute + dict_ui_elem["duration_video"] = duration_video + dict_ui_elem["height"] = height + dict_ui_elem["width"] = width + + dict_ui_elem["depth_strength"] = depth_strength + dict_ui_elem["branch1_influence"] = branch1_influence + dict_ui_elem["branch1_max_depth_influence"] = branch1_max_depth_influence + dict_ui_elem["branch1_influence_decay"] = branch1_influence_decay + + dict_ui_elem["num_inference_steps"] = num_inference_steps + dict_ui_elem["guidance_scale"] = guidance_scale + dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper + dict_ui_elem["seed1"] = seed1 + dict_ui_elem["seed2"] = seed2 + + dict_ui_elem["parental_max_depth_influence"] = parental_max_depth_influence + dict_ui_elem["parental_influence"] = parental_influence + dict_ui_elem["parental_influence_decay"] = parental_influence_decay + + # Convert to list, as gradio doesn't seem to accept dicts + list_ui_elem = [] + list_ui_keys = [] + for k in dict_ui_elem.keys(): + list_ui_elem.append(dict_ui_elem[k]) + list_ui_keys.append(k) + self.list_ui_keys = list_ui_keys + + b_newseed1.click(self.randomize_seed1, outputs=seed1) + b_newseed2.click(self.randomize_seed2, outputs=seed2) + b_compute1.click(self.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5]) + b_compute2.click(self.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5]) + b_compute_transition.click(self.compute_transition, + inputs=list_ui_elem, + outputs=[img2, img3, img4, vid_transition]) + + b_stackforward.click(self.stack_forward, + inputs=[prompt2, seed2], + outputs=[img1, img2, img3, img4, img5, prompt1, seed1, prompt2]) + + with gr.Tab("Multi Transition"): + with gr.Row(): + multi_img1_prev = gr.Image(value=self.list_imgs_shown_last[0], label="1/5") + multi_img2_prev = gr.Image(value=self.list_imgs_shown_last[1], label="2/5") + multi_img3_prev = gr.Image(value=self.list_imgs_shown_last[2], label="3/5") + multi_img4_prev = gr.Image(value=self.list_imgs_shown_last[3], label="4/5") + multi_img5_prev = gr.Image(value=self.list_imgs_shown_last[4], label="5/5") + + with gr.Row(): + b_older = gr.Button("show older") + b_newer = gr.Button("show newer") + text_timestamp = gr.Textbox(label="created", interactive=False) + b_append = gr.Button("append this transition") + + with gr.Row(): + text_all_timestamps = gr.Textbox(label="movie list", interactive=False) + with gr.Row(): + b_reset = gr.Button("reset") + b_concat = gr.Button("merge together", variant='primary') + + with gr.Row(): + vid_multi = gr.Video() + + 
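+            # Wire the multi-transition browser: b_older/b_newer step through the saved
+            # lowres_* transitions, b_append collects segments, b_concat merges their movies.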
+ b_older.click(self.multi_trans_show_older, inputs=[], outputs=[multi_img1_prev, multi_img2_prev, multi_img3_prev, multi_img4_prev, multi_img5_prev, text_timestamp]) + b_newer.click(self.multi_trans_show_newer, inputs=[], outputs=[multi_img1_prev, multi_img2_prev, multi_img3_prev, multi_img4_prev, multi_img5_prev, text_timestamp]) + b_append.click(self.multi_append, inputs=[], outputs=[text_all_timestamps]) + b_reset.click(self.multi_reset, inputs=[], outputs=[text_all_timestamps]) + b_concat.click(self.multi_concat, inputs=[], outputs=[vid_multi]) + - with gr.Row(): - b_stackforward = gr.Button('multi-movie start next segment (move last image -> first image)') - - # Collect all UI elemts in list to easily pass as inputs - dict_ui_elem["prompt1"] = prompt1 - dict_ui_elem["negative_prompt"] = negative_prompt - dict_ui_elem["prompt2"] = prompt2 - - dict_ui_elem["duration_compute"] = duration_compute - dict_ui_elem["duration_video"] = duration_video - dict_ui_elem["height"] = height - dict_ui_elem["width"] = width - - dict_ui_elem["depth_strength"] = depth_strength - dict_ui_elem["branch1_influence"] = branch1_influence - dict_ui_elem["branch1_max_depth_influence"] = branch1_max_depth_influence - dict_ui_elem["branch1_influence_decay"] = branch1_influence_decay - - dict_ui_elem["num_inference_steps"] = num_inference_steps - dict_ui_elem["guidance_scale"] = guidance_scale - dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper - dict_ui_elem["seed1"] = seed1 - dict_ui_elem["seed2"] = seed2 - - dict_ui_elem["parental_max_depth_influence"] = parental_max_depth_influence - dict_ui_elem["parental_influence"] = parental_influence - dict_ui_elem["parental_influence_decay"] = parental_influence_decay - - # Convert to list, as gradio doesn't seem to accept dicts - list_ui_elem = [] - list_ui_keys = [] - for k in dict_ui_elem.keys(): - list_ui_elem.append(dict_ui_elem[k]) - list_ui_keys.append(k) - self.list_ui_keys = list_ui_keys - - b_newseed1.click(self.randomize_seed1, outputs=seed1) - b_newseed2.click(self.randomize_seed2, outputs=seed2) - b_compute1.click(self.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5]) - b_compute2.click(self.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5]) - b_compute_transition.click(self.compute_transition, - inputs=list_ui_elem, - outputs=[img2, img3, img4, vid_transition]) - - b_stackforward.click(self.stack_forward, - inputs=[prompt2, seed2], - outputs=[img1, img2, img3, img4, img5, prompt1, seed1, prompt2]) demo.launch(share=self.share, inbrowser=True, inline=False) diff --git a/latent_blending.py b/latent_blending.py index ef100b1..ca32f15 100644 --- a/latent_blending.py +++ b/latent_blending.py @@ -231,36 +231,36 @@ class LatentBlending(): if quality == 'lowest': num_inference_steps = 12 - nmb_branches_final = 5 + nmb_max_branches = 5 elif quality == 'low': num_inference_steps = 15 - nmb_branches_final = nmb_frames//16 + nmb_max_branches = nmb_frames//16 elif quality == 'medium': num_inference_steps = 30 - nmb_branches_final = nmb_frames//8 + nmb_max_branches = nmb_frames//8 elif quality == 'high': num_inference_steps = 60 - nmb_branches_final = nmb_frames//4 + nmb_max_branches = nmb_frames//4 elif quality == 'ultra': num_inference_steps = 100 - nmb_branches_final = nmb_frames//2 + nmb_max_branches = nmb_frames//2 elif quality == 'upscaling_step1': num_inference_steps = 40 - nmb_branches_final = 12 + nmb_max_branches = 12 elif quality == 'upscaling_step2': num_inference_steps = 100 - nmb_branches_final 
= 6
+            nmb_max_branches = 6
         else: 
             raise ValueError(f"quality = '{quality}' not supported")
 
-        self.autosetup_branching(depth_strength, num_inference_steps, nmb_branches_final)
+        self.autosetup_branching(depth_strength, num_inference_steps, nmb_max_branches)
 
     def autosetup_branching(
             self,
             depth_strength: float = 0.65,
             num_inference_steps: int = 30,
-            nmb_branches_final: int = 20,
+            nmb_max_branches: int = 20,
             nmb_mindist: int = 3,
         ):
         r"""
@@ -273,7 +273,7 @@ class LatentBlending():
             more shallow values will go into alpha-blendy land.
         num_inference_steps: int
             Number of diffusion steps. Higher values will take more compute time.
-        nmb_branches_final (int): The number of diffusion-generated images
+        nmb_max_branches (int): The number of diffusion-generated images
             at the end of the inference.
         nmb_mindist (int): The minimum number of diffusion steps
             between two injections.
@@ -285,7 +285,7 @@ class LatentBlending():
 
         list_injection_idx = [0]
         list_injection_idx.extend(np.linspace(idx_injection_first, idx_injection_last, nmb_injections).astype(int))
-        list_nmb_branches = np.round(np.logspace(np.log10(2), np.log10(nmb_branches_final), nmb_injections+1)).astype(int)
+        list_nmb_branches = np.round(np.logspace(np.log10(2), np.log10(nmb_max_branches), nmb_injections+1)).astype(int)
 
         # Cleanup. There should be at least nmb_mindist diffusion steps between each injection and list_nmb_branches increases
         list_nmb_branches_clean = [list_nmb_branches[0]]
@@ -294,7 +294,7 @@ class LatentBlending():
             if idx_injection - list_injection_idx_clean[-1] >= nmb_mindist and nmb_branches > list_nmb_branches_clean[-1]:
                 list_nmb_branches_clean.append(nmb_branches)
                 list_injection_idx_clean.append(idx_injection)
-        list_nmb_branches_clean[-1] = nmb_branches_final
+        list_nmb_branches_clean[-1] = nmb_max_branches
 
         list_injection_idx_clean = [int(l) for l in list_injection_idx_clean]
         list_nmb_branches_clean = [int(l) for l in list_nmb_branches_clean]
@@ -394,8 +394,36 @@ class LatentBlending():
             recycle_img2: Optional[bool] = False,
             num_inference_steps: Optional[int] = 30,
             depth_strength: Optional[float] = 0.3,
+            t_compute_max_allowed: Optional[float] = None,
+            nmb_max_branches: Optional[int] = None,
             fixed_seeds: Optional[List[int]] = None,
         ):
+        r"""
+        Function for computing transitions.
+        Returns a list of transition images using spherical latent blending.
+        Args:
+            recycle_img1: Optional[bool]:
+                Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
+            recycle_img2: Optional[bool]:
+                Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
+            num_inference_steps:
+                Number of diffusion steps. Higher values will take more compute time.
+            depth_strength:
+                Determines how deep the first injection will happen.
+                Deeper injections will cause (unwanted) formation of new structures,
+                more shallow values will go into alpha-blendy land.
+            t_compute_max_allowed:
+                Either provide t_compute_max_allowed or nmb_max_branches.
+                The maximum time allowed for computation. Higher values give better results but take longer.
+            nmb_max_branches: int
+                Either provide t_compute_max_allowed or nmb_max_branches.
+                The maximum number of branches to be computed. Higher values give better
+                results. Use this if you want to have controllable results independent
+                of your computer.
+            fixed_seeds: Optional[List[int]]:
+                You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
+                Otherwise random seeds will be taken.
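+        Example (illustrative sketch; assumes `lb` is a LatentBlending instance
+        whose prompts were already set via set_prompt1/set_prompt2):
+            imgs = lb.run_transition(
+                num_inference_steps = 30,
+                depth_strength = 0.3,
+                t_compute_max_allowed = 20,
+                fixed_seeds = [42, 420])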
+ + """ # Sanity checks first assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before' @@ -412,7 +440,8 @@ class LatentBlending(): self.seed2 = fixed_seeds[1] # Ensure correct num_inference_steps in holder - self.sdh.num_inference_steps = self.num_inference_steps + self.num_inference_steps = num_inference_steps + self.sdh.num_inference_steps = num_inference_steps # Compute / Recycle first image if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps: @@ -433,11 +462,61 @@ class LatentBlending(): self.tree_idx_injection = [0, 0] # Set up branching scheme (dependent on provided compute time) + list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches) + + # Run iteratively, starting with the longest trajectory. + # Always inserting new branches where they are needed most according to image similarity + for s_idx in tqdm(range(len(list_idx_injection))): + nmb_stems = list_nmb_stems[s_idx] + idx_injection = list_idx_injection[s_idx] + + for i in range(nmb_stems): + fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection) + self.set_guidance_mid_dampening(fract_mixing) + list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection) + self.insert_into_tree(fract_mixing, idx_injection, list_latents) + # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}") + + return self.tree_final_imgs + + + def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None): + r""" + Sets up the branching scheme dependent on the time that is granted for compute. + The scheme uses an estimation derived from the first image's computation speed. + Either provide t_compute_max_allowed or nmb_max_branches + Args: + depth_strength: + Determines how deep the first injection will happen. + Deeper injections will cause (unwanted) formation of new structures, + more shallow values will go into alpha-blendy land. + t_compute_max_allowed: float + The maximum time allowed for computation. Higher values give better results + but take longer. Use this if you want to fix your waiting time for the results. + nmb_max_branches: int + The maximum number of branches to be computed. Higher values give better + results. Use this if you want to have controllable results independent + of your computer. 
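+        Example (worked numbers, assuming num_inference_steps=30 and depth_strength=0.5):
+            idx_injection_base = round(30*0.5) = 15, so the candidate injection
+            indices computed below are spaced 3 apart: 15, 18, 21, 24, 27.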
+ """ idx_injection_base = int(round(self.num_inference_steps*depth_strength)) list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3) list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32) t_compute = 0 - while t_compute < self.t_compute_max_allowed: + + if nmb_max_branches is None: + assert t_compute_max_allowed is not None, "Either specify t_compute_max_allowed or nmb_max_branches" + stop_criterion = "t_compute_max_allowed" + elif t_compute_max_allowed is None: + assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches" + stop_criterion = "nmb_max_branches" + nmb_max_branches -= 2 # discounting the outer frames + else: + raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches") + + stop_criterion_reached = False + is_first_iteration = True + + while not stop_criterion_reached: list_compute_steps = self.num_inference_steps - list_idx_injection list_compute_steps *= list_nmb_stems t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15*np.sum(list_nmb_stems) @@ -449,23 +528,21 @@ class LatentBlending(): break if not increase_done: list_nmb_stems[-1] += 1 - # print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}") - # Run iteratively, always inserting new branches where they are needed most - for s_idx in tqdm(range(len(list_idx_injection))): - nmb_stems = list_nmb_stems[s_idx] - idx_injection = list_idx_injection[s_idx] - - for i in range(nmb_stems): - fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection) - self.set_guidance_mid_dampening(fract_mixing) - # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}") - list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection) - self.insert_into_tree(fract_mixing, idx_injection, list_latents) - - return self.tree_final_imgs + if stop_criterion == "t_compute_max_allowed" and t_compute > t_compute_max_allowed: + stop_criterion_reached = True + # FIXME: also undersample here... but how... maybe drop them iteratively? + elif stop_criterion == "nmb_max_branches" and np.sum(list_nmb_stems) >= nmb_max_branches: + stop_criterion_reached = True + if is_first_iteration: + # Need to undersample. 
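+                    # The very first pass already exceeded the branch budget, so
+                    # spread the nmb_max_branches injection points evenly over the
+                    # candidate range instead of growing individual stems.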
+ list_idx_injection = np.linspace(list_idx_injection[0], list_idx_injection[-1], nmb_max_branches).astype(np.int32) + list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32) + else: + is_first_iteration = False - + # print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}") + return list_idx_injection, list_nmb_stems def get_mixing_parameters(self, idx_injection): r""" @@ -581,7 +658,7 @@ class LatentBlending(): whether to return an image or the list of latents """ print("starting compute_latents1") - list_conditionings = [self.text_embedding1] + list_conditionings = self.get_mixed_conditioning(0) t0 = time.time() latents_start = self.get_noise(self.seed1) list_latents1 = self.run_diffusion( @@ -604,7 +681,8 @@ class LatentBlending(): return_image: bool whether to return an image or the list of latents """ - list_conditionings = [self.text_embedding2] + print("starting compute_latents2") + list_conditionings = self.get_mixed_conditioning(1) latents_start = self.get_noise(self.seed2) # Influence from branch1 if self.branch1_influence > 0.0: @@ -630,178 +708,24 @@ class LatentBlending(): return list_latents2 def get_noise(self, seed): - generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed)) - shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f] - C, H, W = shape_latents - return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device) - - - def run_transition_legacy( - self, - recycle_img1: Optional[bool] = False, - recycle_img2: Optional[bool] = False, - fixed_seeds: Optional[List[int]] = None, - premature_stop: Optional[int] = np.inf, - ): r""" - Old legacy function for computing transitions. - Returns a list of transition images using spherical latent blending. + Helper function to get noise given seed. Args: - recycle_img1: Optional[bool]: - Don't recompute the latents for the first keyframe (purely prompt1). Saves compute. - recycle_img2: Optional[bool]: - Don't recompute the latents for the second keyframe (purely prompt2). Saves compute. - fixed_seeds: Optional[List[int)]: - You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2). - Otherwise random seeds will be taken. - premature_stop: Optional[int]: - Stop the computation after premature_stop frames have been computed in the transition + seed: int """ - # Sanity checks first - assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before' - assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) 
before' - assert self.list_injection_idx is not None, 'Set the branching structure before, by calling autosetup_branching or setup_branching' + generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed)) + if self.mode == 'standard': + shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f] + C, H, W = shape_latents + elif self.mode == 'upscale': + w = self.image1_lowres.size[0] + h = self.image1_lowres.size[1] + shape_latents = [self.sdh.model.channels, h, w] + C, H, W = shape_latents - if fixed_seeds is not None: - if fixed_seeds == 'randomize': - fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32)) - else: - assert len(fixed_seeds)==2, "Supply a list with len = 2" - - self.seed1 = fixed_seeds[0] - self.seed2 = fixed_seeds[1] - - # Process interruption variable - self.stop_diffusion = False - - # Ensure correct num_inference_steps in holder - self.sdh.num_inference_steps = self.num_inference_steps - - # Make a backup for future reference - self.list_nmb_branches_prev = self.list_nmb_branches[:] - self.list_injection_idx_prev = self.list_injection_idx[:] - - # Split the first block if there is branch1 crossfeeding - if self.branch1_influence > 0.0 and not self.branch1_insertion_completed: - assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0' - self.branch1_influence = np.clip(self.branch1_influence, 0, 1) - self.branch1_max_depth_influence = np.clip(self.branch1_max_depth_influence, 0, 1) - self.list_nmb_branches.insert(1, 2) - idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_max_depth_influence)) - self.list_injection_idx_ext.insert(1, idx_crossfeed) - self.tree_fracts.insert(1, self.tree_fracts[0]) - self.tree_status.insert(1, self.tree_status[0]) - self.tree_latents.insert(1, self.tree_latents[0]) - self.branch1_insertion_completed = True - - - - # Pre-define entire branching tree structures - self.tree_final_imgs = [None]*self.list_nmb_branches[-1] - nmb_blocks_time = len(self.list_injection_idx_ext)-1 - if not recycle_img1 and not recycle_img2: - self.init_tree_struct() - else: - self.tree_final_imgs = [None]*self.list_nmb_branches[-1] - for t_block in range(nmb_blocks_time): - nmb_branches = self.list_nmb_branches[t_block] - for idx_branch in range(nmb_branches): - self.tree_status[t_block][idx_branch] = 'untouched' - if recycle_img1: - self.tree_status[t_block][0] = 'computed' - self.tree_final_imgs[0] = self.sdh.latent2image(self.tree_latents[-1][0][-1]) - self.tree_final_imgs_timing[0] = 0 - if recycle_img2: - self.tree_status[t_block][-1] = 'computed' - self.tree_final_imgs[-1] = self.sdh.latent2image(self.tree_latents[-1][-1][-1]) - self.tree_final_imgs_timing[-1] = 0 - - # setup compute order: goal: try to get last branch computed asap. - # first compute the right keyframe. needs to be there in any case - list_compute = [] - list_local_stem = [] - for t_block in range(nmb_blocks_time - 1, -1, -1): - if self.tree_status[t_block][0] == 'untouched': - self.tree_status[t_block][0] = 'prefetched' - list_local_stem.append([t_block, 0]) - list_compute.extend(list_local_stem[::-1]) - - # setup compute order: start from last leafs (the final transition images) and work way down. what parents do they need? 
- for idx_leaf in range(1, self.list_nmb_branches[-1]): - list_local_stem = [] - t_block = nmb_blocks_time - 1 - t_block_prev = t_block - 1 - self.tree_status[t_block][idx_leaf] = 'prefetched' - list_local_stem.append([t_block, idx_leaf]) - idx_leaf_deep = idx_leaf - - for t_block in range(nmb_blocks_time-1, 0, -1): - t_block_prev = t_block - 1 - fract_mixing = self.tree_fracts[t_block][idx_leaf_deep] - list_fract_mixing_prev = self.tree_fracts[t_block_prev] - b_parent1, b_parent2 = get_closest_idx(fract_mixing, list_fract_mixing_prev) - assert self.tree_status[t_block_prev][b_parent1] != 'untouched', 'Branch destruction??? This should never happen!' - if self.tree_status[t_block_prev][b_parent2] == 'untouched': - self.tree_status[t_block_prev][b_parent2] = 'prefetched' - list_local_stem.append([t_block_prev, b_parent2]) - idx_leaf_deep = b_parent2 - list_compute.extend(list_local_stem[::-1]) - - # Diffusion computations start here - time_start = time.time() - for t_block, idx_branch in tqdm(list_compute, desc="computing transition", smoothing=0.01): - if self.stop_diffusion: - print("run_transition: process interrupted") - return self.tree_final_imgs - if idx_branch > premature_stop: - print(f"run_transition: premature_stop criterion reached. returning tree with {premature_stop} branches") - return self.tree_final_imgs - - # print(f"computing t_block {t_block} idx_branch {idx_branch}") - idx_stop = self.list_injection_idx_ext[t_block+1] - fract_mixing = self.tree_fracts[t_block][idx_branch] - - list_conditionings = self.get_mixed_conditioning(fract_mixing) - self.set_guidance_mid_dampening(fract_mixing) - # print(f"fract_mixing {fract_mixing} guid {self.sdh.guidance_scale}") - if t_block == 0: - if fixed_seeds is not None: - if idx_branch == 0: - self.set_seed(fixed_seeds[0]) - elif idx_branch == self.list_nmb_branches[0] -1: - self.set_seed(fixed_seeds[1]) - - list_latents = self.run_diffusion(list_conditionings, idx_stop=idx_stop) - - # Inject latents from first branch for very first block - if idx_branch==1 and self.branch1_influence > 0: - fract_base_influence = np.clip(self.branch1_influence, 0, 1) - for i in range(len(list_latents)): - list_latents[i] = interpolate_spherical(list_latents[i], self.tree_latents[0][0][i], fract_base_influence) - else: - # find parents latents - b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts[t_block-1]) - latents1 = self.tree_latents[t_block-1][b_parent1][-1] - if fract_mixing == 0: - latents2 = latents1 - else: - latents2 = self.tree_latents[t_block-1][b_parent2][-1] - idx_start = self.list_injection_idx_ext[t_block] - fract_mixing_parental = (fract_mixing - self.tree_fracts[t_block-1][b_parent1]) / (self.tree_fracts[t_block-1][b_parent2] - self.tree_fracts[t_block-1][b_parent1]) - latents_for_injection = interpolate_spherical(latents1, latents2, fract_mixing_parental) - list_latents = self.run_diffusion(list_conditionings, latents_for_injection, idx_start=idx_start, idx_stop=idx_stop) - - self.tree_latents[t_block][idx_branch] = list_latents - self.tree_status[t_block][idx_branch] = 'computed' - - # Convert latents to image directly for the last t_block - if t_block == nmb_blocks_time-1: - self.tree_final_imgs[idx_branch] = self.sdh.latent2image(list_latents[-1]) - self.tree_final_imgs_timing[idx_branch] = time.time() - time_start - - return self.tree_final_imgs - + return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device) + def run_multi_transition( self, @@ -906,24 +830,31 @@ class LatentBlending(): 
                 return_image = return_image,
                 )
 
+        elif self.mode == 'upscale':
+            cond = list_conditionings[0]
+            uc_full = list_conditionings[1]
+            return self.sdh.run_diffusion_upscaling(
+                cond,
+                uc_full,
+                latents_start=latents_start,
+                idx_start=idx_start,
+                list_latents_mixing = list_latents_mixing,
+                mixing_coeffs = mixing_coeffs,
+                return_image=return_image)
+
         # elif self.mode == 'inpaint':
         #     text_embeddings = list_conditionings[0]
         #     assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
         #     assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
         #     return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
-
-        # # FIXME LONG LINE and bad args
-        # elif self.mode == 'upscale':
-        #     cond = list_conditionings[0]
-        #     uc_full = list_conditionings[1]
-        #     return self.sdh.run_diffusion_upscaling(cond, uc_full, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
+
+    # FIXME. new transition engine
     def run_upscaling_step1(
             self,
             dp_img: str,
             depth_strength: float = 0.65,
             num_inference_steps: int = 30,
-            nmb_branches_final: int = 10,
+            nmb_max_branches: int = 10,
             fixed_seeds: Optional[List[int]] = None,
         ):
         r"""
@@ -932,6 +863,7 @@ class LatentBlending():
         dp_img: Path to directory where the low-res images and yaml will be saved to.
             This directory cannot exist and will be created here.
+        FIXME: the 'quality' argument described below was removed; this function
+            now takes depth_strength, num_inference_steps and nmb_max_branches.
         quality: str
             Determines how many diffusion steps are being made + how many branches in total.
             We suggest to leave it with upscaling_step1 which has 10 final branches.
 
             fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
 
         # Run latent blending
-        self.autosetup_branching(depth_strength, num_inference_steps, nmb_branches_final)
         imgs_transition = self.run_transition(fixed_seeds=fixed_seeds)
-        self.write_imgs_transition(dp_img, imgs_transition)
+        self.write_imgs_transition(dp_img)
 
@@ -962,13 +893,14 @@
             self,
             dp_img: str,
             depth_strength: float = 0.65,
-            num_inference_steps: int = 30,
-            nmb_branches_final: int = 10,
+            num_inference_steps: int = 100,
+            nmb_max_branches_highres: int = 5,
+            nmb_max_branches_lowres: int = 6,
             fixed_seeds: Optional[List[int]] = None,
         ):
 
         fp_yml = os.path.join(dp_img, "lowres.yaml")
-        fp_movie = os.path.join(dp_img, "movie.mp4")
+        fp_movie = os.path.join(dp_img, "movie_highres.mp4")
         fps = 24
         ms = MovieSaver(fp_movie, fps=fps)
         assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
         dict_stuff = yml_load(fp_yml)
 
         # load lowres images
         nmb_images_lowres = dict_stuff['nmb_images']
         prompt1 = dict_stuff['prompt1']
         prompt2 = dict_stuff['prompt2']
+        idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres-1, nmb_max_branches_lowres)).astype(np.int32)
         imgs_lowres = []
-        for i in range(nmb_images_lowres):
+        for i in idx_img_lowres:
             fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
             assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
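+            # idx_img_lowres picks nmb_max_branches_lowres evenly spaced keyframes
+            # from the low-res transition written by run_upscaling_step1.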
             imgs_lowres.append(Image.open(fp_img_lowres))
@@ -989,13 +922,12 @@ class LatentBlending():
         text_embeddingA = self.sdh.get_text_embedding(prompt1)
         text_embeddingB = self.sdh.get_text_embedding(prompt2)
 
-        self.autosetup_branching(depth_strength, num_inference_steps, nmb_branches_final)
-
+        #FIXME: have a total length for the whole video section
         duration_single_trans = 3
-        list_fract_mixing = np.linspace(0, 1, nmb_images_lowres-1)
+        list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres-1)
 
-        for i in range(nmb_images_lowres-1):
-            print(f"Starting movie segment {i+1}/{nmb_images_lowres-1}")
+        for i in range(nmb_max_branches_lowres-1):
+            print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
 
             self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
             self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1-list_fract_mixing[i])
@@ -1008,7 +940,15 @@ class LatentBlending():
             self.set_image1(imgs_lowres[i])
             self.set_image2(imgs_lowres[i+1])
 
-            list_imgs = self.run_transition(recycle_img1=recycle_img1)
+            list_imgs = self.run_transition(
+                recycle_img1 = recycle_img1,
+                recycle_img2 = False,
+                num_inference_steps = num_inference_steps,
+                depth_strength = depth_strength,
+                nmb_max_branches = nmb_max_branches_highres,
+                )
+
             list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_trans)
 
             # Save movie frame
@@ -1075,11 +1015,12 @@ class LatentBlending():
         return self.sdh.get_text_embedding(prompt)
 
 
-    def write_imgs_transition(self, dp_img, imgs_transition):
+    def write_imgs_transition(self, dp_img):
         r"""
         Writes the transition images into the folder dp_img.
         """
-        os.makedirs(dp_img)
+        imgs_transition = self.tree_final_imgs
+        os.makedirs(dp_img, exist_ok=True)
         for i, img in enumerate(imgs_transition):
             img_leaf = Image.fromarray(img)
             img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
@@ -1090,6 +1031,7 @@ class LatentBlending():
     def save_statedict(self, fp_yml):
         # Dump everything relevant into yaml
+        imgs_transition = self.tree_final_imgs
         state_dict = self.get_state_dict()
         state_dict['nmb_images'] = len(imgs_transition)
         yml_save(fp_yml, state_dict)
@@ -1098,7 +1040,9 @@ class LatentBlending():
         state_dict = {}
         grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
                      'num_inference_steps', 'depth_strength', 'guidance_scale',
-                     'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt']
+                     'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt',
+                     'branch1_influence', 'branch1_max_depth_influence', 'branch1_influence_decay',
+                     'parental_influence', 'parental_max_depth_influence', 'parental_influence_decay']
         for v in grab_vars:
             if hasattr(self, v):
                 if v == 'seed1' or v == 'seed2':
@@ -1107,9 +1051,11 @@ class LatentBlending():
                     state_dict[v] = float(getattr(self, v))
                 else:
-                    state_dict[v] = getattr(self, v)
+                    try:
+                        state_dict[v] = getattr(self, v)
+                    except Exception as e:
+                        pass
-
         return state_dict
 
     def randomize_seed(self):
@@ -1163,8 +1109,7 @@ class LatentBlending():
            as in run_multi_transition()
         """
         # Move over all latents
-        for t_block in range(len(self.tree_latents)):
-            self.tree_latents[t_block][0] = self.tree_latents[t_block][-1]
+        self.tree_latents[0] = self.tree_latents[-1]
 
         # Move over prompts and text embeddings
         self.prompt1 = self.prompt2
diff --git a/stable_diffusion_holder.py b/stable_diffusion_holder.py
index 235e319..8bcf28e 100644
--- a/stable_diffusion_holder.py
+++ b/stable_diffusion_holder.py
@@ -285,8 +285,7 @@ class StableDiffusionHolder:
             return_image: Optional[bool] = False,
         ):
         r"""
-        Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
-        Depending on the mode, the correct one will be executed.
+        Diffusion standard version.
 
         Args:
             text_embeddings: torch.FloatTensor
@@ -363,6 +362,99 @@ class StableDiffusionHolder:
             return self.latent2image(latents)
         else:
             return list_latents_out
+
+
+    @torch.no_grad()
+    def run_diffusion_upscaling(
+            self,
+            cond,
+            uc_full,
+            latents_start: torch.FloatTensor,
+            idx_start: int = -1,
+            list_latents_mixing = None,
+            mixing_coeffs = 0.0,
+            return_image: Optional[bool] = False
+        ):
+        r"""
+        Diffusion upscaling version.
+        Args:
+            cond:
+                Conditioning for the upscaling model.
+            uc_full:
+                Unconditional conditioning, used for classifier-free guidance.
+            latents_start: torch.FloatTensor
+                Latents that are used as starting point, injected at idx_start
+            idx_start: int
+                Index of the diffusion process start; earlier steps are skipped
+            list_latents_mixing:
+                Latents of a parent trajectory that are mixed in (see mixing_coeffs)
+            mixing_coeffs:
+                Float or list (len = num_inference_steps) of mixing weights
+            return_image: Optional[bool]
+                Optionally return image directly
+        """
+
+        # Asserts
+        if type(mixing_coeffs) == float:
+            list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
+        elif type(mixing_coeffs) == list:
+            assert len(mixing_coeffs) == self.num_inference_steps
+            list_mixing_coeffs = mixing_coeffs
+        else:
+            raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
+
+        if np.sum(list_mixing_coeffs) > 0:
+            assert len(list_latents_mixing) == self.num_inference_steps
+
+        precision_scope = autocast if self.precision == "autocast" else nullcontext
+        generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
+
+        h = uc_full['c_concat'][0].shape[2]
+        w = uc_full['c_concat'][0].shape[3]
+
+        with precision_scope("cuda"):
+            with self.model.ema_scope():
+
+                shape_latents = [self.model.channels, h, w]
+
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
+                C, H, W = shape_latents
+                size = (1, C, H, W)
+                b = size[0]
+
+                latents = latents_start.clone()
+
+                timesteps = self.sampler.ddim_timesteps
+
+                time_range = np.flip(timesteps)
+                total_steps = timesteps.shape[0]
+
+                # collect latents
+                list_latents_out = []
+                for i, step in enumerate(time_range):
+                    # Set the right starting latents
+                    if i < idx_start:
+                        list_latents_out.append(None)
+                        continue
+                    elif i == idx_start:
+                        latents = latents_start.clone()
+
+                    # Mix the latents. 
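+                    # list_mixing_coeffs[i] is the spherical-interpolation weight toward
+                    # the parent trajectory: 0.0 keeps this branch's latents, 1.0 copies the parent's.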
+ if i > 0 and list_mixing_coeffs[i]>0: + latents_mixtarget = list_latents_mixing[i-1].clone() + latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i]) + + # print(f"diffusion iter {i}") + index = total_steps - i - 1 + ts = torch.full((b,), step, device=self.device, dtype=torch.long) + outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False, + quantize_denoised=False, temperature=1.0, + noise_dropout=0.0, score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=self.guidance_scale, + unconditional_conditioning=uc_full, + dynamic_threshold=None) + latents, pred_x0 = outs + list_latents_out.append(latents.clone()) + + if return_image: + return self.latent2image(latents) + else: + return list_latents_out @torch.no_grad() def run_diffusion_inpaint( @@ -473,93 +565,6 @@ class StableDiffusionHolder: return self.latent2image(latents) else: return list_latents_out - - @torch.no_grad() - def run_diffusion_upscaling( - self, - cond, - uc_full, - latents_for_injection: torch.FloatTensor = None, - idx_start: int = -1, - idx_stop: int = -1, - return_image: Optional[bool] = False - ): - r""" - Wrapper function for run_diffusion_standard and run_diffusion_inpaint. - Depending on the mode, the correct one will be executed. - - Args: - ?? - latents_for_injection: torch.FloatTensor - Latents that are used for injection - idx_start: int - Index of the diffusion process start and where the latents_for_injection are injected - idx_stop: int - Index of the diffusion process end. - return_image: Optional[bool] - Optionally return image directly - """ - - - if latents_for_injection is None: - do_inject_latents = False - else: - do_inject_latents = True - - precision_scope = autocast if self.precision == "autocast" else nullcontext - generator = torch.Generator(device=self.device).manual_seed(int(self.seed)) - - h = uc_full['c_concat'][0].shape[2] - w = uc_full['c_concat'][0].shape[3] - - with precision_scope("cuda"): - with self.model.ema_scope(): - - - shape_latents = [self.model.channels, h, w] - - self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False) - C, H, W = shape_latents - size = (1, C, H, W) - b = size[0] - - latents = torch.randn(size, generator=generator, device=self.device) - - timesteps = self.sampler.ddim_timesteps - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - - # collect latents - list_latents_out = [] - for i, step in enumerate(time_range): - if do_inject_latents: - # Inject latent at right place - if i < idx_start: - continue - elif i == idx_start: - latents = latents_for_injection.clone() - - if i == idx_stop: - return list_latents_out - - # print(f"diffusion iter {i}") - index = total_steps - i - 1 - ts = torch.full((b,), step, device=self.device, dtype=torch.long) - outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False, - quantize_denoised=False, temperature=1.0, - noise_dropout=0.0, score_corrector=None, - corrector_kwargs=None, - unconditional_guidance_scale=self.guidance_scale, - unconditional_conditioning=uc_full, - dynamic_threshold=None) - latents, pred_x0 = outs - list_latents_out.append(latents.clone()) - - if return_image: - return self.latent2image(latents) - else: - return list_latents_out @torch.no_grad() def latent2image(