From 20147a17c6f9781aa55843a3d7da2a1a72f660f9 Mon Sep 17 00:00:00 2001
From: Johannes Stelzer
Date: Sun, 15 Jan 2023 16:52:42 +0100
Subject: [PATCH] simplification and prompt bugfix

---
 gradio_ui.py | 208 ++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 149 insertions(+), 59 deletions(-)

diff --git a/gradio_ui.py b/gradio_ui.py
index 29c5635..a56c418 100644
--- a/gradio_ui.py
+++ b/gradio_ui.py
@@ -24,7 +24,7 @@ import torch
 from tqdm.auto import tqdm
 from PIL import Image
 import torch
-from movie_util import MovieSaver
+from movie_util import MovieSaver, concatenate_movies
 from typing import Callable, List, Optional, Union
 from latent_blending import get_time, yml_save, LatentBlending, add_frames_linear_interp, compare_dicts
 from stable_diffusion_holder import StableDiffusionHolder
@@ -33,12 +33,6 @@ import gradio as gr
 import copy
 
 
-"""
-try this:
-    button variant 'primary' for main call-to-action, 'secondary' for a more subdued style
-    gr.Column(scale=1, min_width=600):
-
-"""
 
 #%%
 
@@ -62,17 +56,19 @@ class BlendingFrontend():
         self.prompt2 = ""
         self.negative_prompt = ""
         self.list_settings = []
-        self.state_prev = {}
         self.state_current = {}
         self.showing_current = True
         self.branch1_influence = 0.02
-        self.imgs_show_last = []
-        self.imgs_show_current = []
         self.nmb_branches_final = 9
         self.nmb_imgs_show = 5 # don't change
         self.fps = 30
         self.duration = 10
-        self.max_size_imgs = 200 # gradio otherwise mega slow
+        self.dict_multi_trans = {}
+        self.dict_multi_trans_include = {}
+        self.multi_trans_currently_shown = []
+        self.list_fp_imgs_current = []
+        self.current_timestamp = None
+        self.nmb_trans_stack = 8
 
         if not self.use_debug:
             self.lb.sdh.num_inference_steps = self.num_inference_steps
@@ -82,6 +78,10 @@ class BlendingFrontend():
             self.height = 768
             self.width = 768
 
+        # make dummy image
+        self.fp_img_empty = 'empty.jpg'
+        Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
+
     def change_depth_strength(self, value):
         self.depth_strength = value
         print(f"changed depth_strength to {value}")
@@ -127,14 +127,6 @@ class BlendingFrontend():
         self.fps = value
         print(f"changed fps to {value}")
 
-    def change_prompt1(self, value):
-        self.prompt1 = value
-        # print(f"changed prompt1 to {value}")
-
-    def change_prompt2(self, value):
-        self.prompt2 = value
-        # print(f"changed prompt2 to {value}")
-
     def change_negative_prompt(self, value):
         self.negative_prompt = value
 
@@ -157,30 +149,25 @@ class BlendingFrontend():
 
         return seed
 
-    def run(self):
+    def compute_transition(self, prompt1, prompt2):
+        self.prompt1 = prompt1
+        self.prompt2 = prompt2
         print("STARTING DIFFUSION!")
-        self.state_prev = self.state_current.copy()
         self.state_current = self.get_state_dict()
-        # Copy last iteration
-        self.imgs_show_last = copy.deepcopy(self.imgs_show_current)
-
         if self.use_debug:
             list_imgs = [(255*np.random.rand(self.height,self.width,3)).astype(np.uint8) for l in range(5)]
             list_imgs = [Image.fromarray(l) for l in list_imgs]
-            list_imgs = self.downscale_imgs(list_imgs)
-            self.imgs_show_current = copy.deepcopy(list_imgs)
             print("DONE! SENDING BACK RESULTS")
             return list_imgs
 
+        # Collect latent blending variables
         self.lb.set_width(self.width)
         self.lb.set_height(self.height)
-
         self.lb.autosetup_branching(
             depth_strength = self.depth_strength,
             num_inference_steps = self.num_inference_steps,
             nmb_branches_final = self.nmb_branches_final,
             nmb_mindist = 3)
-
         self.lb.set_prompt1(self.prompt1)
         self.lb.set_prompt2(self.prompt2)
         self.lb.set_negative_prompt(self.negative_prompt)
@@ -189,31 +176,32 @@ class BlendingFrontend():
         self.lb.guidance_scale_mid_damper = self.guidance_scale_mid_damper
         self.lb.mid_compression_scaler = self.mid_compression_scaler
         self.lb.branch1_influence = self.branch1_influence
-
         fixed_seeds = [self.seed1, self.seed2]
