diff --git a/gradio_ui.py b/gradio_ui.py
index ad6e03f..7aa874b 100644
--- a/gradio_ui.py
+++ b/gradio_ui.py
@@ -31,15 +31,9 @@ from stable_diffusion_holder import StableDiffusionHolder
 torch.set_grad_enabled(False)
 import gradio as gr
 import copy
+from dotenv import find_dotenv, load_dotenv
 
-"""
-TODOS:
- - clean parameter handling
- - three buttons: diffuse A, diffuse B, make transition
- - collapse for easy mode
- - transition quality in terms of render time
-"""
 
 #%%
 
@@ -65,13 +59,17 @@ class BlendingFrontend():
         self.list_settings = []
         self.state_current = {}
         self.showing_current = True
-        self.branch1_influence = 0.1
-        self.branch1_mixing_depth = 0.3
+        self.branch1_influence = 0.3
+        self.branch1_max_depth_influence = 0.6
+        self.branch1_influence_decay = 0.3
+        self.parental_influence = 0.1
+        self.parental_max_depth_influence = 1.0
+        self.parental_influence_decay = 1.0
         self.nmb_branches_final = 9
         self.nmb_imgs_show = 5 # don't change
         self.fps = 30
-        self.duration_video = 15
-        self.t_compute_max_allowed = 15
+        self.duration_video = 10
+        self.t_compute_max_allowed = 10
         self.dict_multi_trans = {}
         self.dict_multi_trans_include = {}
         self.multi_trans_currently_shown = []
@@ -86,10 +84,21 @@ class BlendingFrontend():
         else:
             self.height = 768
             self.width = 768
+
+        self.init_save_dir()
+
+
+    def init_save_dir(self):
+        load_dotenv(find_dotenv(), verbose=False)
+        try:
+            self.dp_out = os.getenv("dp_out")
+        except Exception as e:
+            self.dp_out = ""
+
 
     # make dummy image
     def save_empty_image(self):
-        self.fp_img_empty = 'empty.jpg'
+        self.fp_img_empty = os.path.join(self.dp_out, 'empty.jpg')
         Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
 
@@ -116,8 +125,6 @@ class BlendingFrontend():
         self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
         self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
         self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
-        self.lb.branch1_influence = list_ui_elem[list_ui_keys.index('branch1_influence')]
-        self.lb.branch1_mixing_depth = list_ui_elem[list_ui_keys.index('branch1_mixing_depth')]
         self.lb.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
         self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
         self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
@@ -125,11 +132,19 @@ class BlendingFrontend():
         self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')]
         self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
 
+        self.lb.branch1_influence = list_ui_elem[list_ui_keys.index('branch1_influence')]
+        self.lb.branch1_max_depth_influence = list_ui_elem[list_ui_keys.index('branch1_max_depth_influence')]
+        self.lb.branch1_influence_decay = list_ui_elem[list_ui_keys.index('branch1_influence_decay')]
+        self.lb.parental_influence = list_ui_elem[list_ui_keys.index('parental_influence')]
+        self.lb.parental_max_depth_influence = list_ui_elem[list_ui_keys.index('parental_max_depth_influence')]
+        self.lb.parental_influence_decay = list_ui_elem[list_ui_keys.index('parental_influence_decay')]
+
+
     def compute_img1(self, *args):
         list_ui_elem = args
        self.setup_lb(list_ui_elem)
-        fp_img1 = f"img1_{get_time('second')}.jpg"
+        fp_img1 = os.path.join(self.dp_out, f"img1_{get_time('second')}.jpg")
         img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
         img1.save(fp_img1)
         self.save_empty_image()
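A note on the new init_save_dir hunk above: os.getenv returns None rather than raising when the variable is missing, so the except branch can never fire, and an unset dp_out would only surface later as os.path.join(None, ...). A more defensive sketch of the same idea, behavior otherwise unchanged:

    import os
    from dotenv import find_dotenv, load_dotenv

    def init_save_dir(self):
        load_dotenv(find_dotenv(), verbose=False)
        # os.getenv yields None (it does not raise) when dp_out is unset,
        # so fall back to "" explicitly instead of relying on try/except.
        self.dp_out = os.getenv("dp_out") or ""

The expected .env entry is a single line of the form dp_out=/some/output/dir (a made-up example path); find_dotenv() locates the file by walking up from the working directory.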
@@ -138,7 +153,7 @@ class BlendingFrontend():
     def compute_img2(self, *args):
         list_ui_elem = args
         self.setup_lb(list_ui_elem)
-        fp_img2 = f"img2_{get_time('second')}.jpg"
+        fp_img2 = os.path.join(self.dp_out, f"img2_{get_time('second')}.jpg")
         img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
         img2.save(fp_img2)
         return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2]
@@ -198,9 +213,11 @@ class BlendingFrontend():
     def get_fp_movie(self, timestamp, is_stacked=False):
         if not is_stacked:
-            return f"movie_{timestamp}.mp4"
+            fn = f"movie_{timestamp}.mp4"
         else:
-            return f"movie_stacked_{timestamp}.mp4"
+            fn = f"movie_stacked_{timestamp}.mp4"
+        fp = os.path.join(self.dp_out, fn)
+        return fp
 
 
     def stack_forward(self, prompt2, seed2):
@@ -319,32 +336,43 @@ if __name__ == "__main__":
     with gr.Blocks() as demo:
         with gr.Row():
             prompt1 = gr.Textbox(label="prompt 1")
-            negative_prompt = gr.Textbox(label="negative prompt")
             prompt2 = gr.Textbox(label="prompt 2")
 
         with gr.Row():
-            duration_compute = gr.Slider(10, 40, self.duration_video, step=1, label='compute budget for transition (seconds)', interactive=True)
+            duration_compute = gr.Slider(5, 45, self.t_compute_max_allowed, step=1, label='compute budget for transition (seconds)', interactive=True)
             duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True)
             height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
             width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True)
 
         with gr.Accordion("Advanced Settings (click to expand)", open=False):
-            with gr.Row():
-                depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
-                branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='branch1_influence', interactive=True)
-                branch1_mixing_depth = gr.Slider(0.0, 1.0, self.branch1_mixing_depth, step=0.01, label='branch1_mixing_depth', interactive=True)
+            with gr.Accordion("Diffusion settings", open=True):
+                with gr.Row():
+                    num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
+                    guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
+                    negative_prompt = gr.Textbox(label="negative prompt")
+
+            with gr.Accordion("Seeds control", open=True):
+                with gr.Row():
+                    seed1 = gr.Number(420, label="seed 1", interactive=True)
+                    b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
+                    seed2 = gr.Number(420, label="seed 2", interactive=True)
+                    b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
+
+            with gr.Accordion("Crossfeeding for last image", open=True):
+                with gr.Row():
+                    branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='crossfeed power', interactive=True)
+                    branch1_max_depth_influence = gr.Slider(0.0, 1.0, self.branch1_max_depth_influence, step=0.01, label='crossfeed range', interactive=True)
+                    branch1_influence_decay = gr.Slider(0.0, 1.0, self.branch1_influence_decay, step=0.01, label='crossfeed decay', interactive=True)
 
-            with gr.Row():
-                num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
-                guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
-                guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
+            with gr.Accordion("Transition settings", open=True):
+                with gr.Row():
+                    depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
+                    guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
+                    parental_influence = gr.Slider(0.0, 1.0, self.parental_influence, step=0.01, label='parental power', interactive=True)
+                    parental_max_depth_influence = gr.Slider(0.0, 1.0, self.parental_max_depth_influence, step=0.01, label='parental range', interactive=True)
+                    parental_influence_decay = gr.Slider(0.0, 1.0, self.parental_influence_decay, step=0.01, label='parental decay', interactive=True)
 
-            with gr.Row():
-                seed1 = gr.Number(420, label="seed 1", interactive=True)
-                b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
-                seed2 = gr.Number(420, label="seed 2", interactive=True)
-                b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
 
         with gr.Row():
             b_compute1 = gr.Button('compute first image', variant='primary')
gr.Accordion("Transition settings", open=True): + with gr.Row(): + depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True) + guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True) + parental_influence = gr.Slider(0.0, 1.0, self.parental_influence, step=0.01, label='parental power', interactive=True) + parental_max_depth_influence = gr.Slider(0.0, 1.0, self.parental_max_depth_influence, step=0.01, label='parental range', interactive=True) + parental_influence_decay = gr.Slider(0.0, 1.0, self.parental_influence_decay, step=0.01, label='parental decay', interactive=True) - with gr.Row(): - seed1 = gr.Number(420, label="seed 1", interactive=True) - b_newseed1 = gr.Button("randomize seed 1", variant='secondary') - seed2 = gr.Number(420, label="seed 2", interactive=True) - b_newseed2 = gr.Button("randomize seed 2", variant='secondary') with gr.Row(): b_compute1 = gr.Button('compute first image', variant='primary') @@ -373,7 +401,8 @@ if __name__ == "__main__": dict_ui_elem["depth_strength"] = depth_strength dict_ui_elem["branch1_influence"] = branch1_influence - dict_ui_elem["branch1_mixing_depth"] = branch1_mixing_depth + dict_ui_elem["branch1_max_depth_influence"] = branch1_max_depth_influence + dict_ui_elem["branch1_influence_decay"] = branch1_influence_decay dict_ui_elem["num_inference_steps"] = num_inference_steps dict_ui_elem["guidance_scale"] = guidance_scale @@ -381,6 +410,10 @@ if __name__ == "__main__": dict_ui_elem["seed1"] = seed1 dict_ui_elem["seed2"] = seed2 + dict_ui_elem["parental_max_depth_influence"] = parental_max_depth_influence + dict_ui_elem["parental_influence"] = parental_influence + dict_ui_elem["parental_influence_decay"] = parental_influence_decay + # Convert to list, as gradio doesn't seem to accept dicts list_ui_elem = [] list_ui_keys = [] diff --git a/latent_blending.py b/latent_blending.py index 6364a11..ef100b1 100644 --- a/latent_blending.py +++ b/latent_blending.py @@ -107,8 +107,16 @@ class LatentBlending(): self.noise_level_upscaling = 20 self.list_injection_idx = None self.list_nmb_branches = None + + # Mixing parameters self.branch1_influence = 0.0 - self.branch1_mixing_depth = 0.65 + self.branch1_max_depth_influence = 0.65 + self.branch1_influence_decay = 0.8 + + self.parental_influence = 0.0 + self.parental_max_depth_influence = 1.0 + self.parental_influence_decay = 1.0 + self.branch1_insertion_completed = False self.set_guidance_scale(guidance_scale) self.init_mode() @@ -389,10 +397,6 @@ class LatentBlending(): fixed_seeds: Optional[List[int]] = None, ): - # # FIXME: deal with these tree args later - # self.num_inference_steps = 30 - # self.t_compute_max_allowed = 60 - # Sanity checks first assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before' assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before' @@ -411,17 +415,15 @@ class LatentBlending(): self.sdh.num_inference_steps = self.num_inference_steps # Compute / Recycle first image - if not recycle_img1: + if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps: list_latents1 = self.compute_latents1() else: - # FIXME: check if latents there... 
diff --git a/latent_blending.py b/latent_blending.py
index 6364a11..ef100b1 100644
--- a/latent_blending.py
+++ b/latent_blending.py
@@ -107,8 +107,16 @@ class LatentBlending():
         self.noise_level_upscaling = 20
         self.list_injection_idx = None
         self.list_nmb_branches = None
+
+        # Mixing parameters
         self.branch1_influence = 0.0
-        self.branch1_mixing_depth = 0.65
+        self.branch1_max_depth_influence = 0.65
+        self.branch1_influence_decay = 0.8
+
+        self.parental_influence = 0.0
+        self.parental_max_depth_influence = 1.0
+        self.parental_influence_decay = 1.0
+
         self.branch1_insertion_completed = False
         self.set_guidance_scale(guidance_scale)
         self.init_mode()
@@ -389,10 +397,6 @@ class LatentBlending():
             fixed_seeds: Optional[List[int]] = None,
         ):
 
-        # # FIXME: deal with these tree args later
-        # self.num_inference_steps = 30
-        # self.t_compute_max_allowed = 60
-
         # Sanity checks first
         assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
         assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
@@ -411,17 +415,15 @@ class LatentBlending():
         self.sdh.num_inference_steps = self.num_inference_steps
 
         # Compute / Recycle first image
-        if not recycle_img1:
+        if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
             list_latents1 = self.compute_latents1()
         else:
-            # FIXME: check if latents there...
             list_latents1 = self.tree_latents[0]
 
         # Compute / Recycle first image
-        if not recycle_img2:
+        if not recycle_img2 or len(self.tree_latents[-1]) != self.num_inference_steps:
             list_latents2 = self.compute_latents2()
         else:
-            # FIXME: check if latents there...
             list_latents2 = self.tree_latents[-1]
 
         # Reset the tree, injecting the edge latents1/2 we just generated/recycled
@@ -438,7 +440,7 @@ class LatentBlending():
         while t_compute < self.t_compute_max_allowed:
             list_compute_steps = self.num_inference_steps - list_idx_injection
             list_compute_steps *= list_nmb_stems
-            t_compute = np.sum(list_compute_steps) * self.dt_per_diff
+            t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15*np.sum(list_nmb_stems)
             increase_done = False
             for s_idx in range(len(list_nmb_stems)-1):
                 if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
@@ -449,13 +451,14 @@ class LatentBlending():
             list_nmb_stems[-1] += 1
             # print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
 
-        # Run iteratively
+        # Run iteratively, always inserting new branches where they are needed most
         for s_idx in tqdm(range(len(list_idx_injection))):
             nmb_stems = list_nmb_stems[s_idx]
             idx_injection = list_idx_injection[s_idx]
 
             for i in range(nmb_stems):
                 fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
+                self.set_guidance_mid_dampening(fract_mixing)
                 # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
                 list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
                 self.insert_into_tree(fract_mixing, idx_injection, list_latents)
@@ -465,6 +468,14 @@ class LatentBlending():
 
 
     def get_mixing_parameters(self, idx_injection):
+        r"""
+        Computes which parental latents should be mixed together to achieve a smooth blend.
+        As metric, we are using lpips image similarity. The insertion takes place
+        where the metric is maximal.
+        Args:
+            idx_injection: int
+                the index in terms of diffusion steps, where the next insertion will start.
+        """
         # get_lpips_similarity
         similarities = []
         for i in range(len(self.tree_final_imgs)-1):
@@ -499,7 +510,16 @@ class LatentBlending():
 
 
     def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
-        # FIXME
+        r"""
+        Inserts all necessary parameters into the trajectory tree.
+        Args:
+            fract_mixing: float
+                the fraction along the transition axis [0, 1]
+            idx_injection: int
+                the index in terms of diffusion steps, where the next insertion will start.
+            list_latents: list
+                list of the latents to be inserted
+        """
         b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
         self.tree_latents.insert(b_parent1+1, list_latents)
         self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
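The run_transition hunks above implement a greedy budgeting scheme: stems (new branches) are added one at a time, preferring shallow injection depths until the stem counts roughly double with depth, and the cost estimate now charges a flat ~0.15 s overhead per stem on top of the per-step diffusion time. A self-contained sketch of that allocation; dt_per_diff and idx_injection_base are stand-in values (in the class they come from timing measurements and depth_strength):

    import numpy as np

    def allocate_stems(num_inference_steps=30, dt_per_diff=0.1,
                       t_compute_max_allowed=10, idx_injection_base=2):
        # One candidate injection depth every 2 diffusion steps
        list_idx_injection = np.arange(idx_injection_base, num_inference_steps - 1, 2)
        list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
        t_compute = 0.0
        while t_compute < t_compute_max_allowed:
            # Cost: remaining diffusion steps per stem, plus ~0.15 s fixed overhead per stem
            list_compute_steps = (num_inference_steps - list_idx_injection) * list_nmb_stems
            t_compute = np.sum(list_compute_steps) * dt_per_diff + 0.15 * np.sum(list_nmb_stems)
            # Prefer shallow depths: top up a stem count whenever the next
            # depth already has at least twice as many stems
            increase_done = False
            for s_idx in range(len(list_nmb_stems) - 1):
                if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
                    list_nmb_stems[s_idx] += 1
                    increase_done = True
                    break
            if not increase_done:
                list_nmb_stems[-1] += 1
        return list_idx_injection, list_nmb_stems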
+ """ list_conditionings = self.get_mixed_conditioning(fract_mixing) fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1]) - idx_reversed = self.num_inference_steps - idx_injection - latents_for_injection = interpolate_spherical( - self.tree_latents[b_parent1][-idx_reversed-1], - self.tree_latents[b_parent2][-idx_reversed-1], - fract_mixing_parental) - list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection) + # idx_reversed = self.num_inference_steps - idx_injection + + list_latents_parental_mix = [] + for i in range(self.num_inference_steps): + latents_p1 = self.tree_latents[b_parent1][i] + latents_p2 = self.tree_latents[b_parent2][i] + if latents_p1 is None or latents_p2 is None: + latents_parental = None + else: + latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental) + list_latents_parental_mix.append(latents_parental) + + idx_mixing_stop = int(round(self.num_inference_steps*self.parental_max_depth_influence)) + mixing_coeffs = idx_injection*[self.parental_influence] + nmb_mixing = idx_mixing_stop - idx_injection + if nmb_mixing > 0: + mixing_coeffs.extend(list(np.linspace(self.parental_influence, self.parental_influence*self.parental_influence_decay, nmb_mixing))) + mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0]) + + latents_start = list_latents_parental_mix[idx_injection-1] + list_latents = self.run_diffusion( + list_conditionings, + latents_start = latents_start, + idx_start = idx_injection, + list_latents_mixing = list_latents_parental_mix, + mixing_coeffs = mixing_coeffs + ) + return list_latents def compute_latents1(self, return_image=False): + r""" + Runs a diffusion trajectory for the first image + Args: + return_image: bool + whether to return an image or the list of latents + """ print("starting compute_latents1") list_conditionings = [self.text_embedding1] t0 = time.time() - list_latents1 = self.run_diffusion(list_conditionings, seed_source=self.seed1) + latents_start = self.get_noise(self.seed1) + list_latents1 = self.run_diffusion( + list_conditionings, + latents_start = latents_start, + idx_start = 0 + ) t1 = time.time() self.dt_per_diff = (t1-t0) / self.num_inference_steps self.tree_latents[0] = list_latents1 @@ -533,25 +597,31 @@ class LatentBlending(): else: return list_latents1 - def compute_latents2(self, return_image=False): - print("starting compute_latents2") - list_conditionings = [self.text_embedding2 - ] + r""" + Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory. 
@@ -533,25 +597,31 @@ class LatentBlending():
         else:
             return list_latents1
 
-
     def compute_latents2(self, return_image=False):
-        print("starting compute_latents2")
-        list_conditionings = [self.text_embedding2
-                              ]
+        r"""
+        Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
+        Args:
+            return_image: bool
+                whether to return an image or the list of latents
+        """
+        list_conditionings = [self.text_embedding2]
+        latents_start = self.get_noise(self.seed2)
         # Influence from branch1
         if self.branch1_influence > 0.0:
-            self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
-            self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
-            idx_crossfeed = int(round(self.num_inference_steps*self.branch1_mixing_depth))
+            # Set up the mixing_coeffs
+            idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_max_depth_influence))
+            mixing_coeffs = list(np.linspace(self.branch1_influence, self.branch1_influence*self.branch1_influence_decay, idx_mixing_stop))
+            mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
+            list_latents_mixing = self.tree_latents[0]
             list_latents2 = self.run_diffusion(
                 list_conditionings,
-                idx_start=idx_crossfeed,
-                latents_for_injection=self.tree_latents[0],
-                seed_source=self.seed2,
-                seed_mixing_target=self.seed1,
-                mixing_coeff=self.branch1_influence)
+                latents_start = latents_start,
+                idx_start = 0,
+                list_latents_mixing = list_latents_mixing,
+                mixing_coeffs = mixing_coeffs
+                )
         else:
-            list_latents2 = self.run_diffusion(list_conditionings)
+            list_latents2 = self.run_diffusion(list_conditionings, latents_start)
 
         self.tree_latents[-1] = list_latents2
 
         if return_image:
@@ -559,9 +629,14 @@ class LatentBlending():
         else:
             return list_latents2
 
+    def get_noise(self, seed):
+        generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
+        shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
+        C, H, W = shape_latents
+        return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
+
 
-
-    def run_transition_OLD(
+    def run_transition_legacy(
         self,
         recycle_img1: Optional[bool] = False,
         recycle_img2: Optional[bool] = False,
@@ -569,6 +644,7 @@ class LatentBlending():
         num_inference_steps: Optional[int] = 30,
         premature_stop: Optional[int] = np.inf,
     ):
         r"""
+        Old legacy function for computing transitions.
         Returns a list of transition images using spherical latent blending.
         Args:
             recycle_img1: Optional[bool]:
@@ -610,9 +686,9 @@ class LatentBlending():
         if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
             assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0'
             self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
-            self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
+            self.branch1_max_depth_influence = np.clip(self.branch1_max_depth_influence, 0, 1)
             self.list_nmb_branches.insert(1, 2)
-            idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_mixing_depth))
+            idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_max_depth_influence))
             self.list_injection_idx_ext.insert(1, idx_crossfeed)
             self.tree_fracts.insert(1, self.tree_fracts[0])
             self.tree_status.insert(1, self.tree_status[0])
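get_noise above centralizes what run_diffusion_standard used to do internally: the starting latents are drawn once from a seeded generator, which is what makes each image reproducible per seed. A standalone version; C=4 and f=8 are the usual Stable Diffusion latent channel count and downsampling factor, assumed here as defaults:

    import torch

    def get_noise(seed, height=512, width=512, C=4, f=8, device="cpu"):
        # Identical seed -> identical starting latents -> reproducible image
        generator = torch.Generator(device=device).manual_seed(int(seed))
        return torch.randn((1, C, height // f, width // f),
                           generator=generator, device=device)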
@@ -790,27 +866,27 @@ class LatentBlending():
     def run_diffusion(
             self,
             list_conditionings,
-            latents_for_injection: torch.FloatTensor = None,
-            idx_start: int = -1,
-            idx_stop: int = -1,
-            seed_source: int = -1,
-            seed_mixing_target: int = -1,
-            mixing_coeff: float = 0.0,
+            latents_start: torch.FloatTensor = None,
+            idx_start: int = 0,
+            list_latents_mixing = None,
+            mixing_coeffs = 0.0,
             return_image: Optional[bool] = False
         ):
+
         r"""
-        Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
+        Wrapper function for diffusion runners. Depending on the mode, the correct one will be executed.
         Args:
             list_conditionings: List of all conditionings for the diffusion model.
-            latents_for_injection: torch.FloatTensor
+            latents_start: torch.FloatTensor
                 Latents that are used for injection
             idx_start: int
                 Index of the diffusion process start and where the latents_for_injection are injected
-            idx_stop: int
-                Index of the diffusion process end.
-            FIXME ARGS
+            list_latents_mixing: torch.FloatTensor
+                List of latents (latent trajectories) that are used for mixing
+            mixing_coeffs: float or list
+                Coefficients, how strong each element of list_latents_mixing will be mixed in.
             return_image: Optional[bool]
                 Optionally return image directly
         """
@@ -822,26 +898,25 @@ class LatentBlending():
         if self.mode == 'standard':
             text_embeddings = list_conditionings[0]
             return self.sdh.run_diffusion_standard(
-                text_embeddings,
-                latents_for_injection=latents_for_injection,
-                idx_start=idx_start,
-                idx_stop=idx_stop,
-                seed_source=seed_source,
-                seed_mixing_target=seed_mixing_target,
-                mixing_coeff=mixing_coeff,
-                return_image=return_image,
+                text_embeddings = text_embeddings,
+                latents_start = latents_start,
+                idx_start = idx_start,
+                list_latents_mixing = list_latents_mixing,
+                mixing_coeffs = mixing_coeffs,
+                return_image = return_image,
                 )
 
-        elif self.mode == 'inpaint':
-            text_embeddings = list_conditionings[0]
-            assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
-            assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
-            return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
-        # FIXME LONG LINE
-        elif self.mode == 'upscale':
-            cond = list_conditionings[0]
-            uc_full = list_conditionings[1]
-            return self.sdh.run_diffusion_upscaling(cond, uc_full, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
+        # elif self.mode == 'inpaint':
+        #     text_embeddings = list_conditionings[0]
+        #     assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
+        #     assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
+        #     return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
+
+        # # FIXME LONG LINE and bad args
+        # elif self.mode == 'upscale':
+        #     cond = list_conditionings[0]
+        #     uc_full = list_conditionings[1]
+        #     return self.sdh.run_diffusion_upscaling(cond, uc_full, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
 
     def run_upscaling_step1(
             self,
+ """ tensorA = torch.from_numpy(imgA).float().cuda(self.device) tensorA = 2*tensorA/255.0 - 1 tensorA = tensorA.permute([2,0,1]).unsqueeze(0) @@ -1406,55 +1485,24 @@ if __name__ == "__main__": # Run latent blending self.branch1_influence = 0.3 - self.branch1_mixing_depth = 0.4 - self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds) - #%% - self.branch1_influence = 0.3 - self.branch1_mixing_depth = 0.5 - img2 = self.compute_latents2(return_image=True) - Image.fromarray(img2) + self.branch1_max_depth_influence = 0.4 + # self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds) + self.seed1=21312 + img1 =self.compute_latents1(True) + #% + self.seed2=1234121 + self.branch1_influence = 0.7 + self.branch1_max_depth_influence = 0.3 + self.branch1_influence_decay = 0.3 + img2 =self.compute_latents2(True) + # Image.fromarray(np.concatenate((img1, img2), axis=1)) #%% - idx_injection = 15 - fract_mixing = 0.5 - list_conditionings = self.get_mixed_conditioning(fract_mixing) - latents_for_injection = interpolate_spherical(self.tree_latents[0][idx_injection], self.tree_latents[-1][idx_injection], fract_mixing) - list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection) - img_mix = self.sdh.latent2image((list_latents[-1])) - - Image.fromarray(np.concatenate((img1,img_mix,img2), axis=1)).resize((800,800//3)) - - #%% scheme - # init scheme - list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 2) - list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32) - - #%% - list_compute_steps = self.num_inference_steps - list_idx_injection - list_compute_steps *= list_nmb_stems - t_compute = np.sum(list_compute_steps) * self.dt_per_diff - increase_done = False - for s_idx in range(len(list_nmb_stems)-1): - if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 3: - list_nmb_stems[s_idx] += 1 - increase_done = True - break - if not increase_done: - list_nmb_stems[-1] += 1 - print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}") - - - #%% - - imgs_transition = self.tree_final_imgs - # Let's get more cheap frames via linear interpolation (duration_transition*fps frames) - imgs_transition_ext = add_frames_linear_interp(imgs_transition, 15, fps) - - # Save as MP4 - fp_movie = "test.mp4" - if os.path.isfile(fp_movie): - os.remove(fp_movie) - ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width]) - for img in tqdm(imgs_transition_ext): - ms.write_frame(img) - ms.finalize() + t0 = time.time() + self.t_compute_max_allowed = 30 + self.parental_max_depth_influence = 1.0 + self.parental_influence = 0.0 + self.parental_influence_decay = 1.0 + imgs_transition = self.run_transition(recycle_img1=True, recycle_img2=True) + t1 = time.time() + print(f"took: {t1-t0}s") \ No newline at end of file diff --git a/stable_diffusion_holder.py b/stable_diffusion_holder.py index 8cae407..235e319 100644 --- a/stable_diffusion_holder.py +++ b/stable_diffusion_holder.py @@ -278,12 +278,10 @@ class StableDiffusionHolder: def run_diffusion_standard( self, text_embeddings: torch.FloatTensor, - latents_for_injection = None, - idx_start: int = -1, - idx_stop: int = -1, - seed_source: int = -1, - seed_mixing_target: int = -1, - mixing_coeff: float = 0.0, + latents_start: torch.FloatTensor, + idx_start: int = 0, + list_latents_mixing = None, + mixing_coeffs = 0.0, return_image: Optional[bool] = False, ): r""" @@ -297,34 +295,26 @@ class StableDiffusionHolder: Latents that are used for 
diff --git a/stable_diffusion_holder.py b/stable_diffusion_holder.py
index 8cae407..235e319 100644
--- a/stable_diffusion_holder.py
+++ b/stable_diffusion_holder.py
@@ -278,12 +278,10 @@ class StableDiffusionHolder:
     def run_diffusion_standard(
             self,
             text_embeddings: torch.FloatTensor,
-            latents_for_injection = None,
-            idx_start: int = -1,
-            idx_stop: int = -1,
-            seed_source: int = -1,
-            seed_mixing_target: int = -1,
-            mixing_coeff: float = 0.0,
+            latents_start: torch.FloatTensor,
+            idx_start: int = 0,
+            list_latents_mixing = None,
+            mixing_coeffs = 0.0,
             return_image: Optional[bool] = False,
         ):
         r"""
@@ -297,34 +295,26 @@ class StableDiffusionHolder:
                 Latents that are used for injection
             idx_start: int
                 Index of the diffusion process start and where the latents_for_injection are injected
-            idx_stop: int
-                Index of the diffusion process end.
             mixing_coeff:
                 # FIXME
-            seed_source:
-                # FIXME
-            seed_mixing:
-                # FIXME
             return_image: Optional[bool]
                 Optionally return image directly
+
         """
-
-        if latents_for_injection is None:
-            do_inject_latents = False
-            do_mix_latents = False
+        # Asserts
+        if type(mixing_coeffs) == float:
+            list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
+        elif type(mixing_coeffs) == list:
+            assert len(mixing_coeffs) == self.num_inference_steps
+            list_mixing_coeffs = mixing_coeffs
         else:
-            if mixing_coeff > 0.0:
-                do_inject_latents = False
-                do_mix_latents = True
-                assert seed_mixing_target != -1, "Set to correct seed for mixing"
-            else:
-                do_inject_latents = True
-                do_mix_latents = False
+            raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
+
+        if np.sum(list_mixing_coeffs) > 0:
+            assert len(list_latents_mixing) == self.num_inference_steps
 
         precision_scope = autocast if self.precision == "autocast" else nullcontext
-        generator = torch.Generator(device=self.device).manual_seed(int(seed_source))
 
         with precision_scope("cuda"):
             with self.model.ema_scope():
@@ -332,14 +322,10 @@ class StableDiffusionHolder:
                     uc = self.model.get_learned_conditioning(self.negative_prompt)
                 else:
                     uc = None
-                shape_latents = [self.C, self.height // self.f, self.width // self.f]
                 self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
-                C, H, W = shape_latents
-                size = (1, C, H, W)
-                b = size[0]
-                latents = torch.randn(size, generator=generator, device=self.device)
+                latents = latents_start.clone()
 
                 timesteps = self.sampler.ddim_timesteps
@@ -349,29 +335,20 @@ class StableDiffusionHolder:
                 # collect latents
                 list_latents_out = []
                 for i, step in enumerate(time_range):
-                    if do_inject_latents:
-                        # Inject latent at right place
-                        if i < idx_start:
-                            continue
-                        elif i == idx_start:
-                            latents = latents_for_injection.clone()
-
-                    if do_mix_latents:
-                        if i == 0:
-                            generator = torch.Generator(device=self.device).manual_seed(int(seed_mixing_target))
-                            latents_mixtarget = torch.randn(size, generator=generator, device=self.device)
-                        if i < idx_start:
-                            latents_mixtarget = latents_for_injection[i-1].clone()
-                        latents = interpolate_spherical(latents, latents_mixtarget, mixing_coeff)
-
-                        if i == idx_start:
-                            do_mix_latents = False
+                    # Set the right starting latents
+                    if i < idx_start:
+                        list_latents_out.append(None)
+                        continue
+                    elif i == idx_start:
+                        latents = latents_start.clone()
+
+                    # Mix the latents.
+                    if i > 0 and list_mixing_coeffs[i]>0:
+                        latents_mixtarget = list_latents_mixing[i-1].clone()
+                        latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
 
-                    if i == idx_stop:
-                        return list_latents_out
-
-                    # print(f"diffusion iter {i}")
                     index = total_steps - i - 1
-                    ts = torch.full((b,), step, device=self.device, dtype=torch.long)
+                    ts = torch.full((1,), step, device=self.device, dtype=torch.long)
                     outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
                                                       quantize_denoised=False, temperature=1.0,
                                                       noise_dropout=0.0, score_corrector=None,
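The reworked sampling loop that closes the patch has a simple contract: steps before idx_start are recorded as None placeholders (the parents already hold those latents), the supplied latents take over at idx_start, and each later step may first be pulled spherically toward the parents' latent from the previous step, weighted by its mixing coefficient. Reduced to a sketch, with a stand-in denoise_step in place of sampler.p_sample_ddim and interpolate_spherical passed in as a parameter:

    def sample_with_mixing(latents_start, idx_start, list_latents_mixing,
                           list_mixing_coeffs, num_steps, denoise_step,
                           interpolate_spherical):
        list_latents_out = []
        latents = latents_start.clone()
        for i in range(num_steps):
            if i < idx_start:
                # Not computed on this branch; the parents hold these steps
                list_latents_out.append(None)
                continue
            elif i == idx_start:
                latents = latents_start.clone()
            if i > 0 and list_mixing_coeffs[i] > 0:
                # Pull toward the parents' latents from the previous step
                latents_mixtarget = list_latents_mixing[i - 1].clone()
                latents = interpolate_spherical(latents, latents_mixtarget,
                                                list_mixing_coeffs[i])
            latents = denoise_step(latents, i)
            list_latents_out.append(latents)
        return list_latents_out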