latent injection for inpainting

This commit is contained in:
lugo 2022-11-29 18:03:08 +01:00
parent 58cadd23d5
commit 994c77e924
2 changed files with 86 additions and 59 deletions

View File

@ -52,21 +52,25 @@ prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, orga
lb.set_prompt1(prompt1)
lb.init_inpainting(init_empty=True)
lb.set_seed(seed0)
image_source = lb.run_diffusion(lb.text_embedding1, return_image=True)
list_latents = lb.run_diffusion(lb.text_embedding1)
image_source = lb.sdh.latent2image(list_latents[-1])
mask_image = 255*np.ones([512,512], dtype=np.uint8)
mask_image[340:420, 170:280, ] = 0
mask_image = Image.fromarray(mask_image)
#%% Next let's set up all parameters
fixed_seeds = [seed0, 280335986]
lb.inject_latents(list_latents, inject_img1=True)
fixed_seeds = [seed0, 6579436]
prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, organic, intricate, sci-fi movie, mesmerizing, scary"
prompt2 = "aerial photo of a futuristic alien temple in a coastal area, waves clashing"
prompt2 = "aerial photo of a futuristic alien temple in a blue coastal area, the sun is shining with a bright light"
lb.set_prompt1(prompt1)
lb.set_prompt2(prompt2)
lb.init_inpainting(image_source, mask_image)
imgs_transition = lb.run_transition(fixed_seeds=fixed_seeds)
imgs_transition = lb.run_transition(recycle_img1=True, fixed_seeds=fixed_seeds)
# let's get more cheap frames via linear interpolation
duration_transition = 12

View File