+
+        # Run Latent Blending
         imgs_transition = self.lb.run_transition(fixed_seeds=fixed_seeds)
+        print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
 
-        print(f"DONE DIFFUSION! Resulted in {len(imgs_transition)} images")
-
+        # Subselect the preview images (hard fixed to self.nmb_imgs_show=5)
         assert np.mod((self.nmb_branches_final-self.nmb_imgs_show)/4, 1)==0, 'self.nmb_branches_final illegal value!'
         idx_list = np.linspace(0, self.nmb_branches_final-1, self.nmb_imgs_show).astype(np.int32)
         list_imgs_preview = []
         for j in idx_list:
             list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
 
-        # Save as jpgs on disk so we are not sending umcompressed data around
-        timestamp = get_time('second')
-        list_fp_imgs = []
+        # Save the preview imgs as jpgs on disk so we are not sending uncompressed data around
+        self.current_timestamp = get_time('second')
+        self.list_fp_imgs_current = []
         for i in range(len(list_imgs_preview)):
-            fp_img = f"img_preview_{i}_{timestamp}.jpg"
+            fp_img = f"img_preview_{i}_{self.current_timestamp}.jpg"
             list_imgs_preview[i].save(fp_img)
-            list_fp_imgs.append(fp_img)
+            self.list_fp_imgs_current.append(fp_img)
 
-        # Save the movie as well
+        # Insert cheap frames for the movie
         imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration, self.fps)
 
         # Save as movie
-        fp_movie = f"movie_{timestamp}.mp4"
+        fp_movie = self.get_fp_movie(self.current_timestamp)
         if os.path.isfile(fp_movie):
             os.remove(fp_movie)
         ms = MovieSaver(fp_movie, fps=self.fps)
@@ -221,11 +209,49 @@ class BlendingFrontend():
             ms.write_frame(img)
         ms.finalize()
         print("DONE SAVING MOVIE! SENDING BACK...")
-        list_return = list_fp_imgs + [fp_movie]
-        return list_return
-
-
+        # Assemble the output, updating the preview images and the movie
+        list_return = self.list_fp_imgs_current + [fp_movie]
+        return list_return
+
+    def get_fp_movie(self, timestamp, is_stacked=False):
+        if not is_stacked:
+            return f"movie_{timestamp}.mp4"
+        else:
+            return f"movie_stacked_{timestamp}.mp4"
+
+
+    def stack_forward(self, prompt2, seed2):
+        # Save preview images, prompts and seeds into dictionary for stacking
+        self.dict_multi_trans[self.current_timestamp] = generate_list_output(self.prompt1, self.prompt2, self.seed1, self.seed2, self.list_fp_imgs_current)
+        self.dict_multi_trans_include[self.current_timestamp] = True
+
+        self.lb.swap_forward()
+        list_out = [self.list_fp_imgs_current[-1]]
+        list_out.extend([self.fp_img_empty]*4)
+        list_out.append(prompt2)
+        list_out.append(seed2)
+        list_out.append("")
+        list_out.append(np.random.randint(0, 10000000))
+
+        list_out_multi_tab = self.update_trans_stacks()
+        list_out.extend(list_out_multi_tab)
+        # self.nmb_trans_stack = len(self.dict_multi_trans_include)
+        return list_out
+
+    def stack_movie(self):
+        # Collect the movies of all transitions currently shown in the stack
+        list_fp_movies = []
+        for timestamp in self.multi_trans_currently_shown:
+            if timestamp is not None:
+                list_fp_movies.append(self.get_fp_movie(timestamp))
+
+        fp_stacked = self.get_fp_movie(get_time('second'), True)
+        concatenate_movies(fp_stacked, list_fp_movies)
+        return fp_stacked
+
+
     def get_state_dict(self):
         state_dict = {}
         grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
@@ -235,17 +261,77 @@ class BlendingFrontend():
         for v in grab_vars:
             state_dict[v] = getattr(self, v)
         return state_dict
