diff --git a/diffusers_holder.py b/diffusers_holder.py index 627414a..8a560bb 100644 --- a/diffusers_holder.py +++ b/diffusers_holder.py @@ -152,34 +152,59 @@ class DiffusersHolder(): output_type: "pil" or "np" """ assert output_type in ["pil", "np"] - if self.use_sd_xl: - # make sure the VAE is in float32 mode, as it overflows in float16 - self.pipe.vae.to(dtype=torch.float32) - - use_torch_2_0_or_xformers = isinstance( - self.pipe.vae.decoder.mid_block.attentions[0].processor, - ( - AttnProcessor2_0, - XFormersAttnProcessor, - LoRAXFormersAttnProcessor, - LoRAAttnProcessor2_0, - ), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.pipe.vae.post_quant_conv.to(latents.dtype) - self.pipe.vae.decoder.conv_in.to(latents.dtype) - self.pipe.vae.decoder.mid_block.to(latents.dtype) - else: - latents = latents.float() - + + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.pipe.vae.dtype == torch.float16 and self.pipe.vae.config.force_upcast + + if needs_upcasting: + self.pipe.upcast_vae() + latents = latents.to(next(iter(self.pipe.vae.post_quant_conv.parameters())).dtype) + image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0] - image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])[0] - if output_type == "np": - return np.asarray(image) - else: - return image + + # cast back to fp16 if needed + if needs_upcasting: + self.pipe.vae.to(dtype=torch.float16) + + image = self.pipe.image_processor.postprocess(image, output_type=output_type)[0] + + return image + + # if output_type == "np": + # return np.asarray(image) + # else: + # return image + + + # # xxx + # if self.use_sd_xl: + # # make sure the VAE is in float32 mode, as it overflows in float16 + # self.pipe.vae.to(dtype=torch.float32) + + # use_torch_2_0_or_xformers = isinstance( + # self.pipe.vae.decoder.mid_block.attentions[0].processor, + # ( + # AttnProcessor2_0, + # XFormersAttnProcessor, + # LoRAXFormersAttnProcessor, + # LoRAAttnProcessor2_0, + # ), + # ) + # # if xformers or torch_2_0 is used attention block does not need + # # to be in float32 which can save lots of memory + # if use_torch_2_0_or_xformers: + # self.pipe.vae.post_quant_conv.to(latents.dtype) + # self.pipe.vae.decoder.conv_in.to(latents.dtype) + # self.pipe.vae.decoder.mid_block.to(latents.dtype) + # else: + # latents = latents.float() + + # image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0] + # image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])[0] + # if output_type == "np": + # return np.asarray(image) + # else: + # return image def prepare_mixing(self, mixing_coeffs, list_latents_mixing): if type(mixing_coeffs) == float: @@ -718,13 +743,27 @@ class DiffusersHolder(): if __name__ == "__main__": from PIL import Image #%% + from diffusers import AutoencoderTiny # pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0" pretrained_model_name_or_path = "stabilityai/sdxl-turbo" pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16) pipe.to('cuda') # xxx + #% + pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16) + pipe.vae = pipe.vae.cuda() #%% + self = DiffusersHolder(pipe) + self.set_num_inference_steps(4) + prompt1 = "Photo of a colorful landscape with a blue sky with clouds" + text_embeddings1 = self.get_text_embedding(prompt1) + latents_start = self.get_noise(seed=420) + latents = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=False)[-1] + image = self.latent2image(latents) + + + xxxx # # xxx # self.set_dimensions((512, 512)) # self.set_num_inference_steps(4) @@ -773,49 +812,6 @@ if __name__ == "__main__": self.run_diffusion_sd_xl(text_embeddings_mix, latents_start_mixed, idx_start=idx_start, return_image=True) - #%% - fract=0.8 - latentsmix = interpolate_spherical(latents1[-1], latents2[-1], fract) - self.latent2image(latentsmix) - - - - #%% - - """ - - - xxxxx - - # step1: first latents - latents1_step1 = pipe(latents=latents_start, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds1, negative_prompt_embeds=negative_prompt_embeds1, pooled_prompt_embeds=pooled_prompt_embeds1, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds1, output_type='latent', timesteps=timesteps_step1) - - # step2: second latents - img_diffusion1 = pipe(latents=latents1_step1[0], guidance_scale=guidance_scale, prompt_embeds=prompt_embeds1, negative_prompt_embeds=negative_prompt_embeds1, pooled_prompt_embeds=pooled_prompt_embeds1, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds1, timesteps=timesteps_step2) - - - #%% img2 - latents_start = torch.randn((1,4,64//1,64)).half().cuda() - - # step1: first latents - latents2_step1 = pipe(latents=latents_start, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds2, negative_prompt_embeds=negative_prompt_embeds2, pooled_prompt_embeds=pooled_prompt_embeds2, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds2, output_type='latent', timesteps=timesteps_step1) - - # step2: second latents - img_diffusion2 = pipe(latents=latents2_step1[0], guidance_scale=guidance_scale, prompt_embeds=prompt_embeds2, negative_prompt_embeds=negative_prompt_embeds2, pooled_prompt_embeds=pooled_prompt_embeds2, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds2, timesteps=timesteps_step2) - - xxx - - #%% find the middle - - prompt_embeds = prompt_embeds1 #interpolate_spherical(prompt_embeds1, prompt_embeds2, 0.5) - pooled_prompt_embeds = pooled_prompt_embeds1# interpolate_spherical(pooled_prompt_embeds1, pooled_prompt_embeds2, 0.5) - negative_prompt_embeds = negative_prompt_embeds1 - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds1 - - latents1_stepM = interpolate_spherical(latents1_step1[0], latents2_step1[0], 0.5) - - img_diffusionM = pipe(latents=latents1_stepM, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, timesteps=timesteps_step2) -"""