diff --git a/gradio_ui.py b/gradio_ui.py
index 7aa874b..33116b2 100644
--- a/gradio_ui.py
+++ b/gradio_ui.py
@@ -56,9 +56,7 @@ class BlendingFrontend():
         self.prompt1 = ""
         self.prompt2 = ""
         self.negative_prompt = ""
-        self.list_settings = []
         self.state_current = {}
-        self.showing_current = True
         self.branch1_influence = 0.3
         self.branch1_max_depth_influence = 0.6
         self.branch1_influence_decay = 0.3
@@ -70,12 +68,10 @@ class BlendingFrontend():
         self.fps = 30
         self.duration_video = 10
         self.t_compute_max_allowed = 10
-        self.dict_multi_trans = {}
-        self.dict_multi_trans_include = {}
-        self.multi_trans_currently_shown = []
         self.list_fp_imgs_current = []
         self.current_timestamp = None
-        self.nmb_trans_stack = 8
+        self.recycle_img1 = False
+        self.recycle_img2 = False
 
         if not self.use_debug:
             self.lb.sdh.num_inference_steps = self.num_inference_steps
@@ -125,7 +121,7 @@ class BlendingFrontend():
         self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
         self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
         self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
-        self.lb.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
+        self.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
         self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
         self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
         self.duration_video = list_ui_elem[list_ui_keys.index('duration_video')]
@@ -148,6 +144,8 @@ class BlendingFrontend():
         img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
         img1.save(fp_img1)
         self.save_empty_image()
+        self.recycle_img1 = True
+        self.recycle_img2 = False
         return [fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
 
     def compute_img2(self, *args):
@@ -156,9 +154,19 @@ class BlendingFrontend():
         fp_img2 = os.path.join(self.dp_out, f"img2_{get_time('second')}.jpg")
         img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
         img2.save(fp_img2)
+        self.recycle_img2 = True
         return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2]
 
     def compute_transition(self, *args):
+
+        if not self.recycle_img1:
+            print("compute first image before transition")
+            return
+        if not self.recycle_img2:
+            print("compute last image before transition")
+            return
+
+
         list_ui_elem = args
         self.setup_lb(list_ui_elem)
         print("STARTING DIFFUSION!")
@@ -172,10 +180,11 @@ class BlendingFrontend():
 
         # Run Latent Blending
         imgs_transition = self.lb.run_transition(
-            recycle_img1=True,
-            recycle_img2=True,
+            recycle_img1=self.recycle_img1,
+            recycle_img2=self.recycle_img2,
             num_inference_steps=self.num_inference_steps,
             depth_strength=self.depth_strength,
+            t_compute_max_allowed=self.t_compute_max_allowed,
             fixed_seeds=fixed_seeds
             )
         print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
@@ -222,9 +231,8 @@ class BlendingFrontend():
 
     def stack_forward(self, prompt2, seed2):
         # Save preview images, prompts and seeds into dictionary for stacking
-        self.dict_multi_trans[self.current_timestamp] = generate_list_output(self.prompt1, self.prompt2, self.seed1, self.seed2, self.list_fp_imgs_current)
-        self.dict_multi_trans_include[self.current_timestamp] = True
-
+        dp_out = os.path.join(self.dp_out, get_time('second'))
+        self.lb.write_imgs_transition(dp_out)
         self.lb.swap_forward()
         list_out = [self.list_fp_imgs_current[-1]]
         list_out.extend([self.fp_img_empty]*4)
@@ -232,19 +240,14 @@ class BlendingFrontend():
         list_out.append(seed2)
         list_out.append("")
         list_out.append(np.random.randint(0, 10000000))
-
-        list_out_multi_tab = self.update_trans_stacks()
-
-        list_out.extend(list_out_multi_tab)
-        # self.nmb_trans_stack = len(self.dict_multi_trans_include)
         return list_out
 
+
     def stack_movie(self):
         # collect all that are in...
         list_fp_movies = []
-        for timestamp in self.multi_trans_currently_shown:
-            if timestamp is not None:
-                list_fp_movies.append(self.get_fp_movie(timestamp))
+
+        list_fp_movies.append(self.get_fp_movie(self.current_timestamp))
 
         fp_stacked = self.get_fp_movie(get_time('second'), True)
         concatenate_movies(fp_stacked, list_fp_movies)
@@ -261,44 +264,7 @@ class BlendingFrontend():
 
             state_dict[v] = getattr(self, v)
         return state_dict
-
-    def update_trans_stacks(self):
-        print("Updating transition stack...")
-
-        self.multi_trans_currently_shown = []
-        list_output = []
-        # Figure out which transitions should be shown
-        for timestamp in self.dict_multi_trans_include.keys():
-            if len(self.multi_trans_currently_shown) >= self.nmb_trans_stack:
-                continue
-
-            if self.dict_multi_trans_include[timestamp]:
-                last_timestamp_vals = self.dict_multi_trans[timestamp]
-                list_output.extend(self.dict_multi_trans[timestamp])
-                self.multi_trans_currently_shown.append(timestamp)
-                print(f"including timestamp: {timestamp}")
-
-        # Fill with empty images if below nmb_trans_stack
-        nmb_empty_missing = self.nmb_trans_stack - len(self.multi_trans_currently_shown)
-        for i in range(nmb_empty_missing):
-            list_output.extend([gr.update(visible=False)]*len(last_timestamp_vals))
-            self.multi_trans_currently_shown.append(None)
-
-        return list_output
-
-
-    def remove_trans(self, idx_row):
-        idx_row = int(idx_row)
-        # do removal...
-        if idx_row < len(self.multi_trans_currently_shown):
-            timestamp = self.multi_trans_currently_shown[idx_row]
-            if timestamp in self.dict_multi_trans_include.keys():
-                self.dict_multi_trans_include[timestamp] = False
-                print(f"remove_trans called: {timestamp}")
-            else:
-                print(f"remove_trans called: idx_row too large {idx_row}")
-
-        return self.update_trans_stacks()
+
 
     def get_img_rand():
         return (255*np.random.rand(self.height,self.width,3)).astype(np.uint8)
@@ -388,6 +354,9 @@ if __name__ == "__main__":
 
         with gr.Row():
             vid_transition = gr.Video()
+
+        with gr.Row():
+            b_stackforward = gr.Button('multi-movie start next segment (move last image -> first image)')
 
         # Collect all UI elemts in list to easily pass as inputs
         dict_ui_elem["prompt1"] = prompt1
@@ -430,4 +399,7 @@ if __name__ == "__main__":
             inputs=list_ui_elem,
             outputs=[img2, img3, img4, vid_transition])
 
+        b_stackforward.click(self.stack_forward,
+            inputs=[prompt2, seed2],
+            outputs=[img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
     demo.launch(share=self.share, inbrowser=True, inline=False)