latent injection for inpainting
This commit is contained in:
parent
58cadd23d5
commit
994c77e924
|
@ -52,21 +52,25 @@ prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, orga
|
||||||
lb.set_prompt1(prompt1)
|
lb.set_prompt1(prompt1)
|
||||||
lb.init_inpainting(init_empty=True)
|
lb.init_inpainting(init_empty=True)
|
||||||
lb.set_seed(seed0)
|
lb.set_seed(seed0)
|
||||||
image_source = lb.run_diffusion(lb.text_embedding1, return_image=True)
|
list_latents = lb.run_diffusion(lb.text_embedding1)
|
||||||
|
image_source = lb.sdh.latent2image(list_latents[-1])
|
||||||
|
|
||||||
mask_image = 255*np.ones([512,512], dtype=np.uint8)
|
mask_image = 255*np.ones([512,512], dtype=np.uint8)
|
||||||
mask_image[340:420, 170:280, ] = 0
|
mask_image[340:420, 170:280, ] = 0
|
||||||
mask_image = Image.fromarray(mask_image)
|
mask_image = Image.fromarray(mask_image)
|
||||||
|
|
||||||
|
|
||||||
#%% Next let's set up all parameters
|
#%% Next let's set up all parameters
|
||||||
fixed_seeds = [seed0, 280335986]
|
lb.inject_latents(list_latents, inject_img1=True)
|
||||||
|
|
||||||
|
fixed_seeds = [seed0, 6579436]
|
||||||
|
|
||||||
prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, organic, intricate, sci-fi movie, mesmerizing, scary"
|
prompt1 = "photo of a futuristic alien temple in a desert, mystic, glowing, organic, intricate, sci-fi movie, mesmerizing, scary"
|
||||||
prompt2 = "aerial photo of a futuristic alien temple in a coastal area, waves clashing"
|
prompt2 = "aerial photo of a futuristic alien temple in a blue coastal area, the sun is shining with a bright light"
|
||||||
lb.set_prompt1(prompt1)
|
lb.set_prompt1(prompt1)
|
||||||
lb.set_prompt2(prompt2)
|
lb.set_prompt2(prompt2)
|
||||||
lb.init_inpainting(image_source, mask_image)
|
lb.init_inpainting(image_source, mask_image)
|
||||||
imgs_transition = lb.run_transition(fixed_seeds=fixed_seeds)
|
imgs_transition = lb.run_transition(recycle_img1=True, fixed_seeds=fixed_seeds)
|
||||||
|
|
||||||
# let's get more cheap frames via linear interpolation
|
# let's get more cheap frames via linear interpolation
|
||||||
duration_transition = 12
|
duration_transition = 12
|
||||||
|
|
|
@ -69,11 +69,12 @@ class LatentBlending():
|
||||||
thus high values can give rough transitions. Values around 2 should be fine.
|
thus high values can give rough transitions. Values around 2 should be fine.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
|
||||||
|
|
||||||
self.sdh = sdh
|
self.sdh = sdh
|
||||||
self.device = self.sdh.device
|
self.device = self.sdh.device
|
||||||
self.width = self.sdh.width
|
self.width = self.sdh.width
|
||||||
self.height = self.sdh.height
|
self.height = self.sdh.height
|
||||||
assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
|
|
||||||
self.guidance_scale_mid_damper = guidance_scale_mid_damper
|
self.guidance_scale_mid_damper = guidance_scale_mid_damper
|
||||||
self.mid_compression_scaler = mid_compression_scaler
|
self.mid_compression_scaler = mid_compression_scaler
|
||||||
self.seed = 420 # Run self.set_seed or fixed_seeds argument in run_transition
|
self.seed = 420 # Run self.set_seed or fixed_seeds argument in run_transition
|
||||||
|
@ -81,9 +82,9 @@ class LatentBlending():
|
||||||
# Initialize vars
|
# Initialize vars
|
||||||
self.prompt1 = ""
|
self.prompt1 = ""
|
||||||
self.prompt2 = ""
|
self.prompt2 = ""
|
||||||
self.tree_latents = []
|
self.tree_latents = None
|
||||||
self.tree_fracts = []
|
self.tree_fracts = None
|
||||||
self.tree_status = []
|
self.tree_status = None
|
||||||
self.tree_final_imgs = []
|
self.tree_final_imgs = []
|
||||||
self.list_nmb_branches_prev = []
|
self.list_nmb_branches_prev = []
|
||||||
self.list_injection_idx_prev = []
|
self.list_injection_idx_prev = []
|
||||||
|
@ -100,8 +101,7 @@ class LatentBlending():
|
||||||
|
|
||||||
def init_mode(self, mode='standard'):
|
def init_mode(self, mode='standard'):
|
||||||
r"""
|
r"""
|
||||||
Automatically sets the mode of this class, depending on the supplied pipeline.
|
Sets the mode of this class, either inpaint or standard.
|
||||||
FIXME XXX
|
|
||||||
"""
|
"""
|
||||||
if mode == 'inpaint':
|
if mode == 'inpaint':
|
||||||
self.sdh.image_source = None
|
self.sdh.image_source = None
|
||||||
|
@ -268,16 +268,46 @@ class LatentBlending():
|
||||||
assert max(list_injection_idx) < num_inference_steps, "Decrease the injection index or strength"
|
assert max(list_injection_idx) < num_inference_steps, "Decrease the injection index or strength"
|
||||||
assert len(list_injection_idx) == len(list_nmb_branches), "Need to have same length"
|
assert len(list_injection_idx) == len(list_nmb_branches), "Need to have same length"
|
||||||
assert max(list_injection_idx) < num_inference_steps,"Injection index cannot happen after last diffusion step! Decrease list_injection_idx or list_injection_strength[-1]"
|
assert max(list_injection_idx) < num_inference_steps,"Injection index cannot happen after last diffusion step! Decrease list_injection_idx or list_injection_strength[-1]"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Auto inits
|
||||||
|
list_injection_idx_ext = list_injection_idx[:]
|
||||||
|
list_injection_idx_ext.append(num_inference_steps)
|
||||||
|
|
||||||
|
# If injection at depth 0 not specified, we will start out with 2 branches
|
||||||
|
if list_injection_idx_ext[0] != 0:
|
||||||
|
list_injection_idx_ext.insert(0,0)
|
||||||
|
list_nmb_branches.insert(0,2)
|
||||||
|
assert list_nmb_branches[0] == 2, "Need to start with 2 branches. set list_nmb_branches[0]=2"
|
||||||
|
|
||||||
|
|
||||||
# Set attributes
|
# Set attributes
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
self.sdh.num_inference_steps = num_inference_steps
|
self.sdh.num_inference_steps = num_inference_steps
|
||||||
self.list_nmb_branches = list_nmb_branches
|
self.list_nmb_branches = list_nmb_branches
|
||||||
self.list_injection_idx = list_injection_idx
|
self.list_injection_idx = list_injection_idx
|
||||||
|
self.list_injection_idx_ext = list_injection_idx_ext
|
||||||
|
|
||||||
|
self.init_tree_struct()
|
||||||
|
|
||||||
|
def init_tree_struct(self):
    r"""
    Initializes tree variables for holding latents etc.

    Resets self.tree_latents, self.tree_fracts and self.tree_status to
    fresh lists and fills them with one entry per time block. The number
    of time blocks is len(self.list_injection_idx_ext) - 1. Also resets
    self.tree_final_imgs_timing to zeros, one per branch of the last block.
    """
    # Bind the fresh containers to local names first; the attributes and
    # the locals refer to the same list objects, so appends below are
    # visible through self.
    latents_per_block, fracts_per_block, status_per_block = [], [], []
    self.tree_latents = latents_per_block
    self.tree_fracts = fracts_per_block
    self.tree_status = status_per_block

    nmb_leaves = self.list_nmb_branches[-1]
    self.tree_final_imgs_timing = [0] * nmb_leaves

    block_count = len(self.list_injection_idx_ext) - 1
    for idx_block in range(block_count):
        width = self.list_nmb_branches[idx_block]
        # Mixing fractions for this block come from the file's get_spacing helper.
        fracts_per_block.append(get_spacing(width, self.mid_compression_scaler))
        latents_per_block.append([None for _ in range(width)])
        status_per_block.append(['untouched' for _ in range(width)])
