masked
This commit is contained in:
parent
c5d88046a0
commit
bc34a83008
|
@ -123,6 +123,7 @@ class LatentBlending():
|
||||||
self.multi_transition_img_first = None
|
self.multi_transition_img_first = None
|
||||||
self.multi_transition_img_last = None
|
self.multi_transition_img_last = None
|
||||||
self.dt_per_diff = 0
|
self.dt_per_diff = 0
|
||||||
|
self.spatial_mask = None
|
||||||
|
|
||||||
self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
|
self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
|
||||||
|
|
||||||
|
@ -277,6 +278,9 @@ class LatentBlending():
|
||||||
self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
|
self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
|
||||||
self.tree_idx_injection = [0, 0]
|
self.tree_idx_injection = [0, 0]
|
||||||
|
|
||||||
|
# Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP...
|
||||||
|
self.spatial_mask = None
|
||||||
|
|
||||||
# Set up branching scheme (dependent on provided compute time)
|
# Set up branching scheme (dependent on provided compute time)
|
||||||
list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
|
list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
|
||||||
|
|
||||||
|
@ -296,6 +300,109 @@ class LatentBlending():
|
||||||
return self.tree_final_imgs
|
return self.tree_final_imgs
|
||||||
|
|
||||||
|
|
||||||
|
def compute_latents1(self, return_image=False):
    r"""
    Runs a diffusion trajectory for the first image.
    Side effects: the resulting list of latents is stored in self.tree_latents[0]
    and the measured time per diffusion step is stored in self.dt_per_diff.
    Args:
        return_image: bool
            whether to return an image or the list of latents
    Returns:
        self.sdh.latent2image applied to the last latent if return_image is True,
        otherwise the list of latents returned by self.run_diffusion.
    """
    print("starting compute_latents1")
    list_conditionings = self.get_mixed_conditioning(0)
    # perf_counter is monotonic, so the measured interval cannot be distorted by
    # wall-clock adjustments. The interval also covers the noise sampling.
    t0 = time.perf_counter()
    latents_start = self.get_noise(self.seed1)
    list_latents1 = self.run_diffusion(
        list_conditionings,
        latents_start=latents_start,
        idx_start=0,
    )
    t1 = time.perf_counter()
    self.dt_per_diff = (t1 - t0) / self.num_inference_steps
    self.tree_latents[0] = list_latents1
    if return_image:
        return self.sdh.latent2image(list_latents1[-1])
    return list_latents1
|
||||||
|
|
||||||
|
def compute_latents2(self, return_image=False):
    r"""
    Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
    Side effects: the resulting list of latents is stored in self.tree_latents[-1].
    Args:
        return_image: bool
            whether to return an image or the list of latents
    Returns:
        self.sdh.latent2image applied to the last latent if return_image is True,
        otherwise the list of latents returned by self.run_diffusion.
    """
    print("starting compute_latents2")
    list_conditionings = self.get_mixed_conditioning(1)
    latents_start = self.get_noise(self.seed2)
    # Influence from branch1
    if self.branch1_influence > 0.0:
        nmb_steps = self.num_inference_steps
        # Set up the mixing_coeffs: a linear ramp from branch1_influence down to
        # branch1_influence*branch1_influence_decay, followed by zeros.
        idx_mixing_stop = int(round(nmb_steps * self.branch1_max_depth_influence))
        # Keep the stop index inside [0, nmb_steps] so that there is always
        # exactly one mixing coefficient per diffusion step.
        idx_mixing_stop = min(max(idx_mixing_stop, 0), nmb_steps)
        mixing_coeffs = list(np.linspace(
            self.branch1_influence,
            self.branch1_influence * self.branch1_influence_decay,
            idx_mixing_stop,
        ))
        mixing_coeffs.extend((nmb_steps - idx_mixing_stop) * [0])
        list_latents2 = self.run_diffusion(
            list_conditionings,
            latents_start=latents_start,
            idx_start=0,
            list_latents_mixing=self.tree_latents[0],
            mixing_coeffs=mixing_coeffs,
        )
    else:
        list_latents2 = self.run_diffusion(list_conditionings, latents_start)
    self.tree_latents[-1] = list_latents2

    if return_image:
        return self.sdh.latent2image(list_latents2[-1])
    return list_latents2
