new transition engine and crossfeeding

This commit is contained in:
Johannes Stelzer 2023-02-15 18:21:00 +01:00
parent 5e979818b2
commit 0371868603
3 changed files with 480 additions and 150 deletions

View File

@ -33,6 +33,13 @@ import gradio as gr
import copy import copy
"""
TODOS:
- clean parameter handling
- three buttons: diffuse A, diffuse B, make transition
- collapse for easy mode
- transition quality in terms of render time
"""
#%% #%%
@ -45,7 +52,7 @@ class BlendingFrontend():
self.lb = LatentBlending(sdh) self.lb = LatentBlending(sdh)
self.share = True self.share = True
self.num_inference_steps = 20 self.num_inference_steps = 30
self.depth_strength = 0.25 self.depth_strength = 0.25
self.seed1 = 42 self.seed1 = 42
self.seed2 = 420 self.seed2 = 420
@ -58,11 +65,13 @@ class BlendingFrontend():
self.list_settings = [] self.list_settings = []
self.state_current = {} self.state_current = {}
self.showing_current = True self.showing_current = True
self.branch1_influence = 0.02 self.branch1_influence = 0.1
self.branch1_mixing_depth = 0.3
self.nmb_branches_final = 9 self.nmb_branches_final = 9
self.nmb_imgs_show = 5 # don't change self.nmb_imgs_show = 5 # don't change
self.fps = 30 self.fps = 30
self.duration = 10 self.duration_video = 15
self.t_compute_max_allowed = 15
self.dict_multi_trans = {} self.dict_multi_trans = {}
self.dict_multi_trans_include = {} self.dict_multi_trans_include = {}
self.multi_trans_currently_shown = [] self.multi_trans_currently_shown = []
@ -79,114 +88,87 @@ class BlendingFrontend():
self.width = 768 self.width = 768
# make dummy image # make dummy image
def save_empty_image(self):
self.fp_img_empty = 'empty.jpg' self.fp_img_empty = 'empty.jpg'
Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5) Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
def change_depth_strength(self, value):
self.depth_strength = value
print(f"changed depth_strength to {value}")
def change_num_inference_steps(self, value):
self.num_inference_steps = value
print(f"changed num_inference_steps to {value}")
def change_guidance_scale(self, value):
self.guidance_scale = value
self.lb.set_guidance_scale(value)
print(f"changed guidance_scale to {value}")
def change_guidance_scale_mid_damper(self, value):
self.guidance_scale_mid_damper = value
print(f"changed guidance_scale_mid_damper to {value}")
def change_mid_compression_scaler(self, value):
self.mid_compression_scaler = value
print(f"changed mid_compression_scaler to {value}")
def change_branch1_influence(self, value):
self.branch1_influence = value
print(f"changed branch1_influence to {value}")
def change_height(self, value):
self.height = value
print(f"changed height to {value}")
def change_width(self, value):
self.width = value
print(f"changed width to {value}")
def change_nmb_branches_final(self, value):
self.nmb_branches_final = value
print(f"changed nmb_branches_final to {value}")
def change_duration(self, value):
self.duration = value
print(f"changed duration to {value}")
def change_fps(self, value):
self.fps = value
print(f"changed fps to {value}")
def change_negative_prompt(self, value):
self.negative_prompt = value
def change_seed1(self, value):
self.seed1 = int(value)
def change_seed2(self, value):
self.seed2 = int(value)
def randomize_seed1(self): def randomize_seed1(self):
seed = np.random.randint(0, 10000000) seed = np.random.randint(0, 10000000)
self.change_seed1(seed) self.seed1 = int(seed)
print(f"randomize_seed1: new seed = {self.seed1}") print(f"randomize_seed1: new seed = {self.seed1}")
return seed return seed
def randomize_seed2(self): def randomize_seed2(self):
seed = np.random.randint(0, 10000000) seed = np.random.randint(0, 10000000)
self.change_seed2(seed) self.seed2 = int(seed)
print(f"randomize_seed2: new seed = {self.seed2}") print(f"randomize_seed2: new seed = {self.seed2}")
return seed return seed
def compute_transition(self, prompt1, prompt2): def setup_lb(self, list_ui_elem):
self.prompt1 = prompt1 # Collect latent blending variables
self.prompt2 = prompt2
print("STARTING DIFFUSION!")
self.state_current = self.get_state_dict() self.state_current = self.get_state_dict()
self.lb.set_width(list_ui_elem[list_ui_keys.index('width')])
self.lb.set_height(list_ui_elem[list_ui_keys.index('height')])
self.lb.set_prompt1(list_ui_elem[list_ui_keys.index('prompt1')])
self.lb.set_prompt2(list_ui_elem[list_ui_keys.index('prompt2')])
self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
self.lb.branch1_influence = list_ui_elem[list_ui_keys.index('branch1_influence')]
self.lb.branch1_mixing_depth = list_ui_elem[list_ui_keys.index('branch1_mixing_depth')]
self.lb.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
self.duration_video = list_ui_elem[list_ui_keys.index('duration_video')]
self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')]
self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
def compute_img1(self, *args):
list_ui_elem = args
self.setup_lb(list_ui_elem)
fp_img1 = f"img1_{get_time('second')}.jpg"
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
img1.save(fp_img1)
self.save_empty_image()
return [fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
def compute_img2(self, *args):
list_ui_elem = args
self.setup_lb(list_ui_elem)
fp_img2 = f"img2_{get_time('second')}.jpg"
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
img2.save(fp_img2)
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2]
def compute_transition(self, *args):
list_ui_elem = args
self.setup_lb(list_ui_elem)
print("STARTING DIFFUSION!")
if self.use_debug: if self.use_debug:
list_imgs = [(255*np.random.rand(self.height,self.width,3)).astype(np.uint8) for l in range(5)] list_imgs = [(255*np.random.rand(self.height,self.width,3)).astype(np.uint8) for l in range(5)]
list_imgs = [Image.fromarray(l) for l in list_imgs] list_imgs = [Image.fromarray(l) for l in list_imgs]
print("DONE! SENDING BACK RESULTS") print("DONE! SENDING BACK RESULTS")
return list_imgs return list_imgs
# Collect latent blending variables
self.lb.set_width(self.width)
self.lb.set_height(self.height)
self.lb.autosetup_branching(
depth_strength = self.depth_strength,
num_inference_steps = self.num_inference_steps,
nmb_branches_final = self.nmb_branches_final,
nmb_mindist = 3)
self.lb.set_prompt1(self.prompt1)
self.lb.set_prompt2(self.prompt2)
self.lb.set_negative_prompt(self.negative_prompt)
self.lb.guidance_scale = self.guidance_scale
self.lb.guidance_scale_mid_damper = self.guidance_scale_mid_damper
self.lb.mid_compression_scaler = self.mid_compression_scaler
self.lb.branch1_influence = self.branch1_influence
fixed_seeds = [self.seed1, self.seed2] fixed_seeds = [self.seed1, self.seed2]
# Run Latent Blending # Run Latent Blending
imgs_transition = self.lb.run_transition(fixed_seeds=fixed_seeds) imgs_transition = self.lb.run_transition(
recycle_img1=True,
recycle_img2=True,
num_inference_steps=self.num_inference_steps,
depth_strength=self.depth_strength,
fixed_seeds=fixed_seeds
)
print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images") print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
# Subselect the preview images (hard fixed to self.nmb_imgs_show=5) # Subselect three preview images
assert np.mod((self.nmb_branches_final-self.nmb_imgs_show)/4, 1)==0, 'self.nmb_branches_final illegal value!' idx_img_prev = np.round(np.linspace(0, len(imgs_transition)-1, 5)[1:-1]).astype(np.int32)
idx_list = np.linspace(0, self.nmb_branches_final-1, self.nmb_imgs_show).astype(np.int32)
list_imgs_preview = [] list_imgs_preview = []
for j in idx_list: for j in idx_img_prev:
list_imgs_preview.append(Image.fromarray(imgs_transition[j])) list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
# Save the preview imgs as jpgs on disk so we are not sending umcompressed data around # Save the preview imgs as jpgs on disk so we are not sending umcompressed data around
@ -198,7 +180,7 @@ class BlendingFrontend():
self.list_fp_imgs_current.append(fp_img) self.list_fp_imgs_current.append(fp_img)
# Insert cheap frames for the movie # Insert cheap frames for the movie
imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration, self.fps) imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
# Save as movie # Save as movie
fp_movie = self.get_fp_movie(self.current_timestamp) fp_movie = self.get_fp_movie(self.current_timestamp)
@ -330,35 +312,44 @@ if __name__ == "__main__":
sdh = StableDiffusionHolder(fp_ckpt) sdh = StableDiffusionHolder(fp_ckpt)
self = BlendingFrontend(sdh) # Yes this is possible in python and yes it is an awesome trick self = BlendingFrontend(sdh) # Yes this is possible in python and yes it is an awesome trick
# self = BlendingFrontend(None) # Yes this is possible in python and yes it is an awesome trick
dict_ui_elem = {}
with gr.Blocks() as demo: with gr.Blocks() as demo:
with gr.Row(): with gr.Row():
prompt1 = gr.Textbox(label="prompt 1") prompt1 = gr.Textbox(label="prompt 1")
prompt2 = gr.Textbox(label="prompt 2")
negative_prompt = gr.Textbox(label="negative prompt") negative_prompt = gr.Textbox(label="negative prompt")
prompt2 = gr.Textbox(label="prompt 2")
with gr.Row(): with gr.Row():
nmb_branches_final = gr.Slider(5, 125, self.nmb_branches_final, step=4, label='nmb trans images', interactive=True) duration_compute = gr.Slider(10, 40, self.duration_video, step=1, label='compute budget for transition (seconds)', interactive=True)
duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True)
height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True) height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True) width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True)
with gr.Row(): with gr.Accordion("Advanced Settings (click to expand)", open=False):
num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='branch1_influence', interactive=True)
guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
with gr.Row(): with gr.Row():
depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True) depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
duration = gr.Slider(0.1, 30, self.duration, step=0.1, label='video duration', interactive=True) branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='branch1_influence', interactive=True)
branch1_mixing_depth = gr.Slider(0.0, 1.0, self.branch1_mixing_depth, step=0.01, label='branch1_mixing_depth', interactive=True)
with gr.Row():
num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True) guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
with gr.Row(): with gr.Row():
seed1 = gr.Number(42, label="seed 1", interactive=True) seed1 = gr.Number(420, label="seed 1", interactive=True)
b_newseed1 = gr.Button("randomize seed 1", variant='secondary') b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
seed2 = gr.Number(420, label="seed 2", interactive=True) seed2 = gr.Number(420, label="seed 2", interactive=True)
b_newseed2 = gr.Button("randomize seed 2", variant='secondary') b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
with gr.Row(): with gr.Row():
b_compute1 = gr.Button('compute first image', variant='primary')
b_compute_transition = gr.Button('compute transition', variant='primary') b_compute_transition = gr.Button('compute transition', variant='primary')
b_compute2 = gr.Button('compute last image', variant='primary')
with gr.Row(): with gr.Row():
img1 = gr.Image(label="1/5") img1 = gr.Image(label="1/5")
@ -370,31 +361,40 @@ if __name__ == "__main__":
with gr.Row(): with gr.Row():
vid_transition = gr.Video() vid_transition = gr.Video()
# Bind the on-change methods # Collect all UI elemts in list to easily pass as inputs
depth_strength.change(fn=self.change_depth_strength, inputs=depth_strength) dict_ui_elem["prompt1"] = prompt1
num_inference_steps.change(fn=self.change_num_inference_steps, inputs=num_inference_steps) dict_ui_elem["negative_prompt"] = negative_prompt
nmb_branches_final.change(fn=self.change_nmb_branches_final, inputs=nmb_branches_final) dict_ui_elem["prompt2"] = prompt2
guidance_scale.change(fn=self.change_guidance_scale, inputs=guidance_scale) dict_ui_elem["duration_compute"] = duration_compute
guidance_scale_mid_damper.change(fn=self.change_guidance_scale_mid_damper, inputs=guidance_scale_mid_damper) dict_ui_elem["duration_video"] = duration_video
dict_ui_elem["height"] = height
dict_ui_elem["width"] = width
height.change(fn=self.change_height, inputs=height) dict_ui_elem["depth_strength"] = depth_strength
width.change(fn=self.change_width, inputs=width) dict_ui_elem["branch1_influence"] = branch1_influence
negative_prompt.change(fn=self.change_negative_prompt, inputs=negative_prompt) dict_ui_elem["branch1_mixing_depth"] = branch1_mixing_depth
seed1.change(fn=self.change_seed1, inputs=seed1)
seed2.change(fn=self.change_seed2, inputs=seed2) dict_ui_elem["num_inference_steps"] = num_inference_steps
duration.change(fn=self.change_duration, inputs=duration) dict_ui_elem["guidance_scale"] = guidance_scale
branch1_influence.change(fn=self.change_branch1_influence, inputs=branch1_influence) dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
dict_ui_elem["seed1"] = seed1
dict_ui_elem["seed2"] = seed2
# Convert to list, as gradio doesn't seem to accept dicts
list_ui_elem = []
list_ui_keys = []
for k in dict_ui_elem.keys():
list_ui_elem.append(dict_ui_elem[k])
list_ui_keys.append(k)
self.list_ui_keys = list_ui_keys
b_newseed1.click(self.randomize_seed1, outputs=seed1) b_newseed1.click(self.randomize_seed1, outputs=seed1)
b_newseed2.click(self.randomize_seed2, outputs=seed2) b_newseed2.click(self.randomize_seed2, outputs=seed2)
# b_stackforward.click(self.stack_forward, b_compute1.click(self.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5])
# inputs=[prompt2, seed2], b_compute2.click(self.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5])
# outputs=[img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
b_compute_transition.click(self.compute_transition, b_compute_transition.click(self.compute_transition,
inputs=[prompt1, prompt2], inputs=list_ui_elem,
outputs=[img1, img2, img3, img4, img5, vid_transition]) outputs=[img2, img3, img4, vid_transition])
demo.launch(share=self.share, inbrowser=True, inline=False) demo.launch(share=self.share, inbrowser=True, inline=False)

