multi user support for gradio (huggingface server)
This commit is contained in:
parent
cc8baf6c72
commit
6f78b990e1
126
gradio_ui.py
126
gradio_ui.py
|
@ -33,7 +33,8 @@ import gradio as gr
|
|||
import copy
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
import shutil
|
||||
|
||||
import random
|
||||
import time
|
||||
|
||||
|
||||
#%%
|
||||
|
@ -54,7 +55,7 @@ class BlendingFrontend():
|
|||
|
||||
self.init_save_dir()
|
||||
self.save_empty_image()
|
||||
self.share = True
|
||||
self.share = False
|
||||
self.transition_can_be_computed = False
|
||||
self.depth_strength = 0.25
|
||||
self.seed1 = 420
|
||||
|
@ -79,12 +80,12 @@ class BlendingFrontend():
|
|||
self.current_timestamp = None
|
||||
self.recycle_img1 = False
|
||||
self.recycle_img2 = False
|
||||
self.fp_img1 = None
|
||||
self.fp_img2 = None
|
||||
self.multi_idx_current = -1
|
||||
self.list_imgs_shown_last = 5*[self.fp_img_empty]
|
||||
self.list_all_segments = []
|
||||
self.dp_session = ""
|
||||
self.user_id = None
|
||||
self.block_transition = False
|
||||
|
||||
|
||||
def init_save_dir(self):
|
||||
|
@ -106,9 +107,6 @@ class BlendingFrontend():
|
|||
|
||||
def randomize_seed1(self):
|
||||
# Dont randomize seed if we are in a multi concat mode. we don't want to change this one otherwise the movie breaks
|
||||
if len(self.list_all_segments) > 0:
|
||||
seed = self.seed1
|
||||
else:
|
||||
seed = np.random.randint(0, 10000000)
|
||||
self.seed1 = int(seed)
|
||||
print(f"randomize_seed1: new seed = {self.seed1}")
|
||||
|
@ -147,47 +145,80 @@ class BlendingFrontend():
|
|||
self.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
|
||||
self.depth_strength = list_ui_elem[list_ui_keys.index('depth_strength')]
|
||||
|
||||
if len(list_ui_elem[list_ui_keys.index('user_id')]) > 1:
|
||||
self.user_id = list_ui_elem[list_ui_keys.index('user_id')]
|
||||
else:
|
||||
# generate new user id
|
||||
self.user_id = ''.join((random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ') for i in range(8)))
|
||||
print(f"made new user_id: {self.user_id}")
|
||||
|
||||
def save_latents(self, fp_latents, list_latents):
|
||||
list_latents_cpu = [l.cpu().numpy() for l in list_latents]
|
||||
np.save(fp_latents, list_latents_cpu)
|
||||
|
||||
|
||||
def load_latents(self, fp_latents):
|
||||
list_latents_cpu = np.load(fp_latents)
|
||||
list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu]
|
||||
return list_latents
|
||||
|
||||
|
||||
def compute_img1(self, *args):
|
||||
list_ui_elem = args
|
||||
self.setup_lb(list_ui_elem)
|
||||
self.fp_img1 = os.path.join(self.dp_imgs, f"img1_{get_time('second')}.jpg")
|
||||
fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}")
|
||||
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
|
||||
img1.save(self.fp_img1)
|
||||
img1.save(fp_img1+".jpg")
|
||||
self.save_latents(fp_img1+".npy", self.lb.tree_latents[0])
|
||||
|
||||
self.recycle_img1 = True
|
||||
self.recycle_img2 = False
|
||||
return [self.fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
|
||||
# fixme save seeds. change filenames?
|
||||
return [fp_img1+".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
|
||||
|
||||
def compute_img2(self, *args):
|
||||
if self.fp_img1 is None: # don't do anything
|
||||
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
|
||||
if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")): # don't do anything
|
||||
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
|
||||
list_ui_elem = args
|
||||
self.setup_lb(list_ui_elem)
|
||||
self.fp_img2 = os.path.join(self.dp_imgs, f"img2_{get_time('second')}.jpg")
|
||||
|
||||
self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
|
||||
fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}")
|
||||
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
|
||||
img2.save(self.fp_img2)
|
||||
img2.save(fp_img2+'.jpg')
|
||||
self.save_latents(fp_img2+".npy", self.lb.tree_latents[-1])
|
||||
self.recycle_img2 = True
|
||||
self.transition_can_be_computed = True
|
||||
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img2]
|
||||
# fixme save seeds. change filenames?
|
||||
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2+".jpg", self.user_id]
|
||||
|
||||
|
||||
def compute_transition(self, *args):
|
||||
|
||||
if not self.transition_can_be_computed:
|
||||
list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
|
||||
list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
|
||||
return list_return
|
||||
|
||||
list_ui_elem = args
|
||||
self.setup_lb(list_ui_elem)
|
||||
print("STARTING TRANSITION...")
|
||||
if self.use_debug:
|
||||
list_imgs = [(255*np.random.rand(self.height,self.width,3)).astype(np.uint8) for l in range(5)]
|
||||
list_imgs = [Image.fromarray(l) for l in list_imgs]
|
||||
print("DONE! SENDING BACK RESULTS")
|
||||
return list_imgs
|
||||
|
||||
fixed_seeds = [self.seed1, self.seed2]
|
||||
|
||||
# Run Latent Blending
|
||||
# Check if another user is blocking this... otherwise everything will become mixed.
|
||||
# t_now = time.time()
|
||||
# if self.block_transition:
|
||||
# while True:
|
||||
# time.sleep(1)
|
||||
# if not self.block_transition:
|
||||
# break
|
||||
# if time.time() - t_now > 1000:
|
||||
# return
|
||||
|
||||
self.block_transition = True
|
||||
# Inject loaded latents (other user interference)
|
||||
self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
|
||||
self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
|
||||
imgs_transition = self.lb.run_transition(
|
||||
recycle_img1=self.recycle_img1,
|
||||
recycle_img2=self.recycle_img2,
|
||||
|
@ -211,12 +242,12 @@ class BlendingFrontend():
|
|||
fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg")
|
||||
list_imgs_preview[i].save(fp_img)
|
||||
self.list_fp_imgs_current.append(fp_img)
|
||||
|
||||
self.block_transition = False
|
||||
# Insert cheap frames for the movie
|
||||
imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
|
||||
|
||||
# Save as movie
|
||||
self.fp_movie = os.path.join(self.dp_movies, f"movie_{self.current_timestamp}.mp4")
|
||||
self.fp_movie = self.get_fp_video_last()
|
||||
if os.path.isfile(self.fp_movie):
|
||||
os.remove(self.fp_movie)
|
||||
ms = MovieSaver(self.fp_movie, fps=self.fps)
|
||||
|
@ -244,12 +275,17 @@ class BlendingFrontend():
|
|||
|
||||
self.list_all_segments.append(dp_segment)
|
||||
self.lb.write_imgs_transition(dp_segment)
|
||||
shutil.copyfile(self.fp_movie, os.path.join(dp_segment, "movie.mp4"))
|
||||
|
||||
fp_movie_last = self.get_fp_video_last()
|
||||
fp_movie_next = self.get_fp_video_next()
|
||||
|
||||
shutil.copyfile(fp_movie_last, fp_movie_next)
|
||||
|
||||
self.lb.swap_forward()
|
||||
fp_multi = self.multi_concat()
|
||||
list_out = [fp_multi]
|
||||
list_out.extend([self.fp_img2])
|
||||
|
||||
list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")])
|
||||
list_out.extend([self.fp_img_empty]*4)
|
||||
list_out.append(gr.update(interactive=False, value=prompt2))
|
||||
list_out.append(gr.update(interactive=False, value=seed2))
|
||||
|
@ -260,15 +296,36 @@ class BlendingFrontend():
|
|||
|
||||
|
||||
def multi_concat(self):
|
||||
list_fp_movies = []
|
||||
for dp_segment in self.list_all_segments:
|
||||
list_fp_movies.append(os.path.join(dp_segment, "movie.mp4"))
|
||||
|
||||
list_fp_movies = self.get_fp_video_all()
|
||||
# Concatenate movies and save
|
||||
fp_final = os.path.join(self.dp_session, "movie.mp4")
|
||||
fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4")
|
||||
concatenate_movies(fp_final, list_fp_movies)
|
||||
return fp_final
|
||||
|
||||
|
||||
def get_fp_video_all(self):
|
||||
list_all = os.listdir(self.dp_movies)
|
||||
str_beg = f"movie_{self.user_id}_"
|
||||
list_user = [l for l in list_all if str_beg in l]
|
||||
list_user.sort()
|
||||
list_user = [os.path.join(self.dp_movies, l) for l in list_user]
|
||||
return list_user
|
||||
|
||||
|
||||
def get_fp_video_next(self):
|
||||
list_videos = self.get_fp_video_all()
|
||||
if len(list_videos) == 0:
|
||||
idx_next = 0
|
||||
else:
|
||||
idx_next = len(list_videos)
|
||||
fp_video_next = os.path.join(self.dp_movies, f"movie_{self.user_id}_{str(idx_next).zfill(3)}.mp4")
|
||||
return fp_video_next
|
||||
|
||||
def get_fp_video_last(self):
|
||||
fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4")
|
||||
return fp_video_last
|
||||
|
||||
|
||||
def get_state_dict(self):
|
||||
state_dict = {}
|
||||
grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
|
||||
|
@ -378,6 +435,8 @@ if __name__ == "__main__":
|
|||
"""
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
user_id = gr.Textbox(label="user id", interactive=False)
|
||||
|
||||
# Collect all UI elemts in list to easily pass as inputs in gradio
|
||||
dict_ui_elem = {}
|
||||
|
@ -404,6 +463,7 @@ if __name__ == "__main__":
|
|||
dict_ui_elem["parental_crossfeed_range"] = parental_crossfeed_range
|
||||
dict_ui_elem["parental_crossfeed_power"] = parental_crossfeed_power
|
||||
dict_ui_elem["parental_crossfeed_power_decay"] = parental_crossfeed_power_decay
|
||||
dict_ui_elem["user_id"] = user_id
|
||||
|
||||
# Convert to list, as gradio doesn't seem to accept dicts
|
||||
list_ui_elem = []
|
||||
|
@ -415,8 +475,8 @@ if __name__ == "__main__":
|
|||
|
||||
b_newseed1.click(bf.randomize_seed1, outputs=seed1)
|
||||
b_newseed2.click(bf.randomize_seed2, outputs=seed2)
|
||||
b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5])
|
||||
b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5])
|
||||
b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5, user_id])
|
||||
b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5, user_id])
|
||||
b_compute_transition.click(bf.compute_transition,
|
||||
inputs=list_ui_elem,
|
||||
outputs=[img2, img3, img4, vid_single])
|
||||
|
|
Loading…
Reference in New Issue