|
||||||
|
|
||||||
|
|
||||||
|
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
    r"""
    Runs a diffusion trajectory, using the latents from the respective parents
    Args:
        fract_mixing: float
            the fraction along the transition axis [0, 1]
        b_parent1: int
            index of parent1 to be used
        b_parent2: int
            index of parent2 to be used
        idx_injection: int
            the index in terms of diffusion steps, where the next insertion will start.
    Returns:
        The list of latents returned by self.run_diffusion.
    """
    nmb_steps = self.num_inference_steps
    list_conditionings = self.get_mixed_conditioning(fract_mixing)

    # Position of fract_mixing relative to the two parents.
    fract_parent1 = self.tree_fracts[b_parent1]
    fract_parent2 = self.tree_fracts[b_parent2]
    fract_mixing_parental = (fract_mixing - fract_parent1) / (fract_parent2 - fract_parent1)

    # Interpolate the parents step by step; a step missing in either parent
    # (None) stays None in the mix.
    latents_parent1 = self.tree_latents[b_parent1]
    latents_parent2 = self.tree_latents[b_parent2]
    list_latents_parental_mix = [
        None
        if latents_parent1[i] is None or latents_parent2[i] is None
        else interpolate_spherical(latents_parent1[i], latents_parent2[i], fract_mixing_parental)
        for i in range(nmb_steps)
    ]

    # Mixing coefficients: constant parental_influence up to idx_injection, then
    # a linear ramp down to parental_influence*parental_influence_decay, then zeros.
    idx_mixing_stop = int(round(nmb_steps * self.parental_max_depth_influence))
    # Never ramp past the last diffusion step, otherwise more coefficients than
    # steps would be produced.
    idx_mixing_stop = min(idx_mixing_stop, nmb_steps)
    mixing_coeffs = idx_injection * [self.parental_influence]
    nmb_mixing = idx_mixing_stop - idx_injection
    if nmb_mixing > 0:
        mixing_coeffs.extend(list(np.linspace(
            self.parental_influence,
            self.parental_influence * self.parental_influence_decay,
            nmb_mixing,
        )))
    mixing_coeffs.extend((nmb_steps - len(mixing_coeffs)) * [0])

    # The trajectory starts from the mixed latents of the step preceding the injection.
    # NOTE(review): for idx_injection == 0 this index wraps around to the last step - confirm callers never pass 0.
    latents_start = list_latents_parental_mix[idx_injection - 1]
    list_latents = self.run_diffusion(
        list_conditionings,
        latents_start=latents_start,
        idx_start=idx_injection,
        list_latents_mixing=list_latents_parental_mix,
        mixing_coeffs=mixing_coeffs,
    )

    return list_latents
