old python version compat
This commit is contained in: parent ab500dd288 · commit 87ca894694
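Both hunks below replace the parenthesized multi-item with-statement with the classic nested form. Grouping several context managers in parentheses inside a single with-statement is only officially supported from Python 3.10 onward, so the nested spelling keeps the module importable on older interpreters. A minimal, self-contained sketch of the two equivalent spellings follows; scope is a hypothetical stand-in for the real precision_scope("cuda") and self.model.ema_scope() context managers:

from contextlib import contextmanager

@contextmanager
def scope(name):
    # Hypothetical stand-in for precision_scope("cuda") / self.model.ema_scope().
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

# Python 3.10+ only: several context managers grouped in parentheses.
# with (
#     scope("precision"),
#     scope("ema"),
# ):
#     ...

# Equivalent nested form, valid on older interpreters as well.
with scope("precision"):
    with scope("ema"):
        print("sampling body runs here, one indent level deeper")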
@@ -224,59 +224,56 @@ class StableDiffusionHolder:
         precision_scope = autocast if self.precision == "autocast" else nullcontext
         generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
 
-        with (
-            precision_scope("cuda"),
-            self.model.ema_scope(),
-        ):
-            if self.guidance_scale != 1.0:
-                uc = self.model.get_learned_conditioning([""])
-            else:
-                uc = None
-            shape_latents = [self.C, self.height // self.f, self.width // self.f]
-
-            self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
-            C, H, W = shape_latents
-            size = (1, C, H, W)
-            b = size[0]
-
-            latents = torch.randn(size, generator=generator, device=self.device)
-
-            timesteps = self.sampler.ddim_timesteps
-
-            time_range = np.flip(timesteps)
-            total_steps = timesteps.shape[0]
-
-            # collect latents
-            list_latents_out = []
-            for i, step in enumerate(time_range):
-                if do_inject_latents:
-                    # Inject latent at right place
-                    if i < idx_start:
-                        continue
-                    elif i == idx_start:
-                        latents = latents_for_injection.clone()
-
-                if i == idx_stop:
-                    return list_latents_out
-
-                # print(f"diffusion iter {i}")
-                index = total_steps - i - 1
-                ts = torch.full((b,), step, device=self.device, dtype=torch.long)
-                outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
-                                                  quantize_denoised=False, temperature=1.0,
-                                                  noise_dropout=0.0, score_corrector=None,
-                                                  corrector_kwargs=None,
-                                                  unconditional_guidance_scale=self.guidance_scale,
-                                                  unconditional_conditioning=uc,
-                                                  dynamic_threshold=None)
-                latents, pred_x0 = outs
-                list_latents_out.append(latents.clone())
-
-            if return_image:
-                return self.latent2image(latents)
-            else:
-                return list_latents_out
+        with precision_scope("cuda"):
+            with self.model.ema_scope():
+                if self.guidance_scale != 1.0:
+                    uc = self.model.get_learned_conditioning([""])
+                else:
+                    uc = None
+                shape_latents = [self.C, self.height // self.f, self.width // self.f]
+
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
+                C, H, W = shape_latents
+                size = (1, C, H, W)
+                b = size[0]
+
+                latents = torch.randn(size, generator=generator, device=self.device)
+
+                timesteps = self.sampler.ddim_timesteps
+
+                time_range = np.flip(timesteps)
+                total_steps = timesteps.shape[0]
+
+                # collect latents
+                list_latents_out = []
+                for i, step in enumerate(time_range):
+                    if do_inject_latents:
+                        # Inject latent at right place
+                        if i < idx_start:
+                            continue
+                        elif i == idx_start:
+                            latents = latents_for_injection.clone()
+
+                    if i == idx_stop:
+                        return list_latents_out
+
+                    # print(f"diffusion iter {i}")
+                    index = total_steps - i - 1
+                    ts = torch.full((b,), step, device=self.device, dtype=torch.long)
+                    outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
+                                                      quantize_denoised=False, temperature=1.0,
+                                                      noise_dropout=0.0, score_corrector=None,
+                                                      corrector_kwargs=None,
+                                                      unconditional_guidance_scale=self.guidance_scale,
+                                                      unconditional_conditioning=uc,
+                                                      dynamic_threshold=None)
+                    latents, pred_x0 = outs
+                    list_latents_out.append(latents.clone())
+
+                if return_image:
+                    return self.latent2image(latents)
+                else:
+                    return list_latents_out
 
     @torch.no_grad()
     def run_diffusion_inpaint(
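Aside from the syntax change, the sampling loops in this hunk and the next share the same latent-injection control flow: iterations before idx_start are skipped, the stored latent replaces the running one at idx_start, and the loop returns early at idx_stop, so a diffusion trajectory can be resumed mid-way or truncated. A condensed, hypothetical sketch of just that flow, with denoise_step standing in for self.sampler.p_sample_ddim:

def run_with_injection(time_range, latents, denoise_step,
                       latents_for_injection=None, idx_start=0, idx_stop=-1):
    # Sketch only: mirrors the injection logic of the loops in this diff.
    do_inject_latents = latents_for_injection is not None
    list_latents_out = []
    for i, step in enumerate(time_range):
        if do_inject_latents:
            if i < idx_start:
                continue                         # skip steps computed in an earlier run
            elif i == idx_start:
                latents = latents_for_injection  # resume from the stored latent
        if i == idx_stop:
            return list_latents_out              # stop early with the partial trajectory
        latents = denoise_step(latents, step)
        list_latents_out.append(latents)
    return list_latents_out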
@@ -318,77 +315,75 @@ class StableDiffusionHolder:
         precision_scope = autocast if self.precision == "autocast" else nullcontext
         generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
 
-        with (
-            precision_scope("cuda"),
-            self.model.ema_scope(),
-        ):
-            batch = make_batch_sd(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
-            c = text_embeddings
-            c_cat = list()
-            for ck in self.model.concat_keys:
-                cc = batch[ck].float()
-                if ck != self.model.masked_image_key:
-                    bchw = [1, 4, self.height // 8, self.width // 8]
-                    cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
-                else:
-                    cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
-                c_cat.append(cc)
-            c_cat = torch.cat(c_cat, dim=1)
-
-            # cond
-            cond = {"c_concat": [c_cat], "c_crossattn": [c]}
-
-            # uncond cond
-            uc_cross = self.model.get_unconditional_conditioning(1, "")
-            uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
-
-            shape_latents = [self.model.channels, self.height // 8, self.width // 8]
-
-            self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
-            # sampling
-            C, H, W = shape_latents
-            size = (1, C, H, W)
-
-            device = self.model.betas.device
-            b = size[0]
-            latents = torch.randn(size, generator=generator, device=device)
-
-            timesteps = self.sampler.ddim_timesteps
-
-            time_range = np.flip(timesteps)
-            total_steps = timesteps.shape[0]
-
-            # collect latents
-            list_latents_out = []
-            for i, step in enumerate(time_range):
-                if do_inject_latents:
-                    # Inject latent at right place
-                    if i < idx_start:
-                        continue
-                    elif i == idx_start:
-                        latents = latents_for_injection.clone()
-
-                if i == idx_stop:
-                    return list_latents_out
-
-                index = total_steps - i - 1
-                ts = torch.full((b,), step, device=device, dtype=torch.long)
-
-                outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
-                                                  quantize_denoised=False, temperature=1.0,
-                                                  noise_dropout=0.0, score_corrector=None,
-                                                  corrector_kwargs=None,
-                                                  unconditional_guidance_scale=self.guidance_scale,
-                                                  unconditional_conditioning=uc_full,
-                                                  dynamic_threshold=None)
-                latents, pred_x0 = outs
-                list_latents_out.append(latents.clone())
-
-            if return_image:
-                return self.latent2image(latents)
-            else:
-                return list_latents_out
+        with precision_scope("cuda"):
+            with self.model.ema_scope():
+                batch = make_batch_sd(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
+                c = text_embeddings
+                c_cat = list()
+                for ck in self.model.concat_keys:
+                    cc = batch[ck].float()
+                    if ck != self.model.masked_image_key:
+                        bchw = [1, 4, self.height // 8, self.width // 8]
+                        cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+                    else:
+                        cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
+                    c_cat.append(cc)
+                c_cat = torch.cat(c_cat, dim=1)
+
+                # cond
+                cond = {"c_concat": [c_cat], "c_crossattn": [c]}
+
+                # uncond cond
+                uc_cross = self.model.get_unconditional_conditioning(1, "")
+                uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
+
+                shape_latents = [self.model.channels, self.height // 8, self.width // 8]
+
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
+                # sampling
+                C, H, W = shape_latents
+                size = (1, C, H, W)
+
+                device = self.model.betas.device
+                b = size[0]
+                latents = torch.randn(size, generator=generator, device=device)
+
+                timesteps = self.sampler.ddim_timesteps
+
+                time_range = np.flip(timesteps)
+                total_steps = timesteps.shape[0]
+
+                # collect latents
+                list_latents_out = []
+                for i, step in enumerate(time_range):
+                    if do_inject_latents:
+                        # Inject latent at right place
+                        if i < idx_start:
+                            continue
+                        elif i == idx_start:
+                            latents = latents_for_injection.clone()
+
+                    if i == idx_stop:
+                        return list_latents_out
+
+                    index = total_steps - i - 1
+                    ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+                    outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
+                                                      quantize_denoised=False, temperature=1.0,
+                                                      noise_dropout=0.0, score_corrector=None,
+                                                      corrector_kwargs=None,
+                                                      unconditional_guidance_scale=self.guidance_scale,
+                                                      unconditional_conditioning=uc_full,
+                                                      dynamic_threshold=None)
+                    latents, pred_x0 = outs
+                    list_latents_out.append(latents.clone())
+
+                if return_image:
+                    return self.latent2image(latents)
+                else:
+                    return list_latents_out
 
     @torch.no_grad()
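If more than two context managers ever need to be entered together here, contextlib.ExitStack sidesteps both the 3.10-only parenthesized syntax and ever-deeper nesting. A minimal sketch, with nullcontext standing in for the real scopes:

from contextlib import ExitStack, nullcontext

# ExitStack enters any number of context managers at one indent level and
# unwinds them in reverse order on exit; available since Python 3.3
# (nullcontext itself needs 3.7).
with ExitStack() as stack:
    stack.enter_context(nullcontext())  # e.g. precision_scope("cuda")
    stack.enter_context(nullcontext())  # e.g. self.model.ema_scope()
    print("sampling body stays at a single indent level")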