# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from contextlib import nullcontext
from typing import Optional

import numpy as np
import torch
from PIL import Image
from torch import autocast

from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)

from utils import interpolate_spherical

# Global torch setup: inference only, no autograd, deterministic cudnn autotuner off.
torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
warnings.filterwarnings('ignore')
class DiffusersHolder():
    """Thin wrapper around a diffusers pipeline (SD 1.x/2.x, SDXL, or ControlNet)
    that exposes the individual pieces needed for latent blending: prompt
    encoding, seeded noise generation, a step-by-step denoising loop with
    optional spherical latent mixing, and latent-to-image decoding.
    """

    def __init__(self, pipe):
        # Base settings
        self.negative_prompt = ""
        self.guidance_scale = 5.0
        self.num_inference_steps = 30

        # Check if valid pipe
        self.pipe = pipe
        self.device = str(pipe._execution_device)
        self.init_types()

        # Default latent size comes from the UNet config; overridden by set_dimensions().
        self.width_latent = self.pipe.unet.config.sample_size
        self.height_latent = self.pipe.unet.config.sample_size

    def init_types(self):
        """Detect the pipeline flavor (SDXL vs. SD 1.x/2.x) and probe the prompt
        encoder once with a dummy prompt to record the model dtype."""
        assert hasattr(self.pipe, "__class__"), "No valid diffusers pipeline found."
        assert hasattr(self.pipe.__class__, "__name__"), "No valid diffusers pipeline found."
        if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline':
            self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
            self.use_sd_xl = True
            prompt_embeds, _, _, _ = self.pipe.encode_prompt("test")
        else:
            self.use_sd_xl = False
            prompt_embeds = self.pipe._encode_prompt("test", self.device, 1, True)
        self.dtype = prompt_embeds.dtype

    def set_num_inference_steps(self, num_inference_steps):
        """Set the number of denoising steps (re-syncs the SDXL scheduler)."""
        self.num_inference_steps = num_inference_steps
        if self.use_sd_xl:
            self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)

    def set_dimensions(self, width, height):
        """Set output image dimensions in pixels, rounded to a multiple of the
        VAE scale factor; None selects the model's default size."""
        s = self.pipe.vae_scale_factor
        if width is None:
            self.width_latent = self.pipe.unet.config.sample_size
            self.width_img = self.width_latent * self.pipe.vae_scale_factor
        else:
            self.width_img = int(round(width / s) * s)
            self.width_latent = int(self.width_img / s)
        if height is None:
            self.height_latent = self.pipe.unet.config.sample_size
            # BUGFIX: was computed from width_latent, yielding a wrong default height
            self.height_img = self.height_latent * self.pipe.vae_scale_factor
        else:
            self.height_img = int(round(height / s) * s)
            self.height_latent = int(self.height_img / s)

    def set_negative_prompt(self, negative_prompt):
        r"""Set the negative prompt. Currently only one negative prompt is supported."""
        if isinstance(negative_prompt, str):
            self.negative_prompt = [negative_prompt]
        else:
            self.negative_prompt = negative_prompt
        # Only a single negative prompt is supported; silently drop extras.
        if len(self.negative_prompt) > 1:
            self.negative_prompt = [self.negative_prompt[0]]

    def get_text_embedding(self, prompt, do_classifier_free_guidance=True):
        """Encode a prompt with the pipeline's text encoder(s).

        Returns the raw output of the pipeline's prompt encoder (a tuple of
        four tensors for SDXL, a single tensor for SD 1.x/2.x).
        """
        if self.use_sd_xl:
            pr_encoder = self.pipe.encode_prompt
        else:
            pr_encoder = self.pipe._encode_prompt

        prompt_embeds = pr_encoder(
            prompt=prompt,
            device=self.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=self.negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            lora_scale=None,
        )
        return prompt_embeds

    def get_noise(self, seed=420, mode=None):
        """Return a seeded initial latent tensor of shape (1, C, H, W)."""
        H = self.height_latent
        W = self.width_latent
        C = self.pipe.unet.config.in_channels
        generator = torch.Generator(device=self.device).manual_seed(int(seed))
        latents = torch.randn((1, C, H, W), generator=generator, dtype=self.dtype, device=self.device)
        if self.use_sd_xl:
            # SDXL schedulers expect the initial noise scaled by init_noise_sigma.
            latents = latents * self.pipe.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def latent2image(
            self,
            latents: torch.FloatTensor,
            convert_numpy=True):
        r"""
        Returns an image provided a latent representation from diffusion.
        Args:
            latents: torch.FloatTensor
                Result of the diffusion process.
            convert_numpy: if converting to numpy
        """
        if self.use_sd_xl:
            # make sure the VAE is in float32 mode, as it overflows in float16
            self.pipe.vae.to(dtype=torch.float32)
            use_torch_2_0_or_xformers = isinstance(
                self.pipe.vae.decoder.mid_block.attentions[0].processor,
                (
                    AttnProcessor2_0,
                    XFormersAttnProcessor,
                    LoRAXFormersAttnProcessor,
                    LoRAAttnProcessor2_0,
                ),
            )
            # if xformers or torch_2_0 is used attention block does not need
            # to be in float32 which can save lots of memory
            if use_torch_2_0_or_xformers:
                self.pipe.vae.post_quant_conv.to(latents.dtype)
                self.pipe.vae.decoder.conv_in.to(latents.dtype)
                self.pipe.vae.decoder.mid_block.to(latents.dtype)
            else:
                latents = latents.float()

        image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0]
        image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=[True] * image.shape[0])[0]
        if convert_numpy:
            return np.asarray(image)
        else:
            return image

    def prepare_mixing(self, mixing_coeffs, list_latents_mixing):
        """Normalize mixing_coeffs to a per-step list of coefficients.

        A scalar is broadcast to every step; a list must have exactly
        num_inference_steps entries. When any coefficient is positive,
        list_latents_mixing must supply one latent per step.
        """
        if isinstance(mixing_coeffs, float):
            # NOTE: broadcast length is num_inference_steps + 1, one spare entry
            # for schedulers that yield an extra timestep.
            list_mixing_coeffs = (1 + self.num_inference_steps) * [mixing_coeffs]
        elif isinstance(mixing_coeffs, list):
            assert len(mixing_coeffs) == self.num_inference_steps, f"len(mixing_coeffs) {len(mixing_coeffs)} != self.num_inference_steps {self.num_inference_steps}"
            list_mixing_coeffs = mixing_coeffs
        else:
            raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")

        if np.sum(list_mixing_coeffs) > 0:
            assert len(list_latents_mixing) == self.num_inference_steps, f"len(list_latents_mixing) {len(list_latents_mixing)} != self.num_inference_steps {self.num_inference_steps}"
        return list_mixing_coeffs

    @torch.no_grad()
    def run_diffusion(
            self,
            text_embeddings: torch.FloatTensor,
            latents_start: torch.FloatTensor,
            idx_start: int = 0,
            list_latents_mixing=None,
            mixing_coeffs=0.0,
            return_image: Optional[bool] = False):
        """Dispatch to the denoising loop matching the wrapped pipeline type."""
        if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline':
            return self.run_diffusion_sd_xl(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image)
        elif self.pipe.__class__.__name__ == 'StableDiffusionPipeline':
            return self.run_diffusion_sd12x(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image)
        elif self.pipe.__class__.__name__ == 'StableDiffusionControlNetPipeline':
            # ControlNet requires the image conditioning; use run_diffusion_controlnet directly.
            pass

    @torch.no_grad()
    def run_diffusion_sd12x(
            self,
            text_embeddings: torch.FloatTensor,
            latents_start: torch.FloatTensor,
            idx_start: int = 0,
            list_latents_mixing=None,
            mixing_coeffs=0.0,
            return_image: Optional[bool] = False):
        """Denoising loop for SD 1.x/2.x with optional per-step latent mixing.

        Steps before idx_start are skipped (None placeholders are appended so
        the returned list always has one entry per timestep).
        """
        # BUGFIX: prepare_mixing was called without its required arguments.
        list_mixing_coeffs = self.prepare_mixing(mixing_coeffs, list_latents_mixing)

        do_classifier_free_guidance = self.guidance_scale > 1.0

        # accommodate different sd model types: some schedulers return one more
        # timestep than requested, so probe with num_inference_steps - 1 first.
        self.pipe.scheduler.set_timesteps(self.num_inference_steps - 1, device=self.device)
        timesteps = self.pipe.scheduler.timesteps
        if len(timesteps) != self.num_inference_steps:
            self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
            timesteps = self.pipe.scheduler.timesteps

        latents = latents_start.clone()
        list_latents_out = []

        for i, t in enumerate(timesteps):
            # Set the right starting latents
            if i < idx_start:
                list_latents_out.append(None)
                continue
            elif i == idx_start:
                latents = latents_start.clone()

            # Mix latents
            if i > 0 and list_mixing_coeffs[i] > 0:
                latents_mixtarget = list_latents_mixing[i - 1].clone()
                latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            list_latents_out.append(latents.clone())

        if return_image:
            return self.latent2image(latents)
        else:
            return list_latents_out

    @torch.no_grad()
    def run_diffusion_sd_xl(
            self,
            text_embeddings: list,
            latents_start: torch.FloatTensor,
            idx_start: int = 0,
            list_latents_mixing=None,
            mixing_coeffs=0.0,
            return_image: Optional[bool] = False):
        """Denoising loop for SDXL with optional per-step latent mixing.

        text_embeddings must be the 4-tuple returned by pipe.encode_prompt.
        """
        # 0. Default height and width to unet
        original_size = (self.width_img, self.height_img)  # FIXME
        crops_coords_top_left = (0, 0)  # FIXME
        target_size = original_size
        batch_size = 1
        eta = 0.0
        num_images_per_prompt = 1
        cross_attention_kwargs = None
        generator = torch.Generator(device=self.device)  # dummy generator
        do_classifier_free_guidance = self.guidance_scale > 1.0

        # 1. Check inputs. Raise error if not correct & 2. Define call parameters
        list_mixing_coeffs = self.prepare_mixing(mixing_coeffs, list_latents_mixing)

        # 3. Encode input prompt (already encoded outside bc of mixing, just split here)
        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = text_embeddings

        # 4. Prepare timesteps
        self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
        timesteps = self.pipe.scheduler.timesteps

        # 5. Prepare latent variables
        latents = latents_start.clone()
        list_latents_out = []

        # 6. Prepare extra step kwargs. use dummy generator
        extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)  # dummy

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        add_time_ids = self.pipe._get_add_time_ids(
            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
        )

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(self.device)
        add_text_embeds = add_text_embeds.to(self.device)
        add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1)

        # 8. Denoising loop
        for i, t in enumerate(timesteps):
            # Set the right starting latents
            if i < idx_start:
                list_latents_out.append(None)
                continue
            elif i == idx_start:
                latents = latents_start.clone()

            # Mix latents for crossfeeding
            if i > 0 and list_mixing_coeffs[i] > 0:
                latents_mixtarget = list_latents_mixing[i - 1].clone()
                latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            # Always scale latents
            latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
            noise_pred = self.pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # FIXME guidance_rescale disabled

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # Append latents
            list_latents_out.append(latents.clone())

        if return_image:
            return self.latent2image(latents)
        else:
            return list_latents_out

    @torch.no_grad()
    def run_diffusion_controlnet(
            self,
            conditioning: list,
            latents_start: torch.FloatTensor,
            idx_start: int = 0,
            list_latents_mixing=None,
            mixing_coeffs=0.0,
            return_image: Optional[bool] = False):
        """Denoising loop for ControlNet pipelines.

        conditioning is [prompt_embeds, control_image].
        """
        prompt_embeds = conditioning[0]
        image = conditioning[1]
        # BUGFIX: prepare_mixing was called without its required arguments.
        list_mixing_coeffs = self.prepare_mixing(mixing_coeffs, list_latents_mixing)

        controlnet = self.pipe.controlnet
        control_guidance_start = [0.0]
        control_guidance_end = [1.0]
        guess_mode = False
        num_images_per_prompt = 1
        batch_size = 1
        eta = 0.0
        controlnet_conditioning_scale = 1.0

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]

        # 2. Define call parameters
        device = self.pipe._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = self.guidance_scale > 1.0

        # 4. Prepare image
        image = self.pipe.prepare_image(
            image=image,
            width=None,
            height=None,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=self.device,
            dtype=controlnet.dtype,
            do_classifier_free_guidance=do_classifier_free_guidance,
            guess_mode=guess_mode,
        )
        height, width = image.shape[-2:]

        # 5. Prepare timesteps
        self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
        timesteps = self.pipe.scheduler.timesteps

        # 6. Prepare latent variables
        generator = torch.Generator(device=self.device).manual_seed(int(420))
        latents = latents_start.clone()
        list_latents_out = []

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Create tensor stating which controlnets to keep
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)

        # 8. Denoising loop
        for i, t in enumerate(timesteps):
            # Set the right starting latents
            if i < idx_start:
                list_latents_out.append(None)
                continue
            elif i == idx_start:
                latents = latents_start.clone()

            # Mix latents for crossfeeding
            if i > 0 and list_mixing_coeffs[i] > 0:
                latents_mixtarget = list_latents_mixing[i - 1].clone()
                latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)

            control_model_input = latent_model_input
            controlnet_prompt_embeds = prompt_embeds

            if isinstance(controlnet_keep[i], list):
                cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
            else:
                cond_scale = controlnet_conditioning_scale * controlnet_keep[i]

            down_block_res_samples, mid_block_res_sample = self.pipe.controlnet(
                control_model_input,
                t,
                encoder_hidden_states=controlnet_prompt_embeds,
                controlnet_cond=image,
                conditioning_scale=cond_scale,
                guess_mode=guess_mode,
                return_dict=False,
            )

            if guess_mode and do_classifier_free_guidance:
                # Infered ControlNet only for the conditional batch.
                # To apply the output of ControlNet to both the unconditional and conditional batches,
                # add 0 to the unconditional batch to keep it unchanged.
                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

            # predict the noise residual
            noise_pred = self.pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=None,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # Append latents
            list_latents_out.append(latents.clone())

        if return_image:
            return self.latent2image(latents)
        else:
            return list_latents_out