View File

@ -45,7 +45,7 @@ from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
from stable_diffusion_holder import StableDiffusionHolder from stable_diffusion_holder import StableDiffusionHolder
import yaml import yaml
import lpips
#%% #%%
class LatentBlending(): class LatentBlending():
def __init__( def __init__(
@ -88,10 +88,13 @@ class LatentBlending():
self.prompt1 = "" self.prompt1 = ""
self.prompt2 = "" self.prompt2 = ""
self.negative_prompt = "" self.negative_prompt = ""
self.tree_latents = None
self.tree_latents = [None, None]
self.tree_fracts = None self.tree_fracts = None
self.idx_injection = []
self.tree_status = None self.tree_status = None
self.tree_final_imgs = [] self.tree_final_imgs = []
self.list_nmb_branches_prev = [] self.list_nmb_branches_prev = []
self.list_injection_idx_prev = [] self.list_injection_idx_prev = []
self.text_embedding1 = None self.text_embedding1 = None
@ -105,12 +108,15 @@ class LatentBlending():
self.list_injection_idx = None self.list_injection_idx = None
self.list_nmb_branches = None self.list_nmb_branches = None
self.branch1_influence = 0.0 self.branch1_influence = 0.0
self.branch1_fract_crossfeed = 0.65 self.branch1_mixing_depth = 0.65
self.branch1_insertion_completed = False self.branch1_insertion_completed = False
self.set_guidance_scale(guidance_scale) self.set_guidance_scale(guidance_scale)
self.init_mode() self.init_mode()
self.multi_transition_img_first = None self.multi_transition_img_first = None
self.multi_transition_img_last = None self.multi_transition_img_last = None
self.dt_per_diff = 0
self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
def init_mode(self): def init_mode(self):
@ -375,6 +381,187 @@ class LatentBlending():
self.tree_status.append(['untouched']*nmb_branches) self.tree_status.append(['untouched']*nmb_branches)
def run_transition( def run_transition(
self,
recycle_img1: Optional[bool] = False,
recycle_img2: Optional[bool] = False,
num_inference_steps: Optional[int] = 30,
depth_strength: Optional[float] = 0.3,
fixed_seeds: Optional[List[int]] = None,
):
# # FIXME: deal with these tree args later
# self.num_inference_steps = 30
# self.t_compute_max_allowed = 60
# Sanity checks first
assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
# Random seeds
if fixed_seeds is not None:
if fixed_seeds == 'randomize':
fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
else:
assert len(fixed_seeds)==2, "Supply a list with len = 2"
self.seed1 = fixed_seeds[0]
self.seed2 = fixed_seeds[1]
# Ensure correct num_inference_steps in holder
self.sdh.num_inference_steps = self.num_inference_steps
# Compute / Recycle first image
if not recycle_img1:
list_latents1 = self.compute_latents1()
else:
# FIXME: check if latents there...
list_latents1 = self.tree_latents[0]
# Compute / Recycle first image
if not recycle_img2:
list_latents2 = self.compute_latents2()
else:
# FIXME: check if latents there...
list_latents2 = self.tree_latents[-1]
# Reset the tree, injecting the edge latents1/2 we just generated/recycled
self.tree_latents = [list_latents1, list_latents2]
self.tree_fracts = [0.0, 1.0]
self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
self.tree_idx_injection = [0, 0]
# Set up branching scheme (dependent on provided compute time)
idx_injection_base = int(round(self.num_inference_steps*depth_strength))
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3)
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
t_compute = 0
while t_compute < self.t_compute_max_allowed:
list_compute_steps = self.num_inference_steps - list_idx_injection
list_compute_steps *= list_nmb_stems
t_compute = np.sum(list_compute_steps) * self.dt_per_diff
increase_done = False
for s_idx in range(len(list_nmb_stems)-1):
if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
list_nmb_stems[s_idx] += 1
increase_done = True
break
if not increase_done:
list_nmb_stems[-1] += 1
# print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
# Run iteratively
for s_idx in tqdm(range(len(list_idx_injection))):
nmb_stems = list_nmb_stems[s_idx]
idx_injection = list_idx_injection[s_idx]
for i in range(nmb_stems):
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
self.insert_into_tree(fract_mixing, idx_injection, list_latents)
return self.tree_final_imgs
def get_mixing_parameters(self, idx_injection):
# get_lpips_similarity
similarities = []
for i in range(len(self.tree_final_imgs)-1):
similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i+1]))
b_closest1 = np.argmax(similarities)
b_closest2 = b_closest1+1
fract_closest1 = self.tree_fracts[b_closest1]
fract_closest2 = self.tree_fracts[b_closest2]
# Ensure that the parents are indeed older!
b_parent1 = b_closest1
while True:
if self.tree_idx_injection[b_parent1] < idx_injection:
break
else:
b_parent1 -= 1
b_parent2 = b_closest2
while True:
if self.tree_idx_injection[b_parent2] < idx_injection:
break
else:
b_parent2 += 1
# print(f"\n\nb_closest: {b_closest1} {b_closest2} fract_closest1 {fract_closest1} fract_closest2 {fract_closest2}")
# print(f"b_parent: {b_parent1} {b_parent2}")
# print(f"similarities {similarities}")
# print(f"idx_injection {idx_injection} tree_idx_injection {self.tree_idx_injection}")
fract_mixing = (fract_closest1 + fract_closest2) /2
return fract_mixing, b_parent1, b_parent2
def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
# FIXME
b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
self.tree_latents.insert(b_parent1+1, list_latents)
self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
self.tree_fracts.insert(b_parent1+1, fract_mixing)
self.tree_idx_injection.insert(b_parent1+1, idx_injection)
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
# FIXME
list_conditionings = self.get_mixed_conditioning(fract_mixing)
fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
idx_reversed = self.num_inference_steps - idx_injection
latents_for_injection = interpolate_spherical(
self.tree_latents[b_parent1][-idx_reversed-1],
self.tree_latents[b_parent2][-idx_reversed-1],
fract_mixing_parental)
list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
return list_latents
def compute_latents1(self, return_image=False):
print("starting compute_latents1")
list_conditionings = [self.text_embedding1]
t0 = time.time()
list_latents1 = self.run_diffusion(list_conditionings, seed_source=self.seed1)
t1 = time.time()
self.dt_per_diff = (t1-t0) / self.num_inference_steps
self.tree_latents[0] = list_latents1
if return_image:
return self.sdh.latent2image(list_latents1[-1])
else:
return list_latents1
def compute_latents2(self, return_image=False):
print("starting compute_latents2")
list_conditionings = [self.text_embedding2
]
# Influence from branch1
if self.branch1_influence > 0.0:
self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
idx_crossfeed = int(round(self.num_inference_steps*self.branch1_mixing_depth))
list_latents2 = self.run_diffusion(
list_conditionings,
idx_start=idx_crossfeed,
latents_for_injection=self.tree_latents[0],
seed_source=self.seed2,
seed_mixing_target=self.seed1,
mixing_coeff=self.branch1_influence)
else:
list_latents2 = self.run_diffusion(list_conditionings)
self.tree_latents[-1] = list_latents2
if return_image:
return self.sdh.latent2image(list_latents2[-1])
else:
return list_latents2
def run_transition_OLD(
self, self,
recycle_img1: Optional[bool] = False, recycle_img1: Optional[bool] = False,
recycle_img2: Optional[bool] = False, recycle_img2: Optional[bool] = False,
@ -423,9 +610,9 @@ class LatentBlending():
if self.branch1_influence > 0.0 and not self.branch1_insertion_completed: if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0' assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0'
self.branch1_influence = np.clip(self.branch1_influence, 0, 1) self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
self.branch1_fract_crossfeed = np.clip(self.branch1_fract_crossfeed, 0, 1) self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
self.list_nmb_branches.insert(1, 2) self.list_nmb_branches.insert(1, 2)
idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_fract_crossfeed)) idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_mixing_depth))
self.list_injection_idx_ext.insert(1, idx_crossfeed) self.list_injection_idx_ext.insert(1, idx_crossfeed)
self.tree_fracts.insert(1, self.tree_fracts[0]) self.tree_fracts.insert(1, self.tree_fracts[0])
self.tree_status.insert(1, self.tree_status[0]) self.tree_status.insert(1, self.tree_status[0])
@ -606,6 +793,9 @@ class LatentBlending():
latents_for_injection: torch.FloatTensor = None, latents_for_injection: torch.FloatTensor = None,
idx_start: int = -1, idx_start: int = -1,
idx_stop: int = -1, idx_stop: int = -1,
seed_source: int = -1,
seed_mixing_target: int = -1,
mixing_coeff: float = 0.0,
return_image: Optional[bool] = False return_image: Optional[bool] = False
): ):
r""" r"""
@ -620,6 +810,7 @@ class LatentBlending():
Index of the diffusion process start and where the latents_for_injection are injected Index of the diffusion process start and where the latents_for_injection are injected
idx_stop: int idx_stop: int
Index of the diffusion process end. Index of the diffusion process end.
FIXME ARGS
return_image: Optional[bool] return_image: Optional[bool]
Optionally return image directly Optionally return image directly
""" """
@ -630,14 +821,23 @@ class LatentBlending():
if self.mode == 'standard': if self.mode == 'standard':
text_embeddings = list_conditionings[0] text_embeddings = list_conditionings[0]
return self.sdh.run_diffusion_standard(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image) return self.sdh.run_diffusion_standard(
text_embeddings,
latents_for_injection=latents_for_injection,
idx_start=idx_start,
idx_stop=idx_stop,
seed_source=seed_source,
seed_mixing_target=seed_mixing_target,
mixing_coeff=mixing_coeff,
return_image=return_image,
)
elif self.mode == 'inpaint': elif self.mode == 'inpaint':
text_embeddings = list_conditionings[0] text_embeddings = list_conditionings[0]
assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first." assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first." assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image) return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
# FIXME LONG LINE
elif self.mode == 'upscale': elif self.mode == 'upscale':
cond = list_conditionings[0] cond = list_conditionings[0]
uc_full = list_conditionings[1] uc_full = list_conditionings[1]
@ -882,8 +1082,6 @@ class LatentBlending():
self.tree_latents[t_block][-1] = list_latents[self.list_injection_idx_ext[t_block]:self.list_injection_idx_ext[t_block+1]] self.tree_latents[t_block][-1] = list_latents[self.list_injection_idx_ext[t_block]:self.list_injection_idx_ext[t_block+1]]
def swap_forward(self): def swap_forward(self):
r""" r"""
Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions
@ -901,6 +1099,21 @@ class LatentBlending():
self.tree_final_imgs = [] self.tree_final_imgs = []
def get_lpips_similarity(self, imgA, imgB):
# FIXME
tensorA = torch.from_numpy(imgA).float().cuda(self.device)
tensorA = 2*tensorA/255.0 - 1
tensorA = tensorA.permute([2,0,1]).unsqueeze(0)
tensorB = torch.from_numpy(imgB).float().cuda(self.device)
tensorB = 2*tensorB/255.0 - 1
tensorB = tensorB.permute([2,0,1]).unsqueeze(0)
lploss = self.lpips(tensorA, tensorB)
lploss = float(lploss[0][0][0][0])
return lploss
# Auxiliary functions # Auxiliary functions
def get_closest_idx( def get_closest_idx(
fract_mixing: float, fract_mixing: float,
@ -1169,31 +1382,79 @@ if __name__ == "__main__":
#%% First let us spawn a stable diffusion holder #%% First let us spawn a stable diffusion holder
device = "cuda" device = "cuda"
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt" fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
fp_config = 'configs/v2-inference.yaml'
sdh = StableDiffusionHolder(fp_ckpt, fp_config, device) sdh = StableDiffusionHolder(fp_ckpt)
xxx xxx
#%% Next let's set up all parameters #%% Next let's set up all parameters
quality = 'medium' depth_strength = 0.3 # Specifies how deep (in terms of diffusion iterations the first branching happens)
depth_strength = 0.65 # Specifies how deep (in terms of diffusion iterations the first branching happens) fixed_seeds = [697164, 430214]
fixed_seeds = [69731932, 504430820]
prompt1 = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic" prompt1 = "photo of a desert and a sky"
prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail" prompt2 = "photo of a tree with a lake"
duration_transition = 12 # In seconds duration_transition = 12 # In seconds
fps = 30 fps = 30
# Spawn latent blending # Spawn latent blending
self = LatentBlending(sdh) self = LatentBlending(sdh)
self.branch1_influence = 0.8
self.load_branching_profile(quality=quality, depth_strength=0.3)
self.set_prompt1(prompt1) self.set_prompt1(prompt1)
self.set_prompt2(prompt2) self.set_prompt2(prompt2)
# Run latent blending # Run latent blending
imgs_transition = self.run_transition(fixed_seeds=fixed_seeds) self.branch1_influence = 0.3
self.branch1_mixing_depth = 0.4
self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
#%%
self.branch1_influence = 0.3
self.branch1_mixing_depth = 0.5
img2 = self.compute_latents2(return_image=True)
Image.fromarray(img2)
#%%
idx_injection = 15
fract_mixing = 0.5
list_conditionings = self.get_mixed_conditioning(fract_mixing)
latents_for_injection = interpolate_spherical(self.tree_latents[0][idx_injection], self.tree_latents[-1][idx_injection], fract_mixing)
list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
img_mix = self.sdh.latent2image((list_latents[-1]))
Image.fromarray(np.concatenate((img1,img_mix,img2), axis=1)).resize((800,800//3))
#%% scheme
# init scheme
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 2)
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
#%%
list_compute_steps = self.num_inference_steps - list_idx_injection
list_compute_steps *= list_nmb_stems
t_compute = np.sum(list_compute_steps) * self.dt_per_diff
increase_done = False
for s_idx in range(len(list_nmb_stems)-1):
if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 3:
list_nmb_stems[s_idx] += 1
increase_done = True
break
if not increase_done:
list_nmb_stems[-1] += 1
print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
#%%
imgs_transition = self.tree_final_imgs
# Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
imgs_transition_ext = add_frames_linear_interp(imgs_transition, 15, fps)
# Save as MP4
fp_movie = "test.mp4"
if os.path.isfile(fp_movie):
os.remove(fp_movie)
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width])
for img in tqdm(imgs_transition_ext):
ms.write_frame(img)
ms.finalize()

View File

@ -42,7 +42,6 @@ from contextlib import nullcontext
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from einops import repeat, rearrange from einops import repeat, rearrange
#%% #%%
@ -279,10 +278,13 @@ class StableDiffusionHolder:
def run_diffusion_standard( def run_diffusion_standard(
self, self,
text_embeddings: torch.FloatTensor, text_embeddings: torch.FloatTensor,
latents_for_injection: torch.FloatTensor = None, latents_for_injection = None,
idx_start: int = -1, idx_start: int = -1,
idx_stop: int = -1, idx_stop: int = -1,
return_image: Optional[bool] = False seed_source: int = -1,
seed_mixing_target: int = -1,
mixing_coeff: float = 0.0,
return_image: Optional[bool] = False,
): ):
r""" r"""
Wrapper function for run_diffusion_standard and run_diffusion_inpaint. Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
@ -291,12 +293,18 @@ class StableDiffusionHolder:
Args: Args:
text_embeddings: torch.FloatTensor text_embeddings: torch.FloatTensor
Text embeddings used for diffusion Text embeddings used for diffusion
latents_for_injection: torch.FloatTensor latents_for_injection: torch.FloatTensor or list
Latents that are used for injection Latents that are used for injection
idx_start: int idx_start: int
Index of the diffusion process start and where the latents_for_injection are injected Index of the diffusion process start and where the latents_for_injection are injected
idx_stop: int idx_stop: int
Index of the diffusion process end. Index of the diffusion process end.
mixing_coeff:
# FIXME
seed_source:
# FIXME
seed_mixing:
# FIXME
return_image: Optional[bool] return_image: Optional[bool]
Optionally return image directly Optionally return image directly
""" """
@ -304,12 +312,19 @@ class StableDiffusionHolder:
if latents_for_injection is None: if latents_for_injection is None:
do_inject_latents = False do_inject_latents = False
do_mix_latents = False
else:
if mixing_coeff > 0.0:
do_inject_latents = False
do_mix_latents = True
assert seed_mixing_target != -1, "Set to correct seed for mixing"
else: else:
do_inject_latents = True do_inject_latents = True
do_mix_latents = False
precision_scope = autocast if self.precision == "autocast" else nullcontext precision_scope = autocast if self.precision == "autocast" else nullcontext
generator = torch.Generator(device=self.device).manual_seed(int(self.seed)) generator = torch.Generator(device=self.device).manual_seed(int(seed_source))
with precision_scope("cuda"): with precision_scope("cuda"):
with self.model.ema_scope(): with self.model.ema_scope():
@ -340,6 +355,16 @@ class StableDiffusionHolder:
continue continue
elif i == idx_start: elif i == idx_start:
latents = latents_for_injection.clone() latents = latents_for_injection.clone()
if do_mix_latents:
if i == 0:
generator = torch.Generator(device=self.device).manual_seed(int(seed_mixing_target))
latents_mixtarget = torch.randn(size, generator=generator, device=self.device)
if i < idx_start:
latents_mixtarget = latents_for_injection[i-1].clone()
latents = interpolate_spherical(latents, latents_mixtarget, mixing_coeff)
if i == idx_start:
do_mix_latents = False
if i == idx_stop: if i == idx_stop:
return list_latents_out return list_latents_out
@ -576,6 +601,50 @@ class StableDiffusionHolder:
image = x_sample.astype(np.uint8) image = x_sample.astype(np.uint8)
return image return image
@torch.no_grad()
def interpolate_spherical(p0, p1, fract_mixing: float):
r"""
Helper function to correctly mix two random variables using spherical interpolation.
See https://en.wikipedia.org/wiki/Slerp
The function will always cast up to float64 for sake of extra 4.
Args:
p0:
First tensor for interpolation
p1:
Second tensor for interpolation
fract_mixing: float
Mixing coefficient of interval [0, 1].
0 will return in p0
1 will return in p1
0.x will return a mix between both preserving angular velocity.
"""
if p0.dtype == torch.float16:
recast_to = 'fp16'
else:
recast_to = 'fp32'
p0 = p0.double()
p1 = p1.double()
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
epsilon = 1e-7
dot = torch.sum(p0 * p1) / norm
dot = dot.clamp(-1+epsilon, 1-epsilon)
theta_0 = torch.arccos(dot)
sin_theta_0 = torch.sin(theta_0)
theta_t = theta_0 * fract_mixing
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
s1 = torch.sin(theta_t) / sin_theta_0
interp = p0*s0 + p1*s1
if recast_to == 'fp16':
interp = interp.half()
elif recast_to == 'fp32':
interp = interp.float()
return interp
if __name__ == "__main__": if __name__ == "__main__":