|
||||||
|
|
||||||
def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
|
def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
|
||||||
r"""
|
r"""
|
||||||
Sets up the branching scheme dependent on the time that is granted for compute.
|
Sets up the branching scheme dependent on the time that is granted for compute.
|
||||||
|
@ -419,110 +526,34 @@ class LatentBlending():
|
||||||
self.tree_fracts.insert(b_parent1+1, fract_mixing)
|
self.tree_fracts.insert(b_parent1+1, fract_mixing)
|
||||||
self.tree_idx_injection.insert(b_parent1+1, idx_injection)
|
self.tree_idx_injection.insert(b_parent1+1, idx_injection)
|
||||||
|
|
||||||
|
|
||||||
def compute_latents1(self, return_image=False):
|
|
||||||
r"""
|
|
||||||
Runs a diffusion trajectory for the first image
|
|
||||||
Args:
|
|
||||||
return_image: bool
|
|
||||||
whether to return an image or the list of latents
|
|
||||||
"""
|
|
||||||
print("starting compute_latents1")
|
|
||||||
list_conditionings = self.get_mixed_conditioning(0)
|
|
||||||
t0 = time.time()
|
|
||||||
latents_start = self.get_noise(self.seed1)
|
|
||||||
list_latents1 = self.run_diffusion(
|
|
||||||
list_conditionings,
|
|
||||||
latents_start = latents_start,
|
|
||||||
idx_start = 0
|
|
||||||
)
|
|
||||||
t1 = time.time()
|
|
||||||
self.dt_per_diff = (t1-t0) / self.num_inference_steps
|
|
||||||
self.tree_latents[0] = list_latents1
|
|
||||||
if return_image:
|
|
||||||
return self.sdh.latent2image(list_latents1[-1])
|
|
||||||
else:
|
|
||||||
return list_latents1
|
|
||||||
|
|
||||||
def compute_latents2(self, return_image=False):
|
def get_spatial_mask_template(self):
    r"""
    Returns an all-ones numpy array with the spatial size of the latents,
    i.e. (self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f).
    """
    height_latents = self.sdh.height // self.sdh.f
    width_latents = self.sdh.width // self.sdh.f
    return np.ones((height_latents, width_latents))