@ -69,11 +69,12 @@ class LatentBlending():
thus high values can give rough transitions. Values around 2 should be fine.
"""
assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
self.sdh = sdh
self.device = self.sdh.device
self.width = self.sdh.width
self.height = self.sdh.height
assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
self.guidance_scale_mid_damper = guidance_scale_mid_damper
self.mid_compression_scaler = mid_compression_scaler
self.seed = 420 # Run self.set_seed or fixed_seeds argument in run_transition
@ -81,9 +82,9 @@ class LatentBlending():
# Initialize vars
self.prompt1 = ""
self.prompt2 = ""
self.tree_latents = []
self.tree_fracts = []
self.tree_status = []
self.tree_latents = None
self.tree_fracts = None
self.tree_status = None
self.tree_final_imgs = []
self.list_nmb_branches_prev = []
self.list_injection_idx_prev = []
@ -100,8 +101,7 @@ class LatentBlending():
def init_mode(self, mode='standard'):
r"""
Automatically sets the mode of this class, depending on the supplied pipeline.
FIXME XXX
Sets the mode of this class, either inpaint of standard.
"""
if mode == 'inpaint':
self.sdh.image_source = None
@ -268,16 +268,46 @@ class LatentBlending():
assert max(list_injection_idx) < num_inference_steps, "Decrease the injection index or strength"
assert len(list_injection_idx) == len(list_nmb_branches), "Need to have same length"
assert max(list_injection_idx) < num_inference_steps,"Injection index cannot happen after last diffusion step! Decrease list_injection_idx or list_injection_strength[-1]"
# Auto inits
list_injection_idx_ext = list_injection_idx[:]
list_injection_idx_ext.append(num_inference_steps)
# If injection at depth 0 not specified, we will start out with 2 branches
if list_injection_idx_ext[0] != 0:
list_injection_idx_ext.insert(0,0)
list_nmb_branches.insert(0,2)
assert list_nmb_branches[0] == 2, "Need to start with 2 branches. set list_nmb_branches[0]=2"
# Set attributes
self.num_inference_steps = num_inference_steps
self.sdh.num_inference_steps = num_inference_steps
self.list_nmb_branches = list_nmb_branches
self.list_injection_idx = list_injection_idx
self.list_injection_idx_ext = list_injection_idx_ext
self.init_tree_struct()
def init_tree_struct(self):
r"""
Initializes tree variables for holding latents etc.
"""
self.tree_latents = []
self.tree_fracts = []
self.tree_status = []
self.tree_final_imgs_timing = [0]*self.list_nmb_branches[-1]
nmb_blocks_time = len(self.list_injection_idx_ext)-1
for t_block in range(nmb_blocks_time):
nmb_branches = self.list_nmb_branches[t_block]
list_fract_mixing_current = get_spacing(nmb_branches, self.mid_compression_scaler)
self.tree_fracts.append(list_fract_mixing_current)
self.tree_latents.append([None]*nmb_branches)
self.tree_status.append(['untouched']*nmb_branches)
def run_transition(
self,
recycle_img1: Optional[bool] = False,
@ -313,57 +343,34 @@ class LatentBlending():
# Ensure correct num_inference_steps in holder
self.sdh.num_inference_steps = self.num_inference_steps
# Recycling? There are requirements
if recycle_img1 or recycle_img2:
if self.list_nmb_branches_prev == []:
print("Warning. You want to recycle but there is nothing here. Disabling recycling.")
recycle_img1 = False
recycle_img2 = False
elif self.list_nmb_branches_prev != self.list_nmb_branches:
print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.")
recycle_img1 = False
recycle_img2 = False
elif self.list_injection_idx_prev != self.list_injection_idx:
print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.")
recycle_img1 = False
recycle_img2 = False
# # Recycling? There are requirements
# if recycle_img1 or recycle_img2:
# # if self.list_nmb_branches_prev == []:
# # print("Warning. You want to recycle but there is nothing here. Disabling recycling.")
# # recycle_img1 = False
# # recycle_img2 = False
# if self.list_nmb_branches_prev != self.list_nmb_branches:
# print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.")
# recycle_img1 = False
# recycle_img2 = False
# elif self.list_injection_idx_prev != self.list_injection_idx:
# print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.")
# recycle_img1 = False
# recycle_img2 = False
# Make a backup for future reference
self.list_nmb_branches_prev = self.list_nmb_branches[:]
self.list_injection_idx_prev = self.list_injection_idx[:]
# Auto inits
list_injection_idx_ext = self.list_injection_idx[:]
list_nmb_branches = self.list_nmb_branches[:]
list_injection_idx_ext.append(self.num_inference_steps)
# If injection at depth 0 not specified, we will start out with 2 branches
if list_injection_idx_ext[0] != 0:
list_injection_idx_ext.insert(0,0)
list_nmb_branches.insert(0,2)
assert list_nmb_branches[0] == 2, "Need to start with 2 branches. set list_nmb_branches[0]=2"
# Pre-define entire branching tree structures
self.tree_final_imgs = [None]*self.list_nmb_branches[-1]
nmb_blocks_time = len(self.list_injection_idx_ext)-1
if not recycle_img1 and not recycle_img2:
self.tree_latents = []
self.tree_fracts = []
self.tree_status = []
self.tree_final_imgs = [None]*list_nmb_branches[-1]
self.tree_final_imgs_timing = [0]*list_nmb_branches[-1]
nmb_blocks_time = len(list_injection_idx_ext)-1
for t_block in range(nmb_blocks_time):
nmb_branches = list_nmb_branches[t_block]
# list_fract_mixing_current = np.linspace(0, 1, nmb_branches)
list_fract_mixing_current = get_spacing(nmb_branches, self.mid_compression_scaler)
self.tree_fracts.append(list_fract_mixing_current)
self.tree_latents.append([None]*nmb_branches)
self.tree_status.append(['untouched']*nmb_branches)
self.init_tree_struct()
else:
self.tree_final_imgs = [None]*list_nmb_branches[-1]
nmb_blocks_time = len(list_injection_idx_ext)-1
self.tree_final_imgs = [None]*self.list_nmb_branches[-1]
for t_block in range(nmb_blocks_time):
nmb_branches = list_nmb_branches[t_block]
nmb_branches = self.list_nmb_branches[t_block]
for idx_branch in range(nmb_branches):
self.tree_status[t_block][idx_branch] = 'untouched'
if recycle_img1:
@ -386,7 +393,7 @@ class LatentBlending():
list_compute.extend(list_local_stem[::-1])
# setup compute order: start from last leafs (the final transition images) and work way down. what parents do they need?
for idx_leaf in range(1, list_nmb_branches[-1]):
for idx_leaf in range(1, self.list_nmb_branches[-1]):
list_local_stem = []
t_block = nmb_blocks_time - 1
t_block_prev = t_block - 1
@ -414,7 +421,7 @@ class LatentBlending():
return self.tree_final_imgs
# print(f"computing t_block {t_block} idx_branch {idx_branch}")
idx_stop = list_injection_idx_ext[t_block+1]
idx_stop = self.list_injection_idx_ext[t_block+1]
fract_mixing = self.tree_fracts[t_block][idx_branch]
text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
self.set_guidance_mid_dampening(fract_mixing)
@ -423,7 +430,7 @@ class LatentBlending():
if fixed_seeds is not None:
if idx_branch == 0:
self.set_seed(fixed_seeds[0])
elif idx_branch == list_nmb_branches[0] -1:
elif idx_branch == self.list_nmb_branches[0] -1:
self.set_seed(fixed_seeds[1])
list_latents = self.run_diffusion(text_embeddings_mix, idx_stop=idx_stop)
else:
@ -434,7 +441,7 @@ class LatentBlending():
latents2 = latents1
else:
latents2 = self.tree_latents[t_block-1][b_parent2][-1]
idx_start = list_injection_idx_ext[t_block]
idx_start = self.list_injection_idx_ext[t_block]
fract_mixing_parental = (fract_mixing - self.tree_fracts[t_block-1][b_parent1]) / (self.tree_fracts[t_block-1][b_parent2] - self.tree_fracts[t_block-1][b_parent1])
latents_for_injection = interpolate_spherical(latents1, latents2, fract_mixing_parental)
list_latents = self.run_diffusion(text_embeddings_mix, latents_for_injection, idx_start=idx_start, idx_stop=idx_stop)
@ -594,6 +601,22 @@ class LatentBlending():
self.seed = seed
self.sdh.seed = seed
def inject_latents(self, list_latents, inject_img1=True, inject_img2=False):
r"""
Injects list of latents into tree structure.
"""
assert inject_img1 != inject_img2, "Either inject into img1 or img2"
assert self.tree_latents is not None, "You need to setup the branching beforehand, run autosetup_branching() or setup_branching() before"
for t_block in range(len(self.list_injection_idx)):
if inject_img1:
self.tree_latents[t_block][0] = list_latents[self.list_injection_idx_ext[t_block]:self.list_injection_idx_ext[t_block+1]]
if inject_img2:
self.tree_latents[t_block][-1] = list_latents[self.list_injection_idx_ext[t_block]:self.list_injection_idx_ext[t_block+1]]
def swap_forward(self):
r"""