fixed latent2image for tinyvae
This commit is contained in:
parent
e9ece062a4
commit
95870823fc
|
@ -152,34 +152,59 @@ class DiffusersHolder():
|
|||
output_type: "pil" or "np"
|
||||
"""
|
||||
assert output_type in ["pil", "np"]
|
||||
if self.use_sd_xl:
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
self.pipe.vae.to(dtype=torch.float32)
|
||||
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.pipe.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.pipe.vae.post_quant_conv.to(latents.dtype)
|
||||
self.pipe.vae.decoder.conv_in.to(latents.dtype)
|
||||
self.pipe.vae.decoder.mid_block.to(latents.dtype)
|
||||
else:
|
||||
latents = latents.float()
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.pipe.vae.dtype == torch.float16 and self.pipe.vae.config.force_upcast
|
||||
|
||||
if needs_upcasting:
|
||||
self.pipe.upcast_vae()
|
||||
latents = latents.to(next(iter(self.pipe.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])[0]
|
||||
if output_type == "np":
|
||||
return np.asarray(image)
|
||||
else:
|
||||
return image
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.pipe.vae.to(dtype=torch.float16)
|
||||
|
||||
image = self.pipe.image_processor.postprocess(image, output_type=output_type)[0]
|
||||
|
||||
return image
|
||||
|
||||
# if output_type == "np":
|
||||
# return np.asarray(image)
|
||||
# else:
|
||||
# return image
|
||||
|
||||
|
||||
# # xxx
|
||||
# if self.use_sd_xl:
|
||||
# # make sure the VAE is in float32 mode, as it overflows in float16
|
||||
# self.pipe.vae.to(dtype=torch.float32)
|
||||
|
||||
# use_torch_2_0_or_xformers = isinstance(
|
||||
# self.pipe.vae.decoder.mid_block.attentions[0].processor,
|
||||
# (
|
||||
# AttnProcessor2_0,
|
||||
# XFormersAttnProcessor,
|
||||
# LoRAXFormersAttnProcessor,
|
||||
# LoRAAttnProcessor2_0,
|
||||
# ),
|
||||
# )
|
||||
# # if xformers or torch_2_0 is used attention block does not need
|
||||
# # to be in float32 which can save lots of memory
|
||||
# if use_torch_2_0_or_xformers:
|
||||
# self.pipe.vae.post_quant_conv.to(latents.dtype)
|
||||
# self.pipe.vae.decoder.conv_in.to(latents.dtype)
|
||||
# self.pipe.vae.decoder.mid_block.to(latents.dtype)
|
||||
# else:
|
||||
# latents = latents.float()
|
||||
|
||||
# image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0]
|
||||
# image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])[0]
|
||||
# if output_type == "np":
|
||||
# return np.asarray(image)
|
||||
# else:
|
||||
# return image
|
||||
|
||||
def prepare_mixing(self, mixing_coeffs, list_latents_mixing):
|
||||
if type(mixing_coeffs) == float:
|
||||
|
@ -718,13 +743,27 @@ class DiffusersHolder():
|
|||
if __name__ == "__main__":
|
||||
from PIL import Image
|
||||
#%%
|
||||
from diffusers import AutoencoderTiny
|
||||
# pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
|
||||
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
|
||||
pipe.to('cuda') # xxx
|
||||
|
||||
#%
|
||||
pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16)
|
||||
pipe.vae = pipe.vae.cuda()
|
||||
#%%
|
||||
self = DiffusersHolder(pipe)
|
||||
self.set_num_inference_steps(4)
|
||||
prompt1 = "Photo of a colorful landscape with a blue sky with clouds"
|
||||
text_embeddings1 = self.get_text_embedding(prompt1)
|
||||
latents_start = self.get_noise(seed=420)
|
||||
latents = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=False)[-1]
|
||||
image = self.latent2image(latents)
|
||||
|
||||
|
||||
|
||||
xxxx
|
||||
# # xxx
|
||||
# self.set_dimensions((512, 512))
|
||||
# self.set_num_inference_steps(4)
|
||||
|
@ -773,49 +812,6 @@ if __name__ == "__main__":
|
|||
|
||||
self.run_diffusion_sd_xl(text_embeddings_mix, latents_start_mixed, idx_start=idx_start, return_image=True)
|
||||
|
||||
#%%
|
||||
fract=0.8
|
||||
latentsmix = interpolate_spherical(latents1[-1], latents2[-1], fract)
|
||||
self.latent2image(latentsmix)
|
||||
|
||||
|
||||
|
||||
#%%
|
||||
|
||||
"""
|
||||
|
||||
|
||||
xxxxx
|
||||
|
||||
# step1: first latents
|
||||
latents1_step1 = pipe(latents=latents_start, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds1, negative_prompt_embeds=negative_prompt_embeds1, pooled_prompt_embeds=pooled_prompt_embeds1, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds1, output_type='latent', timesteps=timesteps_step1)
|
||||
|
||||
# step2: second latents
|
||||
img_diffusion1 = pipe(latents=latents1_step1[0], guidance_scale=guidance_scale, prompt_embeds=prompt_embeds1, negative_prompt_embeds=negative_prompt_embeds1, pooled_prompt_embeds=pooled_prompt_embeds1, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds1, timesteps=timesteps_step2)
|
||||
|
||||
|
||||
#%% img2
|
||||
latents_start = torch.randn((1,4,64//1,64)).half().cuda()
|
||||
|
||||
|
||||
|
||||
# step1: first latents
|
||||
latents2_step1 = pipe(latents=latents_start, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds2, negative_prompt_embeds=negative_prompt_embeds2, pooled_prompt_embeds=pooled_prompt_embeds2, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds2, output_type='latent', timesteps=timesteps_step1)
|
||||
|
||||
# step2: second latents
|
||||
img_diffusion2 = pipe(latents=latents2_step1[0], guidance_scale=guidance_scale, prompt_embeds=prompt_embeds2, negative_prompt_embeds=negative_prompt_embeds2, pooled_prompt_embeds=pooled_prompt_embeds2, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds2, timesteps=timesteps_step2)
|
||||
|
||||
xxx
|
||||
|
||||
#%% find the middle
|
||||
|
||||
prompt_embeds = prompt_embeds1 #interpolate_spherical(prompt_embeds1, prompt_embeds2, 0.5)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds1# interpolate_spherical(pooled_prompt_embeds1, pooled_prompt_embeds2, 0.5)
|
||||
negative_prompt_embeds = negative_prompt_embeds1
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds1
|
||||
|
||||
latents1_stepM = interpolate_spherical(latents1_step1[0], latents2_step1[0], 0.5)
|
||||
|
||||
img_diffusionM = pipe(latents=latents1_stepM, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, timesteps=timesteps_step2)
|
||||
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue