multi user support for gradio (huggingface server)

This commit is contained in:
Johannes Stelzer 2023-02-21 19:46:16 +01:00
parent cc8baf6c72
commit 6f78b990e1
1 changed files with 94 additions and 34 deletions

View File

@ -33,7 +33,8 @@ import gradio as gr
import copy import copy
from dotenv import find_dotenv, load_dotenv from dotenv import find_dotenv, load_dotenv
import shutil import shutil
import random
import time
#%% #%%
@ -54,7 +55,7 @@ class BlendingFrontend():
self.init_save_dir() self.init_save_dir()
self.save_empty_image() self.save_empty_image()
self.share = True self.share = False
self.transition_can_be_computed = False self.transition_can_be_computed = False
self.depth_strength = 0.25 self.depth_strength = 0.25
self.seed1 = 420 self.seed1 = 420
@ -79,12 +80,12 @@ class BlendingFrontend():
self.current_timestamp = None self.current_timestamp = None
self.recycle_img1 = False self.recycle_img1 = False
self.recycle_img2 = False self.recycle_img2 = False
self.fp_img1 = None
self.fp_img2 = None
self.multi_idx_current = -1 self.multi_idx_current = -1
self.list_imgs_shown_last = 5*[self.fp_img_empty] self.list_imgs_shown_last = 5*[self.fp_img_empty]
self.list_all_segments = [] self.list_all_segments = []
self.dp_session = "" self.dp_session = ""
self.user_id = None
self.block_transition = False
def init_save_dir(self): def init_save_dir(self):
@ -106,9 +107,6 @@ class BlendingFrontend():
def randomize_seed1(self): def randomize_seed1(self):
# Dont randomize seed if we are in a multi concat mode. we don't want to change this one otherwise the movie breaks # Dont randomize seed if we are in a multi concat mode. we don't want to change this one otherwise the movie breaks
if len(self.list_all_segments) > 0:
seed = self.seed1
else:
seed = np.random.randint(0, 10000000) seed = np.random.randint(0, 10000000)
self.seed1 = int(seed) self.seed1 = int(seed)
print(f"randomize_seed1: new seed = {self.seed1}") print(f"randomize_seed1: new seed = {self.seed1}")
@ -147,47 +145,80 @@ class BlendingFrontend():
self.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')] self.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
self.depth_strength = list_ui_elem[list_ui_keys.index('depth_strength')] self.depth_strength = list_ui_elem[list_ui_keys.index('depth_strength')]
if len(list_ui_elem[list_ui_keys.index('user_id')]) > 1:
self.user_id = list_ui_elem[list_ui_keys.index('user_id')]
else:
# generate new user id
self.user_id = ''.join((random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ') for i in range(8)))
print(f"made new user_id: {self.user_id}")
def save_latents(self, fp_latents, list_latents):
list_latents_cpu = [l.cpu().numpy() for l in list_latents]
np.save(fp_latents, list_latents_cpu)
def load_latents(self, fp_latents):
list_latents_cpu = np.load(fp_latents)
list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu]
return list_latents
def compute_img1(self, *args): def compute_img1(self, *args):
list_ui_elem = args list_ui_elem = args
self.setup_lb(list_ui_elem) self.setup_lb(list_ui_elem)
self.fp_img1 = os.path.join(self.dp_imgs, f"img1_{get_time('second')}.jpg") fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}")
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True)) img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
img1.save(self.fp_img1) img1.save(fp_img1+".jpg")
self.save_latents(fp_img1+".npy", self.lb.tree_latents[0])
self.recycle_img1 = True self.recycle_img1 = True
self.recycle_img2 = False self.recycle_img2 = False
return [self.fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty] # fixme save seeds. change filenames?
return [fp_img1+".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
def compute_img2(self, *args): def compute_img2(self, *args):
if self.fp_img1 is None: # don't do anything if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")): # don't do anything
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty] return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
list_ui_elem = args list_ui_elem = args
self.setup_lb(list_ui_elem) self.setup_lb(list_ui_elem)
self.fp_img2 = os.path.join(self.dp_imgs, f"img2_{get_time('second')}.jpg")
self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}")
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True)) img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
img2.save(self.fp_img2) img2.save(fp_img2+'.jpg')
self.save_latents(fp_img2+".npy", self.lb.tree_latents[-1])
self.recycle_img2 = True self.recycle_img2 = True
self.transition_can_be_computed = True self.transition_can_be_computed = True
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img2] # fixme save seeds. change filenames?
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2+".jpg", self.user_id]
def compute_transition(self, *args): def compute_transition(self, *args):
if not self.transition_can_be_computed: if not self.transition_can_be_computed:
list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty] list_return = [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
return list_return return list_return
list_ui_elem = args list_ui_elem = args
self.setup_lb(list_ui_elem) self.setup_lb(list_ui_elem)
print("STARTING TRANSITION...") print("STARTING TRANSITION...")
if self.use_debug:
list_imgs = [(255*np.random.rand(self.height,self.width,3)).astype(np.uint8) for l in range(5)]
list_imgs = [Image.fromarray(l) for l in list_imgs]
print("DONE! SENDING BACK RESULTS")
return list_imgs
fixed_seeds = [self.seed1, self.seed2] fixed_seeds = [self.seed1, self.seed2]
# Run Latent Blending # Run Latent Blending
# Check if another user is blocking this... otherwise everything will become mixed.
# t_now = time.time()
# if self.block_transition:
# while True:
# time.sleep(1)
# if not self.block_transition:
# break
# if time.time() - t_now > 1000:
# return
self.block_transition = True
# Inject loaded latents (other user interference)
self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
imgs_transition = self.lb.run_transition( imgs_transition = self.lb.run_transition(
recycle_img1=self.recycle_img1, recycle_img1=self.recycle_img1,
recycle_img2=self.recycle_img2, recycle_img2=self.recycle_img2,
@ -211,12 +242,12 @@ class BlendingFrontend():
fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg") fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{self.current_timestamp}.jpg")
list_imgs_preview[i].save(fp_img) list_imgs_preview[i].save(fp_img)
self.list_fp_imgs_current.append(fp_img) self.list_fp_imgs_current.append(fp_img)
self.block_transition = False
# Insert cheap frames for the movie # Insert cheap frames for the movie
imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps) imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
# Save as movie # Save as movie
self.fp_movie = os.path.join(self.dp_movies, f"movie_{self.current_timestamp}.mp4") self.fp_movie = self.get_fp_video_last()
if os.path.isfile(self.fp_movie): if os.path.isfile(self.fp_movie):
os.remove(self.fp_movie) os.remove(self.fp_movie)
ms = MovieSaver(self.fp_movie, fps=self.fps) ms = MovieSaver(self.fp_movie, fps=self.fps)
@ -244,12 +275,17 @@ class BlendingFrontend():
self.list_all_segments.append(dp_segment) self.list_all_segments.append(dp_segment)
self.lb.write_imgs_transition(dp_segment) self.lb.write_imgs_transition(dp_segment)
shutil.copyfile(self.fp_movie, os.path.join(dp_segment, "movie.mp4"))
fp_movie_last = self.get_fp_video_last()
fp_movie_next = self.get_fp_video_next()
shutil.copyfile(fp_movie_last, fp_movie_next)
self.lb.swap_forward() self.lb.swap_forward()
fp_multi = self.multi_concat() fp_multi = self.multi_concat()
list_out = [fp_multi] list_out = [fp_multi]
list_out.extend([self.fp_img2])
list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")])
list_out.extend([self.fp_img_empty]*4) list_out.extend([self.fp_img_empty]*4)
list_out.append(gr.update(interactive=False, value=prompt2)) list_out.append(gr.update(interactive=False, value=prompt2))
list_out.append(gr.update(interactive=False, value=seed2)) list_out.append(gr.update(interactive=False, value=seed2))
@ -260,15 +296,36 @@ class BlendingFrontend():
def multi_concat(self): def multi_concat(self):
list_fp_movies = [] list_fp_movies = self.get_fp_video_all()
for dp_segment in self.list_all_segments:
list_fp_movies.append(os.path.join(dp_segment, "movie.mp4"))
# Concatenate movies and save # Concatenate movies and save
fp_final = os.path.join(self.dp_session, "movie.mp4") fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4")
concatenate_movies(fp_final, list_fp_movies) concatenate_movies(fp_final, list_fp_movies)
return fp_final return fp_final
def get_fp_video_all(self):
list_all = os.listdir(self.dp_movies)
str_beg = f"movie_{self.user_id}_"
list_user = [l for l in list_all if str_beg in l]
list_user.sort()
list_user = [os.path.join(self.dp_movies, l) for l in list_user]
return list_user
def get_fp_video_next(self):
list_videos = self.get_fp_video_all()
if len(list_videos) == 0:
idx_next = 0
else:
idx_next = len(list_videos)
fp_video_next = os.path.join(self.dp_movies, f"movie_{self.user_id}_{str(idx_next).zfill(3)}.mp4")
return fp_video_next
def get_fp_video_last(self):
fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4")
return fp_video_last
def get_state_dict(self): def get_state_dict(self):
state_dict = {} state_dict = {}
grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width', grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
@ -378,6 +435,8 @@ if __name__ == "__main__":
""" """
) )
with gr.Row():
user_id = gr.Textbox(label="user id", interactive=False)
# Collect all UI elemts in list to easily pass as inputs in gradio # Collect all UI elemts in list to easily pass as inputs in gradio
dict_ui_elem = {} dict_ui_elem = {}
@ -404,6 +463,7 @@ if __name__ == "__main__":
dict_ui_elem["parental_crossfeed_range"] = parental_crossfeed_range dict_ui_elem["parental_crossfeed_range"] = parental_crossfeed_range
dict_ui_elem["parental_crossfeed_power"] = parental_crossfeed_power dict_ui_elem["parental_crossfeed_power"] = parental_crossfeed_power
dict_ui_elem["parental_crossfeed_power_decay"] = parental_crossfeed_power_decay dict_ui_elem["parental_crossfeed_power_decay"] = parental_crossfeed_power_decay
dict_ui_elem["user_id"] = user_id
# Convert to list, as gradio doesn't seem to accept dicts # Convert to list, as gradio doesn't seem to accept dicts
list_ui_elem = [] list_ui_elem = []
@ -415,8 +475,8 @@ if __name__ == "__main__":
b_newseed1.click(bf.randomize_seed1, outputs=seed1) b_newseed1.click(bf.randomize_seed1, outputs=seed1)
b_newseed2.click(bf.randomize_seed2, outputs=seed2) b_newseed2.click(bf.randomize_seed2, outputs=seed2)
b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5]) b_compute1.click(bf.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5, user_id])
b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5]) b_compute2.click(bf.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5, user_id])
b_compute_transition.click(bf.compute_transition, b_compute_transition.click(bf.compute_transition,
inputs=list_ui_elem, inputs=list_ui_elem,
outputs=[img2, img3, img4, vid_single]) outputs=[img2, img3, img4, vid_single])