Fix latent2image for TinyVAE (AutoencoderTiny)
parent e9ece062a4
commit 95870823fc
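In short: latent2image used to force the whole VAE into float32 whenever use_sd_xl was set, which presumes an AutoencoderKL-style decoder (post_quant_conv, mid-block attention processors) that AutoencoderTiny does not have. The rewrite gates the upcast on vae.config.force_upcast, mirroring the upstream diffusers SD-XL decode path; since AutoencoderTiny (TAESD) registers force_upcast=False and scaling_factor=1.0, a swapped-in tiny VAE now decodes directly in float16 instead of crashing on attributes it lacks.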
@@ -152,35 +152,60 @@ class DiffusersHolder():
         output_type: "pil" or "np"
         """
         assert output_type in ["pil", "np"]
-        if self.use_sd_xl:
-            # make sure the VAE is in float32 mode, as it overflows in float16
-            self.pipe.vae.to(dtype=torch.float32)
-
-            use_torch_2_0_or_xformers = isinstance(
-                self.pipe.vae.decoder.mid_block.attentions[0].processor,
-                (
-                    AttnProcessor2_0,
-                    XFormersAttnProcessor,
-                    LoRAXFormersAttnProcessor,
-                    LoRAAttnProcessor2_0,
-                ),
-            )
-            # if xformers or torch_2_0 is used attention block does not need
-            # to be in float32 which can save lots of memory
-            if use_torch_2_0_or_xformers:
-                self.pipe.vae.post_quant_conv.to(latents.dtype)
-                self.pipe.vae.decoder.conv_in.to(latents.dtype)
-                self.pipe.vae.decoder.mid_block.to(latents.dtype)
-            else:
-                latents = latents.float()
-
+        # make sure the VAE is in float32 mode, as it overflows in float16
+        needs_upcasting = self.pipe.vae.dtype == torch.float16 and self.pipe.vae.config.force_upcast
+        if needs_upcasting:
+            self.pipe.upcast_vae()
+            latents = latents.to(next(iter(self.pipe.vae.post_quant_conv.parameters())).dtype)
+
         image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0]
-        image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])[0]
-        if output_type == "np":
-            return np.asarray(image)
-        else:
-            return image
+
+        # cast back to fp16 if needed
+        if needs_upcasting:
+            self.pipe.vae.to(dtype=torch.float16)
+
+        image = self.pipe.image_processor.postprocess(image, output_type=output_type)[0]
+        return image
+
+        # if output_type == "np":
+        #     return np.asarray(image)
+        # else:
+        #     return image
+
+        # # xxx
+        # if self.use_sd_xl:
+        #     # make sure the VAE is in float32 mode, as it overflows in float16
+        #     self.pipe.vae.to(dtype=torch.float32)
+
+        #     use_torch_2_0_or_xformers = isinstance(
+        #         self.pipe.vae.decoder.mid_block.attentions[0].processor,
+        #         (
+        #             AttnProcessor2_0,
+        #             XFormersAttnProcessor,
+        #             LoRAXFormersAttnProcessor,
+        #             LoRAAttnProcessor2_0,
+        #         ),
+        #     )
+        #     # if xformers or torch_2_0 is used attention block does not need
+        #     # to be in float32 which can save lots of memory
+        #     if use_torch_2_0_or_xformers:
+        #         self.pipe.vae.post_quant_conv.to(latents.dtype)
+        #         self.pipe.vae.decoder.conv_in.to(latents.dtype)
+        #         self.pipe.vae.decoder.mid_block.to(latents.dtype)
+        #     else:
+        #         latents = latents.float()
+
+        #     image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0]
+        #     image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])[0]
+        #     if output_type == "np":
+        #         return np.asarray(image)
+        #     else:
+        #         return image
 
     def prepare_mixing(self, mixing_coeffs, list_latents_mixing):
         if type(mixing_coeffs) == float:
             list_mixing_coeffs = (1 + self.num_inference_steps) * [mixing_coeffs]
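For context, the hunk above converges on the same upcast dance the stock diffusers SD-XL pipeline performs before decoding. A minimal standalone sketch, assuming pipe is a StableDiffusionXLPipeline-style object exposing upcast_vae() (the function name decode_latents here is illustrative, not part of the commit):

    import torch

    def decode_latents(pipe, latents, output_type="pil"):
        # Upcast only when the VAE asks for it: AutoencoderKL registers
        # force_upcast=True (its fp16 decode overflows), while AutoencoderTiny
        # registers force_upcast=False, so a TAESD decoder stays in fp16.
        needs_upcasting = pipe.vae.dtype == torch.float16 and pipe.vae.config.force_upcast
        if needs_upcasting:
            pipe.upcast_vae()
            latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
        image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
        if needs_upcasting:
            # cast back to fp16 so later decodes stay cheap
            pipe.vae.to(dtype=torch.float16)
        return pipe.image_processor.postprocess(image, output_type=output_type)[0]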
@@ -718,13 +743,27 @@ class DiffusersHolder():
 if __name__ == "__main__":
     from PIL import Image
     #%%
+    from diffusers import AutoencoderTiny
     # pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
     pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
     pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
     pipe.to('cuda') # xxx
 
+    #%
+    pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16)
+    pipe.vae = pipe.vae.cuda()
     #%%
+    self = DiffusersHolder(pipe)
+    self.set_num_inference_steps(4)
+    prompt1 = "Photo of a colorful landscape with a blue sky with clouds"
+    text_embeddings1 = self.get_text_embedding(prompt1)
+    latents_start = self.get_noise(seed=420)
+    latents = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=False)[-1]
+    image = self.latent2image(latents)
+
+    xxxx
     # # xxx
     # self.set_dimensions((512, 512))
     # self.set_num_inference_steps(4)
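Usage note for the demo above: because output_type is now forwarded straight to image_processor.postprocess, the deleted np branch is still reachable from the call site. Assuming latent2image kept its output_type keyword (as the docstring and assert in the first hunk suggest):

    image_pil = self.latent2image(latents)                   # PIL.Image, the default
    image_np = self.latent2image(latents, output_type="np")  # float array in [0, 1] from diffusers' postprocess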
@@ -773,49 +812,6 @@ if __name__ == "__main__":
     self.run_diffusion_sd_xl(text_embeddings_mix, latents_start_mixed, idx_start=idx_start, return_image=True)
 
-    #%%
-    fract=0.8
-    latentsmix = interpolate_spherical(latents1[-1], latents2[-1], fract)
-    self.latent2image(latentsmix)
-
-    #%%
-    """
-    xxxxx
-    # step1: first latents
-    latents1_step1 = pipe(latents=latents_start, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds1, negative_prompt_embeds=negative_prompt_embeds1, pooled_prompt_embeds=pooled_prompt_embeds1, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds1, output_type='latent', timesteps=timesteps_step1)
-
-    # step2: second latents
-    img_diffusion1 = pipe(latents=latents1_step1[0], guidance_scale=guidance_scale, prompt_embeds=prompt_embeds1, negative_prompt_embeds=negative_prompt_embeds1, pooled_prompt_embeds=pooled_prompt_embeds1, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds1, timesteps=timesteps_step2)
-
-    #%% img2
-    latents_start = torch.randn((1,4,64//1,64)).half().cuda()
-
-    # step1: first latents
-    latents2_step1 = pipe(latents=latents_start, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds2, negative_prompt_embeds=negative_prompt_embeds2, pooled_prompt_embeds=pooled_prompt_embeds2, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds2, output_type='latent', timesteps=timesteps_step1)
-
-    # step2: second latents
-    img_diffusion2 = pipe(latents=latents2_step1[0], guidance_scale=guidance_scale, prompt_embeds=prompt_embeds2, negative_prompt_embeds=negative_prompt_embeds2, pooled_prompt_embeds=pooled_prompt_embeds2, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds2, timesteps=timesteps_step2)
-
-    xxx
-    #%% find the middle
-    prompt_embeds = prompt_embeds1 #interpolate_spherical(prompt_embeds1, prompt_embeds2, 0.5)
-    pooled_prompt_embeds = pooled_prompt_embeds1 # interpolate_spherical(pooled_prompt_embeds1, pooled_prompt_embeds2, 0.5)
-    negative_prompt_embeds = negative_prompt_embeds1
-    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds1
-
-    latents1_stepM = interpolate_spherical(latents1_step1[0], latents2_step1[0], 0.5)
-
-    img_diffusionM = pipe(latents=latents1_stepM, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, timesteps=timesteps_step2)
-    """
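The scratch experiments removed above leaned on interpolate_spherical to blend latents; that helper lives elsewhere in this repo. For reference, a self-contained spherical interpolation (slerp) sketch over two equally shaped latent tensors, not necessarily the repo's exact implementation:

    import torch

    def slerp(p0: torch.Tensor, p1: torch.Tensor, fract: float) -> torch.Tensor:
        # Walk the great circle between p0 and p1; unlike a plain lerp this
        # roughly preserves the norm of Gaussian noise latents.
        eps = 1e-7
        cos_omega = torch.clamp(
            ((p0 / p0.norm()) * (p1 / p1.norm())).sum(), -1 + eps, 1 - eps)
        omega = torch.arccos(cos_omega)
        so = torch.sin(omega)
        return (torch.sin((1.0 - fract) * omega) / so) * p0 + (torch.sin(fract * omega) / so) * p1

    # e.g. the deleted fract=0.8 blend: latents_mix = slerp(latents1[-1], latents2[-1], 0.8)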