old python version compat

lunar 2022-12-09 14:03:20 +00:00
parent ab500dd288
commit 87ca894694
1 changed file with 114 additions and 119 deletions
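What this commit changes: the parenthesized multi-manager form `with (a, b):` is only officially supported syntax since Python 3.10 (CPython 3.9's PEG parser happens to accept it as well). On older interpreters the parenthesized list parses as a tuple expression, and entering a tuple as a context manager fails at runtime. The two hunks below therefore rewrite the parenthesized form into nested `with` blocks. A minimal, self-contained reproduction of the issue (illustrative only, not code from this repo):

```python
from contextlib import nullcontext

# Parenthesized context managers: official syntax only since Python 3.10.
# On Python <= 3.8 the parentheses make this a single tuple expression,
# so it fails at runtime with "AttributeError: __enter__".
with (
    nullcontext(),
    nullcontext(),
):
    pass

# Portable equivalent, as used by this commit: nested `with` blocks.
with nullcontext():
    with nullcontext():
        pass
```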


@@ -224,59 +224,56 @@ class StableDiffusionHolder:
         precision_scope = autocast if self.precision == "autocast" else nullcontext
         generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
-        with (
-            precision_scope("cuda"),
-            self.model.ema_scope(),
-        ):
-
-            if self.guidance_scale != 1.0:
-                uc = self.model.get_learned_conditioning([""])
-            else:
-                uc = None
-            shape_latents = [self.C, self.height // self.f, self.width // self.f]
-
-            self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
-            C, H, W = shape_latents
-            size = (1, C, H, W)
-            b = size[0]
-            latents = torch.randn(size, generator=generator, device=self.device)
-            timesteps = self.sampler.ddim_timesteps
-            time_range = np.flip(timesteps)
-            total_steps = timesteps.shape[0]
-            # collect latents
-            list_latents_out = []
-            for i, step in enumerate(time_range):
-                if do_inject_latents:
-                    # Inject latent at right place
-                    if i < idx_start:
-                        continue
-                    elif i == idx_start:
-                        latents = latents_for_injection.clone()
-                if i == idx_stop:
-                    return list_latents_out
-                # print(f"diffusion iter {i}")
-                index = total_steps - i - 1
-                ts = torch.full((b,), step, device=self.device, dtype=torch.long)
-                outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
-                                                  quantize_denoised=False, temperature=1.0,
-                                                  noise_dropout=0.0, score_corrector=None,
-                                                  corrector_kwargs=None,
-                                                  unconditional_guidance_scale=self.guidance_scale,
-                                                  unconditional_conditioning=uc,
-                                                  dynamic_threshold=None)
-                latents, pred_x0 = outs
-                list_latents_out.append(latents.clone())
-            if return_image:
-                return self.latent2image(latents)
-            else:
-                return list_latents_out
+        with precision_scope("cuda"):
+            with self.model.ema_scope():
+                if self.guidance_scale != 1.0:
+                    uc = self.model.get_learned_conditioning([""])
+                else:
+                    uc = None
+                shape_latents = [self.C, self.height // self.f, self.width // self.f]
+
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
+                C, H, W = shape_latents
+                size = (1, C, H, W)
+                b = size[0]
+                latents = torch.randn(size, generator=generator, device=self.device)
+                timesteps = self.sampler.ddim_timesteps
+                time_range = np.flip(timesteps)
+                total_steps = timesteps.shape[0]
+                # collect latents
+                list_latents_out = []
+                for i, step in enumerate(time_range):
+                    if do_inject_latents:
+                        # Inject latent at right place
+                        if i < idx_start:
+                            continue
+                        elif i == idx_start:
+                            latents = latents_for_injection.clone()
+                    if i == idx_stop:
+                        return list_latents_out
+                    # print(f"diffusion iter {i}")
+                    index = total_steps - i - 1
+                    ts = torch.full((b,), step, device=self.device, dtype=torch.long)
+                    outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
+                                                      quantize_denoised=False, temperature=1.0,
+                                                      noise_dropout=0.0, score_corrector=None,
+                                                      corrector_kwargs=None,
+                                                      unconditional_guidance_scale=self.guidance_scale,
+                                                      unconditional_conditioning=uc,
+                                                      dynamic_threshold=None)
+                    latents, pred_x0 = outs
+                    list_latents_out.append(latents.clone())
+                if return_image:
+                    return self.latent2image(latents)
+                else:
+                    return list_latents_out
 
     @torch.no_grad()
     def run_diffusion_inpaint(
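Both hunks re-indent the same sampling pattern: a DDIM loop that can skip ahead to `idx_start`, resume from externally supplied latents, and stop early at `idx_stop`, which is what makes latent injection possible. A distilled sketch of that control flow (variable names taken from the diff; `denoise_step` is a hypothetical stand-in for `self.sampler.p_sample_ddim`):

```python
# Sketch of the injection control flow used in both hunks; `denoise_step`
# is a hypothetical placeholder, not an API from this repository.
def sample_with_injection(latents, time_range, denoise_step,
                          do_inject_latents=False, latents_for_injection=None,
                          idx_start=0, idx_stop=-1):
    list_latents_out = []
    for i, step in enumerate(time_range):
        if do_inject_latents:
            if i < idx_start:
                continue  # fast-forward: earlier steps come from a previous run
            elif i == idx_start:
                latents = latents_for_injection.clone()  # resume from injected state
        if i == idx_stop:
            return list_latents_out  # early exit: caller only wants a prefix
        latents = denoise_step(latents, step)
        list_latents_out.append(latents.clone())
    return list_latents_out
```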
@@ -318,77 +315,75 @@ class StableDiffusionHolder:
         precision_scope = autocast if self.precision == "autocast" else nullcontext
         generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
-        with (
-            precision_scope("cuda"),
-            self.model.ema_scope(),
-        ):
-
-            batch = make_batch_sd(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
-            c = text_embeddings
-            c_cat = list()
-            for ck in self.model.concat_keys:
-                cc = batch[ck].float()
-                if ck != self.model.masked_image_key:
-                    bchw = [1, 4, self.height // 8, self.width // 8]
-                    cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
-                else:
-                    cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
-                c_cat.append(cc)
-            c_cat = torch.cat(c_cat, dim=1)
-            # cond
-            cond = {"c_concat": [c_cat], "c_crossattn": [c]}
-            # uncond cond
-            uc_cross = self.model.get_unconditional_conditioning(1, "")
-            uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
-            shape_latents = [self.model.channels, self.height // 8, self.width // 8]
-
-            self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
-            # sampling
-            C, H, W = shape_latents
-            size = (1, C, H, W)
-
-            device = self.model.betas.device
-            b = size[0]
-            latents = torch.randn(size, generator=generator, device=device)
-            timesteps = self.sampler.ddim_timesteps
-            time_range = np.flip(timesteps)
-            total_steps = timesteps.shape[0]
-            # collect latents
-            list_latents_out = []
-            for i, step in enumerate(time_range):
-                if do_inject_latents:
-                    # Inject latent at right place
-                    if i < idx_start:
-                        continue
-                    elif i == idx_start:
-                        latents = latents_for_injection.clone()
-                if i == idx_stop:
-                    return list_latents_out
-                index = total_steps - i - 1
-                ts = torch.full((b,), step, device=device, dtype=torch.long)
-                outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
-                                                  quantize_denoised=False, temperature=1.0,
-                                                  noise_dropout=0.0, score_corrector=None,
-                                                  corrector_kwargs=None,
-                                                  unconditional_guidance_scale=self.guidance_scale,
-                                                  unconditional_conditioning=uc_full,
-                                                  dynamic_threshold=None)
-                latents, pred_x0 = outs
-                list_latents_out.append(latents.clone())
-            if return_image:
-                return self.latent2image(latents)
-            else:
-                return list_latents_out
+        with precision_scope("cuda"):
+            with self.model.ema_scope():
+
+                batch = make_batch_sd(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
+                c = text_embeddings
+                c_cat = list()
+                for ck in self.model.concat_keys:
+                    cc = batch[ck].float()
+                    if ck != self.model.masked_image_key:
+                        bchw = [1, 4, self.height // 8, self.width // 8]
+                        cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+                    else:
+                        cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
+                    c_cat.append(cc)
+                c_cat = torch.cat(c_cat, dim=1)
+                # cond
+                cond = {"c_concat": [c_cat], "c_crossattn": [c]}
+                # uncond cond
+                uc_cross = self.model.get_unconditional_conditioning(1, "")
+                uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
+                shape_latents = [self.model.channels, self.height // 8, self.width // 8]
+
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
+                # sampling
+                C, H, W = shape_latents
+                size = (1, C, H, W)
+
+                device = self.model.betas.device
+                b = size[0]
+                latents = torch.randn(size, generator=generator, device=device)
+                timesteps = self.sampler.ddim_timesteps
+                time_range = np.flip(timesteps)
+                total_steps = timesteps.shape[0]
+                # collect latents
+                list_latents_out = []
+                for i, step in enumerate(time_range):
+                    if do_inject_latents:
+                        # Inject latent at right place
+                        if i < idx_start:
+                            continue
+                        elif i == idx_start:
+                            latents = latents_for_injection.clone()
+                    if i == idx_stop:
+                        return list_latents_out
+                    index = total_steps - i - 1
+                    ts = torch.full((b,), step, device=device, dtype=torch.long)
+                    outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
+                                                      quantize_denoised=False, temperature=1.0,
+                                                      noise_dropout=0.0, score_corrector=None,
+                                                      corrector_kwargs=None,
+                                                      unconditional_guidance_scale=self.guidance_scale,
+                                                      unconditional_conditioning=uc_full,
+                                                      dynamic_threshold=None)
+                    latents, pred_x0 = outs
+                    list_latents_out.append(latents.clone())
+                if return_image:
+                    return self.latent2image(latents)
+                else:
+                    return list_latents_out
 
     @torch.no_grad()
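If the extra indentation level were undesirable, `contextlib.ExitStack` would be a version-independent way to enter several context managers while keeping a flat body. A minimal sketch of that alternative (not what this commit does):

```python
from contextlib import ExitStack, nullcontext

# Version-independent alternative to the parenthesized `with` form
# (sketch only; the commit itself uses plain nested `with` blocks).
with ExitStack() as stack:
    stack.enter_context(nullcontext())  # e.g. precision_scope("cuda")
    stack.enter_context(nullcontext())  # e.g. model.ema_scope()
    # ... body stays at a single indentation level
```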