new transition engine and crossfeeding
This commit is contained in:
parent
5e979818b2
commit
0371868603
224
gradio_ui.py
224
gradio_ui.py
|
@ -33,6 +33,13 @@ import gradio as gr
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
TODOS:
|
||||||
|
- clean parameter handling
|
||||||
|
- three buttons: diffuse A, diffuse B, make transition
|
||||||
|
- collapse for easy mode
|
||||||
|
- transition quality in terms of render time
|
||||||
|
"""
|
||||||
|
|
||||||
#%%
|
#%%
|
||||||
|
|
||||||
|
@ -45,7 +52,7 @@ class BlendingFrontend():
|
||||||
self.lb = LatentBlending(sdh)
|
self.lb = LatentBlending(sdh)
|
||||||
|
|
||||||
self.share = True
|
self.share = True
|
||||||
self.num_inference_steps = 20
|
self.num_inference_steps = 30
|
||||||
self.depth_strength = 0.25
|
self.depth_strength = 0.25
|
||||||
self.seed1 = 42
|
self.seed1 = 42
|
||||||
self.seed2 = 420
|
self.seed2 = 420
|
||||||
|
@ -58,11 +65,13 @@ class BlendingFrontend():
|
||||||
self.list_settings = []
|
self.list_settings = []
|
||||||
self.state_current = {}
|
self.state_current = {}
|
||||||
self.showing_current = True
|
self.showing_current = True
|
||||||
self.branch1_influence = 0.02
|
self.branch1_influence = 0.1
|
||||||
|
self.branch1_mixing_depth = 0.3
|
||||||
self.nmb_branches_final = 9
|
self.nmb_branches_final = 9
|
||||||
self.nmb_imgs_show = 5 # don't change
|
self.nmb_imgs_show = 5 # don't change
|
||||||
self.fps = 30
|
self.fps = 30
|
||||||
self.duration = 10
|
self.duration_video = 15
|
||||||
|
self.t_compute_max_allowed = 15
|
||||||
self.dict_multi_trans = {}
|
self.dict_multi_trans = {}
|
||||||
self.dict_multi_trans_include = {}
|
self.dict_multi_trans_include = {}
|
||||||
self.multi_trans_currently_shown = []
|
self.multi_trans_currently_shown = []
|
||||||
|
@ -79,114 +88,87 @@ class BlendingFrontend():
|
||||||
self.width = 768
|
self.width = 768
|
||||||
|
|
||||||
# make dummy image
|
# make dummy image
|
||||||
|
def save_empty_image(self):
|
||||||
self.fp_img_empty = 'empty.jpg'
|
self.fp_img_empty = 'empty.jpg'
|
||||||
Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
|
Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
|
||||||
|
|
||||||
def change_depth_strength(self, value):
|
|
||||||
self.depth_strength = value
|
|
||||||
print(f"changed depth_strength to {value}")
|
|
||||||
|
|
||||||
def change_num_inference_steps(self, value):
|
|
||||||
self.num_inference_steps = value
|
|
||||||
print(f"changed num_inference_steps to {value}")
|
|
||||||
|
|
||||||
def change_guidance_scale(self, value):
|
|
||||||
self.guidance_scale = value
|
|
||||||
self.lb.set_guidance_scale(value)
|
|
||||||
print(f"changed guidance_scale to {value}")
|
|
||||||
|
|
||||||
def change_guidance_scale_mid_damper(self, value):
|
|
||||||
self.guidance_scale_mid_damper = value
|
|
||||||
print(f"changed guidance_scale_mid_damper to {value}")
|
|
||||||
|
|
||||||
def change_mid_compression_scaler(self, value):
|
|
||||||
self.mid_compression_scaler = value
|
|
||||||
print(f"changed mid_compression_scaler to {value}")
|
|
||||||
|
|
||||||
def change_branch1_influence(self, value):
|
|
||||||
self.branch1_influence = value
|
|
||||||
print(f"changed branch1_influence to {value}")
|
|
||||||
|
|
||||||
def change_height(self, value):
|
|
||||||
self.height = value
|
|
||||||
print(f"changed height to {value}")
|
|
||||||
|
|
||||||
def change_width(self, value):
|
|
||||||
self.width = value
|
|
||||||
print(f"changed width to {value}")
|
|
||||||
|
|
||||||
def change_nmb_branches_final(self, value):
|
|
||||||
self.nmb_branches_final = value
|
|
||||||
print(f"changed nmb_branches_final to {value}")
|
|
||||||
|
|
||||||
def change_duration(self, value):
|
|
||||||
self.duration = value
|
|
||||||
print(f"changed duration to {value}")
|
|
||||||
|
|
||||||
def change_fps(self, value):
|
|
||||||
self.fps = value
|
|
||||||
print(f"changed fps to {value}")
|
|
||||||
|
|
||||||
def change_negative_prompt(self, value):
|
|
||||||
self.negative_prompt = value
|
|
||||||
|
|
||||||
def change_seed1(self, value):
|
|
||||||
self.seed1 = int(value)
|
|
||||||
|
|
||||||
def change_seed2(self, value):
|
|
||||||
self.seed2 = int(value)
|
|
||||||
|
|
||||||
def randomize_seed1(self):
|
def randomize_seed1(self):
|
||||||
seed = np.random.randint(0, 10000000)
|
seed = np.random.randint(0, 10000000)
|
||||||
self.change_seed1(seed)
|
self.seed1 = int(seed)
|
||||||
print(f"randomize_seed1: new seed = {self.seed1}")
|
print(f"randomize_seed1: new seed = {self.seed1}")
|
||||||
return seed
|
return seed
|
||||||
|
|
||||||
def randomize_seed2(self):
|
def randomize_seed2(self):
|
||||||
seed = np.random.randint(0, 10000000)
|
seed = np.random.randint(0, 10000000)
|
||||||
self.change_seed2(seed)
|
self.seed2 = int(seed)
|
||||||
print(f"randomize_seed2: new seed = {self.seed2}")
|
print(f"randomize_seed2: new seed = {self.seed2}")
|
||||||
return seed
|
return seed
|
||||||
|
|
||||||
|
|
||||||
def compute_transition(self, prompt1, prompt2):
|
def setup_lb(self, list_ui_elem):
|
||||||
self.prompt1 = prompt1
|
# Collect latent blending variables
|
||||||
self.prompt2 = prompt2
|
|
||||||
print("STARTING DIFFUSION!")
|
|
||||||
self.state_current = self.get_state_dict()
|
self.state_current = self.get_state_dict()
|
||||||
|
self.lb.set_width(list_ui_elem[list_ui_keys.index('width')])
|
||||||
|
self.lb.set_height(list_ui_elem[list_ui_keys.index('height')])
|
||||||
|
self.lb.set_prompt1(list_ui_elem[list_ui_keys.index('prompt1')])
|
||||||
|
self.lb.set_prompt2(list_ui_elem[list_ui_keys.index('prompt2')])
|
||||||
|
self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
|
||||||
|
self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
|
||||||
|
self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
|
||||||
|
self.lb.branch1_influence = list_ui_elem[list_ui_keys.index('branch1_influence')]
|
||||||
|
self.lb.branch1_mixing_depth = list_ui_elem[list_ui_keys.index('branch1_mixing_depth')]
|
||||||
|
self.lb.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
|
||||||
|
self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
|
||||||
|
self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
|
||||||
|
self.duration_video = list_ui_elem[list_ui_keys.index('duration_video')]
|
||||||
|
self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')]
|
||||||
|
self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
|
||||||
|
|
||||||
|
|
||||||
|
def compute_img1(self, *args):
|
||||||
|
list_ui_elem = args
|
||||||
|
self.setup_lb(list_ui_elem)
|
||||||
|
fp_img1 = f"img1_{get_time('second')}.jpg"
|
||||||
|
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
|
||||||
|
img1.save(fp_img1)
|
||||||
|
self.save_empty_image()
|
||||||
|
return [fp_img1, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty]
|
||||||
|
|
||||||
|
def compute_img2(self, *args):
|
||||||
|
list_ui_elem = args
|
||||||
|
self.setup_lb(list_ui_elem)
|
||||||
|
fp_img2 = f"img2_{get_time('second')}.jpg"
|
||||||
|
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
|
||||||
|
img2.save(fp_img2)
|
||||||
|
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2]
|
||||||
|
|
||||||
|
def compute_transition(self, *args):
|
||||||
|
list_ui_elem = args
|
||||||
|
self.setup_lb(list_ui_elem)
|
||||||
|
print("STARTING DIFFUSION!")
|
||||||
if self.use_debug:
|
if self.use_debug:
|
||||||
list_imgs = [(255*np.random.rand(self.height,self.width,3)).astype(np.uint8) for l in range(5)]
|
list_imgs = [(255*np.random.rand(self.height,self.width,3)).astype(np.uint8) for l in range(5)]
|
||||||
list_imgs = [Image.fromarray(l) for l in list_imgs]
|
list_imgs = [Image.fromarray(l) for l in list_imgs]
|
||||||
print("DONE! SENDING BACK RESULTS")
|
print("DONE! SENDING BACK RESULTS")
|
||||||
return list_imgs
|
return list_imgs
|
||||||
|
|
||||||
# Collect latent blending variables
|
|
||||||
self.lb.set_width(self.width)
|
|
||||||
self.lb.set_height(self.height)
|
|
||||||
self.lb.autosetup_branching(
|
|
||||||
depth_strength = self.depth_strength,
|
|
||||||
num_inference_steps = self.num_inference_steps,
|
|
||||||
nmb_branches_final = self.nmb_branches_final,
|
|
||||||
nmb_mindist = 3)
|
|
||||||
self.lb.set_prompt1(self.prompt1)
|
|
||||||
self.lb.set_prompt2(self.prompt2)
|
|
||||||
self.lb.set_negative_prompt(self.negative_prompt)
|
|
||||||
|
|
||||||
self.lb.guidance_scale = self.guidance_scale
|
|
||||||
self.lb.guidance_scale_mid_damper = self.guidance_scale_mid_damper
|
|
||||||
self.lb.mid_compression_scaler = self.mid_compression_scaler
|
|
||||||
self.lb.branch1_influence = self.branch1_influence
|
|
||||||
fixed_seeds = [self.seed1, self.seed2]
|
fixed_seeds = [self.seed1, self.seed2]
|
||||||
|
|
||||||
# Run Latent Blending
|
# Run Latent Blending
|
||||||
imgs_transition = self.lb.run_transition(fixed_seeds=fixed_seeds)
|
imgs_transition = self.lb.run_transition(
|
||||||
|
recycle_img1=True,
|
||||||
|
recycle_img2=True,
|
||||||
|
num_inference_steps=self.num_inference_steps,
|
||||||
|
depth_strength=self.depth_strength,
|
||||||
|
fixed_seeds=fixed_seeds
|
||||||
|
)
|
||||||
print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
|
print(f"Latent Blending pass finished. Resulted in {len(imgs_transition)} images")
|
||||||
|
|
||||||
# Subselect the preview images (hard fixed to self.nmb_imgs_show=5)
|
# Subselect three preview images
|
||||||
assert np.mod((self.nmb_branches_final-self.nmb_imgs_show)/4, 1)==0, 'self.nmb_branches_final illegal value!'
|
idx_img_prev = np.round(np.linspace(0, len(imgs_transition)-1, 5)[1:-1]).astype(np.int32)
|
||||||
idx_list = np.linspace(0, self.nmb_branches_final-1, self.nmb_imgs_show).astype(np.int32)
|
|
||||||
list_imgs_preview = []
|
list_imgs_preview = []
|
||||||
for j in idx_list:
|
for j in idx_img_prev:
|
||||||
list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
|
list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
|
||||||
|
|
||||||
# Save the preview imgs as jpgs on disk so we are not sending umcompressed data around
|
# Save the preview imgs as jpgs on disk so we are not sending umcompressed data around
|
||||||
|
@ -198,7 +180,7 @@ class BlendingFrontend():
|
||||||
self.list_fp_imgs_current.append(fp_img)
|
self.list_fp_imgs_current.append(fp_img)
|
||||||
|
|
||||||
# Insert cheap frames for the movie
|
# Insert cheap frames for the movie
|
||||||
imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration, self.fps)
|
imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
|
||||||
|
|
||||||
# Save as movie
|
# Save as movie
|
||||||
fp_movie = self.get_fp_movie(self.current_timestamp)
|
fp_movie = self.get_fp_movie(self.current_timestamp)
|
||||||
|
@ -330,35 +312,44 @@ if __name__ == "__main__":
|
||||||
sdh = StableDiffusionHolder(fp_ckpt)
|
sdh = StableDiffusionHolder(fp_ckpt)
|
||||||
|
|
||||||
self = BlendingFrontend(sdh) # Yes this is possible in python and yes it is an awesome trick
|
self = BlendingFrontend(sdh) # Yes this is possible in python and yes it is an awesome trick
|
||||||
|
# self = BlendingFrontend(None) # Yes this is possible in python and yes it is an awesome trick
|
||||||
|
|
||||||
|
dict_ui_elem = {}
|
||||||
|
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
prompt1 = gr.Textbox(label="prompt 1")
|
prompt1 = gr.Textbox(label="prompt 1")
|
||||||
prompt2 = gr.Textbox(label="prompt 2")
|
|
||||||
negative_prompt = gr.Textbox(label="negative prompt")
|
negative_prompt = gr.Textbox(label="negative prompt")
|
||||||
|
prompt2 = gr.Textbox(label="prompt 2")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
nmb_branches_final = gr.Slider(5, 125, self.nmb_branches_final, step=4, label='nmb trans images', interactive=True)
|
duration_compute = gr.Slider(10, 40, self.duration_video, step=1, label='compute budget for transition (seconds)', interactive=True)
|
||||||
|
duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True)
|
||||||
height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
|
height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
|
||||||
width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True)
|
width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Accordion("Advanced Settings (click to expand)", open=False):
|
||||||
num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
|
|
||||||
branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='branch1_influence', interactive=True)
|
|
||||||
guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
|
depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
|
||||||
duration = gr.Slider(0.1, 30, self.duration, step=0.1, label='video duration', interactive=True)
|
branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='branch1_influence', interactive=True)
|
||||||
|
branch1_mixing_depth = gr.Slider(0.0, 1.0, self.branch1_mixing_depth, step=0.01, label='branch1_mixing_depth', interactive=True)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
|
||||||
|
guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
|
||||||
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
|
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
seed1 = gr.Number(42, label="seed 1", interactive=True)
|
seed1 = gr.Number(420, label="seed 1", interactive=True)
|
||||||
b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
|
b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
|
||||||
seed2 = gr.Number(420, label="seed 2", interactive=True)
|
seed2 = gr.Number(420, label="seed 2", interactive=True)
|
||||||
b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
|
b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
b_compute1 = gr.Button('compute first image', variant='primary')
|
||||||
b_compute_transition = gr.Button('compute transition', variant='primary')
|
b_compute_transition = gr.Button('compute transition', variant='primary')
|
||||||
|
b_compute2 = gr.Button('compute last image', variant='primary')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
img1 = gr.Image(label="1/5")
|
img1 = gr.Image(label="1/5")
|
||||||
|
@ -370,31 +361,40 @@ if __name__ == "__main__":
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
vid_transition = gr.Video()
|
vid_transition = gr.Video()
|
||||||
|
|
||||||
# Bind the on-change methods
|
# Collect all UI elemts in list to easily pass as inputs
|
||||||
depth_strength.change(fn=self.change_depth_strength, inputs=depth_strength)
|
dict_ui_elem["prompt1"] = prompt1
|
||||||
num_inference_steps.change(fn=self.change_num_inference_steps, inputs=num_inference_steps)
|
dict_ui_elem["negative_prompt"] = negative_prompt
|
||||||
nmb_branches_final.change(fn=self.change_nmb_branches_final, inputs=nmb_branches_final)
|
dict_ui_elem["prompt2"] = prompt2
|
||||||
|
|
||||||
guidance_scale.change(fn=self.change_guidance_scale, inputs=guidance_scale)
|
dict_ui_elem["duration_compute"] = duration_compute
|
||||||
guidance_scale_mid_damper.change(fn=self.change_guidance_scale_mid_damper, inputs=guidance_scale_mid_damper)
|
dict_ui_elem["duration_video"] = duration_video
|
||||||
|
dict_ui_elem["height"] = height
|
||||||
|
dict_ui_elem["width"] = width
|
||||||
|
|
||||||
height.change(fn=self.change_height, inputs=height)
|
dict_ui_elem["depth_strength"] = depth_strength
|
||||||
width.change(fn=self.change_width, inputs=width)
|
dict_ui_elem["branch1_influence"] = branch1_influence
|
||||||
negative_prompt.change(fn=self.change_negative_prompt, inputs=negative_prompt)
|
dict_ui_elem["branch1_mixing_depth"] = branch1_mixing_depth
|
||||||
seed1.change(fn=self.change_seed1, inputs=seed1)
|
|
||||||
seed2.change(fn=self.change_seed2, inputs=seed2)
|
dict_ui_elem["num_inference_steps"] = num_inference_steps
|
||||||
duration.change(fn=self.change_duration, inputs=duration)
|
dict_ui_elem["guidance_scale"] = guidance_scale
|
||||||
branch1_influence.change(fn=self.change_branch1_influence, inputs=branch1_influence)
|
dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
|
||||||
|
dict_ui_elem["seed1"] = seed1
|
||||||
|
dict_ui_elem["seed2"] = seed2
|
||||||
|
|
||||||
|
# Convert to list, as gradio doesn't seem to accept dicts
|
||||||
|
list_ui_elem = []
|
||||||
|
list_ui_keys = []
|
||||||
|
for k in dict_ui_elem.keys():
|
||||||
|
list_ui_elem.append(dict_ui_elem[k])
|
||||||
|
list_ui_keys.append(k)
|
||||||
|
self.list_ui_keys = list_ui_keys
|
||||||
|
|
||||||
b_newseed1.click(self.randomize_seed1, outputs=seed1)
|
b_newseed1.click(self.randomize_seed1, outputs=seed1)
|
||||||
b_newseed2.click(self.randomize_seed2, outputs=seed2)
|
b_newseed2.click(self.randomize_seed2, outputs=seed2)
|
||||||
# b_stackforward.click(self.stack_forward,
|
b_compute1.click(self.compute_img1, inputs=list_ui_elem, outputs=[img1, img2, img3, img4, img5])
|
||||||
# inputs=[prompt2, seed2],
|
b_compute2.click(self.compute_img2, inputs=list_ui_elem, outputs=[img2, img3, img4, img5])
|
||||||
# outputs=[img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
|
|
||||||
b_compute_transition.click(self.compute_transition,
|
b_compute_transition.click(self.compute_transition,
|
||||||
inputs=[prompt1, prompt2],
|
inputs=list_ui_elem,
|
||||||
outputs=[img1, img2, img3, img4, img5, vid_transition])
|
outputs=[img2, img3, img4, vid_transition])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
demo.launch(share=self.share, inbrowser=True, inline=False)
|
demo.launch(share=self.share, inbrowser=True, inline=False)
|
||||||
|
|
|
@ -45,7 +45,7 @@ from ldm.util import instantiate_from_config
|
||||||
from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
|
from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
|
||||||
from stable_diffusion_holder import StableDiffusionHolder
|
from stable_diffusion_holder import StableDiffusionHolder
|
||||||
import yaml
|
import yaml
|
||||||
|
import lpips
|
||||||
#%%
|
#%%
|
||||||
class LatentBlending():
|
class LatentBlending():
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -88,10 +88,13 @@ class LatentBlending():
|
||||||
self.prompt1 = ""
|
self.prompt1 = ""
|
||||||
self.prompt2 = ""
|
self.prompt2 = ""
|
||||||
self.negative_prompt = ""
|
self.negative_prompt = ""
|
||||||
self.tree_latents = None
|
|
||||||
|
self.tree_latents = [None, None]
|
||||||
self.tree_fracts = None
|
self.tree_fracts = None
|
||||||
|
self.idx_injection = []
|
||||||
self.tree_status = None
|
self.tree_status = None
|
||||||
self.tree_final_imgs = []
|
self.tree_final_imgs = []
|
||||||
|
|
||||||
self.list_nmb_branches_prev = []
|
self.list_nmb_branches_prev = []
|
||||||
self.list_injection_idx_prev = []
|
self.list_injection_idx_prev = []
|
||||||
self.text_embedding1 = None
|
self.text_embedding1 = None
|
||||||
|
@ -105,12 +108,15 @@ class LatentBlending():
|
||||||
self.list_injection_idx = None
|
self.list_injection_idx = None
|
||||||
self.list_nmb_branches = None
|
self.list_nmb_branches = None
|
||||||
self.branch1_influence = 0.0
|
self.branch1_influence = 0.0
|
||||||
self.branch1_fract_crossfeed = 0.65
|
self.branch1_mixing_depth = 0.65
|
||||||
self.branch1_insertion_completed = False
|
self.branch1_insertion_completed = False
|
||||||
self.set_guidance_scale(guidance_scale)
|
self.set_guidance_scale(guidance_scale)
|
||||||
self.init_mode()
|
self.init_mode()
|
||||||
self.multi_transition_img_first = None
|
self.multi_transition_img_first = None
|
||||||
self.multi_transition_img_last = None
|
self.multi_transition_img_last = None
|
||||||
|
self.dt_per_diff = 0
|
||||||
|
|
||||||
|
self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
|
||||||
|
|
||||||
|
|
||||||
def init_mode(self):
|
def init_mode(self):
|
||||||
|
@ -375,6 +381,187 @@ class LatentBlending():
|
||||||
self.tree_status.append(['untouched']*nmb_branches)
|
self.tree_status.append(['untouched']*nmb_branches)
|
||||||
|
|
||||||
def run_transition(
|
def run_transition(
|
||||||
|
self,
|
||||||
|
recycle_img1: Optional[bool] = False,
|
||||||
|
recycle_img2: Optional[bool] = False,
|
||||||
|
num_inference_steps: Optional[int] = 30,
|
||||||
|
depth_strength: Optional[float] = 0.3,
|
||||||
|
fixed_seeds: Optional[List[int]] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
# # FIXME: deal with these tree args later
|
||||||
|
# self.num_inference_steps = 30
|
||||||
|
# self.t_compute_max_allowed = 60
|
||||||
|
|
||||||
|
# Sanity checks first
|
||||||
|
assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
|
||||||
|
assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
|
||||||
|
|
||||||
|
# Random seeds
|
||||||
|
if fixed_seeds is not None:
|
||||||
|
if fixed_seeds == 'randomize':
|
||||||
|
fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
|
||||||
|
else:
|
||||||
|
assert len(fixed_seeds)==2, "Supply a list with len = 2"
|
||||||
|
|
||||||
|
self.seed1 = fixed_seeds[0]
|
||||||
|
self.seed2 = fixed_seeds[1]
|
||||||
|
|
||||||
|
# Ensure correct num_inference_steps in holder
|
||||||
|
self.sdh.num_inference_steps = self.num_inference_steps
|
||||||
|
|
||||||
|
# Compute / Recycle first image
|
||||||
|
if not recycle_img1:
|
||||||
|
list_latents1 = self.compute_latents1()
|
||||||
|
else:
|
||||||
|
# FIXME: check if latents there...
|
||||||
|
list_latents1 = self.tree_latents[0]
|
||||||
|
|
||||||
|
# Compute / Recycle first image
|
||||||
|
if not recycle_img2:
|
||||||
|
list_latents2 = self.compute_latents2()
|
||||||
|
else:
|
||||||
|
# FIXME: check if latents there...
|
||||||
|
list_latents2 = self.tree_latents[-1]
|
||||||
|
|
||||||
|
# Reset the tree, injecting the edge latents1/2 we just generated/recycled
|
||||||
|
self.tree_latents = [list_latents1, list_latents2]
|
||||||
|
self.tree_fracts = [0.0, 1.0]
|
||||||
|
self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
|
||||||
|
self.tree_idx_injection = [0, 0]
|
||||||
|
|
||||||
|
# Set up branching scheme (dependent on provided compute time)
|
||||||
|
idx_injection_base = int(round(self.num_inference_steps*depth_strength))
|
||||||
|
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3)
|
||||||
|
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
|
||||||
|
t_compute = 0
|
||||||
|
while t_compute < self.t_compute_max_allowed:
|
||||||
|
list_compute_steps = self.num_inference_steps - list_idx_injection
|
||||||
|
list_compute_steps *= list_nmb_stems
|
||||||
|
t_compute = np.sum(list_compute_steps) * self.dt_per_diff
|
||||||
|
increase_done = False
|
||||||
|
for s_idx in range(len(list_nmb_stems)-1):
|
||||||
|
if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
|
||||||
|
list_nmb_stems[s_idx] += 1
|
||||||
|
increase_done = True
|
||||||
|
break
|
||||||
|
if not increase_done:
|
||||||
|
list_nmb_stems[-1] += 1
|
||||||
|
# print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
|
||||||
|
|
||||||
|
# Run iteratively
|
||||||
|
for s_idx in tqdm(range(len(list_idx_injection))):
|
||||||
|
nmb_stems = list_nmb_stems[s_idx]
|
||||||
|
idx_injection = list_idx_injection[s_idx]
|
||||||
|
|
||||||
|
for i in range(nmb_stems):
|
||||||
|
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
|
||||||
|
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
|
||||||
|
list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
|
||||||
|
self.insert_into_tree(fract_mixing, idx_injection, list_latents)
|
||||||
|
|
||||||
|
return self.tree_final_imgs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_mixing_parameters(self, idx_injection):
|
||||||
|
# get_lpips_similarity
|
||||||
|
similarities = []
|
||||||
|
for i in range(len(self.tree_final_imgs)-1):
|
||||||
|
similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i+1]))
|
||||||
|
b_closest1 = np.argmax(similarities)
|
||||||
|
b_closest2 = b_closest1+1
|
||||||
|
fract_closest1 = self.tree_fracts[b_closest1]
|
||||||
|
fract_closest2 = self.tree_fracts[b_closest2]
|
||||||
|
|
||||||
|
# Ensure that the parents are indeed older!
|
||||||
|
b_parent1 = b_closest1
|
||||||
|
while True:
|
||||||
|
if self.tree_idx_injection[b_parent1] < idx_injection:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
b_parent1 -= 1
|
||||||
|
|
||||||
|
b_parent2 = b_closest2
|
||||||
|
while True:
|
||||||
|
if self.tree_idx_injection[b_parent2] < idx_injection:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
b_parent2 += 1
|
||||||
|
|
||||||
|
# print(f"\n\nb_closest: {b_closest1} {b_closest2} fract_closest1 {fract_closest1} fract_closest2 {fract_closest2}")
|
||||||
|
# print(f"b_parent: {b_parent1} {b_parent2}")
|
||||||
|
# print(f"similarities {similarities}")
|
||||||
|
# print(f"idx_injection {idx_injection} tree_idx_injection {self.tree_idx_injection}")
|
||||||
|
|
||||||
|
fract_mixing = (fract_closest1 + fract_closest2) /2
|
||||||
|
return fract_mixing, b_parent1, b_parent2
|
||||||
|
|
||||||
|
|
||||||
|
def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
|
||||||
|
# FIXME
|
||||||
|
b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
|
||||||
|
self.tree_latents.insert(b_parent1+1, list_latents)
|
||||||
|
self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
|
||||||
|
self.tree_fracts.insert(b_parent1+1, fract_mixing)
|
||||||
|
self.tree_idx_injection.insert(b_parent1+1, idx_injection)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
|
||||||
|
# FIXME
|
||||||
|
list_conditionings = self.get_mixed_conditioning(fract_mixing)
|
||||||
|
fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
|
||||||
|
idx_reversed = self.num_inference_steps - idx_injection
|
||||||
|
latents_for_injection = interpolate_spherical(
|
||||||
|
self.tree_latents[b_parent1][-idx_reversed-1],
|
||||||
|
self.tree_latents[b_parent2][-idx_reversed-1],
|
||||||
|
fract_mixing_parental)
|
||||||
|
list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
|
||||||
|
return list_latents
|
||||||
|
|
||||||
|
|
||||||
|
def compute_latents1(self, return_image=False):
|
||||||
|
print("starting compute_latents1")
|
||||||
|
list_conditionings = [self.text_embedding1]
|
||||||
|
t0 = time.time()
|
||||||
|
list_latents1 = self.run_diffusion(list_conditionings, seed_source=self.seed1)
|
||||||
|
t1 = time.time()
|
||||||
|
self.dt_per_diff = (t1-t0) / self.num_inference_steps
|
||||||
|
self.tree_latents[0] = list_latents1
|
||||||
|
if return_image:
|
||||||
|
return self.sdh.latent2image(list_latents1[-1])
|
||||||
|
else:
|
||||||
|
return list_latents1
|
||||||
|
|
||||||
|
|
||||||
|
def compute_latents2(self, return_image=False):
|
||||||
|
print("starting compute_latents2")
|
||||||
|
list_conditionings = [self.text_embedding2
|
||||||
|
]
|
||||||
|
# Influence from branch1
|
||||||
|
if self.branch1_influence > 0.0:
|
||||||
|
self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
|
||||||
|
self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
|
||||||
|
idx_crossfeed = int(round(self.num_inference_steps*self.branch1_mixing_depth))
|
||||||
|
list_latents2 = self.run_diffusion(
|
||||||
|
list_conditionings,
|
||||||
|
idx_start=idx_crossfeed,
|
||||||
|
latents_for_injection=self.tree_latents[0],
|
||||||
|
seed_source=self.seed2,
|
||||||
|
seed_mixing_target=self.seed1,
|
||||||
|
mixing_coeff=self.branch1_influence)
|
||||||
|
else:
|
||||||
|
list_latents2 = self.run_diffusion(list_conditionings)
|
||||||
|
self.tree_latents[-1] = list_latents2
|
||||||
|
|
||||||
|
if return_image:
|
||||||
|
return self.sdh.latent2image(list_latents2[-1])
|
||||||
|
else:
|
||||||
|
return list_latents2
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def run_transition_OLD(
|
||||||
self,
|
self,
|
||||||
recycle_img1: Optional[bool] = False,
|
recycle_img1: Optional[bool] = False,
|
||||||
recycle_img2: Optional[bool] = False,
|
recycle_img2: Optional[bool] = False,
|
||||||
|
@ -423,9 +610,9 @@ class LatentBlending():
|
||||||
if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
|
if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
|
||||||
assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0'
|
assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0'
|
||||||
self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
|
self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
|
||||||
self.branch1_fract_crossfeed = np.clip(self.branch1_fract_crossfeed, 0, 1)
|
self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
|
||||||
self.list_nmb_branches.insert(1, 2)
|
self.list_nmb_branches.insert(1, 2)
|
||||||
idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_fract_crossfeed))
|
idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_mixing_depth))
|
||||||
self.list_injection_idx_ext.insert(1, idx_crossfeed)
|
self.list_injection_idx_ext.insert(1, idx_crossfeed)
|
||||||
self.tree_fracts.insert(1, self.tree_fracts[0])
|
self.tree_fracts.insert(1, self.tree_fracts[0])
|
||||||
self.tree_status.insert(1, self.tree_status[0])
|
self.tree_status.insert(1, self.tree_status[0])
|
||||||
|
@ -606,6 +793,9 @@ class LatentBlending():
|
||||||
latents_for_injection: torch.FloatTensor = None,
|
latents_for_injection: torch.FloatTensor = None,
|
||||||
idx_start: int = -1,
|
idx_start: int = -1,
|
||||||
idx_stop: int = -1,
|
idx_stop: int = -1,
|
||||||
|
seed_source: int = -1,
|
||||||
|
seed_mixing_target: int = -1,
|
||||||
|
mixing_coeff: float = 0.0,
|
||||||
return_image: Optional[bool] = False
|
return_image: Optional[bool] = False
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
|
@ -620,6 +810,7 @@ class LatentBlending():
|
||||||
Index of the diffusion process start and where the latents_for_injection are injected
|
Index of the diffusion process start and where the latents_for_injection are injected
|
||||||
idx_stop: int
|
idx_stop: int
|
||||||
Index of the diffusion process end.
|
Index of the diffusion process end.
|
||||||
|
FIXME ARGS
|
||||||
return_image: Optional[bool]
|
return_image: Optional[bool]
|
||||||
Optionally return image directly
|
Optionally return image directly
|
||||||
"""
|
"""
|
||||||
|
@ -630,14 +821,23 @@ class LatentBlending():
|
||||||
|
|
||||||
if self.mode == 'standard':
|
if self.mode == 'standard':
|
||||||
text_embeddings = list_conditionings[0]
|
text_embeddings = list_conditionings[0]
|
||||||
return self.sdh.run_diffusion_standard(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
return self.sdh.run_diffusion_standard(
|
||||||
|
text_embeddings,
|
||||||
|
latents_for_injection=latents_for_injection,
|
||||||
|
idx_start=idx_start,
|
||||||
|
idx_stop=idx_stop,
|
||||||
|
seed_source=seed_source,
|
||||||
|
seed_mixing_target=seed_mixing_target,
|
||||||
|
mixing_coeff=mixing_coeff,
|
||||||
|
return_image=return_image,
|
||||||
|
)
|
||||||
|
|
||||||
elif self.mode == 'inpaint':
|
elif self.mode == 'inpaint':
|
||||||
text_embeddings = list_conditionings[0]
|
text_embeddings = list_conditionings[0]
|
||||||
assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
|
assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
|
||||||
assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
|
assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
|
||||||
return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
||||||
|
# FIXME LONG LINE
|
||||||
elif self.mode == 'upscale':
|
elif self.mode == 'upscale':
|
||||||
cond = list_conditionings[0]
|
cond = list_conditionings[0]
|
||||||
uc_full = list_conditionings[1]
|
uc_full = list_conditionings[1]
|
||||||
|
@ -882,8 +1082,6 @@ class LatentBlending():
|
||||||
self.tree_latents[t_block][-1] = list_latents[self.list_injection_idx_ext[t_block]:self.list_injection_idx_ext[t_block+1]]
|
self.tree_latents[t_block][-1] = list_latents[self.list_injection_idx_ext[t_block]:self.list_injection_idx_ext[t_block+1]]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def swap_forward(self):
|
def swap_forward(self):
|
||||||
r"""
|
r"""
|
||||||
Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions
|
Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions
|
||||||
|
@ -901,6 +1099,21 @@ class LatentBlending():
|
||||||
self.tree_final_imgs = []
|
self.tree_final_imgs = []
|
||||||
|
|
||||||
|
|
||||||
|
def get_lpips_similarity(self, imgA, imgB):
|
||||||
|
# FIXME
|
||||||
|
tensorA = torch.from_numpy(imgA).float().cuda(self.device)
|
||||||
|
tensorA = 2*tensorA/255.0 - 1
|
||||||
|
tensorA = tensorA.permute([2,0,1]).unsqueeze(0)
|
||||||
|
|
||||||
|
tensorB = torch.from_numpy(imgB).float().cuda(self.device)
|
||||||
|
tensorB = 2*tensorB/255.0 - 1
|
||||||
|
tensorB = tensorB.permute([2,0,1]).unsqueeze(0)
|
||||||
|
lploss = self.lpips(tensorA, tensorB)
|
||||||
|
lploss = float(lploss[0][0][0][0])
|
||||||
|
|
||||||
|
return lploss
|
||||||
|
|
||||||
|
|
||||||
# Auxiliary functions
|
# Auxiliary functions
|
||||||
def get_closest_idx(
|
def get_closest_idx(
|
||||||
fract_mixing: float,
|
fract_mixing: float,
|
||||||
|
@ -1169,31 +1382,79 @@ if __name__ == "__main__":
|
||||||
#%% First let us spawn a stable diffusion holder
|
#%% First let us spawn a stable diffusion holder
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
|
fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
|
||||||
fp_config = 'configs/v2-inference.yaml'
|
|
||||||
|
|
||||||
sdh = StableDiffusionHolder(fp_ckpt, fp_config, device)
|
sdh = StableDiffusionHolder(fp_ckpt)
|
||||||
|
|
||||||
xxx
|
xxx
|
||||||
|
|
||||||
|
|
||||||
#%% Next let's set up all parameters
|
#%% Next let's set up all parameters
|
||||||
quality = 'medium'
|
depth_strength = 0.3 # Specifies how deep (in terms of diffusion iterations the first branching happens)
|
||||||
depth_strength = 0.65 # Specifies how deep (in terms of diffusion iterations the first branching happens)
|
fixed_seeds = [697164, 430214]
|
||||||
fixed_seeds = [69731932, 504430820]
|
|
||||||
|
|
||||||
prompt1 = "photo of a beautiful cherry forest covered in white flowers, ambient light, very detailed, magic"
|
prompt1 = "photo of a desert and a sky"
|
||||||
prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail"
|
prompt2 = "photo of a tree with a lake"
|
||||||
|
|
||||||
duration_transition = 12 # In seconds
|
duration_transition = 12 # In seconds
|
||||||
fps = 30
|
fps = 30
|
||||||
|
|
||||||
# Spawn latent blending
|
# Spawn latent blending
|
||||||
self = LatentBlending(sdh)
|
self = LatentBlending(sdh)
|
||||||
self.branch1_influence = 0.8
|
|
||||||
self.load_branching_profile(quality=quality, depth_strength=0.3)
|
|
||||||
self.set_prompt1(prompt1)
|
self.set_prompt1(prompt1)
|
||||||
self.set_prompt2(prompt2)
|
self.set_prompt2(prompt2)
|
||||||
|
|
||||||
# Run latent blending
|
# Run latent blending
|
||||||
imgs_transition = self.run_transition(fixed_seeds=fixed_seeds)
|
self.branch1_influence = 0.3
|
||||||
|
self.branch1_mixing_depth = 0.4
|
||||||
|
self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
|
||||||
|
#%%
|
||||||
|
self.branch1_influence = 0.3
|
||||||
|
self.branch1_mixing_depth = 0.5
|
||||||
|
img2 = self.compute_latents2(return_image=True)
|
||||||
|
Image.fromarray(img2)
|
||||||
|
|
||||||
|
#%%
|
||||||
|
idx_injection = 15
|
||||||
|
fract_mixing = 0.5
|
||||||
|
list_conditionings = self.get_mixed_conditioning(fract_mixing)
|
||||||
|
latents_for_injection = interpolate_spherical(self.tree_latents[0][idx_injection], self.tree_latents[-1][idx_injection], fract_mixing)
|
||||||
|
list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
|
||||||
|
img_mix = self.sdh.latent2image((list_latents[-1]))
|
||||||
|
|
||||||
|
Image.fromarray(np.concatenate((img1,img_mix,img2), axis=1)).resize((800,800//3))
|
||||||
|
|
||||||
|
#%% scheme
|
||||||
|
# init scheme
|
||||||
|
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 2)
|
||||||
|
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
|
||||||
|
|
||||||
|
#%%
|
||||||
|
list_compute_steps = self.num_inference_steps - list_idx_injection
|
||||||
|
list_compute_steps *= list_nmb_stems
|
||||||
|
t_compute = np.sum(list_compute_steps) * self.dt_per_diff
|
||||||
|
increase_done = False
|
||||||
|
for s_idx in range(len(list_nmb_stems)-1):
|
||||||
|
if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 3:
|
||||||
|
list_nmb_stems[s_idx] += 1
|
||||||
|
increase_done = True
|
||||||
|
break
|
||||||
|
if not increase_done:
|
||||||
|
list_nmb_stems[-1] += 1
|
||||||
|
print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
|
||||||
|
|
||||||
|
|
||||||
|
#%%
|
||||||
|
|
||||||
|
imgs_transition = self.tree_final_imgs
|
||||||
|
# Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
|
||||||
|
imgs_transition_ext = add_frames_linear_interp(imgs_transition, 15, fps)
|
||||||
|
|
||||||
|
# Save as MP4
|
||||||
|
fp_movie = "test.mp4"
|
||||||
|
if os.path.isfile(fp_movie):
|
||||||
|
os.remove(fp_movie)
|
||||||
|
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width])
|
||||||
|
for img in tqdm(imgs_transition_ext):
|
||||||
|
ms.write_frame(img)
|
||||||
|
ms.finalize()
|
||||||
|
|
|
@ -42,7 +42,6 @@ from contextlib import nullcontext
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from einops import repeat, rearrange
|
from einops import repeat, rearrange
|
||||||
|
|
||||||
#%%
|
#%%
|
||||||
|
|
||||||
|
|
||||||
|
@ -279,10 +278,13 @@ class StableDiffusionHolder:
|
||||||
def run_diffusion_standard(
|
def run_diffusion_standard(
|
||||||
self,
|
self,
|
||||||
text_embeddings: torch.FloatTensor,
|
text_embeddings: torch.FloatTensor,
|
||||||
latents_for_injection: torch.FloatTensor = None,
|
latents_for_injection = None,
|
||||||
idx_start: int = -1,
|
idx_start: int = -1,
|
||||||
idx_stop: int = -1,
|
idx_stop: int = -1,
|
||||||
return_image: Optional[bool] = False
|
seed_source: int = -1,
|
||||||
|
seed_mixing_target: int = -1,
|
||||||
|
mixing_coeff: float = 0.0,
|
||||||
|
return_image: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
|
Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
|
||||||
|
@ -291,12 +293,18 @@ class StableDiffusionHolder:
|
||||||
Args:
|
Args:
|
||||||
text_embeddings: torch.FloatTensor
|
text_embeddings: torch.FloatTensor
|
||||||
Text embeddings used for diffusion
|
Text embeddings used for diffusion
|
||||||
latents_for_injection: torch.FloatTensor
|
latents_for_injection: torch.FloatTensor or list
|
||||||
Latents that are used for injection
|
Latents that are used for injection
|
||||||
idx_start: int
|
idx_start: int
|
||||||
Index of the diffusion process start and where the latents_for_injection are injected
|
Index of the diffusion process start and where the latents_for_injection are injected
|
||||||
idx_stop: int
|
idx_stop: int
|
||||||
Index of the diffusion process end.
|
Index of the diffusion process end.
|
||||||
|
mixing_coeff:
|
||||||
|
# FIXME
|
||||||
|
seed_source:
|
||||||
|
# FIXME
|
||||||
|
seed_mixing:
|
||||||
|
# FIXME
|
||||||
return_image: Optional[bool]
|
return_image: Optional[bool]
|
||||||
Optionally return image directly
|
Optionally return image directly
|
||||||
"""
|
"""
|
||||||
|
@ -304,12 +312,19 @@ class StableDiffusionHolder:
|
||||||
|
|
||||||
if latents_for_injection is None:
|
if latents_for_injection is None:
|
||||||
do_inject_latents = False
|
do_inject_latents = False
|
||||||
|
do_mix_latents = False
|
||||||
|
else:
|
||||||
|
if mixing_coeff > 0.0:
|
||||||
|
do_inject_latents = False
|
||||||
|
do_mix_latents = True
|
||||||
|
assert seed_mixing_target != -1, "Set to correct seed for mixing"
|
||||||
else:
|
else:
|
||||||
do_inject_latents = True
|
do_inject_latents = True
|
||||||
|
do_mix_latents = False
|
||||||
|
|
||||||
|
|
||||||
precision_scope = autocast if self.precision == "autocast" else nullcontext
|
precision_scope = autocast if self.precision == "autocast" else nullcontext
|
||||||
generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
|
generator = torch.Generator(device=self.device).manual_seed(int(seed_source))
|
||||||
|
|
||||||
with precision_scope("cuda"):
|
with precision_scope("cuda"):
|
||||||
with self.model.ema_scope():
|
with self.model.ema_scope():
|
||||||
|
@ -340,6 +355,16 @@ class StableDiffusionHolder:
|
||||||
continue
|
continue
|
||||||
elif i == idx_start:
|
elif i == idx_start:
|
||||||
latents = latents_for_injection.clone()
|
latents = latents_for_injection.clone()
|
||||||
|
if do_mix_latents:
|
||||||
|
if i == 0:
|
||||||
|
generator = torch.Generator(device=self.device).manual_seed(int(seed_mixing_target))
|
||||||
|
latents_mixtarget = torch.randn(size, generator=generator, device=self.device)
|
||||||
|
if i < idx_start:
|
||||||
|
latents_mixtarget = latents_for_injection[i-1].clone()
|
||||||
|
latents = interpolate_spherical(latents, latents_mixtarget, mixing_coeff)
|
||||||
|
|
||||||
|
if i == idx_start:
|
||||||
|
do_mix_latents = False
|
||||||
|
|
||||||
if i == idx_stop:
|
if i == idx_stop:
|
||||||
return list_latents_out
|
return list_latents_out
|
||||||
|
@ -576,6 +601,50 @@ class StableDiffusionHolder:
|
||||||
image = x_sample.astype(np.uint8)
|
image = x_sample.astype(np.uint8)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def interpolate_spherical(p0, p1, fract_mixing: float):
|
||||||
|
r"""
|
||||||
|
Helper function to correctly mix two random variables using spherical interpolation.
|
||||||
|
See https://en.wikipedia.org/wiki/Slerp
|
||||||
|
The function will always cast up to float64 for sake of extra 4.
|
||||||
|
Args:
|
||||||
|
p0:
|
||||||
|
First tensor for interpolation
|
||||||
|
p1:
|
||||||
|
Second tensor for interpolation
|
||||||
|
fract_mixing: float
|
||||||
|
Mixing coefficient of interval [0, 1].
|
||||||
|
0 will return in p0
|
||||||
|
1 will return in p1
|
||||||
|
0.x will return a mix between both preserving angular velocity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if p0.dtype == torch.float16:
|
||||||
|
recast_to = 'fp16'
|
||||||
|
else:
|
||||||
|
recast_to = 'fp32'
|
||||||
|
|
||||||
|
p0 = p0.double()
|
||||||
|
p1 = p1.double()
|
||||||
|
norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
|
||||||
|
epsilon = 1e-7
|
||||||
|
dot = torch.sum(p0 * p1) / norm
|
||||||
|
dot = dot.clamp(-1+epsilon, 1-epsilon)
|
||||||
|
|
||||||
|
theta_0 = torch.arccos(dot)
|
||||||
|
sin_theta_0 = torch.sin(theta_0)
|
||||||
|
theta_t = theta_0 * fract_mixing
|
||||||
|
s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
|
||||||
|
s1 = torch.sin(theta_t) / sin_theta_0
|
||||||
|
interp = p0*s0 + p1*s1
|
||||||
|
|
||||||
|
if recast_to == 'fp16':
|
||||||
|
interp = interp.half()
|
||||||
|
elif recast_to == 'fp32':
|
||||||
|
interp = interp.float()
|
||||||
|
|
||||||
|
return interp
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue