better branch handling

This commit is contained in:
Johannes Stelzer 2024-01-09 15:31:17 +01:00
parent a69b86e2cd
commit e889c2a0cc
2 changed files with 96 additions and 53 deletions

View File

@ -415,7 +415,7 @@ if __name__ == "__main__":
# img_refx = self.pipe(prompt=prompt1, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)[0] # img_refx = self.pipe(prompt=prompt1, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)[0]
img_refx = self.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=True) img_refx = self.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False)
dt_ref = time.time() - t0 dt_ref = time.time() - t0
img_refx.save(f"x_{prefix}_{i}.jpg") img_refx.save(f"x_{prefix}_{i}.jpg")

View File

@ -76,17 +76,11 @@ class LatentBlending():
self.tree_status = None self.tree_status = None
self.tree_final_imgs = [] self.tree_final_imgs = []
self.list_nmb_branches_prev = []
self.list_injection_idx_prev = []
self.text_embedding1 = None self.text_embedding1 = None
self.text_embedding2 = None self.text_embedding2 = None
self.image1_lowres = None self.image1_lowres = None
self.image2_lowres = None self.image2_lowres = None
self.negative_prompt = None self.negative_prompt = None
self.num_inference_steps = self.dh.num_inference_steps
self.noise_level_upscaling = 20
self.list_injection_idx = None
self.list_nmb_branches = None
# Mixing parameters # Mixing parameters
self.branch1_crossfeed_power = 0.0 self.branch1_crossfeed_power = 0.0
@ -100,11 +94,34 @@ class LatentBlending():
self.set_guidance_scale(guidance_scale) self.set_guidance_scale(guidance_scale)
self.multi_transition_img_first = None self.multi_transition_img_first = None
self.multi_transition_img_last = None self.multi_transition_img_last = None
self.dt_per_diff = 0 self.dt_unet_step = 0
self.lpips = lpips.LPIPS(net='alex').cuda(self.device) self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
self.set_prompt1("") self.set_prompt1("")
self.set_prompt2("") self.set_prompt2("")
self.set_num_inference_steps()
self.benchmark_speed()
self.set_branching()
def benchmark_speed(self):
"""
Measures the time per diffusion step and for the vae decoding
"""
text_embeddings = self.dh.get_text_embedding("test")
latents_start = self.dh.get_noise(np.random.randint(111111))
# warmup
list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1)
# bench unet
t0 = time.time()
list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1)
self.dt_unet_step = time.time() - t0
# bench vae
t0 = time.time()
img = self.dh.latent2image(list_latents[-1])
self.dt_vae = time.time() - t0
def set_dimensions(self, size_output=None): def set_dimensions(self, size_output=None):
r""" r"""
@ -208,28 +225,21 @@ class LatentBlending():
image: Image image: Image
""" """
self.image2_lowres = image self.image2_lowres = image
def run_transition( def set_num_inference_steps(self, num_inference_steps=None):
self, if self.dh.is_sdxl_turbo:
recycle_img1: Optional[bool] = False, if num_inference_steps is None:
recycle_img2: Optional[bool] = False, num_inference_steps = 4
num_inference_steps: Optional[int] = 30, else:
list_idx_injection: Optional[int] = None, if num_inference_steps is None:
list_nmb_stems: Optional[int] = None, num_inference_steps = 30
depth_strength: Optional[float] = 0.3,
t_compute_max_allowed: Optional[float] = None, self.num_inference_steps = num_inference_steps
nmb_max_branches: Optional[int] = None, self.dh.set_num_inference_steps(num_inference_steps)
fixed_seeds: Optional[List[int]] = None):
r""" def set_branching(self, depth_strength=None, t_compute_max_allowed=None, nmb_max_branches=None):
Function for computing transitions. """
Returns a list of transition images using spherical latent blending. Sets the branching structure of the blending tree. Default arguments depend on pipe!
Args:
recycle_img1: Optional[bool]:
Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
recycle_img2: Optional[bool]:
Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
num_inference_steps:
Number of diffusion steps. Higher values will take more compute time.
depth_strength: depth_strength:
Determines how deep the first injection will happen. Determines how deep the first injection will happen.
Deeper injections will cause (unwanted) formation of new structures, Deeper injections will cause (unwanted) formation of new structures,
@ -241,6 +251,45 @@ class LatentBlending():
Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
results. Use this if you want to have controllable results independent results. Use this if you want to have controllable results independent
of your computer. of your computer.
"""
if self.dh.is_sdxl_turbo:
assert t_compute_max_allowed is None, "time-based branching not supported for SDXL Turbo"
if depth_strength is not None:
idx_inject = int(round(self.num_inference_steps*depth_strength))
else:
idx_inject = 2
if nmb_max_branches is None:
nmb_max_branches = 10
self.list_idx_injection = [idx_inject]
self.list_nmb_stems = [nmb_max_branches]
else:
if depth_strength is None:
depth_strength = 0.5
if t_compute_max_allowed is None and nmb_max_branches is None:
t_compute_max_allowed = 20
elif t_compute_max_allowed is not None and nmb_max_branches is not None:
raise ValueErorr("Either specify t_compute_max_allowed or nmb_max_branches")
self.list_idx_injection, self.list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
def run_transition(
self,
recycle_img1: Optional[bool] = False,
recycle_img2: Optional[bool] = False,
fixed_seeds: Optional[List[int]] = None):
r"""
Function for computing transitions.
Returns a list of transition images using spherical latent blending.
Args:
recycle_img1: Optional[bool]:
Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
recycle_img2: Optional[bool]:
Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
num_inference_steps:
Number of diffusion steps. Higher values will take more compute time.
fixed_seeds: Optional[List[int)]: fixed_seeds: Optional[List[int)]:
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2). You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
Otherwise random seeds will be taken. Otherwise random seeds will be taken.
@ -261,12 +310,7 @@ class LatentBlending():
self.seed1 = fixed_seeds[0] self.seed1 = fixed_seeds[0]
self.seed2 = fixed_seeds[1] self.seed2 = fixed_seeds[1]
# Ensure correct num_inference_steps in holder
if self.dh.is_sdxl_turbo:
num_inference_steps = 4 #ideal results
self.num_inference_steps = num_inference_steps
self.dh.set_num_inference_steps(num_inference_steps)
# Compute / Recycle first image # Compute / Recycle first image
if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps: if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
list_latents1 = self.compute_latents1() list_latents1 = self.compute_latents1()
@ -291,16 +335,13 @@ class LatentBlending():
self.parental_crossfeed_power = 1.0 self.parental_crossfeed_power = 1.0
self.parental_crossfeed_power_decay = 1.0 self.parental_crossfeed_power_decay = 1.0
self.parental_crossfeed_range = 1.0 self.parental_crossfeed_range = 1.0
list_idx_injection = [2]
list_nmb_stems = [10]
else:
list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
# Run iteratively, starting with the longest trajectory. # Run iteratively, starting with the longest trajectory.
# Always inserting new branches where they are needed most according to image similarity # Always inserting new branches where they are needed most according to image similarity
for s_idx in tqdm(range(len(list_idx_injection))): for s_idx in tqdm(range(len(self.list_idx_injection))):
nmb_stems = list_nmb_stems[s_idx] nmb_stems = self.list_nmb_stems[s_idx]
idx_injection = list_idx_injection[s_idx] idx_injection = self.list_idx_injection[s_idx]
for i in range(nmb_stems): for i in range(nmb_stems):
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection) fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
@ -310,6 +351,9 @@ class LatentBlending():
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection} bp1 {b_parent1} bp2 {b_parent2}") # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection} bp1 {b_parent1} bp2 {b_parent2}")
return self.tree_final_imgs return self.tree_final_imgs
def compute_latents1(self, return_image=False): def compute_latents1(self, return_image=False):
r""" r"""
@ -327,7 +371,7 @@ class LatentBlending():
latents_start=latents_start, latents_start=latents_start,
idx_start=0) idx_start=0)
t1 = time.time() t1 = time.time()
self.dt_per_diff = (t1 - t0) / self.num_inference_steps self.dt_unet_step = (t1 - t0) / self.num_inference_steps
self.tree_latents[0] = list_latents1 self.tree_latents[0] = list_latents1
if return_image: if return_image:
return self.dh.latent2image(list_latents1[-1]) return self.dh.latent2image(list_latents1[-1])
@ -447,8 +491,8 @@ class LatentBlending():
while not stop_criterion_reached: while not stop_criterion_reached:
list_compute_steps = self.num_inference_steps - list_idx_injection list_compute_steps = self.num_inference_steps - list_idx_injection
list_compute_steps *= list_nmb_stems list_compute_steps *= list_nmb_stems
t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems) t_compute = np.sum(list_compute_steps) * self.dt_unet_step + self.dt_vae * np.sum(list_nmb_stems)
t_compute += 2 * self.num_inference_steps * self.dt_per_diff # outer branches t_compute += 2 * (self.num_inference_steps * self.dt_unet_step + self.dt_vae) # outer branches
increase_done = False increase_done = False
for s_idx in range(len(list_nmb_stems) - 1): for s_idx in range(len(list_nmb_stems) - 1):
if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 1: if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 1:
@ -765,8 +809,8 @@ if __name__ == "__main__":
from diffusers_holder import DiffusersHolder from diffusers_holder import DiffusersHolder
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers import AutoencoderTiny from diffusers import AutoencoderTiny
# pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_model_name_or_path = "stabilityai/sdxl-turbo" # pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16") pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16")
@ -776,8 +820,8 @@ if __name__ == "__main__":
dh = DiffusersHolder(pipe) dh = DiffusersHolder(pipe)
# %% Next let's set up all parameters # %% Next let's set up all parameters
size_output = (512, 512) # size_output = (512, 512)
# size_output = (1024, 1024) size_output = (1024, 1024)
prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution" prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution"
prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal" prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal"
negative_prompt = "blurry, ugly, pale" # Optional negative_prompt = "blurry, ugly, pale" # Optional
@ -793,10 +837,9 @@ if __name__ == "__main__":
lb.set_prompt2(prompt2) lb.set_prompt2(prompt2)
lb.set_dimensions(size_output) lb.set_dimensions(size_output)
lb.set_negative_prompt(negative_prompt) lb.set_negative_prompt(negative_prompt)
# Run latent blending # Run latent blending
lb.run_transition(fixed_seeds=[420, 421], t_compute_max_allowed=15) lb.run_transition(fixed_seeds=[420, 421])
# Save movie # Save movie
fp_movie = f'test.mp4' fp_movie = f'test.mp4'
@ -804,4 +847,4 @@ if __name__ == "__main__":
#%%