From 87ca894694e2bdd89b09d563dac2a4ed82730ccb Mon Sep 17 00:00:00 2001
From: lunar
Date: Fri, 9 Dec 2022 14:03:20 +0000
Subject: [PATCH] old python version compat

---
 stable_diffusion_holder.py | 233 ++++++++++++++++++-------------------
 1 file changed, 114 insertions(+), 119 deletions(-)

diff --git a/stable_diffusion_holder.py b/stable_diffusion_holder.py
index b5f6c8c..d28de6f 100644
--- a/stable_diffusion_holder.py
+++ b/stable_diffusion_holder.py
@@ -224,59 +224,56 @@ class StableDiffusionHolder:
         precision_scope = autocast if self.precision == "autocast" else nullcontext
         generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
 
-        with (
-            precision_scope("cuda"),
-            self.model.ema_scope(),
-        ):
-
-            if self.guidance_scale != 1.0:
-                uc = self.model.get_learned_conditioning([""])
-            else:
-                uc = None
-            shape_latents = [self.C, self.height // self.f, self.width // self.f]
-
-            self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
-            C, H, W = shape_latents
-            size = (1, C, H, W)
-            b = size[0]
-
-            latents = torch.randn(size, generator=generator, device=self.device)
-
-            timesteps = self.sampler.ddim_timesteps
-
-            time_range = np.flip(timesteps)
-            total_steps = timesteps.shape[0]
-
-            # collect latents
-            list_latents_out = []
-            for i, step in enumerate(time_range):
-                if do_inject_latents:
-                    # Inject latent at right place
-                    if i < idx_start:
-                        continue
-                    elif i == idx_start:
-                        latents = latents_for_injection.clone()
+        with precision_scope("cuda"):
+            with self.model.ema_scope():
+                if self.guidance_scale != 1.0:
+                    uc = self.model.get_learned_conditioning([""])
+                else:
+                    uc = None
+                shape_latents = [self.C, self.height // self.f, self.width // self.f]
+
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
+                C, H, W = shape_latents
+                size = (1, C, H, W)
+                b = size[0]
 
-                if i == idx_stop:
+                latents = torch.randn(size, generator=generator, device=self.device)
+
+                timesteps = self.sampler.ddim_timesteps
+
+                time_range = np.flip(timesteps)
+                total_steps = timesteps.shape[0]
+
+                # collect latents
+                list_latents_out = []
+                for i, step in enumerate(time_range):
+                    if do_inject_latents:
+                        # Inject latent at right place
+                        if i < idx_start:
+                            continue
+                        elif i == idx_start:
+                            latents = latents_for_injection.clone()
+
+                    if i == idx_stop:
+                        return list_latents_out
+
+                    # print(f"diffusion iter {i}")
+                    index = total_steps - i - 1
+                    ts = torch.full((b,), step, device=self.device, dtype=torch.long)
+                    outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
+                                                      quantize_denoised=False, temperature=1.0,
+                                                      noise_dropout=0.0, score_corrector=None,
+                                                      corrector_kwargs=None,
+                                                      unconditional_guidance_scale=self.guidance_scale,
+                                                      unconditional_conditioning=uc,
+                                                      dynamic_threshold=None)
+                    latents, pred_x0 = outs
+                    list_latents_out.append(latents.clone())
+
+                if return_image:
+                    return self.latent2image(latents)
+                else:
                     return list_latents_out
-
-                # print(f"diffusion iter {i}")
-                index = total_steps - i - 1
-                ts = torch.full((b,), step, device=self.device, dtype=torch.long)
-                outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
-                                                  quantize_denoised=False, temperature=1.0,
-                                                  noise_dropout=0.0, score_corrector=None,
-                                                  corrector_kwargs=None,
-                                                  unconditional_guidance_scale=self.guidance_scale,
-                                                  unconditional_conditioning=uc,
-                                                  dynamic_threshold=None)
-                latents, pred_x0 = outs
-                list_latents_out.append(latents.clone())
-
-            if return_image:
-                return self.latent2image(latents)
-            else:
-                return list_latents_out
 
     @torch.no_grad()
     def run_diffusion_inpaint(
@@ -318,77 +315,75 @@ class StableDiffusionHolder:
         precision_scope = autocast if self.precision == "autocast" else nullcontext
         generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
 
-        with (
-            precision_scope("cuda"),
-            self.model.ema_scope(),
-        ):
+        with precision_scope("cuda"):
+            with self.model.ema_scope():
+
+                batch = make_batch_sd(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
+                c = text_embeddings
+                c_cat = list()
+                for ck in self.model.concat_keys:
+                    cc = batch[ck].float()
+                    if ck != self.model.masked_image_key:
+                        bchw = [1, 4, self.height // 8, self.width // 8]
+                        cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+                    else:
+                        cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
+                    c_cat.append(cc)
+                c_cat = torch.cat(c_cat, dim=1)
+
+                # cond
+                cond = {"c_concat": [c_cat], "c_crossattn": [c]}
+
+                # uncond cond
+                uc_cross = self.model.get_unconditional_conditioning(1, "")
+                uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
+
+                shape_latents = [self.model.channels, self.height // 8, self.width // 8]
 
-            batch = make_batch_sd(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
-            c = text_embeddings
-            c_cat = list()
-            for ck in self.model.concat_keys:
-                cc = batch[ck].float()
-                if ck != self.model.masked_image_key:
-                    bchw = [1, 4, self.height // 8, self.width // 8]
-                    cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
+                # sampling
+                C, H, W = shape_latents
+                size = (1, C, H, W)
+
+                device = self.model.betas.device
+                b = size[0]
+                latents = torch.randn(size, generator=generator, device=device)
+
+                timesteps = self.sampler.ddim_timesteps
+
+                time_range = np.flip(timesteps)
+                total_steps = timesteps.shape[0]
+
+                # collect latents
+                list_latents_out = []
+                for i, step in enumerate(time_range):
+                    if do_inject_latents:
+                        # Inject latent at right place
+                        if i < idx_start:
+                            continue
+                        elif i == idx_start:
+                            latents = latents_for_injection.clone()
+
+                    if i == idx_stop:
+                        return list_latents_out
+
+                    index = total_steps - i - 1
+                    ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+                    outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
+                                                      quantize_denoised=False, temperature=1.0,
+                                                      noise_dropout=0.0, score_corrector=None,
+                                                      corrector_kwargs=None,
+                                                      unconditional_guidance_scale=self.guidance_scale,
+                                                      unconditional_conditioning=uc_full,
+                                                      dynamic_threshold=None)
+                    latents, pred_x0 = outs
+                    list_latents_out.append(latents.clone())
+
+                if return_image:
+                    return self.latent2image(latents)
                 else:
-                    cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
-                c_cat.append(cc)
-            c_cat = torch.cat(c_cat, dim=1)
-
-            # cond
-            cond = {"c_concat": [c_cat], "c_crossattn": [c]}
-
-            # uncond cond
-            uc_cross = self.model.get_unconditional_conditioning(1, "")
-            uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
-
-            shape_latents = [self.model.channels, self.height // 8, self.width // 8]
-
-            self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
-            # sampling
-            C, H, W = shape_latents
-            size = (1, C, H, W)
-
-            device = self.model.betas.device
-            b = size[0]
-            latents = torch.randn(size, generator=generator, device=device)
-
-            timesteps = self.sampler.ddim_timesteps
-
-            time_range = np.flip(timesteps)
-            total_steps = timesteps.shape[0]
-
-            # collect latents
-            list_latents_out = []
-            for i, step in enumerate(time_range):
-                if do_inject_latents:
-                    # Inject latent at right place
-                    if i < idx_start:
-                        continue
-                    elif i == idx_start:
-                        latents = latents_for_injection.clone()
-
-                if i == idx_stop:
                     return list_latents_out
-
-                index = total_steps - i - 1
-                ts = torch.full((b,), step, device=device, dtype=torch.long)
-
-                outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
-                                                  quantize_denoised=False, temperature=1.0,
-                                                  noise_dropout=0.0, score_corrector=None,
-                                                  corrector_kwargs=None,
-                                                  unconditional_guidance_scale=self.guidance_scale,
-                                                  unconditional_conditioning=uc_full,
-                                                  dynamic_threshold=None)
-                latents, pred_x0 = outs
-                list_latents_out.append(latents.clone())
-
-            if return_image:
-                return self.latent2image(latents)
-            else:
-                return list_latents_out
 
     @torch.no_grad()
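
Note on the change above: the parenthesized, multi-item `with (...)` form that this patch removes only became official syntax in Python 3.10, so on older interpreters it fails with a SyntaxError as soon as the module is parsed. Nesting the two context managers, as the patch does, behaves identically and parses everywhere. Below is a minimal sketch of the pattern, using a hypothetical stand-in context manager rather than the real autocast/EMA scopes from stable_diffusion_holder.py:

    from contextlib import contextmanager

    @contextmanager
    def scope(name):
        # hypothetical stand-in for precision_scope("cuda") / model.ema_scope()
        print("enter", name)
        try:
            yield
        finally:
            print("exit", name)

    # Python 3.10+ only: parenthesized context managers in one with-statement
    # with (
    #     scope("autocast"),
    #     scope("ema"),
    # ):
    #     ...

    # Portable equivalent used by this patch: nested with-blocks
    with scope("autocast"):
        with scope("ema"):
            print("sampling loop would run here")

The one-line form `with scope("autocast"), scope("ema"):` also parses on older interpreters; either way the body of the block only needs re-indenting, which is what accounts for most of the 233 changed lines in the diffstat.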