|
||||||
return_image: bool
|
|
||||||
whether to return an image or the list of latents
|
|
||||||
"""
|
|
||||||
print("starting compute_latents2")
|
|
||||||
list_conditionings = self.get_mixed_conditioning(1)
|
|
||||||
latents_start = self.get_noise(self.seed2)
|
|
||||||
# Influence from branch1
|
|
||||||
if self.branch1_influence > 0.0:
|
|
||||||
# Set up the mixing_coeffs
|
|
||||||
idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_max_depth_influence))
|
|
||||||
mixing_coeffs = list(np.linspace(self.branch1_influence, self.branch1_influence*self.branch1_influence_decay, idx_mixing_stop))
|
|
||||||
mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
|
|
||||||
list_latents_mixing = self.tree_latents[0]
|
|
||||||
list_latents2 = self.run_diffusion(
|
|
||||||
list_conditionings,
|
|
||||||
latents_start = latents_start,
|
|
||||||
idx_start = 0,
|
|
||||||
list_latents_mixing = list_latents_mixing,
|
|
||||||
mixing_coeffs = mixing_coeffs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
list_latents2 = self.run_diffusion(list_conditionings, latents_start)
|
|
||||||
self.tree_latents[-1] = list_latents2
|
|
||||||
|
|
||||||
if return_image:
|
|
||||||
return self.sdh.latent2image(list_latents2[-1])
|
|
||||||
else:
|
|
||||||
return list_latents2
|
|
||||||
|
|
||||||
|
|
||||||
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
|
|
||||||
r"""
|
|
||||||
Runs a diffusion trajectory, using the latents from the respective parents
|
|
||||||
Args:
|
|
||||||
fract_mixing: float
|
|
||||||
the fraction along the transition axis [0, 1]
|
|
||||||
b_parent1: int
|
|
||||||
index of parent1 to be used
|
|
||||||
b_parent2: int
|
|
||||||
index of parent2 to be used
|
|
||||||
idx_injection: int
|
|
||||||
the index in terms of diffusion steps, where the next insertion will start.
|
|
||||||
"""
|
|
||||||
list_conditionings = self.get_mixed_conditioning(fract_mixing)
|
|
||||||
fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
|
|
||||||
# idx_reversed = self.num_inference_steps - idx_injection
|
|
||||||
|
|
||||||
list_latents_parental_mix = []
|
|
||||||
for i in range(self.num_inference_steps):
|
|
||||||
latents_p1 = self.tree_latents[b_parent1][i]
|
|
||||||
latents_p2 = self.tree_latents[b_parent2][i]
|
|
||||||
if latents_p1 is None or latents_p2 is None:
|
|
||||||
latents_parental = None
|
|
||||||
else:
|
|
||||||
latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
|
|
||||||
list_latents_parental_mix.append(latents_parental)
|
|
||||||
|
|
||||||
idx_mixing_stop = int(round(self.num_inference_steps*self.parental_max_depth_influence))
|
|
||||||
mixing_coeffs = idx_injection*[self.parental_influence]
|
|
||||||
nmb_mixing = idx_mixing_stop - idx_injection
|
|
||||||
if nmb_mixing > 0:
|
|
||||||
mixing_coeffs.extend(list(np.linspace(self.parental_influence, self.parental_influence*self.parental_influence_decay, nmb_mixing)))
|
|
||||||
mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
|
|
||||||
|
|
||||||
latents_start = list_latents_parental_mix[idx_injection-1]
|
|
||||||
list_latents = self.run_diffusion(
|
|
||||||
list_conditionings,
|
|
||||||
latents_start = latents_start,
|
|
||||||
idx_start = idx_injection,
|
|
||||||
list_latents_mixing = list_latents_parental_mix,
|
|
||||||
mixing_coeffs = mixing_coeffs
|
|
||||||
)
|
|
||||||
|
|
||||||
return list_latents
|
|
||||||
|
|
||||||
|
def set_spatial_mask(self, img_mask):
    r"""
    Converts a 2D mask into a tensor of shape (1, C, H, W) and stores it in
    self.spatial_mask, with C = self.sdh.C, H = self.sdh.height // self.sdh.f
    and W = self.sdh.width // self.sdh.f.
    Args:
        img_mask: array-like
            2D mask of shape (H, W). Values are clipped to the range [0, 1].
            No dtype cast is applied after clipping.
    Raises:
        AssertionError: if img_mask is not 2D or does not have shape (H, W).
    """
    C = self.sdh.C
    H = self.sdh.height // self.sdh.f
    W = self.sdh.width // self.sdh.f

    img_mask = np.asarray(img_mask)
    assert img_mask.ndim == 2, "Currently, only 2D images are supported as mask"
    img_mask = np.clip(img_mask, 0, 1)
    assert img_mask.shape == (H, W), f"Your mask needs to be of dimension {H} x {W}"

    spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
    # (H, W) -> (1, 1, H, W) -> (1, C, H, W): the same mask is copied to every channel.
    spatial_mask = spatial_mask[None, None].repeat(1, C, 1, 1)

    self.spatial_mask = spatial_mask
|
||||||
|
|
||||||
|
|
||||||
def get_noise(self, seed):
|
def get_noise(self, seed):
|
||||||
r"""
|
r"""
|
||||||
|
@ -585,6 +616,7 @@ class LatentBlending():
|
||||||
idx_start = idx_start,
|
idx_start = idx_start,
|
||||||
list_latents_mixing = list_latents_mixing,
|
list_latents_mixing = list_latents_mixing,
|
||||||
mixing_coeffs = mixing_coeffs,
|
mixing_coeffs = mixing_coeffs,
|
||||||
|
spatial_mask = self.spatial_mask,
|
||||||
return_image = return_image,
|
return_image = return_image,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -218,37 +218,6 @@ class StableDiffusionHolder:
|
||||||
if len(self.negative_prompt) > 1:
|
if len(self.negative_prompt) > 1:
|
||||||
self.negative_prompt = [self.negative_prompt[0]]
|
self.negative_prompt = [self.negative_prompt[0]]
|
||||||
|
|
||||||
def init_inpainting(
|
|
||||||
self,
|
|
||||||
image_source: Union[Image.Image, np.ndarray] = None,
|
|
||||||
mask_image: Union[Image.Image, np.ndarray] = None,
|
|
||||||
init_empty: Optional[bool] = False,
|
|
||||||
):
|
|
||||||
r"""
|
|
||||||
Initializes inpainting with a source and maks image.
|
|
||||||
Args:
|
|
||||||
image_source: Union[Image.Image, np.ndarray]
|
|
||||||
Source image onto which the mask will be applied.
|
|
||||||
mask_image: Union[Image.Image, np.ndarray]
|
|
||||||
Mask image, value = 0 will stay untouched, value = 255 subjet to diffusion
|
|
||||||
init_empty: Optional[bool]:
|
|
||||||
Initialize inpainting with an empty image and mask, effectively disabling inpainting,
|
|
||||||
useful for generating a first image for transitions using diffusion.
|
|
||||||
"""
|
|
||||||
if not init_empty:
|
|
||||||
assert image_source is not None, "init_inpainting: you need to provide image_source"
|
|
||||||
assert mask_image is not None, "init_inpainting: you need to provide mask_image"
|
|
||||||
if type(image_source) == np.ndarray:
|
|
||||||
image_source = Image.fromarray(image_source)
|
|
||||||
self.image_source = image_source
|
|
||||||
|
|
||||||
if type(mask_image) == np.ndarray:
|
|
||||||
mask_image = Image.fromarray(mask_image)
|
|
||||||
self.mask_image = mask_image
|
|
||||||
else:
|
|
||||||
self.mask_image = self.mask_empty
|
|
||||||
self.image_source = self.image_empty
|
|
||||||
|
|
||||||
|
|
||||||
def get_text_embedding(self, prompt):
|
def get_text_embedding(self, prompt):
|
||||||
c = self.model.get_learned_conditioning(prompt)
|
c = self.model.get_learned_conditioning(prompt)
|
||||||
|
@ -282,6 +251,7 @@ class StableDiffusionHolder:
|
||||||
idx_start: int = 0,
|
idx_start: int = 0,
|
||||||
list_latents_mixing = None,
|
list_latents_mixing = None,
|
||||||
mixing_coeffs = 0.0,
|
mixing_coeffs = 0.0,
|
||||||
|
spatial_mask = None,
|
||||||
return_image: Optional[bool] = False,
|
return_image: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
|
@ -295,7 +265,7 @@ class StableDiffusionHolder:
|
||||||
idx_start: int
|
idx_start: int
|
||||||
Index of the diffusion process start and where the latents_for_injection are injected
|
Index of the diffusion process start and where the latents_for_injection are injected
|
||||||
mixing_coeff:
|
mixing_coeff:
|
||||||
# FIXME
|
# FIXME spatial_mask
|
||||||
return_image: Optional[bool]
|
return_image: Optional[bool]
|
||||||
Optionally return image directly
|
Optionally return image directly
|
||||||
|
|
||||||
|
@ -313,6 +283,7 @@ class StableDiffusionHolder:
|
||||||
if np.sum(list_mixing_coeffs) > 0:
|
if np.sum(list_mixing_coeffs) > 0:
|
||||||
assert len(list_latents_mixing) == self.num_inference_steps
|
assert len(list_latents_mixing) == self.num_inference_steps
|
||||||
|
|
||||||
|
|
||||||
precision_scope = autocast if self.precision == "autocast" else nullcontext
|
precision_scope = autocast if self.precision == "autocast" else nullcontext
|
||||||
|
|
||||||
with precision_scope("cuda"):
|
with precision_scope("cuda"):
|
||||||
|
@ -345,6 +316,10 @@ class StableDiffusionHolder:
|
||||||
if i > 0 and list_mixing_coeffs[i]>0:
|
if i > 0 and list_mixing_coeffs[i]>0:
|
||||||
latents_mixtarget = list_latents_mixing[i-1].clone()
|
latents_mixtarget = list_latents_mixing[i-1].clone()
|
||||||
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
|
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
|
||||||
|
|
||||||
|
if spatial_mask is not None and list_latents_mixing is not None:
|
||||||
|
latents = interpolate_spherical(latents, list_latents_mixing[i-1], 1-spatial_mask)
|
||||||
|
# latents[:,:,-15:,:] = latents_mixtarget[:,:,-15:,:]
|
||||||
|
|
||||||
index = total_steps - i - 1
|
index = total_steps - i - 1
|
||||||
ts = torch.full((1,), step, device=self.device, dtype=torch.long)
|
ts = torch.full((1,), step, device=self.device, dtype=torch.long)
|
||||||
|
|
Loading…
Reference in New Issue