|
||||||
|
|
||||||
def run_transition(
|
def run_transition(
|
||||||
self,
|
self,
|
||||||
recycle_img1: Optional[bool] = False,
|
recycle_img1: Optional[bool] = False,
|
||||||
|
@ -313,57 +343,34 @@ class LatentBlending():
|
||||||
# Ensure correct num_inference_steps in holder
|
# Ensure correct num_inference_steps in holder
|
||||||
self.sdh.num_inference_steps = self.num_inference_steps
|
self.sdh.num_inference_steps = self.num_inference_steps
|
||||||
|
|
||||||
# Recycling? There are requirements
|
# # Recycling? There are requirements
|
||||||
if recycle_img1 or recycle_img2:
|
# if recycle_img1 or recycle_img2:
|
||||||
if self.list_nmb_branches_prev == []:
|
# # if self.list_nmb_branches_prev == []:
|
||||||
print("Warning. You want to recycle but there is nothing here. Disabling recycling.")
|
# # print("Warning. You want to recycle but there is nothing here. Disabling recycling.")
|
||||||
recycle_img1 = False
|
# # recycle_img1 = False
|
||||||
recycle_img2 = False
|
# # recycle_img2 = False
|
||||||
elif self.list_nmb_branches_prev != self.list_nmb_branches:
|
# if self.list_nmb_branches_prev != self.list_nmb_branches:
|
||||||
print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.")
|
# print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.")
|
||||||
recycle_img1 = False
|
# recycle_img1 = False
|
||||||
recycle_img2 = False
|
# recycle_img2 = False
|
||||||
elif self.list_injection_idx_prev != self.list_injection_idx:
|
# elif self.list_injection_idx_prev != self.list_injection_idx:
|
||||||
print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.")
|
# print("Warning. Cannot change list_nmb_branches if recycling latent. Disabling recycling.")
|
||||||
recycle_img1 = False
|
# recycle_img1 = False
|
||||||
recycle_img2 = False
|
# recycle_img2 = False
|
||||||
|
|
||||||
# Make a backup for future reference
|
# Make a backup for future reference
|
||||||
self.list_nmb_branches_prev = self.list_nmb_branches[:]
|
self.list_nmb_branches_prev = self.list_nmb_branches[:]
|
||||||
self.list_injection_idx_prev = self.list_injection_idx[:]
|
self.list_injection_idx_prev = self.list_injection_idx[:]
|
||||||
|
|
||||||
# Auto inits
|
|
||||||
list_injection_idx_ext = self.list_injection_idx[:]
|
|
||||||
list_nmb_branches = self.list_nmb_branches[:]
|
|
||||||
list_injection_idx_ext.append(self.num_inference_steps)
|
|
||||||
|
|
||||||
# If injection at depth 0 not specified, we will start out with 2 branches
|
|
||||||
if list_injection_idx_ext[0] != 0:
|
|
||||||
list_injection_idx_ext.insert(0,0)
|
|
||||||
list_nmb_branches.insert(0,2)
|
|
||||||
assert list_nmb_branches[0] == 2, "Need to start with 2 branches. set list_nmb_branches[0]=2"
|
|
||||||
|
|
||||||
# Pre-define entire branching tree structures
|
# Pre-define entire branching tree structures
|
||||||
|
self.tree_final_imgs = [None]*self.list_nmb_branches[-1]
|
||||||
|
nmb_blocks_time = len(self.list_injection_idx_ext)-1
|
||||||
if not recycle_img1 and not recycle_img2:
|
if not recycle_img1 and not recycle_img2:
|
||||||
self.tree_latents = []
|
self.init_tree_struct()
|
||||||
self.tree_fracts = []
|
|
||||||
self.tree_status = []
|
|
||||||
self.tree_final_imgs = [None]*list_nmb_branches[-1]
|
|
||||||
self.tree_final_imgs_timing = [0]*list_nmb_branches[-1]
|
|
||||||
|
|
||||||
nmb_blocks_time = len(list_injection_idx_ext)-1
|
|
||||||
for t_block in range(nmb_blocks_time):
|
|
||||||
nmb_branches = list_nmb_branches[t_block]
|
|
||||||
# list_fract_mixing_current = np.linspace(0, 1, nmb_branches)
|
|
||||||
list_fract_mixing_current = get_spacing(nmb_branches, self.mid_compression_scaler)
|
|
||||||
self.tree_fracts.append(list_fract_mixing_current)
|
|
||||||
self.tree_latents.append([None]*nmb_branches)
|
|
||||||
self.tree_status.append(['untouched']*nmb_branches)
|
|
||||||
else:
|
else:
|
||||||
self.tree_final_imgs = [None]*list_nmb_branches[-1]
|
self.tree_final_imgs = [None]*self.list_nmb_branches[-1]
|
||||||
nmb_blocks_time = len(list_injection_idx_ext)-1
|
|
||||||
for t_block in range(nmb_blocks_time):
|
for t_block in range(nmb_blocks_time):
|
||||||
nmb_branches = list_nmb_branches[t_block]
|
nmb_branches = self.list_nmb_branches[t_block]
|
||||||
for idx_branch in range(nmb_branches):
|
for idx_branch in range(nmb_branches):
|
||||||
self.tree_status[t_block][idx_branch] = 'untouched'
|
self.tree_status[t_block][idx_branch] = 'untouched'
|
||||||
if recycle_img1:
|
if recycle_img1:
|
||||||
|
@ -386,7 +393,7 @@ class LatentBlending():
|
||||||
list_compute.extend(list_local_stem[::-1])
|
list_compute.extend(list_local_stem[::-1])
|
||||||
|
|
||||||
# setup compute order: start from last leafs (the final transition images) and work way down. what parents do they need?
|
# setup compute order: start from last leafs (the final transition images) and work way down. what parents do they need?
|
||||||
for idx_leaf in range(1, list_nmb_branches[-1]):
|
for idx_leaf in range(1, self.list_nmb_branches[-1]):
|
||||||
list_local_stem = []
|
list_local_stem = []
|
||||||
t_block = nmb_blocks_time - 1
|
t_block = nmb_blocks_time - 1
|
||||||
t_block_prev = t_block - 1
|
t_block_prev = t_block - 1
|
||||||
|
@ -414,7 +421,7 @@ class LatentBlending():
|
||||||
return self.tree_final_imgs
|
return self.tree_final_imgs
|
||||||
|
|
||||||
# print(f"computing t_block {t_block} idx_branch {idx_branch}")
|
# print(f"computing t_block {t_block} idx_branch {idx_branch}")
|
||||||
idx_stop = list_injection_idx_ext[t_block+1]
|
idx_stop = self.list_injection_idx_ext[t_block+1]
|
||||||
fract_mixing = self.tree_fracts[t_block][idx_branch]
|
fract_mixing = self.tree_fracts[t_block][idx_branch]
|
||||||
text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
|
text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
|
||||||
self.set_guidance_mid_dampening(fract_mixing)
|
self.set_guidance_mid_dampening(fract_mixing)
|
||||||
|
@ -423,7 +430,7 @@ class LatentBlending():
|
||||||
if fixed_seeds is not None:
|
if fixed_seeds is not None:
|
||||||
if idx_branch == 0:
|
if idx_branch == 0:
|
||||||
self.set_seed(fixed_seeds[0])
|
self.set_seed(fixed_seeds[0])
|
||||||
elif idx_branch == list_nmb_branches[0] -1:
|
elif idx_branch == self.list_nmb_branches[0] -1:
|
||||||
self.set_seed(fixed_seeds[1])
|
self.set_seed(fixed_seeds[1])
|
||||||
list_latents = self.run_diffusion(text_embeddings_mix, idx_stop=idx_stop)
|
list_latents = self.run_diffusion(text_embeddings_mix, idx_stop=idx_stop)
|
||||||
else:
|
else:
|
||||||
|
@ -434,7 +441,7 @@ class LatentBlending():
|
||||||
latents2 = latents1
|
latents2 = latents1
|
||||||
else:
|
else:
|
||||||
latents2 = self.tree_latents[t_block-1][b_parent2][-1]
|
latents2 = self.tree_latents[t_block-1][b_parent2][-1]
|
||||||
idx_start = list_injection_idx_ext[t_block]
|
idx_start = self.list_injection_idx_ext[t_block]
|
||||||
fract_mixing_parental = (fract_mixing - self.tree_fracts[t_block-1][b_parent1]) / (self.tree_fracts[t_block-1][b_parent2] - self.tree_fracts[t_block-1][b_parent1])
|
fract_mixing_parental = (fract_mixing - self.tree_fracts[t_block-1][b_parent1]) / (self.tree_fracts[t_block-1][b_parent2] - self.tree_fracts[t_block-1][b_parent1])
|
||||||
latents_for_injection = interpolate_spherical(latents1, latents2, fract_mixing_parental)
|
latents_for_injection = interpolate_spherical(latents1, latents2, fract_mixing_parental)
|
||||||
list_latents = self.run_diffusion(text_embeddings_mix, latents_for_injection, idx_start=idx_start, idx_stop=idx_stop)
|
list_latents = self.run_diffusion(text_embeddings_mix, latents_for_injection, idx_start=idx_start, idx_stop=idx_stop)
|
||||||
|
@ -594,6 +601,22 @@ class LatentBlending():
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.sdh.seed = seed
|
self.sdh.seed = seed
|
||||||
|
|
||||||
|
def inject_latents(self, list_latents, inject_img1=True, inject_img2=False):
    r"""
    Injects list of latents into tree structure.

    The supplied latents are cut into one slice per time block, using
    consecutive entries of self.list_injection_idx_ext as slice
    boundaries. Each slice is stored in the first branch (img1) or the
    last branch (img2) of the corresponding block of self.tree_latents.

    Args:
        list_latents: list
            Latents to inject, sliced by diffusion step index.
            NOTE(review): presumably the list returned by run_diffusion
            for a full run -- confirm against caller.
        inject_img1: bool
            Store the slices in branch 0 of every block.
        inject_img2: bool
            Store the slices in the last branch of every block.

    Exactly one of inject_img1 / inject_img2 has to be True, and the
    branching has to be set up beforehand (self.tree_latents must exist).
    """
    assert inject_img1 != inject_img2, "Either inject into img1 or img2"
    assert self.tree_latents is not None, "You need to setup the branching beforehand, run autosetup_branching() or setup_branching() before"

    # The tree holds len(list_injection_idx_ext) - 1 time blocks. This can
    # be one more than len(list_injection_idx) when an implicit block at
    # depth 0 was added, so iterate over the extended list to make sure
    # the final block is filled as well.
    boundaries = self.list_injection_idx_ext
    idx_branch = 0 if inject_img1 else -1
    for t_block in range(len(boundaries) - 1):
        self.tree_latents[t_block][idx_branch] = list_latents[boundaries[t_block]:boundaries[t_block + 1]]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def swap_forward(self):
|
def swap_forward(self):
|
||||||
r"""
|
r"""
|
||||||
|
|
Loading…
Reference in New Issue