masked
This commit is contained in:
parent
c5d88046a0
commit
bc34a83008
|
@ -123,6 +123,7 @@ class LatentBlending():
|
|||
self.multi_transition_img_first = None
|
||||
self.multi_transition_img_last = None
|
||||
self.dt_per_diff = 0
|
||||
self.spatial_mask = None
|
||||
|
||||
self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
|
||||
|
||||
|
@ -277,6 +278,9 @@ class LatentBlending():
|
|||
self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
|
||||
self.tree_idx_injection = [0, 0]
|
||||
|
||||
# Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP...
|
||||
self.spatial_mask = None
|
||||
|
||||
# Set up branching scheme (dependent on provided compute time)
|
||||
list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
|
||||
|
||||
|
@ -296,6 +300,109 @@ class LatentBlending():
|
|||
return self.tree_final_imgs
|
||||
|
||||
|
||||
def compute_latents1(self, return_image=False):
|
||||
r"""
|
||||
Runs a diffusion trajectory for the first image
|
||||
Args:
|
||||
return_image: bool
|
||||
whether to return an image or the list of latents
|
||||
"""
|
||||
print("starting compute_latents1")
|
||||
list_conditionings = self.get_mixed_conditioning(0)
|
||||
t0 = time.time()
|
||||
latents_start = self.get_noise(self.seed1)
|
||||
list_latents1 = self.run_diffusion(
|
||||
list_conditionings,
|
||||
latents_start = latents_start,
|
||||
idx_start = 0
|
||||
)
|
||||
t1 = time.time()
|
||||
self.dt_per_diff = (t1-t0) / self.num_inference_steps
|
||||
self.tree_latents[0] = list_latents1
|
||||
if return_image:
|
||||
return self.sdh.latent2image(list_latents1[-1])
|
||||
else:
|
||||
return list_latents1
|
||||
|
||||
def compute_latents2(self, return_image=False):
|
||||
r"""
|
||||
Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
|
||||
Args:
|
||||
return_image: bool
|
||||
whether to return an image or the list of latents
|
||||
"""
|
||||
print("starting compute_latents2")
|
||||
list_conditionings = self.get_mixed_conditioning(1)
|
||||
latents_start = self.get_noise(self.seed2)
|
||||
# Influence from branch1
|
||||
if self.branch1_influence > 0.0:
|
||||
# Set up the mixing_coeffs
|
||||
idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_max_depth_influence))
|
||||
mixing_coeffs = list(np.linspace(self.branch1_influence, self.branch1_influence*self.branch1_influence_decay, idx_mixing_stop))
|
||||
mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
|
||||
list_latents_mixing = self.tree_latents[0]
|
||||
list_latents2 = self.run_diffusion(
|
||||
list_conditionings,
|
||||
latents_start = latents_start,
|
||||
idx_start = 0,
|
||||
list_latents_mixing = list_latents_mixing,
|
||||
mixing_coeffs = mixing_coeffs
|
||||
)
|
||||
else:
|
||||
list_latents2 = self.run_diffusion(list_conditionings, latents_start)
|
||||
self.tree_latents[-1] = list_latents2
|
||||
|
||||
if return_image:
|
||||
return self.sdh.latent2image(list_latents2[-1])
|
||||
else:
|
||||
return list_latents2
|
||||
|
||||
|
||||
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
|
||||
r"""
|
||||
Runs a diffusion trajectory, using the latents from the respective parents
|
||||
Args:
|
||||
fract_mixing: float
|
||||
the fraction along the transition axis [0, 1]
|
||||
b_parent1: int
|
||||
index of parent1 to be used
|
||||
b_parent2: int
|
||||
index of parent2 to be used
|
||||
idx_injection: int
|
||||
the index in terms of diffusion steps, where the next insertion will start.
|
||||
"""
|
||||
list_conditionings = self.get_mixed_conditioning(fract_mixing)
|
||||
fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
|
||||
# idx_reversed = self.num_inference_steps - idx_injection
|
||||
|
||||
list_latents_parental_mix = []
|
||||
for i in range(self.num_inference_steps):
|
||||
latents_p1 = self.tree_latents[b_parent1][i]
|
||||
latents_p2 = self.tree_latents[b_parent2][i]
|
||||
if latents_p1 is None or latents_p2 is None:
|
||||
latents_parental = None
|
||||
else:
|
||||
latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
|
||||
list_latents_parental_mix.append(latents_parental)
|
||||
|
||||
idx_mixing_stop = int(round(self.num_inference_steps*self.parental_max_depth_influence))
|
||||
mixing_coeffs = idx_injection*[self.parental_influence]
|
||||
nmb_mixing = idx_mixing_stop - idx_injection
|
||||
if nmb_mixing > 0:
|
||||
mixing_coeffs.extend(list(np.linspace(self.parental_influence, self.parental_influence*self.parental_influence_decay, nmb_mixing)))
|
||||
mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
|
||||
|
||||
latents_start = list_latents_parental_mix[idx_injection-1]
|
||||
list_latents = self.run_diffusion(
|
||||
list_conditionings,
|
||||
latents_start = latents_start,
|
||||
idx_start = idx_injection,
|
||||
list_latents_mixing = list_latents_parental_mix,
|
||||
mixing_coeffs = mixing_coeffs
|
||||
)
|
||||
|
||||
return list_latents
|
||||
|
||||
def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
|
||||
r"""
|
||||
Sets up the branching scheme dependent on the time that is granted for compute.
|
||||
|
@ -420,108 +527,32 @@ class LatentBlending():
|
|||
self.tree_idx_injection.insert(b_parent1+1, idx_injection)
|
||||
|
||||
|
||||
def compute_latents1(self, return_image=False):
|
||||
def get_spatial_mask_template(self):
|
||||
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
|
||||
C, H, W = shape_latents
|
||||
return np.ones((H, W))
|
||||
|
||||
def set_spatial_mask(self, img_mask):
|
||||
r"""
|
||||
Runs a diffusion trajectory for the first image
|
||||
Helper function to #FIXME
|
||||
Args:
|
||||
return_image: bool
|
||||
whether to return an image or the list of latents
|
||||
seed: int
|
||||
|
||||
"""
|
||||
print("starting compute_latents1")
|
||||
list_conditionings = self.get_mixed_conditioning(0)
|
||||
t0 = time.time()
|
||||
latents_start = self.get_noise(self.seed1)
|
||||
list_latents1 = self.run_diffusion(
|
||||
list_conditionings,
|
||||
latents_start = latents_start,
|
||||
idx_start = 0
|
||||
)
|
||||
t1 = time.time()
|
||||
self.dt_per_diff = (t1-t0) / self.num_inference_steps
|
||||
self.tree_latents[0] = list_latents1
|
||||
if return_image:
|
||||
return self.sdh.latent2image(list_latents1[-1])
|
||||
else:
|
||||
return list_latents1
|
||||
|
||||
def compute_latents2(self, return_image=False):
|
||||
r"""
|
||||
Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
|
||||
Args:
|
||||
return_image: bool
|
||||
whether to return an image or the list of latents
|
||||
"""
|
||||
print("starting compute_latents2")
|
||||
list_conditionings = self.get_mixed_conditioning(1)
|
||||
latents_start = self.get_noise(self.seed2)
|
||||
# Influence from branch1
|
||||
if self.branch1_influence > 0.0:
|
||||
# Set up the mixing_coeffs
|
||||
idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_max_depth_influence))
|
||||
mixing_coeffs = list(np.linspace(self.branch1_influence, self.branch1_influence*self.branch1_influence_decay, idx_mixing_stop))
|
||||
mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
|
||||
list_latents_mixing = self.tree_latents[0]
|
||||
list_latents2 = self.run_diffusion(
|
||||
list_conditionings,
|
||||
latents_start = latents_start,
|
||||
idx_start = 0,
|
||||
list_latents_mixing = list_latents_mixing,
|
||||
mixing_coeffs = mixing_coeffs
|
||||
)
|
||||
else:
|
||||
list_latents2 = self.run_diffusion(list_conditionings, latents_start)
|
||||
self.tree_latents[-1] = list_latents2
|
||||
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
|
||||
C, H, W = shape_latents
|
||||
img_mask = np.asarray(img_mask)
|
||||
assert len(img_mask.shape) == 2, "Currently, only 2D images are supported as mask"
|
||||
img_mask = np.clip(img_mask, 0, 1)
|
||||
assert img_mask.shape[0] == H, f"Your mask needs to be of dimension {H} x {W}"
|
||||
assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
|
||||
spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
|
||||
spatial_mask = torch.unsqueeze(spatial_mask, 0)
|
||||
spatial_mask = spatial_mask.repeat((C,1,1))
|
||||
spatial_mask = torch.unsqueeze(spatial_mask, 0)
|
||||
|
||||
if return_image:
|
||||
return self.sdh.latent2image(list_latents2[-1])
|
||||
else:
|
||||
return list_latents2
|
||||
|
||||
|
||||
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
|
||||
r"""
|
||||
Runs a diffusion trajectory, using the latents from the respective parents
|
||||
Args:
|
||||
fract_mixing: float
|
||||
the fraction along the transition axis [0, 1]
|
||||
b_parent1: int
|
||||
index of parent1 to be used
|
||||
b_parent2: int
|
||||
index of parent2 to be used
|
||||
idx_injection: int
|
||||
the index in terms of diffusion steps, where the next insertion will start.
|
||||
"""
|
||||
list_conditionings = self.get_mixed_conditioning(fract_mixing)
|
||||
fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
|
||||
# idx_reversed = self.num_inference_steps - idx_injection
|
||||
|
||||
list_latents_parental_mix = []
|
||||
for i in range(self.num_inference_steps):
|
||||
latents_p1 = self.tree_latents[b_parent1][i]
|
||||
latents_p2 = self.tree_latents[b_parent2][i]
|
||||
if latents_p1 is None or latents_p2 is None:
|
||||
latents_parental = None
|
||||
else:
|
||||
latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
|
||||
list_latents_parental_mix.append(latents_parental)
|
||||
|
||||
idx_mixing_stop = int(round(self.num_inference_steps*self.parental_max_depth_influence))
|
||||
mixing_coeffs = idx_injection*[self.parental_influence]
|
||||
nmb_mixing = idx_mixing_stop - idx_injection
|
||||
if nmb_mixing > 0:
|
||||
mixing_coeffs.extend(list(np.linspace(self.parental_influence, self.parental_influence*self.parental_influence_decay, nmb_mixing)))
|
||||
mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
|
||||
|
||||
latents_start = list_latents_parental_mix[idx_injection-1]
|
||||
list_latents = self.run_diffusion(
|
||||
list_conditionings,
|
||||
latents_start = latents_start,
|
||||
idx_start = idx_injection,
|
||||
list_latents_mixing = list_latents_parental_mix,
|
||||
mixing_coeffs = mixing_coeffs
|
||||
)
|
||||
|
||||
return list_latents
|
||||
self.spatial_mask = spatial_mask
|
||||
|
||||
|
||||
def get_noise(self, seed):
|
||||
|
@ -585,6 +616,7 @@ class LatentBlending():
|
|||
idx_start = idx_start,
|
||||
list_latents_mixing = list_latents_mixing,
|
||||
mixing_coeffs = mixing_coeffs,
|
||||
spatial_mask = self.spatial_mask,
|
||||
return_image = return_image,
|
||||
)
|
||||
|
||||
|
|
|
@ -218,37 +218,6 @@ class StableDiffusionHolder:
|
|||
if len(self.negative_prompt) > 1:
|
||||
self.negative_prompt = [self.negative_prompt[0]]
|
||||
|
||||
def init_inpainting(
|
||||
self,
|
||||
image_source: Union[Image.Image, np.ndarray] = None,
|
||||
mask_image: Union[Image.Image, np.ndarray] = None,
|
||||
init_empty: Optional[bool] = False,
|
||||
):
|
||||
r"""
|
||||
Initializes inpainting with a source and maks image.
|
||||
Args:
|
||||
image_source: Union[Image.Image, np.ndarray]
|
||||
Source image onto which the mask will be applied.
|
||||
mask_image: Union[Image.Image, np.ndarray]
|
||||
Mask image, value = 0 will stay untouched, value = 255 subjet to diffusion
|
||||
init_empty: Optional[bool]:
|
||||
Initialize inpainting with an empty image and mask, effectively disabling inpainting,
|
||||
useful for generating a first image for transitions using diffusion.
|
||||
"""
|
||||
if not init_empty:
|
||||
assert image_source is not None, "init_inpainting: you need to provide image_source"
|
||||
assert mask_image is not None, "init_inpainting: you need to provide mask_image"
|
||||
if type(image_source) == np.ndarray:
|
||||
image_source = Image.fromarray(image_source)
|
||||
self.image_source = image_source
|
||||
|
||||
if type(mask_image) == np.ndarray:
|
||||
mask_image = Image.fromarray(mask_image)
|
||||
self.mask_image = mask_image
|
||||
else:
|
||||
self.mask_image = self.mask_empty
|
||||
self.image_source = self.image_empty
|
||||
|
||||
|
||||
def get_text_embedding(self, prompt):
|
||||
c = self.model.get_learned_conditioning(prompt)
|
||||
|
@ -282,6 +251,7 @@ class StableDiffusionHolder:
|
|||
idx_start: int = 0,
|
||||
list_latents_mixing = None,
|
||||
mixing_coeffs = 0.0,
|
||||
spatial_mask = None,
|
||||
return_image: Optional[bool] = False,
|
||||
):
|
||||
r"""
|
||||
|
@ -295,7 +265,7 @@ class StableDiffusionHolder:
|
|||
idx_start: int
|
||||
Index of the diffusion process start and where the latents_for_injection are injected
|
||||
mixing_coeff:
|
||||
# FIXME
|
||||
# FIXME spatial_mask
|
||||
return_image: Optional[bool]
|
||||
Optionally return image directly
|
||||
|
||||
|
@ -313,6 +283,7 @@ class StableDiffusionHolder:
|
|||
if np.sum(list_mixing_coeffs) > 0:
|
||||
assert len(list_latents_mixing) == self.num_inference_steps
|
||||
|
||||
|
||||
precision_scope = autocast if self.precision == "autocast" else nullcontext
|
||||
|
||||
with precision_scope("cuda"):
|
||||
|
@ -346,6 +317,10 @@ class StableDiffusionHolder:
|
|||
latents_mixtarget = list_latents_mixing[i-1].clone()
|
||||
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
|
||||
|
||||
if spatial_mask is not None and list_latents_mixing is not None:
|
||||
latents = interpolate_spherical(latents, list_latents_mixing[i-1], 1-spatial_mask)
|
||||
# latents[:,:,-15:,:] = latents_mixtarget[:,:,-15:,:]
|
||||
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((1,), step, device=self.device, dtype=torch.long)
|
||||
outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
|
||||
|
|
Loading…
Reference in New Issue