diff --git a/gradio_ui.py b/gradio_ui.py index d16d798..b6074b0 100644 --- a/gradio_ui.py +++ b/gradio_ui.py @@ -33,7 +33,8 @@ import gradio as gr import copy from dotenv import find_dotenv, load_dotenv import shutil - +import random +import time #%% @@ -54,7 +55,7 @@ class BlendingFrontend(): self.init_save_dir() self.save_empty_image() - self.share = True + self.share = False self.transition_can_be_computed = False self.depth_strength = 0.25 self.seed1 = 420 @@ -79,12 +80,12 @@ class BlendingFrontend(): self.current_timestamp = None self.recycle_img1 = False self.recycle_img2 = False - self.fp_img1 = None - self.fp_img2 = None self.multi_idx_current = -1 self.list_imgs_shown_last = 5*[self.fp_img_empty] self.list_all_segments = [] self.dp_session = "" + self.user_id = None + self.block_transition = False def init_save_dir(self): @@ -106,10 +107,7 @@ class BlendingFrontend(): def randomize_seed1(self): # Dont randomize seed if we are in a multi concat mode. we don't want to change this one otherwise the movie breaks - if len(self.list_all_segments) > 0: - seed = self.seed1 - else: - seed = np.random.randint(0, 10000000) + seed = np.random.randint(0, 10000000) self.seed1 = int(seed) print(f"randomize_seed1: new seed = {self.seed1}") return seed @@ -147,47 +145,80 @@ class BlendingFrontend(): self.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')] self.depth_strength = list_ui_elem[list_ui_keys.index('depth_strength')] + if len(list_ui_elem[list_ui_keys.index('user_id')]) > 1: + self.user_id = list_ui_elem[list_ui_keys.index('user_id')] + else: + # generate new user id + self.user_id = ''.join((random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ') for i in range(8))) + print(f"made new user_id: {self.user_id}") + + def save_latents(self, fp_latents, list_latents): + list_latents_cpu = [l.cpu().numpy() for l in list_latents] + np.save(fp_latents, list_latents_cpu) + + + def load_latents(self, fp_latents): + list_latents_cpu = np.load(fp_latents) + list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu] + return list_latents + def compute_img1(self, *args): list_ui_elem = args self.setup_lb(list_ui_elem) - self.fp_img1 = os.path.join(self.dp_imgs, f"img1_{get_time('second')}.jpg") + fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}") img1 = Image.fromarray(self.lb.compute_latents1(return_image=True)) - img1.save(self.fp_img1) + img1.save(fp_img1+".jpg") + self.save_latents(fp_img1+".npy", self.lb.tree_latents[0]) + self.recycle_img1 = True self.recycle_img2 = False - return [self.fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty] + # fixme save seeds. change filenames? + return [fp_img1+".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id] def compute_img2(self, *args): - if self.fp_img1 is None: # don't do anything - return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty] + if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")): # don't do anything + return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id] list_ui_elem = args self.setup_lb(list_ui_elem) - self.fp_img2 = os.path.join(self.dp_imgs, f"img2_{get_time('second')}.jpg") + + self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy")) + fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}") img2 = Image.fromarray(self.lb.compute_latents2(return_image=True)) - img2.save(self.fp_img2) + img2.save(fp_img2+'.jpg') + self.save_latents(fp_img2+".npy", self.lb.tree_latents[-1]) self.recycle_img2 = True self.transition_can_be_computed = True - return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img2] + # fixme save seeds. change filenames? + return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2+".jpg", self.user_id] + def compute_transition(self, *args): - if not self.transition_can_be_computed: - list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty] + list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id] return list_return list_ui_elem = args self.setup_lb(list_ui_elem) print("STARTING TRANSITION...") - if self.use_debug: - list_imgs = [(255*np.random.rand(self.height,self.width,3)).astype(np.uint8) for l in range(5)] - list_imgs = [Image.fromarray(l) for l in list_imgs] - print("DONE! SENDING BACK RESULTS") - return list_imgs fixed_seeds = [self.seed1, self.seed2] # Run Latent Blending + # Check if another user is blocking this... otherwise everything will become mixed. + # t_now = time.time() + # if self.block_transition: + # while True: + # time.sleep(1) + # if not self.block_transition: + # break + # if time.time() - t_now > 1000: + # return + + self.block_transition = True + # Inject loaded latents (other user interference) + self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy")) + self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy")) imgs_transition = self.lb.run_transition( recycle_img1=self.recycle_img1, recycle_img2=self.recycle_img2, @@ -211,12 +242,12 @@ class BlendingFrontend(): fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg") list_imgs_preview[i].save(fp_img) self.list_fp_imgs_current.append(fp_img) - + self.block_transition = False # Insert cheap frames for the movie imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps) # Save as movie - self.fp_movie = os.path.join(self.dp_movies, f"movie_{self.current_timestamp}.mp4") + self.fp_movie = self.get_fp_video_last() if os.path.isfile(self.fp_movie): os.remove(self.fp_movie) ms = MovieSaver(self.fp_movie, fps=self.fps) @@ -244,12 +275,17 @@ class BlendingFrontend(): self.list_all_segments.append(dp_segment) self.lb.write_imgs_transition(dp_segment) - shutil.copyfile(self.fp_movie, os.path.join(dp_segment, "movie.mp4")) + + fp_movie_last = self.get_fp_video_last() + fp_movie_next = self.get_fp_video_next() + + shutil.copyfile(fp_movie_last, fp_movie_next) self.lb.swap_forward() fp_multi = self.multi_concat() list_out = [fp_multi] - list_out.extend([self.fp_img2]) + + list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")]) list_out.extend([self.fp_img_empty]*4) list_out.append(gr.update(interactive=False, value=prompt2)) list_out.append(gr.update(interactive=False, value=seed2)) @@ -260,15 +296,36 @@ class BlendingFrontend(): def multi_concat(self): - list_fp_movies = [] - for dp_segment in self.list_all_segments: - list_fp_movies.append(os.path.join(dp_segment, "movie.mp4")) - + list_fp_movies = self.get_fp_video_all() # Concatenate movies and save - fp_final = os.path.join(self.dp_session, "movie.mp4") + fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4") concatenate_movies(fp_final, list_fp_movies) return fp_final + + def get_fp_video_all(self): + list_all = os.listdir(self.dp_movies) + str_beg = f"movie_{self.user_id}_" + list_user = [l for l in list_all if str_beg in l] + list_user.sort() + list_user = [os.path.join(self.dp_movies, l) for l in list_user] + return list_user + + + def get_fp_video_next(self): + list_videos = self.get_fp_video_all() + if len(list_videos) == 0: + idx_next = 0 + else: + idx_next = len(list_videos) + fp_video_next = os.path.join(self.dp_movies, f"movie_{self.user_id}_{str(idx_next).zfill(3)}.mp4") + return fp_video_next + + def get_fp_video_last(self): + fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4") + return fp_video_last + + def get_state_dict(self): state_dict = {} grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width', @@ -378,6 +435,8 @@ if __name__ == "__main__": """ ) + with gr.Row(): + user_id = gr.Textbox(label="user id", interactive=False) # Collect all UI elemts in list to easily pass as inputs in gradio dict_ui_elem = {} @@ -404,6 +463,7 @@ if __name__ == "__main__": dict_ui_elem["parental_crossfeed_range"] = parental_crossfeed_range dict_ui_elem["parental_crossfeed_power"] = parental_crossfeed_power dict_ui_elem["parental_crossfeed_power_decay"] = parental_crossfeed_power_decay + dict_ui_elem["user_id"] = user_id # Convert to list, as gradio doesn't seem to accept dicts list_ui_elem = [] @@ -415,8 +475,8 @@ if __name__ == "__main__": b_newseed1.click(bf.randomize_seed1, outputs=seed1) b_newseed2.click(bf.randomize_seed2, outputs=seed2) - b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5]) - b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5]) + b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5, user_id]) + b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5, user_id]) b_compute_transition.click(bf.compute_transition, inputs=list_ui_elem, outputs=[img2, img3, img4, vid_single])