better branch handling
This commit is contained in:
parent
a69b86e2cd
commit
e889c2a0cc
|
@ -415,7 +415,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# img_refx = self.pipe(prompt=prompt1, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)[0]
|
# img_refx = self.pipe(prompt=prompt1, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)[0]
|
||||||
|
|
||||||
img_refx = self.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=True)
|
img_refx = self.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False)
|
||||||
|
|
||||||
dt_ref = time.time() - t0
|
dt_ref = time.time() - t0
|
||||||
img_refx.save(f"x_{prefix}_{i}.jpg")
|
img_refx.save(f"x_{prefix}_{i}.jpg")
|
||||||
|
|
|
@ -76,17 +76,11 @@ class LatentBlending():
|
||||||
self.tree_status = None
|
self.tree_status = None
|
||||||
self.tree_final_imgs = []
|
self.tree_final_imgs = []
|
||||||
|
|
||||||
self.list_nmb_branches_prev = []
|
|
||||||
self.list_injection_idx_prev = []
|
|
||||||
self.text_embedding1 = None
|
self.text_embedding1 = None
|
||||||
self.text_embedding2 = None
|
self.text_embedding2 = None
|
||||||
self.image1_lowres = None
|
self.image1_lowres = None
|
||||||
self.image2_lowres = None
|
self.image2_lowres = None
|
||||||
self.negative_prompt = None
|
self.negative_prompt = None
|
||||||
self.num_inference_steps = self.dh.num_inference_steps
|
|
||||||
self.noise_level_upscaling = 20
|
|
||||||
self.list_injection_idx = None
|
|
||||||
self.list_nmb_branches = None
|
|
||||||
|
|
||||||
# Mixing parameters
|
# Mixing parameters
|
||||||
self.branch1_crossfeed_power = 0.0
|
self.branch1_crossfeed_power = 0.0
|
||||||
|
@ -100,11 +94,34 @@ class LatentBlending():
|
||||||
self.set_guidance_scale(guidance_scale)
|
self.set_guidance_scale(guidance_scale)
|
||||||
self.multi_transition_img_first = None
|
self.multi_transition_img_first = None
|
||||||
self.multi_transition_img_last = None
|
self.multi_transition_img_last = None
|
||||||
self.dt_per_diff = 0
|
self.dt_unet_step = 0
|
||||||
self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
|
self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
|
||||||
|
|
||||||
self.set_prompt1("")
|
self.set_prompt1("")
|
||||||
self.set_prompt2("")
|
self.set_prompt2("")
|
||||||
|
|
||||||
|
self.set_num_inference_steps()
|
||||||
|
self.benchmark_speed()
|
||||||
|
self.set_branching()
|
||||||
|
|
||||||
|
def benchmark_speed(self):
|
||||||
|
"""
|
||||||
|
Measures the time per diffusion step and for the vae decoding
|
||||||
|
"""
|
||||||
|
|
||||||
|
text_embeddings = self.dh.get_text_embedding("test")
|
||||||
|
latents_start = self.dh.get_noise(np.random.randint(111111))
|
||||||
|
# warmup
|
||||||
|
list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1)
|
||||||
|
# bench unet
|
||||||
|
t0 = time.time()
|
||||||
|
list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1)
|
||||||
|
self.dt_unet_step = time.time() - t0
|
||||||
|
|
||||||
|
# bench vae
|
||||||
|
t0 = time.time()
|
||||||
|
img = self.dh.latent2image(list_latents[-1])
|
||||||
|
self.dt_vae = time.time() - t0
|
||||||
|
|
||||||
def set_dimensions(self, size_output=None):
|
def set_dimensions(self, size_output=None):
|
||||||
r"""
|
r"""
|
||||||
|
@ -208,28 +225,21 @@ class LatentBlending():
|
||||||
image: Image
|
image: Image
|
||||||
"""
|
"""
|
||||||
self.image2_lowres = image
|
self.image2_lowres = image
|
||||||
|
|
||||||
def run_transition(
|
def set_num_inference_steps(self, num_inference_steps=None):
|
||||||
self,
|
if self.dh.is_sdxl_turbo:
|
||||||
recycle_img1: Optional[bool] = False,
|
if num_inference_steps is None:
|
||||||
recycle_img2: Optional[bool] = False,
|
num_inference_steps = 4
|
||||||
num_inference_steps: Optional[int] = 30,
|
else:
|
||||||
list_idx_injection: Optional[int] = None,
|
if num_inference_steps is None:
|
||||||
list_nmb_stems: Optional[int] = None,
|
num_inference_steps = 30
|
||||||
depth_strength: Optional[float] = 0.3,
|
|
||||||
t_compute_max_allowed: Optional[float] = None,
|
self.num_inference_steps = num_inference_steps
|
||||||
nmb_max_branches: Optional[int] = None,
|
self.dh.set_num_inference_steps(num_inference_steps)
|
||||||
fixed_seeds: Optional[List[int]] = None):
|
|
||||||
r"""
|
def set_branching(self, depth_strength=None, t_compute_max_allowed=None, nmb_max_branches=None):
|
||||||
Function for computing transitions.
|
"""
|
||||||
Returns a list of transition images using spherical latent blending.
|
Sets the branching structure of the blending tree. Default arguments depend on pipe!
|
||||||
Args:
|
|
||||||
recycle_img1: Optional[bool]:
|
|
||||||
Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
|
|
||||||
recycle_img2: Optional[bool]:
|
|
||||||
Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
|
|
||||||
num_inference_steps:
|
|
||||||
Number of diffusion steps. Higher values will take more compute time.
|
|
||||||
depth_strength:
|
depth_strength:
|
||||||
Determines how deep the first injection will happen.
|
Determines how deep the first injection will happen.
|
||||||
Deeper injections will cause (unwanted) formation of new structures,
|
Deeper injections will cause (unwanted) formation of new structures,
|
||||||
|
@ -241,6 +251,45 @@ class LatentBlending():
|
||||||
Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
|
Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
|
||||||
results. Use this if you want to have controllable results independent
|
results. Use this if you want to have controllable results independent
|
||||||
of your computer.
|
of your computer.
|
||||||
|
"""
|
||||||
|
if self.dh.is_sdxl_turbo:
|
||||||
|
assert t_compute_max_allowed is None, "time-based branching not supported for SDXL Turbo"
|
||||||
|
if depth_strength is not None:
|
||||||
|
idx_inject = int(round(self.num_inference_steps*depth_strength))
|
||||||
|
else:
|
||||||
|
idx_inject = 2
|
||||||
|
if nmb_max_branches is None:
|
||||||
|
nmb_max_branches = 10
|
||||||
|
|
||||||
|
self.list_idx_injection = [idx_inject]
|
||||||
|
self.list_nmb_stems = [nmb_max_branches]
|
||||||
|
|
||||||
|
else:
|
||||||
|
if depth_strength is None:
|
||||||
|
depth_strength = 0.5
|
||||||
|
if t_compute_max_allowed is None and nmb_max_branches is None:
|
||||||
|
t_compute_max_allowed = 20
|
||||||
|
elif t_compute_max_allowed is not None and nmb_max_branches is not None:
|
||||||
|
raise ValueErorr("Either specify t_compute_max_allowed or nmb_max_branches")
|
||||||
|
|
||||||
|
self.list_idx_injection, self.list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
|
||||||
|
|
||||||
|
def run_transition(
|
||||||
|
self,
|
||||||
|
recycle_img1: Optional[bool] = False,
|
||||||
|
recycle_img2: Optional[bool] = False,
|
||||||
|
fixed_seeds: Optional[List[int]] = None):
|
||||||
|
r"""
|
||||||
|
Function for computing transitions.
|
||||||
|
Returns a list of transition images using spherical latent blending.
|
||||||
|
Args:
|
||||||
|
recycle_img1: Optional[bool]:
|
||||||
|
Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
|
||||||
|
recycle_img2: Optional[bool]:
|
||||||
|
Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
|
||||||
|
num_inference_steps:
|
||||||
|
Number of diffusion steps. Higher values will take more compute time.
|
||||||
|
|
||||||
fixed_seeds: Optional[List[int)]:
|
fixed_seeds: Optional[List[int)]:
|
||||||
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
|
You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
|
||||||
Otherwise random seeds will be taken.
|
Otherwise random seeds will be taken.
|
||||||
|
@ -261,12 +310,7 @@ class LatentBlending():
|
||||||
self.seed1 = fixed_seeds[0]
|
self.seed1 = fixed_seeds[0]
|
||||||
self.seed2 = fixed_seeds[1]
|
self.seed2 = fixed_seeds[1]
|
||||||
|
|
||||||
# Ensure correct num_inference_steps in holder
|
|
||||||
if self.dh.is_sdxl_turbo:
|
|
||||||
num_inference_steps = 4 #ideal results
|
|
||||||
self.num_inference_steps = num_inference_steps
|
|
||||||
self.dh.set_num_inference_steps(num_inference_steps)
|
|
||||||
|
|
||||||
# Compute / Recycle first image
|
# Compute / Recycle first image
|
||||||
if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
|
if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
|
||||||
list_latents1 = self.compute_latents1()
|
list_latents1 = self.compute_latents1()
|
||||||
|
@ -291,16 +335,13 @@ class LatentBlending():
|
||||||
self.parental_crossfeed_power = 1.0
|
self.parental_crossfeed_power = 1.0
|
||||||
self.parental_crossfeed_power_decay = 1.0
|
self.parental_crossfeed_power_decay = 1.0
|
||||||
self.parental_crossfeed_range = 1.0
|
self.parental_crossfeed_range = 1.0
|
||||||
list_idx_injection = [2]
|
|
||||||
list_nmb_stems = [10]
|
|
||||||
else:
|
|
||||||
list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
|
|
||||||
|
|
||||||
# Run iteratively, starting with the longest trajectory.
|
# Run iteratively, starting with the longest trajectory.
|
||||||
# Always inserting new branches where they are needed most according to image similarity
|
# Always inserting new branches where they are needed most according to image similarity
|
||||||
for s_idx in tqdm(range(len(list_idx_injection))):
|
for s_idx in tqdm(range(len(self.list_idx_injection))):
|
||||||
nmb_stems = list_nmb_stems[s_idx]
|
nmb_stems = self.list_nmb_stems[s_idx]
|
||||||
idx_injection = list_idx_injection[s_idx]
|
idx_injection = self.list_idx_injection[s_idx]
|
||||||
|
|
||||||
for i in range(nmb_stems):
|
for i in range(nmb_stems):
|
||||||
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
|
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
|
||||||
|
@ -310,6 +351,9 @@ class LatentBlending():
|
||||||
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection} bp1 {b_parent1} bp2 {b_parent2}")
|
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection} bp1 {b_parent1} bp2 {b_parent2}")
|
||||||
|
|
||||||
return self.tree_final_imgs
|
return self.tree_final_imgs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def compute_latents1(self, return_image=False):
|
def compute_latents1(self, return_image=False):
|
||||||
r"""
|
r"""
|
||||||
|
@ -327,7 +371,7 @@ class LatentBlending():
|
||||||
latents_start=latents_start,
|
latents_start=latents_start,
|
||||||
idx_start=0)
|
idx_start=0)
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
self.dt_per_diff = (t1 - t0) / self.num_inference_steps
|
self.dt_unet_step = (t1 - t0) / self.num_inference_steps
|
||||||
self.tree_latents[0] = list_latents1
|
self.tree_latents[0] = list_latents1
|
||||||
if return_image:
|
if return_image:
|
||||||
return self.dh.latent2image(list_latents1[-1])
|
return self.dh.latent2image(list_latents1[-1])
|
||||||
|
@ -447,8 +491,8 @@ class LatentBlending():
|
||||||
while not stop_criterion_reached:
|
while not stop_criterion_reached:
|
||||||
list_compute_steps = self.num_inference_steps - list_idx_injection
|
list_compute_steps = self.num_inference_steps - list_idx_injection
|
||||||
list_compute_steps *= list_nmb_stems
|
list_compute_steps *= list_nmb_stems
|
||||||
t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
|
t_compute = np.sum(list_compute_steps) * self.dt_unet_step + self.dt_vae * np.sum(list_nmb_stems)
|
||||||
t_compute += 2 * self.num_inference_steps * self.dt_per_diff # outer branches
|
t_compute += 2 * (self.num_inference_steps * self.dt_unet_step + self.dt_vae) # outer branches
|
||||||
increase_done = False
|
increase_done = False
|
||||||
for s_idx in range(len(list_nmb_stems) - 1):
|
for s_idx in range(len(list_nmb_stems) - 1):
|
||||||
if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 1:
|
if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 1:
|
||||||
|
@ -765,8 +809,8 @@ if __name__ == "__main__":
|
||||||
from diffusers_holder import DiffusersHolder
|
from diffusers_holder import DiffusersHolder
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers import AutoencoderTiny
|
from diffusers import AutoencoderTiny
|
||||||
# pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||||
pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
|
# pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
|
||||||
|
|
||||||
|
|
||||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16")
|
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16")
|
||||||
|
@ -776,8 +820,8 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
dh = DiffusersHolder(pipe)
|
dh = DiffusersHolder(pipe)
|
||||||
# %% Next let's set up all parameters
|
# %% Next let's set up all parameters
|
||||||
size_output = (512, 512)
|
# size_output = (512, 512)
|
||||||
# size_output = (1024, 1024)
|
size_output = (1024, 1024)
|
||||||
prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution"
|
prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution"
|
||||||
prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal"
|
prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal"
|
||||||
negative_prompt = "blurry, ugly, pale" # Optional
|
negative_prompt = "blurry, ugly, pale" # Optional
|
||||||
|
@ -793,10 +837,9 @@ if __name__ == "__main__":
|
||||||
lb.set_prompt2(prompt2)
|
lb.set_prompt2(prompt2)
|
||||||
lb.set_dimensions(size_output)
|
lb.set_dimensions(size_output)
|
||||||
lb.set_negative_prompt(negative_prompt)
|
lb.set_negative_prompt(negative_prompt)
|
||||||
|
|
||||||
|
|
||||||
# Run latent blending
|
# Run latent blending
|
||||||
lb.run_transition(fixed_seeds=[420, 421], t_compute_max_allowed=15)
|
lb.run_transition(fixed_seeds=[420, 421])
|
||||||
|
|
||||||
# Save movie
|
# Save movie
|
||||||
fp_movie = f'test.mp4'
|
fp_movie = f'test.mp4'
|
||||||
|
@ -804,4 +847,4 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#%%
|
|
||||||
|
|
Loading…
Reference in New Issue