better branch handling

This commit is contained in:
Johannes Stelzer 2024-01-09 15:31:17 +01:00
parent a69b86e2cd
commit e889c2a0cc
2 changed files with 96 additions and 53 deletions

View File

@ -415,7 +415,7 @@ if __name__ == "__main__":
# img_refx = self.pipe(prompt=prompt1, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)[0]
img_refx = self.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=True)
img_refx = self.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False)
dt_ref = time.time() - t0
img_refx.save(f"x_{prefix}_{i}.jpg")

View File

@ -76,17 +76,11 @@ class LatentBlending():
self.tree_status = None
self.tree_final_imgs = []
self.list_nmb_branches_prev = []
self.list_injection_idx_prev = []
self.text_embedding1 = None
self.text_embedding2 = None
self.image1_lowres = None
self.image2_lowres = None
self.negative_prompt = None
self.num_inference_steps = self.dh.num_inference_steps
self.noise_level_upscaling = 20
self.list_injection_idx = None
self.list_nmb_branches = None
# Mixing parameters
self.branch1_crossfeed_power = 0.0
@ -100,11 +94,34 @@ class LatentBlending():
self.set_guidance_scale(guidance_scale)
self.multi_transition_img_first = None
self.multi_transition_img_last = None
self.dt_per_diff = 0
self.dt_unet_step = 0
self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
self.set_prompt1("")
self.set_prompt2("")
self.set_num_inference_steps()
self.benchmark_speed()
self.set_branching()
def benchmark_speed(self):
"""
Measures the time per diffusion step and for the vae decoding
"""
text_embeddings = self.dh.get_text_embedding("test")
latents_start = self.dh.get_noise(np.random.randint(111111))
# warmup
list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1)
# bench unet
t0 = time.time()
list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1)
self.dt_unet_step = time.time() - t0
# bench vae
t0 = time.time()
img = self.dh.latent2image(list_latents[-1])
self.dt_vae = time.time() - t0
def set_dimensions(self, size_output=None):
r"""
@ -208,28 +225,21 @@ class LatentBlending():
image: Image
"""
self.image2_lowres = image
def run_transition(
self,
recycle_img1: Optional[bool] = False,
recycle_img2: Optional[bool] = False,
num_inference_steps: Optional[int] = 30,
list_idx_injection: Optional[int] = None,
list_nmb_stems: Optional[int] = None,
depth_strength: Optional[float] = 0.3,
t_compute_max_allowed: Optional[float] = None,
nmb_max_branches: Optional[int] = None,
fixed_seeds: Optional[List[int]] = None):
r"""
Function for computing transitions.
Returns a list of transition images using spherical latent blending.
Args:
recycle_img1: Optional[bool]:
Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
recycle_img2: Optional[bool]:
Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
num_inference_steps:
Number of diffusion steps. Higher values will take more compute time.
def set_num_inference_steps(self, num_inference_steps=None):
if self.dh.is_sdxl_turbo:
if num_inference_steps is None:
num_inference_steps = 4
else:
if num_inference_steps is None:
num_inference_steps = 30
self.num_inference_steps = num_inference_steps
self.dh.set_num_inference_steps(num_inference_steps)
def set_branching(self, depth_strength=None, t_compute_max_allowed=None, nmb_max_branches=None):
"""
Sets the branching structure of the blending tree. Default arguments depend on pipe!
depth_strength:
Determines how deep the first injection will happen.
Deeper injections will cause (unwanted) formation of new structures,
@ -241,6 +251,45 @@ class LatentBlending():
Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
results. Use this if you want to have controllable results independent
of your computer.
"""
if self.dh.is_sdxl_turbo:
assert t_compute_max_allowed is None, "time-based branching not supported for SDXL Turbo"
if depth_strength is not None:
idx_inject = int(round(self.num_inference_steps*depth_strength))
else:
idx_inject = 2
if nmb_max_branches is None:
nmb_max_branches = 10
self.list_idx_injection = [idx_inject]
self.list_nmb_stems = [nmb_max_branches]
else:
if depth_strength is None:
depth_strength = 0.5
if t_compute_max_allowed is None and nmb_max_branches is None:
t_compute_max_allowed = 20
elif t_compute_max_allowed is not None and nmb_max_branches is not None:
raise ValueErorr("Either specify t_compute_max_allowed or nmb_max_branches")
self.list_idx_injection, self.list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
def run_transition(
self,
recycle_img1: Optional[bool] = False,
recycle_img2: Optional[bool] = False,
fixed_seeds: Optional[List[int]] = None):
r"""
Function for computing transitions.
Returns a list of transition images using spherical latent blending.
Args:
recycle_img1: Optional[bool]:
Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
recycle_img2: Optional[bool]:
Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
num_inference_steps:
Number of diffusion steps. Higher values will take more compute time.
fixed_seeds: Optional[List[int)]:
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
Otherwise random seeds will be taken.
@ -261,12 +310,7 @@ class LatentBlending():
self.seed1 = fixed_seeds[0]
self.seed2 = fixed_seeds[1]
# Ensure correct num_inference_steps in holder
if self.dh.is_sdxl_turbo:
num_inference_steps = 4 #ideal results
self.num_inference_steps = num_inference_steps
self.dh.set_num_inference_steps(num_inference_steps)
# Compute / Recycle first image
if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
list_latents1 = self.compute_latents1()
@ -291,16 +335,13 @@ class LatentBlending():
self.parental_crossfeed_power = 1.0
self.parental_crossfeed_power_decay = 1.0
self.parental_crossfeed_range = 1.0
list_idx_injection = [2]
list_nmb_stems = [10]
else:
list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
# Run iteratively, starting with the longest trajectory.
# Always inserting new branches where they are needed most according to image similarity
for s_idx in tqdm(range(len(list_idx_injection))):
nmb_stems = list_nmb_stems[s_idx]
idx_injection = list_idx_injection[s_idx]
for s_idx in tqdm(range(len(self.list_idx_injection))):
nmb_stems = self.list_nmb_stems[s_idx]
idx_injection = self.list_idx_injection[s_idx]
for i in range(nmb_stems):
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
@ -310,6 +351,9 @@ class LatentBlending():
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection} bp1 {b_parent1} bp2 {b_parent2}")
return self.tree_final_imgs
def compute_latents1(self, return_image=False):
r"""
@ -327,7 +371,7 @@ class LatentBlending():
latents_start=latents_start,
idx_start=0)
t1 = time.time()
self.dt_per_diff = (t1 - t0) / self.num_inference_steps
self.dt_unet_step = (t1 - t0) / self.num_inference_steps
self.tree_latents[0] = list_latents1
if return_image:
return self.dh.latent2image(list_latents1[-1])
@ -447,8 +491,8 @@ class LatentBlending():
while not stop_criterion_reached:
list_compute_steps = self.num_inference_steps - list_idx_injection
list_compute_steps *= list_nmb_stems
t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
t_compute += 2 * self.num_inference_steps * self.dt_per_diff # outer branches
t_compute = np.sum(list_compute_steps) * self.dt_unet_step + self.dt_vae * np.sum(list_nmb_stems)
t_compute += 2 * (self.num_inference_steps * self.dt_unet_step + self.dt_vae) # outer branches
increase_done = False
for s_idx in range(len(list_nmb_stems) - 1):
if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 1:
@ -765,8 +809,8 @@ if __name__ == "__main__":
from diffusers_holder import DiffusersHolder
from diffusers import DiffusionPipeline
from diffusers import AutoencoderTiny
# pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
# pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16")
@ -776,8 +820,8 @@ if __name__ == "__main__":
dh = DiffusersHolder(pipe)
# %% Next let's set up all parameters
size_output = (512, 512)
# size_output = (1024, 1024)
# size_output = (512, 512)
size_output = (1024, 1024)
prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution"
prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal"
negative_prompt = "blurry, ugly, pale" # Optional
@ -793,10 +837,9 @@ if __name__ == "__main__":
lb.set_prompt2(prompt2)
lb.set_dimensions(size_output)
lb.set_negative_prompt(negative_prompt)
# Run latent blending
lb.run_transition(fixed_seeds=[420, 421], t_compute_max_allowed=15)
lb.run_transition(fixed_seeds=[420, 421])
# Save movie
fp_movie = f'test.mp4'
@ -804,4 +847,4 @@ if __name__ == "__main__":
#%%