+
+
+    def update_trans_stacks(self):
+        print("Updating transition stack...")
+        self.multi_trans_currently_shown = []
+        list_output = []
+        # A stack row holds 9 values: prompt1, prompt2, seed1, seed2 and 5 preview images.
+        # Initialize here so the padding loop below cannot hit a NameError when no
+        # transition has been included yet.
+        last_timestamp_vals = [None]*9
+        # Figure out which transitions should be shown
+        for timestamp in self.dict_multi_trans_include.keys():
+            if len(self.multi_trans_currently_shown) >= self.nmb_trans_stack:
+                continue
+
+            if self.dict_multi_trans_include[timestamp]:
+                last_timestamp_vals = self.dict_multi_trans[timestamp]
+                list_output.extend(self.dict_multi_trans[timestamp])
+                self.multi_trans_currently_shown.append(timestamp)
+                print(f"including timestamp: {timestamp}")
+
+        # Fill with empty images if below nmb_trans_stack
+        nmb_empty_missing = self.nmb_trans_stack - len(self.multi_trans_currently_shown)
+        for i in range(nmb_empty_missing):
+            list_output.extend([gr.update(visible=False)]*len(last_timestamp_vals))
+            self.multi_trans_currently_shown.append(None)
+
+        return list_output
+
+
+    def remove_trans(self, idx_row):
+        idx_row = int(idx_row)
+        # Flag the selected transition as excluded and rebuild the stack
+        if idx_row < len(self.multi_trans_currently_shown):
+            timestamp = self.multi_trans_currently_shown[idx_row]
+            if timestamp in self.dict_multi_trans_include.keys():
+                self.dict_multi_trans_include[timestamp] = False
+                print(f"remove_trans called: {timestamp}")
+        else:
+            print(f"remove_trans called: idx_row too large {idx_row}")
+
+        return self.update_trans_stacks()
+
+
+def get_img_rand(height, width):
+    return (255*np.random.rand(height, width, 3)).astype(np.uint8)
+
+
+def generate_list_output(
+        prompt1,
+        prompt2,
+        seed1,
+        seed2,
+        list_fp_imgs,
+        ):
+    list_output = []
+    list_output.append(prompt1)
+    list_output.append(prompt2)
+    list_output.append(seed1)
+    list_output.append(seed2)
+    for fp_img in list_fp_imgs:
+        list_output.append(fp_img)
+
+    return list_output
+
+
 
 if __name__ == "__main__":
 
-    fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
-    # fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
-    # sdh = StableDiffusionHolder(fp_ckpt)
-    self = BlendingFrontend(None)
+    # fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
+    fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
+    sdh = StableDiffusionHolder(fp_ckpt)
+
+    self = BlendingFrontend(sdh) # Yes this is possible in python and yes it is an awesome trick
 
     with gr.Blocks() as demo:
-
         with gr.Row():
             prompt1 = gr.Textbox(label="prompt 1")
             prompt2 = gr.Textbox(label="prompt 2")
@@ -265,16 +351,15 @@ if __name__ == "__main__":
             depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
             duration = gr.Slider(0.1, 30, self.duration, step=0.1, label='video duration', interactive=True)
             guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
-
+
         with gr.Row():
-            b_run = gr.Button('COMPUTE!', variant='primary')
             seed1 = gr.Number(42, label="seed 1", interactive=True)
+            b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
             seed2 = gr.Number(420, label="seed 2", interactive=True)
label="seed 2", interactive=True) - - with gr.Column(): - b_newseed1 = gr.Button("randomize \nseed 1", variant='secondary') - b_newseed2 = gr.Button("randomize \nseed 2", variant='secondary') - + b_newseed2 = gr.Button("randomize seed 2", variant='secondary') + with gr.Row(): + b_compute_transition = gr.Button('compute transition', variant='primary') + with gr.Row(): img1 = gr.Image(label="1/5") img2 = gr.Image(label="2/5") @@ -283,7 +368,7 @@ if __name__ == "__main__": img5 = gr.Image(label="5/5") with gr.Row(): - vid = gr.Video() + vid_transition = gr.Video() # Bind the on-change methods depth_strength.change(fn=self.change_depth_strength, inputs=depth_strength) @@ -295,8 +380,6 @@ if __name__ == "__main__": height.change(fn=self.change_height, inputs=height) width.change(fn=self.change_width, inputs=width) - prompt1.change(fn=self.change_prompt1, inputs=prompt1) - prompt2.change(fn=self.change_prompt2, inputs=prompt2) negative_prompt.change(fn=self.change_negative_prompt, inputs=negative_prompt) seed1.change(fn=self.change_seed1, inputs=seed1) seed2.change(fn=self.change_seed2, inputs=seed2) @@ -305,6 +388,13 @@ if __name__ == "__main__": b_newseed1.click(self.randomize_seed1, outputs=seed1) b_newseed2.click(self.randomize_seed2, outputs=seed2) - b_run.click(self.run, outputs=[img1, img2, img3, img4, img5, vid]) - + # b_stackforward.click(self.stack_forward, + # inputs=[prompt2, seed2], + # outputs=[img1, img2, img3, img4, img5, prompt1, seed1, prompt2]) + b_compute_transition.click(self.compute_transition, + inputs=[prompt1, prompt2], + outputs=[img1, img2, img3, img4, img5, vid_transition]) + + + demo.launch(share=self.share, inbrowser=True, inline=False)