parental mixing mode

Johannes Stelzer 2023-02-16 11:48:45 +01:00
parent 0371868603
commit 07a5c4ffd7
3 changed files with 259 additions and 201 deletions

@@ -31,15 +31,9 @@ from stable_diffusion_holder import StableDiffusionHolder
torch.set_grad_enabled(False)
import gradio as gr
import copy
from dotenv import find_dotenv, load_dotenv
"""
TODOS:
- clean parameter handling
- three buttons: diffuse A, diffuse B, make transition
- collapse for easy mode
- transition quality in terms of render time
"""
#%%
@@ -65,13 +59,17 @@ class BlendingFrontend():
self.list_settings = []
self.state_current = {}
self.showing_current = True
self.branch1_influence = 0.1
self.branch1_mixing_depth = 0.3
self.branch1_influence = 0.3
self.branch1_max_depth_influence = 0.6
self.branch1_influence_decay = 0.3
self.parental_influence = 0.1
self.parental_max_depth_influence = 1.0
self.parental_influence_decay = 1.0
self.nmb_branches_final = 9
self.nmb_imgs_show = 5 # don't change
self.fps = 30
self.duration_video = 15
self.t_compute_max_allowed = 15
self.duration_video = 10
self.t_compute_max_allowed = 10
self.dict_multi_trans = {}
self.dict_multi_trans_include = {}
self.multi_trans_currently_shown = []
@@ -87,9 +85,20 @@ class BlendingFrontend():
self.height = 768
self.width = 768
self.init_save_dir()
def init_save_dir(self):
load_dotenv(find_dotenv(), verbose=False)
# os.getenv returns None rather than raising when the variable is unset,
# so a try/except cannot catch that case; pass a default instead.
self.dp_out = os.getenv("dp_out", "")
# make dummy image
def save_empty_image(self):
self.fp_img_empty = 'empty.jpg'
self.fp_img_empty = os.path.join(self.dp_out, 'empty.jpg')
Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
@@ -116,8 +125,6 @@ class BlendingFrontend():
self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
self.lb.branch1_influence = list_ui_elem[list_ui_keys.index('branch1_influence')]
self.lb.branch1_mixing_depth = list_ui_elem[list_ui_keys.index('branch1_mixing_depth')]
self.lb.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
@@ -125,11 +132,19 @@ class BlendingFrontend():
self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')]
self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
self.lb.branch1_influence = list_ui_elem[list_ui_keys.index('branch1_influence')]
self.lb.branch1_max_depth_influence = list_ui_elem[list_ui_keys.index('branch1_max_depth_influence')]
self.lb.branch1_influence_decay = list_ui_elem[list_ui_keys.index('branch1_influence_decay')]
self.lb.parental_influence = list_ui_elem[list_ui_keys.index('parental_influence')]
self.lb.parental_max_depth_influence = list_ui_elem[list_ui_keys.index('parental_max_depth_influence')]
self.lb.parental_influence_decay = list_ui_elem[list_ui_keys.index('parental_influence_decay')]
def compute_img1(self, *args):
list_ui_elem = args
self.setup_lb(list_ui_elem)
fp_img1 = f"img1_{get_time('second')}.jpg"
fp_img1 = os.path.join(self.dp_out, f"img1_{get_time('second')}.jpg")
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
img1.save(fp_img1)
self.save_empty_image()
@@ -138,7 +153,7 @@ class BlendingFrontend():
def compute_img2(self, *args):
list_ui_elem = args
self.setup_lb(list_ui_elem)
fp_img2 = f"img2_{get_time('second')}.jpg"
fp_img2 = os.path.join(self.dp_out, f"img2_{get_time('second')}.jpg")
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
img2.save(fp_img2)
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2]
@@ -198,9 +213,11 @@ class BlendingFrontend():
def get_fp_movie(self, timestamp, is_stacked=False):
if not is_stacked:
return f"movie_{timestamp}.mp4"
fn = f"movie_{timestamp}.mp4"
else:
return f"movie_stacked_{timestamp}.mp4"
fn = f"movie_stacked_{timestamp}.mp4"
fp = os.path.join(self.dp_out, fn)
return fp
def stack_forward(self, prompt2, seed2):
@@ -319,33 +336,44 @@ if __name__ == "__main__":
with gr.Blocks() as demo:
with gr.Row():
prompt1 = gr.Textbox(label="prompt 1")
negative_prompt = gr.Textbox(label="negative prompt")
prompt2 = gr.Textbox(label="prompt 2")
with gr.Row():
duration_compute = gr.Slider(10, 40, self.duration_video, step=1, label='compute budget for transition (seconds)', interactive=True)
duration_compute = gr.Slider(5, 45, self.t_compute_max_allowed, step=1, label='compute budget for transition (seconds)', interactive=True)
duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True)
height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True)
with gr.Accordion("Advanced Settings (click to expand)", open=False):
with gr.Row():
depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='branch1_influence', interactive=True)
branch1_mixing_depth = gr.Slider(0.0, 1.0, self.branch1_mixing_depth, step=0.01, label='branch1_mixing_depth', interactive=True)
with gr.Accordion("Diffusion settings", open=True):
with gr.Row():
num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
negative_prompt = gr.Textbox(label="negative prompt")
with gr.Accordion("Seeds control", open=True):
with gr.Row():
seed1 = gr.Number(420, label="seed 1", interactive=True)
b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
seed2 = gr.Number(420, label="seed 2", interactive=True)
b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
with gr.Accordion("Crossfeeding for last image", open=True):
with gr.Row():
branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='crossfeed power', interactive=True)
branch1_max_depth_influence = gr.Slider(0.0, 1.0, self.branch1_max_depth_influence, step=0.01, label='crossfeed range', interactive=True)
branch1_influence_decay = gr.Slider(0.0, 1.0, self.branch1_influence_decay, step=0.01, label='crossfeed decay', interactive=True)
with gr.Accordion("Transition settings", open=True):
with gr.Row():
depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
parental_influence = gr.Slider(0.0, 1.0, self.parental_influence, step=0.01, label='parental power', interactive=True)
parental_max_depth_influence = gr.Slider(0.0, 1.0, self.parental_max_depth_influence, step=0.01, label='parental range', interactive=True)
parental_influence_decay = gr.Slider(0.0, 1.0, self.parental_influence_decay, step=0.01, label='parental decay', interactive=True)
with gr.Row():
b_compute1 = gr.Button('compute first image', variant='primary')
b_compute_transition = gr.Button('compute transition', variant='primary')
@@ -373,7 +401,8 @@ if __name__ == "__main__":
dict_ui_elem["depth_strength"] = depth_strength
dict_ui_elem["branch1_influence"] = branch1_influence
dict_ui_elem["branch1_mixing_depth"] = branch1_mixing_depth
dict_ui_elem["branch1_max_depth_influence"] = branch1_max_depth_influence
dict_ui_elem["branch1_influence_decay"] = branch1_influence_decay
dict_ui_elem["num_inference_steps"] = num_inference_steps
dict_ui_elem["guidance_scale"] = guidance_scale
@@ -381,6 +410,10 @@ if __name__ == "__main__":
dict_ui_elem["seed1"] = seed1
dict_ui_elem["seed2"] = seed2
dict_ui_elem["parental_max_depth_influence"] = parental_max_depth_influence
dict_ui_elem["parental_influence"] = parental_influence
dict_ui_elem["parental_influence_decay"] = parental_influence_decay
# Convert to list, as gradio doesn't seem to accept dicts
list_ui_elem = []
list_ui_keys = []
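The loop that performs the dict-to-list conversion falls outside this hunk. A minimal sketch of the pattern, with the loop body inferred from the list_ui_keys.index(...) lookups in setup_lb above:

for key, elem in dict_ui_elem.items():
    list_ui_keys.append(key)   # parallel list of keys, used by list_ui_keys.index(...)
    list_ui_elem.append(elem)  # gradio components, passed positionally to the handlers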

@@ -107,8 +107,16 @@ class LatentBlending():
self.noise_level_upscaling = 20
self.list_injection_idx = None
self.list_nmb_branches = None
# Mixing parameters
self.branch1_influence = 0.0
self.branch1_mixing_depth = 0.65
self.branch1_max_depth_influence = 0.65
self.branch1_influence_decay = 0.8
self.parental_influence = 0.0
self.parental_max_depth_influence = 1.0
self.parental_influence_decay = 1.0
self.branch1_insertion_completed = False
self.set_guidance_scale(guidance_scale)
self.init_mode()
@@ -389,10 +397,6 @@ class LatentBlending():
fixed_seeds: Optional[List[int]] = None,
):
# # FIXME: deal with these tree args later
# self.num_inference_steps = 30
# self.t_compute_max_allowed = 60
# Sanity checks first
assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) first'
assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) first'
@@ -411,17 +415,15 @@ class LatentBlending():
self.sdh.num_inference_steps = self.num_inference_steps
# Compute / Recycle first image
if not recycle_img1:
if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
list_latents1 = self.compute_latents1()
else:
# FIXME: check if latents there...
list_latents1 = self.tree_latents[0]
# Compute / Recycle first image
if not recycle_img2:
if not recycle_img2 or len(self.tree_latents[-1]) != self.num_inference_steps:
list_latents2 = self.compute_latents2()
else:
# FIXME: check if latents there...
list_latents2 = self.tree_latents[-1]
# Reset the tree, injecting the edge latents1/2 we just generated/recycled
@@ -438,7 +440,7 @@ class LatentBlending():
while t_compute < self.t_compute_max_allowed:
list_compute_steps = self.num_inference_steps - list_idx_injection
list_compute_steps *= list_nmb_stems
t_compute = np.sum(list_compute_steps) * self.dt_per_diff
t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15*np.sum(list_nmb_stems)
increase_done = False
for s_idx in range(len(list_nmb_stems)-1):
if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
@@ -449,13 +451,14 @@ class LatentBlending():
list_nmb_stems[-1] += 1
# print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
# Run iteratively
# Run iteratively, always inserting new branches where they are needed most
for s_idx in tqdm(range(len(list_idx_injection))):
nmb_stems = list_nmb_stems[s_idx]
idx_injection = list_idx_injection[s_idx]
for i in range(nmb_stems):
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
self.set_guidance_mid_dampening(fract_mixing)
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
self.insert_into_tree(fract_mixing, idx_injection, list_latents)
@@ -465,6 +468,14 @@ class LatentBlending():
def get_mixing_parameters(self, idx_injection):
r"""
Computes which parental latents should be mixed together to achieve a smooth blend.
As a metric, we use LPIPS image similarity; the insertion takes place where the
metric is maximal, i.e. between the most dissimilar pair of adjacent images.
Args:
idx_injection: int
the index in terms of diffusion steps, where the next insertion will start.
"""
# get_lpips_similarity
similarities = []
for i in range(len(self.tree_final_imgs)-1):
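The hunk cuts off before the selection itself. A standalone sketch of the idea with illustrative names (how idx_injection constrains the choice lies outside the shown hunk):

import numpy as np

def pick_insertion_point(similarities, tree_fracts):
    # 'similarities' holds get_lpips_similarity() for each adjacent image
    # pair; a high LPIPS distance means a visually rough transition, so
    # the new branch is inserted there.
    b_parent1 = int(np.argmax(similarities))
    b_parent2 = b_parent1 + 1
    # place the new branch halfway between its parents on the transition axis
    fract_mixing = 0.5 * (tree_fracts[b_parent1] + tree_fracts[b_parent2])
    return fract_mixing, b_parent1, b_parent2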
@@ -499,7 +510,16 @@ class LatentBlending():
def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
# FIXME
r"""
Inserts all necessary parameters into the trajectory tree.
Args:
fract_mixing: float
the fraction along the transition axis [0, 1]
idx_injection: int
the index in terms of diffusion steps, where the next insertion will start.
list_latents: list
list of the latents to be inserted
"""
b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
self.tree_latents.insert(b_parent1+1, list_latents)
self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
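get_closest_idx is defined elsewhere in the file; a plausible sketch of its contract, assuming tree_fracts is kept sorted along the transition axis:

def get_closest_idx_sketch(fract_mixing, tree_fracts):
    # return the indices of the two tree entries that bracket fract_mixing,
    # so the new branch lands between its parents and the tree stays sorted
    for b_parent1 in range(len(tree_fracts) - 1):
        if tree_fracts[b_parent1] <= fract_mixing <= tree_fracts[b_parent1 + 1]:
            return b_parent1, b_parent1 + 1
    raise ValueError("fract_mixing lies outside the current tree")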
@@ -508,23 +528,67 @@ class LatentBlending():
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
# FIXME
r"""
Runs a diffusion trajectory, using the latents from the respective parents
Args:
fract_mixing: float
the fraction along the transition axis [0, 1]
b_parent1: int
index of parent1 to be used
b_parent2: int
index of parent2 to be used
idx_injection: int
the index in terms of diffusion steps, where the next insertion will start.
"""
list_conditionings = self.get_mixed_conditioning(fract_mixing)
fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
idx_reversed = self.num_inference_steps - idx_injection
latents_for_injection = interpolate_spherical(
self.tree_latents[b_parent1][-idx_reversed-1],
self.tree_latents[b_parent2][-idx_reversed-1],
fract_mixing_parental)
list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
# idx_reversed = self.num_inference_steps - idx_injection
list_latents_parental_mix = []
for i in range(self.num_inference_steps):
latents_p1 = self.tree_latents[b_parent1][i]
latents_p2 = self.tree_latents[b_parent2][i]
if latents_p1 is None or latents_p2 is None:
latents_parental = None
else:
latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
list_latents_parental_mix.append(latents_parental)
idx_mixing_stop = int(round(self.num_inference_steps*self.parental_max_depth_influence))
mixing_coeffs = idx_injection*[self.parental_influence]
nmb_mixing = idx_mixing_stop - idx_injection
if nmb_mixing > 0:
mixing_coeffs.extend(list(np.linspace(self.parental_influence, self.parental_influence*self.parental_influence_decay, nmb_mixing)))
mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
latents_start = list_latents_parental_mix[idx_injection-1]
list_latents = self.run_diffusion(
list_conditionings,
latents_start = latents_start,
idx_start = idx_injection,
list_latents_mixing = list_latents_parental_mix,
mixing_coeffs = mixing_coeffs
)
return list_latents
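A worked example of the parental mixing schedule constructed above, with assumed values (num_inference_steps=30, idx_injection=10, parental_influence=0.2, parental_max_depth_influence=0.6, parental_influence_decay=0.5):

import numpy as np

num_inference_steps = 30
idx_injection = 10
parental_influence = 0.2
parental_max_depth_influence = 0.6
parental_influence_decay = 0.5

idx_mixing_stop = int(round(num_inference_steps * parental_max_depth_influence))  # 18
# placeholder entries below idx_injection (diffusion restarts at idx_injection,
# so these coefficients are never applied) ...
mixing_coeffs = idx_injection * [parental_influence]
# ... then a linear decay from 0.2 down to 0.2 * 0.5 = 0.1 over steps 10..17 ...
nmb_mixing = idx_mixing_stop - idx_injection  # 8
if nmb_mixing > 0:
    mixing_coeffs.extend(np.linspace(parental_influence, parental_influence * parental_influence_decay, nmb_mixing))
# ... and zero parental pull for the remaining steps 18..29
mixing_coeffs.extend((num_inference_steps - len(mixing_coeffs)) * [0])
assert len(mixing_coeffs) == num_inference_steps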
def compute_latents1(self, return_image=False):
r"""
Runs a diffusion trajectory for the first image
Args:
return_image: bool
whether to return an image or the list of latents
"""
print("starting compute_latents1")
list_conditionings = [self.text_embedding1]
t0 = time.time()
list_latents1 = self.run_diffusion(list_conditionings, seed_source=self.seed1)
latents_start = self.get_noise(self.seed1)
list_latents1 = self.run_diffusion(
list_conditionings,
latents_start = latents_start,
idx_start = 0
)
t1 = time.time()
self.dt_per_diff = (t1-t0) / self.num_inference_steps
self.tree_latents[0] = list_latents1
@@ -533,25 +597,31 @@ class LatentBlending():
else:
return list_latents1
def compute_latents2(self, return_image=False):
print("starting compute_latents2")
list_conditionings = [self.text_embedding2
]
r"""
Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
Args:
return_image: bool
whether to return an image or the list of latents
"""
list_conditionings = [self.text_embedding2]
latents_start = self.get_noise(self.seed2)
# Influence from branch1
if self.branch1_influence > 0.0:
self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
idx_crossfeed = int(round(self.num_inference_steps*self.branch1_mixing_depth))
# Set up the mixing_coeffs
idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_max_depth_influence))
mixing_coeffs = list(np.linspace(self.branch1_influence, self.branch1_influence*self.branch1_influence_decay, idx_mixing_stop))
mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
list_latents_mixing = self.tree_latents[0]
list_latents2 = self.run_diffusion(
list_conditionings,
idx_start=idx_crossfeed,
latents_for_injection=self.tree_latents[0],
seed_source=self.seed2,
seed_mixing_target=self.seed1,
mixing_coeff=self.branch1_influence)
latents_start = latents_start,
idx_start = 0,
list_latents_mixing = list_latents_mixing,
mixing_coeffs = mixing_coeffs
)
else:
list_latents2 = self.run_diffusion(list_conditionings)
list_latents2 = self.run_diffusion(list_conditionings, latents_start)
self.tree_latents[-1] = list_latents2
if return_image:
@@ -559,9 +629,14 @@ class LatentBlending():
else:
return list_latents2
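For comparison, the crossfeed schedule above has no constant prefix. With the demo values used near the end of this file (branch1_influence=0.7, branch1_max_depth_influence=0.3, branch1_influence_decay=0.3) and an assumed num_inference_steps=30:

import numpy as np

num_inference_steps = 30  # assumption for illustration
idx_mixing_stop = int(round(num_inference_steps * 0.3))  # 9
mixing_coeffs = list(np.linspace(0.7, 0.7 * 0.3, idx_mixing_stop))
mixing_coeffs.extend((num_inference_steps - idx_mixing_stop) * [0])
# decays 0.7 -> 0.21 over the first 9 steps, then zero: the second image
# inherits coarse structure from branch 1 while keeping its own fine detail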
def get_noise(self, seed):
generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
C, H, W = shape_latents
return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
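interpolate_spherical is imported elsewhere in the file; a minimal self-contained sketch, assuming it is the usual slerp applied to latent tensors:

import torch

def slerp(p0: torch.Tensor, p1: torch.Tensor, fract: float) -> torch.Tensor:
    # spherical linear interpolation: fract=0 gives p0, fract=1 gives p1
    eps = 1e-7
    dot = torch.sum(p0 * p1) / (torch.norm(p0) * torch.norm(p1))
    omega = torch.arccos(torch.clamp(dot, -1 + eps, 1 - eps))
    return (torch.sin((1 - fract) * omega) * p0
            + torch.sin(fract * omega) * p1) / torch.sin(omega)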
def run_transition_OLD(
def run_transition_legacy(
self,
recycle_img1: Optional[bool] = False,
recycle_img2: Optional[bool] = False,
@@ -569,6 +644,7 @@ class LatentBlending():
premature_stop: Optional[int] = np.inf,
):
r"""
Legacy function for computing transitions.
Returns a list of transition images using spherical latent blending.
Args:
recycle_img1: Optional[bool]:
@@ -610,9 +686,9 @@ class LatentBlending():
if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
assert self.list_nmb_branches[0]==2, 'branch1 influence currently requires self.list_nmb_branches[0] == 2'
self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
self.branch1_max_depth_influence = np.clip(self.branch1_max_depth_influence, 0, 1)
self.list_nmb_branches.insert(1, 2)
idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_mixing_depth))
idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_max_depth_influence))
self.list_injection_idx_ext.insert(1, idx_crossfeed)
self.tree_fracts.insert(1, self.tree_fracts[0])
self.tree_status.insert(1, self.tree_status[0])
@@ -790,27 +866,27 @@ class LatentBlending():
def run_diffusion(
self,
list_conditionings,
latents_for_injection: torch.FloatTensor = None,
idx_start: int = -1,
idx_stop: int = -1,
seed_source: int = -1,
seed_mixing_target: int = -1,
mixing_coeff: float = 0.0,
latents_start: torch.FloatTensor = None,
idx_start: int = 0,
list_latents_mixing = None,
mixing_coeffs = 0.0,
return_image: Optional[bool] = False
):
r"""
Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
Wrapper function for diffusion runners.
Depending on the mode, the correct one will be executed.
Args:
list_conditionings: List of all conditionings for the diffusion model.
latents_for_injection: torch.FloatTensor
latents_start: torch.FloatTensor
Latents that the diffusion trajectory starts from (initial noise or injected latents)
idx_start: int
Index at which the diffusion process starts; earlier steps are skipped
idx_stop: int
Index of the diffusion process end.
FIXME ARGS
list_latents_mixing: List[torch.FloatTensor]
List of latents, one per diffusion step, that are mixed into the running trajectory
mixing_coeffs: float or list
Coefficients determining how strongly each element of list_latents_mixing is mixed in
return_image: Optional[bool]
Optionally return image directly
"""
@@ -822,26 +898,25 @@ class LatentBlending():
if self.mode == 'standard':
text_embeddings = list_conditionings[0]
return self.sdh.run_diffusion_standard(
text_embeddings,
latents_for_injection=latents_for_injection,
idx_start=idx_start,
idx_stop=idx_stop,
seed_source=seed_source,
seed_mixing_target=seed_mixing_target,
mixing_coeff=mixing_coeff,
return_image=return_image,
text_embeddings = text_embeddings,
latents_start = latents_start,
idx_start = idx_start,
list_latents_mixing = list_latents_mixing,
mixing_coeffs = mixing_coeffs,
return_image = return_image,
)
elif self.mode == 'inpaint':
text_embeddings = list_conditionings[0]
assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
# FIXME LONG LINE
elif self.mode == 'upscale':
cond = list_conditionings[0]
uc_full = list_conditionings[1]
return self.sdh.run_diffusion_upscaling(cond, uc_full, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
# elif self.mode == 'inpaint':
# text_embeddings = list_conditionings[0]
# assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
# assert self.sdh.mask_image is not None, "mask_image is None. Please run init_inpainting first."
# return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
# # FIXME LONG LINE and bad args
# elif self.mode == 'upscale':
# cond = list_conditionings[0]
# uc_full = list_conditionings[1]
# return self.sdh.run_diffusion_upscaling(cond, uc_full, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
def run_upscaling_step1(
self,
@@ -1100,7 +1175,11 @@ class LatentBlending():
def get_lpips_similarity(self, imgA, imgB):
# FIXME
r"""
Computes the image similarity between two images imgA and imgB.
Used to determine the optimal point of insertion to create smooth transitions.
High values indicate low similarity.
"""
tensorA = torch.from_numpy(imgA).float().cuda(self.device)
tensorA = 2*tensorA/255.0 - 1
tensorA = tensorA.permute([2,0,1]).unsqueeze(0)
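The hunk shows only the preprocessing of imgA. A standalone sketch of the full comparison, using the lpips package (the AlexNet backbone is an assumption; the preprocessing mirrors the tensorA lines above):

import lpips
import numpy as np
import torch

loss_fn = lpips.LPIPS(net='alex').cuda()

def lpips_distance(imgA: np.ndarray, imgB: np.ndarray) -> float:
    def prep(img):
        t = torch.from_numpy(img).float().cuda()
        t = 2 * t / 255.0 - 1                      # uint8 [0,255] -> [-1,1]
        return t.permute([2, 0, 1]).unsqueeze(0)   # HWC -> NCHW
    with torch.no_grad():
        return float(loss_fn(prep(imgA), prep(imgB)))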
@@ -1406,55 +1485,24 @@ if __name__ == "__main__":
# Run latent blending
self.branch1_influence = 0.3
self.branch1_mixing_depth = 0.4
self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
#%%
self.branch1_influence = 0.3
self.branch1_mixing_depth = 0.5
img2 = self.compute_latents2(return_image=True)
Image.fromarray(img2)
self.branch1_max_depth_influence = 0.4
# self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
self.seed1 = 21312
img1 = self.compute_latents1(True)
#%%
self.seed2 = 1234121
self.branch1_influence = 0.7
self.branch1_max_depth_influence = 0.3
self.branch1_influence_decay = 0.3
img2 = self.compute_latents2(True)
# Image.fromarray(np.concatenate((img1, img2), axis=1))
#%%
idx_injection = 15
fract_mixing = 0.5
list_conditionings = self.get_mixed_conditioning(fract_mixing)
latents_for_injection = interpolate_spherical(self.tree_latents[0][idx_injection], self.tree_latents[-1][idx_injection], fract_mixing)
list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
img_mix = self.sdh.latent2image((list_latents[-1]))
Image.fromarray(np.concatenate((img1,img_mix,img2), axis=1)).resize((800,800//3))
#%% scheme
# init scheme
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 2)
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
#%%
list_compute_steps = self.num_inference_steps - list_idx_injection
list_compute_steps *= list_nmb_stems
t_compute = np.sum(list_compute_steps) * self.dt_per_diff
increase_done = False
for s_idx in range(len(list_nmb_stems)-1):
if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 3:
list_nmb_stems[s_idx] += 1
increase_done = True
break
if not increase_done:
list_nmb_stems[-1] += 1
print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
#%%
imgs_transition = self.tree_final_imgs
# Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
imgs_transition_ext = add_frames_linear_interp(imgs_transition, 15, fps)
# Save as MP4
fp_movie = "test.mp4"
if os.path.isfile(fp_movie):
os.remove(fp_movie)
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width])
for img in tqdm(imgs_transition_ext):
ms.write_frame(img)
ms.finalize()
t0 = time.time()
self.t_compute_max_allowed = 30
self.parental_max_depth_influence = 1.0
self.parental_influence = 0.0
self.parental_influence_decay = 1.0
imgs_transition = self.run_transition(recycle_img1=True, recycle_img2=True)
t1 = time.time()
print(f"took: {t1-t0}s")

@@ -278,12 +278,10 @@ class StableDiffusionHolder:
def run_diffusion_standard(
self,
text_embeddings: torch.FloatTensor,
latents_for_injection = None,
idx_start: int = -1,
idx_stop: int = -1,
seed_source: int = -1,
seed_mixing_target: int = -1,
mixing_coeff: float = 0.0,
latents_start: torch.FloatTensor,
idx_start: int = 0,
list_latents_mixing = None,
mixing_coeffs = 0.0,
return_image: Optional[bool] = False,
):
r"""
@@ -297,34 +295,26 @@ class StableDiffusionHolder:
Latents that the diffusion trajectory starts from (initial noise or injected latents)
idx_start: int
Index at which the diffusion process starts; earlier steps are skipped and yield None entries
idx_stop: int
Index of the diffusion process end.
mixing_coeff:
# FIXME
seed_source:
# FIXME
seed_mixing:
# FIXME
return_image: Optional[bool]
Optionally return image directly
"""
if latents_for_injection is None:
do_inject_latents = False
do_mix_latents = False
# Validate mixing_coeffs and normalize it to a per-step list
if type(mixing_coeffs) == float:
list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
elif type(mixing_coeffs) == list:
assert len(mixing_coeffs) == self.num_inference_steps
list_mixing_coeffs = mixing_coeffs
else:
if mixing_coeff > 0.0:
do_inject_latents = False
do_mix_latents = True
assert seed_mixing_target != -1, "Set to correct seed for mixing"
else:
do_inject_latents = True
do_mix_latents = False
raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
if np.sum(list_mixing_coeffs) > 0:
assert len(list_latents_mixing) == self.num_inference_steps
precision_scope = autocast if self.precision == "autocast" else nullcontext
generator = torch.Generator(device=self.device).manual_seed(int(seed_source))
with precision_scope("cuda"):
with self.model.ema_scope():
@@ -332,14 +322,10 @@ class StableDiffusionHolder:
uc = self.model.get_learned_conditioning(self.negative_prompt)
else:
uc = None
shape_latents = [self.C, self.height // self.f, self.width // self.f]
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
C, H, W = shape_latents
size = (1, C, H, W)
b = size[0]
latents = torch.randn(size, generator=generator, device=self.device)
latents = latents_start.clone()
timesteps = self.sampler.ddim_timesteps
@@ -349,29 +335,20 @@ class StableDiffusionHolder:
# collect latents
list_latents_out = []
for i, step in enumerate(time_range):
if do_inject_latents:
# Inject latent at right place
# Set the right starting latents
if i < idx_start:
list_latents_out.append(None)
continue
elif i == idx_start:
latents = latents_for_injection.clone()
if do_mix_latents:
if i == 0:
generator = torch.Generator(device=self.device).manual_seed(int(seed_mixing_target))
latents_mixtarget = torch.randn(size, generator=generator, device=self.device)
if i < idx_start:
latents_mixtarget = latents_for_injection[i-1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, mixing_coeff)
latents = latents_start.clone()
if i == idx_start:
do_mix_latents = False
# Mix the latents.
if i > 0 and list_mixing_coeffs[i]>0:
latents_mixtarget = list_latents_mixing[i-1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
if i == idx_stop:
return list_latents_out
# print(f"diffusion iter {i}")
index = total_steps - i - 1
ts = torch.full((b,), step, device=self.device, dtype=torch.long)
ts = torch.full((1,), step, device=self.device, dtype=torch.long)
outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
quantize_denoised=False, temperature=1.0,
noise_dropout=0.0, score_corrector=None,