old python version compat
parent ab500dd288
commit 87ca894694
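
Both hunks below make the same change: a parenthesized group of context
managers in a single with-statement, syntax that only became official in
Python 3.10, is rewritten as two nested with-blocks, re-indenting the body
by one level. A minimal sketch of the two forms, using nullcontext() as a
stand-in for precision_scope("cuda") and self.model.ema_scope():

from contextlib import nullcontext

# Python 3.10+ only: a parenthesized group of context managers.
# On older interpreters this is a SyntaxError at parse time.
# with (
#     nullcontext(),
#     nullcontext(),
# ):
#     ...

# Version-portable nested form, the one this commit switches to:
with nullcontext():
    with nullcontext():
        pass

The single-line form "with nullcontext(), nullcontext():" is also portable,
but gets unwieldy once the context-manager expressions are long.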
@@ -224,59 +224,56 @@ class StableDiffusionHolder:
         precision_scope = autocast if self.precision == "autocast" else nullcontext
         generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
 
-        with (
-            precision_scope("cuda"),
-            self.model.ema_scope(),
-        ):
-
-            if self.guidance_scale != 1.0:
-                uc = self.model.get_learned_conditioning([""])
-            else:
-                uc = None
-            shape_latents = [self.C, self.height // self.f, self.width // self.f]
-
-            self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
-            C, H, W = shape_latents
-            size = (1, C, H, W)
-            b = size[0]
-
-            latents = torch.randn(size, generator=generator, device=self.device)
-
-            timesteps = self.sampler.ddim_timesteps
-
-            time_range = np.flip(timesteps)
-            total_steps = timesteps.shape[0]
-
-            # collect latents
-            list_latents_out = []
-            for i, step in enumerate(time_range):
-                if do_inject_latents:
-                    # Inject latent at right place
-                    if i < idx_start:
-                        continue
-                    elif i == idx_start:
-                        latents = latents_for_injection.clone()
-
-                if i == idx_stop:
-                    return list_latents_out
-
-                # print(f"diffusion iter {i}")
-                index = total_steps - i - 1
-                ts = torch.full((b,), step, device=self.device, dtype=torch.long)
-                outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
-                                                  quantize_denoised=False, temperature=1.0,
-                                                  noise_dropout=0.0, score_corrector=None,
-                                                  corrector_kwargs=None,
-                                                  unconditional_guidance_scale=self.guidance_scale,
-                                                  unconditional_conditioning=uc,
-                                                  dynamic_threshold=None)
-                latents, pred_x0 = outs
-                list_latents_out.append(latents.clone())
-
-            if return_image:
-                return self.latent2image(latents)
-            else:
-                return list_latents_out
+        with precision_scope("cuda"):
+            with self.model.ema_scope():
+                if self.guidance_scale != 1.0:
+                    uc = self.model.get_learned_conditioning([""])
+                else:
+                    uc = None
+                shape_latents = [self.C, self.height // self.f, self.width // self.f]
+
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
+                C, H, W = shape_latents
+                size = (1, C, H, W)
+                b = size[0]
+
+                latents = torch.randn(size, generator=generator, device=self.device)
+
+                timesteps = self.sampler.ddim_timesteps
+
+                time_range = np.flip(timesteps)
+                total_steps = timesteps.shape[0]
+
+                # collect latents
+                list_latents_out = []
+                for i, step in enumerate(time_range):
+                    if do_inject_latents:
+                        # Inject latent at right place
+                        if i < idx_start:
+                            continue
+                        elif i == idx_start:
+                            latents = latents_for_injection.clone()
+
+                    if i == idx_stop:
+                        return list_latents_out
+
+                    # print(f"diffusion iter {i}")
+                    index = total_steps - i - 1
+                    ts = torch.full((b,), step, device=self.device, dtype=torch.long)
+                    outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
+                                                      quantize_denoised=False, temperature=1.0,
+                                                      noise_dropout=0.0, score_corrector=None,
+                                                      corrector_kwargs=None,
+                                                      unconditional_guidance_scale=self.guidance_scale,
+                                                      unconditional_conditioning=uc,
+                                                      dynamic_threshold=None)
+                    latents, pred_x0 = outs
+                    list_latents_out.append(latents.clone())
+
+                if return_image:
+                    return self.latent2image(latents)
+                else:
+                    return list_latents_out
 
     @torch.no_grad()
     def run_diffusion_inpaint(
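
The loop body re-indented above is otherwise unchanged: a DDIM sampling loop
that can start from externally supplied latents at step idx_start and return
the partial trajectory early at idx_stop. A stripped-down sketch of just that
control flow, with denoise_step as a hypothetical stand-in for
self.sampler.p_sample_ddim:

import torch

def denoise_step(latents, step):
    # Hypothetical stand-in for sampler.p_sample_ddim; identity here,
    # purely for illustration.
    return latents

def run_steps(time_range, latents, latents_for_injection=None,
              idx_start=0, idx_stop=-1):
    do_inject_latents = latents_for_injection is not None
    list_latents_out = []
    for i, step in enumerate(time_range):
        if do_inject_latents:
            if i < idx_start:
                continue  # skip steps already baked into the injected latents
            elif i == idx_start:
                latents = latents_for_injection.clone()
        if i == idx_stop:
            return list_latents_out  # early exit with the partial trajectory
        latents = denoise_step(latents, step)
        list_latents_out.append(latents.clone())
    return list_latents_out

For example, run_steps(range(10), torch.randn(1, 4, 8, 8), idx_stop=5)
collects the first five latents and stops; feeding the stored latent back in
with idx_start=5 picks the run up where it left off.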
@@ -318,77 +315,75 @@ class StableDiffusionHolder:
         precision_scope = autocast if self.precision == "autocast" else nullcontext
         generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
 
-        with (
-            precision_scope("cuda"),
-            self.model.ema_scope(),
-        ):
-
-            batch = make_batch_sd(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
-            c = text_embeddings
-            c_cat = list()
-            for ck in self.model.concat_keys:
-                cc = batch[ck].float()
-                if ck != self.model.masked_image_key:
-                    bchw = [1, 4, self.height // 8, self.width // 8]
-                    cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
-                else:
-                    cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
-                c_cat.append(cc)
-            c_cat = torch.cat(c_cat, dim=1)
-
-            # cond
-            cond = {"c_concat": [c_cat], "c_crossattn": [c]}
-
-            # uncond cond
-            uc_cross = self.model.get_unconditional_conditioning(1, "")
-            uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
-
-            shape_latents = [self.model.channels, self.height // 8, self.width // 8]
-
-            self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
-            # sampling
-            C, H, W = shape_latents
-            size = (1, C, H, W)
-
-            device = self.model.betas.device
-            b = size[0]
-            latents = torch.randn(size, generator=generator, device=device)
-
-            timesteps = self.sampler.ddim_timesteps
-
-            time_range = np.flip(timesteps)
-            total_steps = timesteps.shape[0]
-
-            # collect latents
-            list_latents_out = []
-            for i, step in enumerate(time_range):
-                if do_inject_latents:
-                    # Inject latent at right place
-                    if i < idx_start:
-                        continue
-                    elif i == idx_start:
-                        latents = latents_for_injection.clone()
-
-                if i == idx_stop:
-                    return list_latents_out
-
-                index = total_steps - i - 1
-                ts = torch.full((b,), step, device=device, dtype=torch.long)
-
-                outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
-                                                  quantize_denoised=False, temperature=1.0,
-                                                  noise_dropout=0.0, score_corrector=None,
-                                                  corrector_kwargs=None,
-                                                  unconditional_guidance_scale=self.guidance_scale,
-                                                  unconditional_conditioning=uc_full,
-                                                  dynamic_threshold=None)
-                latents, pred_x0 = outs
-                list_latents_out.append(latents.clone())
-
-            if return_image:
-                return self.latent2image(latents)
-            else:
-                return list_latents_out
+        with precision_scope("cuda"):
+            with self.model.ema_scope():
+
+                batch = make_batch_sd(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
+                c = text_embeddings
+                c_cat = list()
+                for ck in self.model.concat_keys:
+                    cc = batch[ck].float()
+                    if ck != self.model.masked_image_key:
+                        bchw = [1, 4, self.height // 8, self.width // 8]
+                        cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+                    else:
+                        cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
+                    c_cat.append(cc)
+                c_cat = torch.cat(c_cat, dim=1)
+
+                # cond
+                cond = {"c_concat": [c_cat], "c_crossattn": [c]}
+
+                # uncond cond
+                uc_cross = self.model.get_unconditional_conditioning(1, "")
+                uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
+
+                shape_latents = [self.model.channels, self.height // 8, self.width // 8]
+
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
+                # sampling
+                C, H, W = shape_latents
+                size = (1, C, H, W)
+
+                device = self.model.betas.device
+                b = size[0]
+                latents = torch.randn(size, generator=generator, device=device)
+
+                timesteps = self.sampler.ddim_timesteps
+
+                time_range = np.flip(timesteps)
+                total_steps = timesteps.shape[0]
+
+                # collect latents
+                list_latents_out = []
+                for i, step in enumerate(time_range):
+                    if do_inject_latents:
+                        # Inject latent at right place
+                        if i < idx_start:
+                            continue
+                        elif i == idx_start:
+                            latents = latents_for_injection.clone()
+
+                    if i == idx_stop:
+                        return list_latents_out
+
+                    index = total_steps - i - 1
+                    ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+                    outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
+                                                      quantize_denoised=False, temperature=1.0,
+                                                      noise_dropout=0.0, score_corrector=None,
+                                                      corrector_kwargs=None,
+                                                      unconditional_guidance_scale=self.guidance_scale,
+                                                      unconditional_conditioning=uc_full,
+                                                      dynamic_threshold=None)
+                    latents, pred_x0 = outs
+                    list_latents_out.append(latents.clone())
+
+                if return_image:
+                    return self.latent2image(latents)
+                else:
+                    return list_latents_out
 
 
     @torch.no_grad()
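
Nested with-blocks are the most direct rewrite, but when the number of
context managers is not fixed, or the nesting gets deep, contextlib.ExitStack
is the other version-portable option (not what this commit uses). A sketch,
again with nullcontext() standing in for the real context managers:

from contextlib import ExitStack, nullcontext

with ExitStack() as stack:
    # Enter any number of context managers at a single indent level;
    # all of them are exited in reverse order when the block ends.
    stack.enter_context(nullcontext())  # e.g. precision_scope("cuda")
    stack.enter_context(nullcontext())  # e.g. self.model.ema_scope()
    pass  # sampling body would go here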