# %%
"""
steps:
 x get controlnet vanilla running.
 - externalize conditions
 - have conditions as input (use one list)
 - include latent blending
 - test latent blending
 - have lora and latent blending
"""
# %%
if __name__ == "__main__":
    # %%
    # Demo: load SDXL base and run one full denoising pass.
    pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
    pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
    pipe.to('cuda:1')  # xxx

    # %%
    dh = DiffusersHolder(pipe)
    # xxx
    dh.set_dimensions(1024, 704)
    dh.set_num_inference_steps(40)
    # dh.set_dimensions(1536, 1024)
    prompt = "Surreal painting of eerie, nebulous glow of an indigo moon, a spine-chilling spectacle unfolds; a baroque, marbled hand reaches out from a viscous, purple lake clutching a melting clock, its face distorted in a never-ending scream of hysteria, while a cluster of laughing orchids, their petals morphed into grotesque human lips, festoon a crimson tree weeping blood instead of sap, a psychedelic cat with an unnaturally playful grin and mismatched eyes lounges atop a floating vintage television showing static, an albino peacock with iridescent, crystalline feathers dances around a towering, inverted pyramid on top of which a humanoid figure with an octopus head lounges seductively, all against the backdrop of a sprawling cityscape where buildings are inverted and writhing as if alive, and the sky is punctuated by floating aquatic creatures glowing neon, adding a touch of haunting beauty to this otherwise deeply unsettling tableau"
    text_embeddings = dh.get_text_embedding(prompt)
    generator = torch.Generator(device=dh.device).manual_seed(int(420))
    latents_start = dh.get_noise()
    list_latents_1 = dh.run_diffusion(text_embeddings, latents_start)
    img_orig = dh.latent2image(list_latents_1[-1])

# %%
"""
OPEN
- rename text encodings to conditionings
- other examples
- kill upscaling? or keep?
- cleanup
    - ldh
    - sdh class
    - diffusion holder
- check linting
- check docstrings
- fix readme
"""