parental mixing mode
This commit is contained in:
parent
0371868603
commit
07a5c4ffd7
101
gradio_ui.py
101
gradio_ui.py
|
@ -31,15 +31,9 @@ from stable_diffusion_holder import StableDiffusionHolder
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import copy
|
import copy
|
||||||
|
from dotenv import find_dotenv, load_dotenv
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
TODOS:
|
|
||||||
- clean parameter handling
|
|
||||||
- three buttons: diffuse A, diffuse B, make transition
|
|
||||||
- collapse for easy mode
|
|
||||||
- transition quality in terms of render time
|
|
||||||
"""
|
|
||||||
|
|
||||||
#%%
|
#%%
|
||||||
|
|
||||||
|
@ -65,13 +59,17 @@ class BlendingFrontend():
|
||||||
self.list_settings = []
|
self.list_settings = []
|
||||||
self.state_current = {}
|
self.state_current = {}
|
||||||
self.showing_current = True
|
self.showing_current = True
|
||||||
self.branch1_influence = 0.1
|
self.branch1_influence = 0.3
|
||||||
self.branch1_mixing_depth = 0.3
|
self.branch1_max_depth_influence = 0.6
|
||||||
|
self.branch1_influence_decay = 0.3
|
||||||
|
self.parental_influence = 0.1
|
||||||
|
self.parental_max_depth_influence = 1.0
|
||||||
|
self.parental_influence_decay = 1.0
|
||||||
self.nmb_branches_final = 9
|
self.nmb_branches_final = 9
|
||||||
self.nmb_imgs_show = 5 # don't change
|
self.nmb_imgs_show = 5 # don't change
|
||||||
self.fps = 30
|
self.fps = 30
|
||||||
self.duration_video = 15
|
self.duration_video = 10
|
||||||
self.t_compute_max_allowed = 15
|
self.t_compute_max_allowed = 10
|
||||||
self.dict_multi_trans = {}
|
self.dict_multi_trans = {}
|
||||||
self.dict_multi_trans_include = {}
|
self.dict_multi_trans_include = {}
|
||||||
self.multi_trans_currently_shown = []
|
self.multi_trans_currently_shown = []
|
||||||
|
@ -87,9 +85,20 @@ class BlendingFrontend():
|
||||||
self.height = 768
|
self.height = 768
|
||||||
self.width = 768
|
self.width = 768
|
||||||
|
|
||||||
|
self.init_save_dir()
|
||||||
|
|
||||||
|
|
||||||
|
def init_save_dir(self):
|
||||||
|
load_dotenv(find_dotenv(), verbose=False)
|
||||||
|
try:
|
||||||
|
self.dp_out = os.getenv("dp_out")
|
||||||
|
except Exception as e:
|
||||||
|
self.dp_out = ""
|
||||||
|
|
||||||
|
|
||||||
# make dummy image
|
# make dummy image
|
||||||
def save_empty_image(self):
|
def save_empty_image(self):
|
||||||
self.fp_img_empty = 'empty.jpg'
|
self.fp_img_empty = os.path.join(self.dp_out, 'empty.jpg')
|
||||||
Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
|
Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
|
||||||
|
|
||||||
|
|
||||||
|
@ -116,8 +125,6 @@ class BlendingFrontend():
|
||||||
self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
|
self.lb.set_negative_prompt(list_ui_elem[list_ui_keys.index('negative_prompt')])
|
||||||
self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
|
self.lb.guidance_scale = list_ui_elem[list_ui_keys.index('guidance_scale')]
|
||||||
self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
|
self.lb.guidance_scale_mid_damper = list_ui_elem[list_ui_keys.index('guidance_scale_mid_damper')]
|
||||||
self.lb.branch1_influence = list_ui_elem[list_ui_keys.index('branch1_influence')]
|
|
||||||
self.lb.branch1_mixing_depth = list_ui_elem[list_ui_keys.index('branch1_mixing_depth')]
|
|
||||||
self.lb.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
|
self.lb.t_compute_max_allowed = list_ui_elem[list_ui_keys.index('duration_compute')]
|
||||||
self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
|
self.lb.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
|
||||||
self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
|
self.lb.sdh.num_inference_steps = list_ui_elem[list_ui_keys.index('num_inference_steps')]
|
||||||
|
@ -125,11 +132,19 @@ class BlendingFrontend():
|
||||||
self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')]
|
self.lb.seed1 = list_ui_elem[list_ui_keys.index('seed1')]
|
||||||
self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
|
self.lb.seed2 = list_ui_elem[list_ui_keys.index('seed2')]
|
||||||
|
|
||||||
|
self.lb.branch1_influence = list_ui_elem[list_ui_keys.index('branch1_influence')]
|
||||||
|
self.lb.branch1_max_depth_influence = list_ui_elem[list_ui_keys.index('branch1_max_depth_influence')]
|
||||||
|
self.lb.branch1_influence_decay = list_ui_elem[list_ui_keys.index('branch1_influence_decay')]
|
||||||
|
self.lb.parental_influence = list_ui_elem[list_ui_keys.index('parental_influence')]
|
||||||
|
self.lb.parental_max_depth_influence = list_ui_elem[list_ui_keys.index('parental_max_depth_influence')]
|
||||||
|
self.lb.parental_influence_decay = list_ui_elem[list_ui_keys.index('parental_influence_decay')]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def compute_img1(self, *args):
|
def compute_img1(self, *args):
|
||||||
list_ui_elem = args
|
list_ui_elem = args
|
||||||
self.setup_lb(list_ui_elem)
|
self.setup_lb(list_ui_elem)
|
||||||
fp_img1 = f"img1_{get_time('second')}.jpg"
|
fp_img1 = os.path.join(self.dp_out, f"img1_{get_time('second')}.jpg")
|
||||||
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
|
img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
|
||||||
img1.save(fp_img1)
|
img1.save(fp_img1)
|
||||||
self.save_empty_image()
|
self.save_empty_image()
|
||||||
|
@ -138,7 +153,7 @@ class BlendingFrontend():
|
||||||
def compute_img2(self, *args):
|
def compute_img2(self, *args):
|
||||||
list_ui_elem = args
|
list_ui_elem = args
|
||||||
self.setup_lb(list_ui_elem)
|
self.setup_lb(list_ui_elem)
|
||||||
fp_img2 = f"img2_{get_time('second')}.jpg"
|
fp_img2 = os.path.join(self.dp_out, f"img2_{get_time('second')}.jpg")
|
||||||
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
|
img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
|
||||||
img2.save(fp_img2)
|
img2.save(fp_img2)
|
||||||
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2]
|
return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2]
|
||||||
|
@ -198,9 +213,11 @@ class BlendingFrontend():
|
||||||
|
|
||||||
def get_fp_movie(self, timestamp, is_stacked=False):
|
def get_fp_movie(self, timestamp, is_stacked=False):
|
||||||
if not is_stacked:
|
if not is_stacked:
|
||||||
return f"movie_{timestamp}.mp4"
|
fn = f"movie_{timestamp}.mp4"
|
||||||
else:
|
else:
|
||||||
return f"movie_stacked_{timestamp}.mp4"
|
fn = f"movie_stacked_{timestamp}.mp4"
|
||||||
|
fp = os.path.join(self.dp_out, fn)
|
||||||
|
return fp
|
||||||
|
|
||||||
|
|
||||||
def stack_forward(self, prompt2, seed2):
|
def stack_forward(self, prompt2, seed2):
|
||||||
|
@ -319,32 +336,43 @@ if __name__ == "__main__":
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
prompt1 = gr.Textbox(label="prompt 1")
|
prompt1 = gr.Textbox(label="prompt 1")
|
||||||
negative_prompt = gr.Textbox(label="negative prompt")
|
|
||||||
prompt2 = gr.Textbox(label="prompt 2")
|
prompt2 = gr.Textbox(label="prompt 2")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
duration_compute = gr.Slider(10, 40, self.duration_video, step=1, label='compute budget for transition (seconds)', interactive=True)
|
duration_compute = gr.Slider(5, 45, self.t_compute_max_allowed, step=1, label='compute budget for transition (seconds)', interactive=True)
|
||||||
duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True)
|
duration_video = gr.Slider(0.1, 30, self.duration_video, step=0.1, label='result video duration (seconds)', interactive=True)
|
||||||
height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
|
height = gr.Slider(256, 2048, self.height, step=128, label='height', interactive=True)
|
||||||
width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True)
|
width = gr.Slider(256, 2048, self.width, step=128, label='width', interactive=True)
|
||||||
|
|
||||||
with gr.Accordion("Advanced Settings (click to expand)", open=False):
|
with gr.Accordion("Advanced Settings (click to expand)", open=False):
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Accordion("Diffusion settings", open=True):
|
||||||
depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
|
with gr.Row():
|
||||||
branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='branch1_influence', interactive=True)
|
num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
|
||||||
branch1_mixing_depth = gr.Slider(0.0, 1.0, self.branch1_mixing_depth, step=0.01, label='branch1_mixing_depth', interactive=True)
|
guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
|
||||||
|
negative_prompt = gr.Textbox(label="negative prompt")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Accordion("Seeds control", open=True):
|
||||||
num_inference_steps = gr.Slider(5, 100, self.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
|
with gr.Row():
|
||||||
guidance_scale = gr.Slider(1, 25, self.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
|
seed1 = gr.Number(420, label="seed 1", interactive=True)
|
||||||
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
|
b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
|
||||||
|
seed2 = gr.Number(420, label="seed 2", interactive=True)
|
||||||
|
b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
|
||||||
|
|
||||||
|
with gr.Accordion("Crossfeeding for last image", open=True):
|
||||||
|
with gr.Row():
|
||||||
|
branch1_influence = gr.Slider(0.0, 1.0, self.branch1_influence, step=0.01, label='crossfeed power', interactive=True)
|
||||||
|
branch1_max_depth_influence = gr.Slider(0.0, 1.0, self.branch1_max_depth_influence, step=0.01, label='crossfeed range', interactive=True)
|
||||||
|
branch1_influence_decay = gr.Slider(0.0, 1.0, self.branch1_influence_decay, step=0.01, label='crossfeed decay', interactive=True)
|
||||||
|
|
||||||
|
with gr.Accordion("Transition settings", open=True):
|
||||||
|
with gr.Row():
|
||||||
|
depth_strength = gr.Slider(0.01, 0.99, self.depth_strength, step=0.01, label='depth_strength', interactive=True)
|
||||||
|
guidance_scale_mid_damper = gr.Slider(0.01, 2.0, self.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
|
||||||
|
parental_influence = gr.Slider(0.0, 1.0, self.parental_influence, step=0.01, label='parental power', interactive=True)
|
||||||
|
parental_max_depth_influence = gr.Slider(0.0, 1.0, self.parental_max_depth_influence, step=0.01, label='parental range', interactive=True)
|
||||||
|
parental_influence_decay = gr.Slider(0.0, 1.0, self.parental_influence_decay, step=0.01, label='parental decay', interactive=True)
|
||||||
|
|
||||||
with gr.Row():
|
|
||||||
seed1 = gr.Number(420, label="seed 1", interactive=True)
|
|
||||||
b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
|
|
||||||
seed2 = gr.Number(420, label="seed 2", interactive=True)
|
|
||||||
b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
b_compute1 = gr.Button('compute first image', variant='primary')
|
b_compute1 = gr.Button('compute first image', variant='primary')
|
||||||
|
@ -373,7 +401,8 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
dict_ui_elem["depth_strength"] = depth_strength
|
dict_ui_elem["depth_strength"] = depth_strength
|
||||||
dict_ui_elem["branch1_influence"] = branch1_influence
|
dict_ui_elem["branch1_influence"] = branch1_influence
|
||||||
dict_ui_elem["branch1_mixing_depth"] = branch1_mixing_depth
|
dict_ui_elem["branch1_max_depth_influence"] = branch1_max_depth_influence
|
||||||
|
dict_ui_elem["branch1_influence_decay"] = branch1_influence_decay
|
||||||
|
|
||||||
dict_ui_elem["num_inference_steps"] = num_inference_steps
|
dict_ui_elem["num_inference_steps"] = num_inference_steps
|
||||||
dict_ui_elem["guidance_scale"] = guidance_scale
|
dict_ui_elem["guidance_scale"] = guidance_scale
|
||||||
|
@ -381,6 +410,10 @@ if __name__ == "__main__":
|
||||||
dict_ui_elem["seed1"] = seed1
|
dict_ui_elem["seed1"] = seed1
|
||||||
dict_ui_elem["seed2"] = seed2
|
dict_ui_elem["seed2"] = seed2
|
||||||
|
|
||||||
|
dict_ui_elem["parental_max_depth_influence"] = parental_max_depth_influence
|
||||||
|
dict_ui_elem["parental_influence"] = parental_influence
|
||||||
|
dict_ui_elem["parental_influence_decay"] = parental_influence_decay
|
||||||
|
|
||||||
# Convert to list, as gradio doesn't seem to accept dicts
|
# Convert to list, as gradio doesn't seem to accept dicts
|
||||||
list_ui_elem = []
|
list_ui_elem = []
|
||||||
list_ui_keys = []
|
list_ui_keys = []
|
||||||
|
|
|
@ -107,8 +107,16 @@ class LatentBlending():
|
||||||
self.noise_level_upscaling = 20
|
self.noise_level_upscaling = 20
|
||||||
self.list_injection_idx = None
|
self.list_injection_idx = None
|
||||||
self.list_nmb_branches = None
|
self.list_nmb_branches = None
|
||||||
|
|
||||||
|
# Mixing parameters
|
||||||
self.branch1_influence = 0.0
|
self.branch1_influence = 0.0
|
||||||
self.branch1_mixing_depth = 0.65
|
self.branch1_max_depth_influence = 0.65
|
||||||
|
self.branch1_influence_decay = 0.8
|
||||||
|
|
||||||
|
self.parental_influence = 0.0
|
||||||
|
self.parental_max_depth_influence = 1.0
|
||||||
|
self.parental_influence_decay = 1.0
|
||||||
|
|
||||||
self.branch1_insertion_completed = False
|
self.branch1_insertion_completed = False
|
||||||
self.set_guidance_scale(guidance_scale)
|
self.set_guidance_scale(guidance_scale)
|
||||||
self.init_mode()
|
self.init_mode()
|
||||||
|
@ -389,10 +397,6 @@ class LatentBlending():
|
||||||
fixed_seeds: Optional[List[int]] = None,
|
fixed_seeds: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
# # FIXME: deal with these tree args later
|
|
||||||
# self.num_inference_steps = 30
|
|
||||||
# self.t_compute_max_allowed = 60
|
|
||||||
|
|
||||||
# Sanity checks first
|
# Sanity checks first
|
||||||
assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
|
assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
|
||||||
assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
|
assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
|
||||||
|
@ -411,17 +415,15 @@ class LatentBlending():
|
||||||
self.sdh.num_inference_steps = self.num_inference_steps
|
self.sdh.num_inference_steps = self.num_inference_steps
|
||||||
|
|
||||||
# Compute / Recycle first image
|
# Compute / Recycle first image
|
||||||
if not recycle_img1:
|
if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
|
||||||
list_latents1 = self.compute_latents1()
|
list_latents1 = self.compute_latents1()
|
||||||
else:
|
else:
|
||||||
# FIXME: check if latents there...
|
|
||||||
list_latents1 = self.tree_latents[0]
|
list_latents1 = self.tree_latents[0]
|
||||||
|
|
||||||
# Compute / Recycle first image
|
# Compute / Recycle first image
|
||||||
if not recycle_img2:
|
if not recycle_img2 or len(self.tree_latents[-1]) != self.num_inference_steps:
|
||||||
list_latents2 = self.compute_latents2()
|
list_latents2 = self.compute_latents2()
|
||||||
else:
|
else:
|
||||||
# FIXME: check if latents there...
|
|
||||||
list_latents2 = self.tree_latents[-1]
|
list_latents2 = self.tree_latents[-1]
|
||||||
|
|
||||||
# Reset the tree, injecting the edge latents1/2 we just generated/recycled
|
# Reset the tree, injecting the edge latents1/2 we just generated/recycled
|
||||||
|
@ -438,7 +440,7 @@ class LatentBlending():
|
||||||
while t_compute < self.t_compute_max_allowed:
|
while t_compute < self.t_compute_max_allowed:
|
||||||
list_compute_steps = self.num_inference_steps - list_idx_injection
|
list_compute_steps = self.num_inference_steps - list_idx_injection
|
||||||
list_compute_steps *= list_nmb_stems
|
list_compute_steps *= list_nmb_stems
|
||||||
t_compute = np.sum(list_compute_steps) * self.dt_per_diff
|
t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15*np.sum(list_nmb_stems)
|
||||||
increase_done = False
|
increase_done = False
|
||||||
for s_idx in range(len(list_nmb_stems)-1):
|
for s_idx in range(len(list_nmb_stems)-1):
|
||||||
if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
|
if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
|
||||||
|
@ -449,13 +451,14 @@ class LatentBlending():
|
||||||
list_nmb_stems[-1] += 1
|
list_nmb_stems[-1] += 1
|
||||||
# print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
|
# print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
|
||||||
|
|
||||||
# Run iteratively
|
# Run iteratively, always inserting new branches where they are needed most
|
||||||
for s_idx in tqdm(range(len(list_idx_injection))):
|
for s_idx in tqdm(range(len(list_idx_injection))):
|
||||||
nmb_stems = list_nmb_stems[s_idx]
|
nmb_stems = list_nmb_stems[s_idx]
|
||||||
idx_injection = list_idx_injection[s_idx]
|
idx_injection = list_idx_injection[s_idx]
|
||||||
|
|
||||||
for i in range(nmb_stems):
|
for i in range(nmb_stems):
|
||||||
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
|
fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
|
||||||
|
self.set_guidance_mid_dampening(fract_mixing)
|
||||||
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
|
# print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
|
||||||
list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
|
list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
|
||||||
self.insert_into_tree(fract_mixing, idx_injection, list_latents)
|
self.insert_into_tree(fract_mixing, idx_injection, list_latents)
|
||||||
|
@ -465,6 +468,14 @@ class LatentBlending():
|
||||||
|
|
||||||
|
|
||||||
def get_mixing_parameters(self, idx_injection):
|
def get_mixing_parameters(self, idx_injection):
|
||||||
|
r"""
|
||||||
|
Computes which parental latents should be mixed together to achieve a smooth blend.
|
||||||
|
As metric, we are using lpips image similarity. The insertion takes place
|
||||||
|
where the metric is maximal.
|
||||||
|
Args:
|
||||||
|
idx_injection: int
|
||||||
|
the index in terms of diffusion steps, where the next insertion will start.
|
||||||
|
"""
|
||||||
# get_lpips_similarity
|
# get_lpips_similarity
|
||||||
similarities = []
|
similarities = []
|
||||||
for i in range(len(self.tree_final_imgs)-1):
|
for i in range(len(self.tree_final_imgs)-1):
|
||||||
|
@ -499,7 +510,16 @@ class LatentBlending():
|
||||||
|
|
||||||
|
|
||||||
def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
|
def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
|
||||||
# FIXME
|
r"""
|
||||||
|
Inserts all necessary parameters into the trajectory tree.
|
||||||
|
Args:
|
||||||
|
fract_mixing: float
|
||||||
|
the fraction along the transition axis [0, 1]
|
||||||
|
idx_injection: int
|
||||||
|
the index in terms of diffusion steps, where the next insertion will start.
|
||||||
|
list_latents: list
|
||||||
|
list of the latents to be inserted
|
||||||
|
"""
|
||||||
b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
|
b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
|
||||||
self.tree_latents.insert(b_parent1+1, list_latents)
|
self.tree_latents.insert(b_parent1+1, list_latents)
|
||||||
self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
|
self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
|
||||||
|
@ -508,23 +528,67 @@ class LatentBlending():
|
||||||
|
|
||||||
|
|
||||||
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
|
def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
|
||||||
# FIXME
|
r"""
|
||||||
|
Runs a diffusion trajectory, using the latents from the respective parents
|
||||||
|
Args:
|
||||||
|
fract_mixing: float
|
||||||
|
the fraction along the transition axis [0, 1]
|
||||||
|
b_parent1: int
|
||||||
|
index of parent1 to be used
|
||||||
|
b_parent2: int
|
||||||
|
index of parent2 to be used
|
||||||
|
idx_injection: int
|
||||||
|
the index in terms of diffusion steps, where the next insertion will start.
|
||||||
|
"""
|
||||||
list_conditionings = self.get_mixed_conditioning(fract_mixing)
|
list_conditionings = self.get_mixed_conditioning(fract_mixing)
|
||||||
fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
|
fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
|
||||||
idx_reversed = self.num_inference_steps - idx_injection
|
# idx_reversed = self.num_inference_steps - idx_injection
|
||||||
latents_for_injection = interpolate_spherical(
|
|
||||||
self.tree_latents[b_parent1][-idx_reversed-1],
|
list_latents_parental_mix = []
|
||||||
self.tree_latents[b_parent2][-idx_reversed-1],
|
for i in range(self.num_inference_steps):
|
||||||
fract_mixing_parental)
|
latents_p1 = self.tree_latents[b_parent1][i]
|
||||||
list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
|
latents_p2 = self.tree_latents[b_parent2][i]
|
||||||
|
if latents_p1 is None or latents_p2 is None:
|
||||||
|
latents_parental = None
|
||||||
|
else:
|
||||||
|
latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
|
||||||
|
list_latents_parental_mix.append(latents_parental)
|
||||||
|
|
||||||
|
idx_mixing_stop = int(round(self.num_inference_steps*self.parental_max_depth_influence))
|
||||||
|
mixing_coeffs = idx_injection*[self.parental_influence]
|
||||||
|
nmb_mixing = idx_mixing_stop - idx_injection
|
||||||
|
if nmb_mixing > 0:
|
||||||
|
mixing_coeffs.extend(list(np.linspace(self.parental_influence, self.parental_influence*self.parental_influence_decay, nmb_mixing)))
|
||||||
|
mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
|
||||||
|
|
||||||
|
latents_start = list_latents_parental_mix[idx_injection-1]
|
||||||
|
list_latents = self.run_diffusion(
|
||||||
|
list_conditionings,
|
||||||
|
latents_start = latents_start,
|
||||||
|
idx_start = idx_injection,
|
||||||
|
list_latents_mixing = list_latents_parental_mix,
|
||||||
|
mixing_coeffs = mixing_coeffs
|
||||||
|
)
|
||||||
|
|
||||||
return list_latents
|
return list_latents
|
||||||
|
|
||||||
|
|
||||||
def compute_latents1(self, return_image=False):
|
def compute_latents1(self, return_image=False):
|
||||||
|
r"""
|
||||||
|
Runs a diffusion trajectory for the first image
|
||||||
|
Args:
|
||||||
|
return_image: bool
|
||||||
|
whether to return an image or the list of latents
|
||||||
|
"""
|
||||||
print("starting compute_latents1")
|
print("starting compute_latents1")
|
||||||
list_conditionings = [self.text_embedding1]
|
list_conditionings = [self.text_embedding1]
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
list_latents1 = self.run_diffusion(list_conditionings, seed_source=self.seed1)
|
latents_start = self.get_noise(self.seed1)
|
||||||
|
list_latents1 = self.run_diffusion(
|
||||||
|
list_conditionings,
|
||||||
|
latents_start = latents_start,
|
||||||
|
idx_start = 0
|
||||||
|
)
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
self.dt_per_diff = (t1-t0) / self.num_inference_steps
|
self.dt_per_diff = (t1-t0) / self.num_inference_steps
|
||||||
self.tree_latents[0] = list_latents1
|
self.tree_latents[0] = list_latents1
|
||||||
|
@ -533,25 +597,31 @@ class LatentBlending():
|
||||||
else:
|
else:
|
||||||
return list_latents1
|
return list_latents1
|
||||||
|
|
||||||
|
|
||||||
def compute_latents2(self, return_image=False):
|
def compute_latents2(self, return_image=False):
|
||||||
print("starting compute_latents2")
|
r"""
|
||||||
list_conditionings = [self.text_embedding2
|
Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
|
||||||
]
|
Args:
|
||||||
|
return_image: bool
|
||||||
|
whether to return an image or the list of latents
|
||||||
|
"""
|
||||||
|
list_conditionings = [self.text_embedding2]
|
||||||
|
latents_start = self.get_noise(self.seed2)
|
||||||
# Influence from branch1
|
# Influence from branch1
|
||||||
if self.branch1_influence > 0.0:
|
if self.branch1_influence > 0.0:
|
||||||
self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
|
# Set up the mixing_coeffs
|
||||||
self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
|
idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_max_depth_influence))
|
||||||
idx_crossfeed = int(round(self.num_inference_steps*self.branch1_mixing_depth))
|
mixing_coeffs = list(np.linspace(self.branch1_influence, self.branch1_influence*self.branch1_influence_decay, idx_mixing_stop))
|
||||||
|
mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
|
||||||
|
list_latents_mixing = self.tree_latents[0]
|
||||||
list_latents2 = self.run_diffusion(
|
list_latents2 = self.run_diffusion(
|
||||||
list_conditionings,
|
list_conditionings,
|
||||||
idx_start=idx_crossfeed,
|
latents_start = latents_start,
|
||||||
latents_for_injection=self.tree_latents[0],
|
idx_start = 0,
|
||||||
seed_source=self.seed2,
|
list_latents_mixing = list_latents_mixing,
|
||||||
seed_mixing_target=self.seed1,
|
mixing_coeffs = mixing_coeffs
|
||||||
mixing_coeff=self.branch1_influence)
|
)
|
||||||
else:
|
else:
|
||||||
list_latents2 = self.run_diffusion(list_conditionings)
|
list_latents2 = self.run_diffusion(list_conditionings, latents_start)
|
||||||
self.tree_latents[-1] = list_latents2
|
self.tree_latents[-1] = list_latents2
|
||||||
|
|
||||||
if return_image:
|
if return_image:
|
||||||
|
@ -559,9 +629,14 @@ class LatentBlending():
|
||||||
else:
|
else:
|
||||||
return list_latents2
|
return list_latents2
|
||||||
|
|
||||||
|
def get_noise(self, seed):
|
||||||
|
generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
|
||||||
|
shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
|
||||||
|
C, H, W = shape_latents
|
||||||
|
return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
|
||||||
|
|
||||||
|
|
||||||
def run_transition_OLD(
|
def run_transition_legacy(
|
||||||
self,
|
self,
|
||||||
recycle_img1: Optional[bool] = False,
|
recycle_img1: Optional[bool] = False,
|
||||||
recycle_img2: Optional[bool] = False,
|
recycle_img2: Optional[bool] = False,
|
||||||
|
@ -569,6 +644,7 @@ class LatentBlending():
|
||||||
premature_stop: Optional[int] = np.inf,
|
premature_stop: Optional[int] = np.inf,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
|
Old legacy function for computing transitions.
|
||||||
Returns a list of transition images using spherical latent blending.
|
Returns a list of transition images using spherical latent blending.
|
||||||
Args:
|
Args:
|
||||||
recycle_img1: Optional[bool]:
|
recycle_img1: Optional[bool]:
|
||||||
|
@ -610,9 +686,9 @@ class LatentBlending():
|
||||||
if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
|
if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
|
||||||
assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0'
|
assert self.list_nmb_branches[0]==2, 'branch1 influnce currently requires the self.list_nmb_branches[0] = 0'
|
||||||
self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
|
self.branch1_influence = np.clip(self.branch1_influence, 0, 1)
|
||||||
self.branch1_mixing_depth = np.clip(self.branch1_mixing_depth, 0, 1)
|
self.branch1_max_depth_influence = np.clip(self.branch1_max_depth_influence, 0, 1)
|
||||||
self.list_nmb_branches.insert(1, 2)
|
self.list_nmb_branches.insert(1, 2)
|
||||||
idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_mixing_depth))
|
idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_max_depth_influence))
|
||||||
self.list_injection_idx_ext.insert(1, idx_crossfeed)
|
self.list_injection_idx_ext.insert(1, idx_crossfeed)
|
||||||
self.tree_fracts.insert(1, self.tree_fracts[0])
|
self.tree_fracts.insert(1, self.tree_fracts[0])
|
||||||
self.tree_status.insert(1, self.tree_status[0])
|
self.tree_status.insert(1, self.tree_status[0])
|
||||||
|
@ -790,27 +866,27 @@ class LatentBlending():
|
||||||
def run_diffusion(
|
def run_diffusion(
|
||||||
self,
|
self,
|
||||||
list_conditionings,
|
list_conditionings,
|
||||||
latents_for_injection: torch.FloatTensor = None,
|
latents_start: torch.FloatTensor = None,
|
||||||
idx_start: int = -1,
|
idx_start: int = 0,
|
||||||
idx_stop: int = -1,
|
list_latents_mixing = None,
|
||||||
seed_source: int = -1,
|
mixing_coeffs = 0.0,
|
||||||
seed_mixing_target: int = -1,
|
|
||||||
mixing_coeff: float = 0.0,
|
|
||||||
return_image: Optional[bool] = False
|
return_image: Optional[bool] = False
|
||||||
):
|
):
|
||||||
|
|
||||||
r"""
|
r"""
|
||||||
Wrapper function for run_diffusion_standard and run_diffusion_inpaint.
|
Wrapper function for diffusion runners.
|
||||||
Depending on the mode, the correct one will be executed.
|
Depending on the mode, the correct one will be executed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
list_conditionings: List of all conditionings for the diffusion model.
|
list_conditionings: List of all conditionings for the diffusion model.
|
||||||
latents_for_injection: torch.FloatTensor
|
latents_start: torch.FloatTensor
|
||||||
Latents that are used for injection
|
Latents that are used for injection
|
||||||
idx_start: int
|
idx_start: int
|
||||||
Index of the diffusion process start and where the latents_for_injection are injected
|
Index of the diffusion process start and where the latents_for_injection are injected
|
||||||
idx_stop: int
|
list_latents_mixing: torch.FloatTensor
|
||||||
Index of the diffusion process end.
|
List of latents (latent trajectories) that are used for mixing
|
||||||
FIXME ARGS
|
mixing_coeffs: float or list
|
||||||
|
Coefficients, how strong each element of list_latents_mixing will be mixed in.
|
||||||
return_image: Optional[bool]
|
return_image: Optional[bool]
|
||||||
Optionally return image directly
|
Optionally return image directly
|
||||||
"""
|
"""
|
||||||
|
@ -822,26 +898,25 @@ class LatentBlending():
|
||||||
if self.mode == 'standard':
|
if self.mode == 'standard':
|
||||||
text_embeddings = list_conditionings[0]
|
text_embeddings = list_conditionings[0]
|
||||||
return self.sdh.run_diffusion_standard(
|
return self.sdh.run_diffusion_standard(
|
||||||
text_embeddings,
|
text_embeddings = text_embeddings,
|
||||||
latents_for_injection=latents_for_injection,
|
latents_start = latents_start,
|
||||||
idx_start=idx_start,
|
idx_start = idx_start,
|
||||||
idx_stop=idx_stop,
|
list_latents_mixing = list_latents_mixing,
|
||||||
seed_source=seed_source,
|
mixing_coeffs = mixing_coeffs,
|
||||||
seed_mixing_target=seed_mixing_target,
|
return_image = return_image,
|
||||||
mixing_coeff=mixing_coeff,
|
|
||||||
return_image=return_image,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
elif self.mode == 'inpaint':
|
# elif self.mode == 'inpaint':
|
||||||
text_embeddings = list_conditionings[0]
|
# text_embeddings = list_conditionings[0]
|
||||||
assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
|
# assert self.sdh.image_source is not None, "image_source is None. Please run init_inpainting first."
|
||||||
assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
|
# assert self.sdh.mask_image is not None, "image_source is None. Please run init_inpainting first."
|
||||||
return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
# return self.sdh.run_diffusion_inpaint(text_embeddings, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
||||||
# FIXME LONG LINE
|
|
||||||
elif self.mode == 'upscale':
|
# # FIXME LONG LINE and bad args
|
||||||
cond = list_conditionings[0]
|
# elif self.mode == 'upscale':
|
||||||
uc_full = list_conditionings[1]
|
# cond = list_conditionings[0]
|
||||||
return self.sdh.run_diffusion_upscaling(cond, uc_full, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
# uc_full = list_conditionings[1]
|
||||||
|
# return self.sdh.run_diffusion_upscaling(cond, uc_full, latents_for_injection=latents_for_injection, idx_start=idx_start, idx_stop=idx_stop, return_image=return_image)
|
||||||
|
|
||||||
def run_upscaling_step1(
|
def run_upscaling_step1(
|
||||||
self,
|
self,
|
||||||
|
@ -1100,7 +1175,11 @@ class LatentBlending():
|
||||||
|
|
||||||
|
|
||||||
def get_lpips_similarity(self, imgA, imgB):
|
def get_lpips_similarity(self, imgA, imgB):
|
||||||
# FIXME
|
r"""
|
||||||
|
Computes the image similarity between two images imgA and imgB.
|
||||||
|
Used to determine the optimal point of insertion to create smooth transitions.
|
||||||
|
High values indicate low similarity.
|
||||||
|
"""
|
||||||
tensorA = torch.from_numpy(imgA).float().cuda(self.device)
|
tensorA = torch.from_numpy(imgA).float().cuda(self.device)
|
||||||
tensorA = 2*tensorA/255.0 - 1
|
tensorA = 2*tensorA/255.0 - 1
|
||||||
tensorA = tensorA.permute([2,0,1]).unsqueeze(0)
|
tensorA = tensorA.permute([2,0,1]).unsqueeze(0)
|
||||||
|
@ -1406,55 +1485,24 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# Run latent blending
|
# Run latent blending
|
||||||
self.branch1_influence = 0.3
|
self.branch1_influence = 0.3
|
||||||
self.branch1_mixing_depth = 0.4
|
self.branch1_max_depth_influence = 0.4
|
||||||
self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
|
# self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
|
||||||
#%%
|
self.seed1=21312
|
||||||
self.branch1_influence = 0.3
|
img1 =self.compute_latents1(True)
|
||||||
self.branch1_mixing_depth = 0.5
|
#%
|
||||||
img2 = self.compute_latents2(return_image=True)
|
self.seed2=1234121
|
||||||
Image.fromarray(img2)
|
self.branch1_influence = 0.7
|
||||||
|
self.branch1_max_depth_influence = 0.3
|
||||||
|
self.branch1_influence_decay = 0.3
|
||||||
|
img2 =self.compute_latents2(True)
|
||||||
|
# Image.fromarray(np.concatenate((img1, img2), axis=1))
|
||||||
|
|
||||||
#%%
|
#%%
|
||||||
idx_injection = 15
|
t0 = time.time()
|
||||||
fract_mixing = 0.5
|
self.t_compute_max_allowed = 30
|
||||||
list_conditionings = self.get_mixed_conditioning(fract_mixing)
|
self.parental_max_depth_influence = 1.0
|
||||||
latents_for_injection = interpolate_spherical(self.tree_latents[0][idx_injection], self.tree_latents[-1][idx_injection], fract_mixing)
|
self.parental_influence = 0.0
|
||||||
list_latents = self.run_diffusion(list_conditionings, latents_for_injection=latents_for_injection, idx_start=idx_injection)
|
self.parental_influence_decay = 1.0
|
||||||
img_mix = self.sdh.latent2image((list_latents[-1]))
|
imgs_transition = self.run_transition(recycle_img1=True, recycle_img2=True)
|
||||||
|
t1 = time.time()
|
||||||
Image.fromarray(np.concatenate((img1,img_mix,img2), axis=1)).resize((800,800//3))
|
print(f"took: {t1-t0}s")
|
||||||
|
|
||||||
#%% scheme
|
|
||||||
# init scheme
|
|
||||||
list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 2)
|
|
||||||
list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
|
|
||||||
|
|
||||||
#%%
|
|
||||||
list_compute_steps = self.num_inference_steps - list_idx_injection
|
|
||||||
list_compute_steps *= list_nmb_stems
|
|
||||||
t_compute = np.sum(list_compute_steps) * self.dt_per_diff
|
|
||||||
increase_done = False
|
|
||||||
for s_idx in range(len(list_nmb_stems)-1):
|
|
||||||
if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 3:
|
|
||||||
list_nmb_stems[s_idx] += 1
|
|
||||||
increase_done = True
|
|
||||||
break
|
|
||||||
if not increase_done:
|
|
||||||
list_nmb_stems[-1] += 1
|
|
||||||
print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
|
|
||||||
|
|
||||||
|
|
||||||
#%%
|
|
||||||
|
|
||||||
imgs_transition = self.tree_final_imgs
|
|
||||||
# Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
|
|
||||||
imgs_transition_ext = add_frames_linear_interp(imgs_transition, 15, fps)
|
|
||||||
|
|
||||||
# Save as MP4
|
|
||||||
fp_movie = "test.mp4"
|
|
||||||
if os.path.isfile(fp_movie):
|
|
||||||
os.remove(fp_movie)
|
|
||||||
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width])
|
|
||||||
for img in tqdm(imgs_transition_ext):
|
|
||||||
ms.write_frame(img)
|
|
||||||
ms.finalize()
|
|
|
@ -278,12 +278,10 @@ class StableDiffusionHolder:
|
||||||
def run_diffusion_standard(
|
def run_diffusion_standard(
|
||||||
self,
|
self,
|
||||||
text_embeddings: torch.FloatTensor,
|
text_embeddings: torch.FloatTensor,
|
||||||
latents_for_injection = None,
|
latents_start: torch.FloatTensor,
|
||||||
idx_start: int = -1,
|
idx_start: int = 0,
|
||||||
idx_stop: int = -1,
|
list_latents_mixing = None,
|
||||||
seed_source: int = -1,
|
mixing_coeffs = 0.0,
|
||||||
seed_mixing_target: int = -1,
|
|
||||||
mixing_coeff: float = 0.0,
|
|
||||||
return_image: Optional[bool] = False,
|
return_image: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
|
@ -297,34 +295,26 @@ class StableDiffusionHolder:
|
||||||
Latents that are used for injection
|
Latents that are used for injection
|
||||||
idx_start: int
|
idx_start: int
|
||||||
Index of the diffusion process start and where the latents_for_injection are injected
|
Index of the diffusion process start and where the latents_for_injection are injected
|
||||||
idx_stop: int
|
|
||||||
Index of the diffusion process end.
|
|
||||||
mixing_coeff:
|
mixing_coeff:
|
||||||
# FIXME
|
# FIXME
|
||||||
seed_source:
|
|
||||||
# FIXME
|
|
||||||
seed_mixing:
|
|
||||||
# FIXME
|
|
||||||
return_image: Optional[bool]
|
return_image: Optional[bool]
|
||||||
Optionally return image directly
|
Optionally return image directly
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Asserts
|
||||||
if latents_for_injection is None:
|
if type(mixing_coeffs) == float:
|
||||||
do_inject_latents = False
|
list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
|
||||||
do_mix_latents = False
|
elif type(mixing_coeffs) == list:
|
||||||
|
assert len(mixing_coeffs) == self.num_inference_steps
|
||||||
|
list_mixing_coeffs = mixing_coeffs
|
||||||
else:
|
else:
|
||||||
if mixing_coeff > 0.0:
|
raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
|
||||||
do_inject_latents = False
|
|
||||||
do_mix_latents = True
|
|
||||||
assert seed_mixing_target != -1, "Set to correct seed for mixing"
|
|
||||||
else:
|
|
||||||
do_inject_latents = True
|
|
||||||
do_mix_latents = False
|
|
||||||
|
|
||||||
|
if np.sum(list_mixing_coeffs) > 0:
|
||||||
|
assert len(list_latents_mixing) == self.num_inference_steps
|
||||||
|
|
||||||
precision_scope = autocast if self.precision == "autocast" else nullcontext
|
precision_scope = autocast if self.precision == "autocast" else nullcontext
|
||||||
generator = torch.Generator(device=self.device).manual_seed(int(seed_source))
|
|
||||||
|
|
||||||
with precision_scope("cuda"):
|
with precision_scope("cuda"):
|
||||||
with self.model.ema_scope():
|
with self.model.ema_scope():
|
||||||
|
@ -332,14 +322,10 @@ class StableDiffusionHolder:
|
||||||
uc = self.model.get_learned_conditioning(self.negative_prompt)
|
uc = self.model.get_learned_conditioning(self.negative_prompt)
|
||||||
else:
|
else:
|
||||||
uc = None
|
uc = None
|
||||||
shape_latents = [self.C, self.height // self.f, self.width // self.f]
|
|
||||||
|
|
||||||
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
|
self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
|
||||||
C, H, W = shape_latents
|
|
||||||
size = (1, C, H, W)
|
|
||||||
b = size[0]
|
|
||||||
|
|
||||||
latents = torch.randn(size, generator=generator, device=self.device)
|
latents = latents_start.clone()
|
||||||
|
|
||||||
timesteps = self.sampler.ddim_timesteps
|
timesteps = self.sampler.ddim_timesteps
|
||||||
|
|
||||||
|
@ -349,29 +335,20 @@ class StableDiffusionHolder:
|
||||||
# collect latents
|
# collect latents
|
||||||
list_latents_out = []
|
list_latents_out = []
|
||||||
for i, step in enumerate(time_range):
|
for i, step in enumerate(time_range):
|
||||||
if do_inject_latents:
|
# Set the right starting latents
|
||||||
# Inject latent at right place
|
if i < idx_start:
|
||||||
if i < idx_start:
|
list_latents_out.append(None)
|
||||||
continue
|
continue
|
||||||
elif i == idx_start:
|
elif i == idx_start:
|
||||||
latents = latents_for_injection.clone()
|
latents = latents_start.clone()
|
||||||
if do_mix_latents:
|
|
||||||
if i == 0:
|
|
||||||
generator = torch.Generator(device=self.device).manual_seed(int(seed_mixing_target))
|
|
||||||
latents_mixtarget = torch.randn(size, generator=generator, device=self.device)
|
|
||||||
if i < idx_start:
|
|
||||||
latents_mixtarget = latents_for_injection[i-1].clone()
|
|
||||||
latents = interpolate_spherical(latents, latents_mixtarget, mixing_coeff)
|
|
||||||
|
|
||||||
if i == idx_start:
|
# Mix the latents.
|
||||||
do_mix_latents = False
|
if i > 0 and list_mixing_coeffs[i]>0:
|
||||||
|
latents_mixtarget = list_latents_mixing[i-1].clone()
|
||||||
|
latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
|
||||||
|
|
||||||
if i == idx_stop:
|
|
||||||
return list_latents_out
|
|
||||||
|
|
||||||
# print(f"diffusion iter {i}")
|
|
||||||
index = total_steps - i - 1
|
index = total_steps - i - 1
|
||||||
ts = torch.full((b,), step, device=self.device, dtype=torch.long)
|
ts = torch.full((1,), step, device=self.device, dtype=torch.long)
|
||||||
outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
|
outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
|
||||||
quantize_denoised=False, temperature=1.0,
|
quantize_denoised=False, temperature=1.0,
|
||||||
noise_dropout=0.0, score_corrector=None,
|
noise_dropout=0.0, score_corrector=None,
|
||||||
|
|
Loading…
Reference in New Issue