mid scaling
This commit is contained in:
parent
b5b703c516
commit
14bc3323b5
|
@ -39,15 +39,15 @@ sdh = StableDiffusionHolder(fp_ckpt, fp_config, device)
|
||||||
|
|
||||||
|
|
||||||
#%% Next let's set up all parameters
|
#%% Next let's set up all parameters
|
||||||
guidance_scale = 5
|
quality = 'medium'
|
||||||
quality = 'high'
|
|
||||||
fixed_seeds = [69731932, 504430820]
|
fixed_seeds = [69731932, 504430820]
|
||||||
|
|
||||||
lb = LatentBlending(sdh, guidance_scale)
|
lb = LatentBlending(sdh)
|
||||||
prompt1 = "photo of a beautiful forest covered in white flowers, ambient light, very detailed, magic"
|
prompt1 = "photo of a beautiful forest covered in white flowers, ambient light, very detailed, magic"
|
||||||
prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail"
|
prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail"
|
||||||
lb.set_prompt1(prompt1)
|
lb.set_prompt1(prompt1)
|
||||||
lb.set_prompt2(prompt2)
|
lb.set_prompt2(prompt2)
|
||||||
|
|
||||||
lb.autosetup_branching(quality=quality)
|
lb.autosetup_branching(quality=quality)
|
||||||
|
|
||||||
imgs_transition = lb.run_transition(fixed_seeds=fixed_seeds)
|
imgs_transition = lb.run_transition(fixed_seeds=fixed_seeds)
|
||||||
|
@ -58,7 +58,7 @@ fps = 60
|
||||||
imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps)
|
imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps)
|
||||||
|
|
||||||
# movie saving
|
# movie saving
|
||||||
fp_movie = f"movie_example1_{quality}.mp4"
|
fp_movie = f"movie_example1.mp4"
|
||||||
if os.path.isfile(fp_movie):
|
if os.path.isfile(fp_movie):
|
||||||
os.remove(fp_movie)
|
os.remove(fp_movie)
|
||||||
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width])
|
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width])
|
||||||
|
|
|
@ -47,31 +47,36 @@ class LatentBlending():
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
sdh: None,
|
sdh: None,
|
||||||
guidance_scale: float = 7.5,
|
guidance_scale: float = 4,
|
||||||
|
guidance_scale_mid_damper: float = 0.5,
|
||||||
|
mid_compression_scaler: float = 2.0,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Initializes the latent blending class.
|
Initializes the latent blending class.
|
||||||
Args:
|
Args:
|
||||||
FIXME XXX
|
|
||||||
height: int
|
|
||||||
Height of the desired output image. The model was trained on 512.
|
|
||||||
width: int
|
|
||||||
Width of the desired output image. The model was trained on 512.
|
|
||||||
guidance_scale: float
|
guidance_scale: float
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||||
usually at the expense of lower image quality.
|
usually at the expense of lower image quality.
|
||||||
seed: int
|
guidance_scale_mid_damper: float = 0.5
|
||||||
Random seed.
|
Reduces the guidance scale towards the middle of the transition.
|
||||||
|
A value of 0.5 would decrease the guidance_scale towards the middle linearly by 0.5.
|
||||||
|
mid_compression_scaler: float = 2.0
|
||||||
|
Increases the sampling density in the middle (where most changes happen). Higher value
|
||||||
|
imply more values in the middle. However the inflection point can occur outside the middle,
|
||||||
|
thus high values can give rough transitions. Values around 2 should be fine.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.sdh = sdh
|
self.sdh = sdh
|
||||||
self.device = self.sdh.device
|
self.device = self.sdh.device
|
||||||
self.width = self.sdh.width
|
self.width = self.sdh.width
|
||||||
self.height = self.sdh.height
|
self.height = self.sdh.height
|
||||||
self.seed = 420 #use self.set_seed or fixed_seeds argument in run_transition
|
assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
|
||||||
|
self.guidance_scale_mid_damper = guidance_scale_mid_damper
|
||||||
|
self.mid_compression_scaler = mid_compression_scaler
|
||||||
|
self.seed = 420 # Run self.set_seed or fixed_seeds argument in run_transition
|
||||||
|
|
||||||
# Initialize vars
|
# Initialize vars
|
||||||
self.prompt1 = ""
|
self.prompt1 = ""
|
||||||
|
@ -109,8 +114,20 @@ class LatentBlending():
|
||||||
r"""
|
r"""
|
||||||
sets the guidance scale.
|
sets the guidance scale.
|
||||||
"""
|
"""
|
||||||
|
self.guidance_scale_base = guidance_scale
|
||||||
self.guidance_scale = guidance_scale
|
self.guidance_scale = guidance_scale
|
||||||
self.sdh.guidance_scale = guidance_scale
|
self.sdh.guidance_scale = guidance_scale
|
||||||
|
|
||||||
|
def set_guidance_mid_dampening(self, fract_mixing):
|
||||||
|
r"""
|
||||||
|
Tunes the guidance scale down as a linear function of fract_mixing,
|
||||||
|
towards 0.5 the minimum will be reached.
|
||||||
|
"""
|
||||||
|
mid_factor = 1 - np.abs(fract_mixing - 0.5)/ 0.5
|
||||||
|
max_guidance_reduction = self.guidance_scale_base * (1-self.guidance_scale_mid_damper)
|
||||||
|
guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction*mid_factor
|
||||||
|
self.guidance_scale = guidance_scale_effective
|
||||||
|
self.sdh.guidance_scale = guidance_scale_effective
|
||||||
|
|
||||||
def set_prompt1(self, prompt: str):
|
def set_prompt1(self, prompt: str):
|
||||||
r"""
|
r"""
|
||||||
|
@ -158,6 +175,7 @@ class LatentBlending():
|
||||||
total number of frames
|
total number of frames
|
||||||
nmb_mindist: int = 3
|
nmb_mindist: int = 3
|
||||||
minimum distance in terms of diffusion iteratinos between subsequent injections
|
minimum distance in terms of diffusion iteratinos between subsequent injections
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if quality == 'lowest':
|
if quality == 'lowest':
|
||||||
|
@ -201,13 +219,13 @@ class LatentBlending():
|
||||||
list_injection_idx = list_injection_idx_clean
|
list_injection_idx = list_injection_idx_clean
|
||||||
list_nmb_branches = list_nmb_branches_clean
|
list_nmb_branches = list_nmb_branches_clean
|
||||||
|
|
||||||
print(f"num_inference_steps: {num_inference_steps}")
|
# print(f"num_inference_steps: {num_inference_steps}")
|
||||||
print(f"list_injection_idx: {list_injection_idx}")
|
# print(f"list_injection_idx: {list_injection_idx}")
|
||||||
print(f"list_nmb_branches: {list_nmb_branches}")
|
# print(f"list_nmb_branches: {list_nmb_branches}")
|
||||||
|
|
||||||
self.num_inference_steps = num_inference_steps
|
list_nmb_branches = list_nmb_branches
|
||||||
self.list_injection_idx = list_injection_idx
|
list_injection_idx = list_injection_idx
|
||||||
self.list_nmb_branches = list_nmb_branches
|
self.setup_branching(num_inference_steps, list_nmb_branches=list_nmb_branches, list_injection_idx=list_injection_idx)
|
||||||
|
|
||||||
|
|
||||||
def setup_branching(self,
|
def setup_branching(self,
|
||||||
|
@ -215,7 +233,7 @@ class LatentBlending():
|
||||||
list_nmb_branches: List[int] = None,
|
list_nmb_branches: List[int] = None,
|
||||||
list_injection_strength: List[float] = None,
|
list_injection_strength: List[float] = None,
|
||||||
list_injection_idx: List[int] = None,
|
list_injection_idx: List[int] = None,
|
||||||
guidance_downscale: float = 1.0,
|
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Sets the branching structure for making transitions.
|
Sets the branching structure for making transitions.
|
||||||
|
@ -229,13 +247,9 @@ class LatentBlending():
|
||||||
list_injection_idx: List[int]:
|
list_injection_idx: List[int]:
|
||||||
list of injection strengths within interval [0, 1), values need to be increasing.
|
list of injection strengths within interval [0, 1), values need to be increasing.
|
||||||
Alternatively you can specify the list_injection_strength.
|
Alternatively you can specify the list_injection_strength.
|
||||||
guidance_downscale: float = 1.0
|
|
||||||
reduces the guidance scale towards the middle of the transition
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Assert
|
# Assert
|
||||||
assert guidance_downscale>0 and guidance_downscale<=1.0, "guidance_downscale neees to be in interval (0,1]"
|
|
||||||
assert not((list_injection_strength is not None) and (list_injection_idx is not None)), "suppyl either list_injection_strength or list_injection_idx"
|
assert not((list_injection_strength is not None) and (list_injection_idx is not None)), "suppyl either list_injection_strength or list_injection_idx"
|
||||||
|
|
||||||
if list_injection_strength is None:
|
if list_injection_strength is None:
|
||||||
|
@ -262,6 +276,7 @@ class LatentBlending():
|
||||||
self.sdh.num_inference_steps = num_inference_steps
|
self.sdh.num_inference_steps = num_inference_steps
|
||||||
self.list_nmb_branches = list_nmb_branches
|
self.list_nmb_branches = list_nmb_branches
|
||||||
self.list_injection_idx = list_injection_idx
|
self.list_injection_idx = list_injection_idx
|
||||||
|
self.guidance_scale_mid_damper = guidance_scale_mid_damper
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -341,7 +356,8 @@ class LatentBlending():
|
||||||
nmb_blocks_time = len(list_injection_idx_ext)-1
|
nmb_blocks_time = len(list_injection_idx_ext)-1
|
||||||
for t_block in range(nmb_blocks_time):
|
for t_block in range(nmb_blocks_time):
|
||||||
nmb_branches = list_nmb_branches[t_block]
|
nmb_branches = list_nmb_branches[t_block]
|
||||||
list_fract_mixing_current = np.linspace(0, 1, nmb_branches)
|
# list_fract_mixing_current = np.linspace(0, 1, nmb_branches)
|
||||||
|
list_fract_mixing_current = get_spacing(nmb_branches, self.mid_compression_scaler)
|
||||||
self.tree_fracts.append(list_fract_mixing_current)
|
self.tree_fracts.append(list_fract_mixing_current)
|
||||||
self.tree_latents.append([None]*nmb_branches)
|
self.tree_latents.append([None]*nmb_branches)
|
||||||
self.tree_status.append(['untouched']*nmb_branches)
|
self.tree_status.append(['untouched']*nmb_branches)
|
||||||
|
@ -403,6 +419,8 @@ class LatentBlending():
|
||||||
idx_stop = list_injection_idx_ext[t_block+1]
|
idx_stop = list_injection_idx_ext[t_block+1]
|
||||||
fract_mixing = self.tree_fracts[t_block][idx_branch]
|
fract_mixing = self.tree_fracts[t_block][idx_branch]
|
||||||
text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
|
text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
|
||||||
|
self.set_guidance_mid_dampening(fract_mixing)
|
||||||
|
# print(f"fract_mixing {fract_mixing} guid {self.sdh.guidance_scale}")
|
||||||
if t_block == 0:
|
if t_block == 0:
|
||||||
if fixed_seeds is not None:
|
if fixed_seeds is not None:
|
||||||
if idx_branch == 0:
|
if idx_branch == 0:
|
||||||
|
@ -787,6 +805,31 @@ def add_frames_linear_interp(
|
||||||
return list_imgs_interp
|
return list_imgs_interp
|
||||||
|
|
||||||
|
|
||||||
|
def get_spacing(nmb_points:int, scaling: float):
|
||||||
|
"""
|
||||||
|
Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
|
||||||
|
Args:
|
||||||
|
nmb_points: int
|
||||||
|
Number of points between [0, 1]
|
||||||
|
scaling: float
|
||||||
|
Higher values will return higher sampling density around 0.5
|
||||||
|
|
||||||
|
"""
|
||||||
|
if scaling < 1.7:
|
||||||
|
return np.linspace(0, 1, nmb_points)
|
||||||
|
nmb_points_per_side = nmb_points//2 + 1
|
||||||
|
if np.mod(nmb_points, 2) != 0: # uneven case
|
||||||
|
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
|
||||||
|
right_side = 1-left_side[::-1][1:]
|
||||||
|
else:
|
||||||
|
left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
|
||||||
|
right_side = 1-left_side[::-1]
|
||||||
|
|
||||||
|
all_fracts = np.hstack([left_side, right_side])
|
||||||
|
|
||||||
|
return all_fracts
|
||||||
|
|
||||||
|
|
||||||
def get_time(resolution=None):
|
def get_time(resolution=None):
|
||||||
"""
|
"""
|
||||||
Helper function returning an nicely formatted time string, e.g. 221117_1620
|
Helper function returning an nicely formatted time string, e.g. 221117_1620
|
||||||
|
|
Loading…
Reference in New Issue