fixed latent2image for tinyvae

This commit is contained in:
Johannes Stelzer 2024-01-07 16:17:42 +01:00
parent e9ece062a4
commit 95870823fc
1 changed files with 66 additions and 70 deletions

View File

@ -152,35 +152,60 @@ class DiffusersHolder():
output_type: "pil" or "np"
"""
assert output_type in ["pil", "np"]
if self.use_sd_xl:
# make sure the VAE is in float32 mode, as it overflows in float16
self.pipe.vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
self.pipe.vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
self.pipe.vae.post_quant_conv.to(latents.dtype)
self.pipe.vae.decoder.conv_in.to(latents.dtype)
self.pipe.vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.pipe.vae.dtype == torch.float16 and self.pipe.vae.config.force_upcast
if needs_upcasting:
self.pipe.upcast_vae()
latents = latents.to(next(iter(self.pipe.vae.post_quant_conv.parameters())).dtype)
image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0]
image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])[0]
if output_type == "np":
return np.asarray(image)
else:
# cast back to fp16 if needed
if needs_upcasting:
self.pipe.vae.to(dtype=torch.float16)
image = self.pipe.image_processor.postprocess(image, output_type=output_type)[0]
return image
# if output_type == "np":
# return np.asarray(image)
# else:
# return image
# # xxx
# if self.use_sd_xl:
# # make sure the VAE is in float32 mode, as it overflows in float16
# self.pipe.vae.to(dtype=torch.float32)
# use_torch_2_0_or_xformers = isinstance(
# self.pipe.vae.decoder.mid_block.attentions[0].processor,
# (
# AttnProcessor2_0,
# XFormersAttnProcessor,
# LoRAXFormersAttnProcessor,
# LoRAAttnProcessor2_0,
# ),
# )
# # if xformers or torch_2_0 is used attention block does not need
# # to be in float32 which can save lots of memory
# if use_torch_2_0_or_xformers:
# self.pipe.vae.post_quant_conv.to(latents.dtype)
# self.pipe.vae.decoder.conv_in.to(latents.dtype)
# self.pipe.vae.decoder.mid_block.to(latents.dtype)
# else:
# latents = latents.float()
# image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0]
# image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])[0]
# if output_type == "np":
# return np.asarray(image)
# else:
# return image
def prepare_mixing(self, mixing_coeffs, list_latents_mixing):
if type(mixing_coeffs) == float:
list_mixing_coeffs = (1 + self.num_inference_steps) * [mixing_coeffs]
@ -718,13 +743,27 @@ class DiffusersHolder():
if __name__ == "__main__":
from PIL import Image
#%%
from diffusers import AutoencoderTiny
# pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
pipe.to('cuda') # xxx
#%
pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16)
pipe.vae = pipe.vae.cuda()
#%%
self = DiffusersHolder(pipe)
self.set_num_inference_steps(4)
prompt1 = "Photo of a colorful landscape with a blue sky with clouds"
text_embeddings1 = self.get_text_embedding(prompt1)
latents_start = self.get_noise(seed=420)
latents = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=False)[-1]
image = self.latent2image(latents)
xxxx
# # xxx
# self.set_dimensions((512, 512))
# self.set_num_inference_steps(4)
@ -773,49 +812,6 @@ if __name__ == "__main__":
self.run_diffusion_sd_xl(text_embeddings_mix, latents_start_mixed, idx_start=idx_start, return_image=True)
#%%
fract=0.8
latentsmix = interpolate_spherical(latents1[-1], latents2[-1], fract)
self.latent2image(latentsmix)
#%%
"""
xxxxx
# step1: first latents
latents1_step1 = pipe(latents=latents_start, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds1, negative_prompt_embeds=negative_prompt_embeds1, pooled_prompt_embeds=pooled_prompt_embeds1, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds1, output_type='latent', timesteps=timesteps_step1)
# step2: second latents
img_diffusion1 = pipe(latents=latents1_step1[0], guidance_scale=guidance_scale, prompt_embeds=prompt_embeds1, negative_prompt_embeds=negative_prompt_embeds1, pooled_prompt_embeds=pooled_prompt_embeds1, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds1, timesteps=timesteps_step2)
#%% img2
latents_start = torch.randn((1,4,64//1,64)).half().cuda()
# step1: first latents
latents2_step1 = pipe(latents=latents_start, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds2, negative_prompt_embeds=negative_prompt_embeds2, pooled_prompt_embeds=pooled_prompt_embeds2, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds2, output_type='latent', timesteps=timesteps_step1)
# step2: second latents
img_diffusion2 = pipe(latents=latents2_step1[0], guidance_scale=guidance_scale, prompt_embeds=prompt_embeds2, negative_prompt_embeds=negative_prompt_embeds2, pooled_prompt_embeds=pooled_prompt_embeds2, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds2, timesteps=timesteps_step2)
xxx
#%% find the middle
prompt_embeds = prompt_embeds1 #interpolate_spherical(prompt_embeds1, prompt_embeds2, 0.5)
pooled_prompt_embeds = pooled_prompt_embeds1# interpolate_spherical(pooled_prompt_embeds1, pooled_prompt_embeds2, 0.5)
negative_prompt_embeds = negative_prompt_embeds1
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds1
latents1_stepM = interpolate_spherical(latents1_step1[0], latents2_step1[0], 0.5)
img_diffusionM = pipe(latents=latents1_stepM, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, timesteps=timesteps_step2)
"""