mid scaling
This commit is contained in:
parent
b5b703c516
commit
14bc3323b5
|
@ -39,15 +39,15 @@ sdh = StableDiffusionHolder(fp_ckpt, fp_config, device)
|
|||
|
||||
|
||||
#%% Next let's set up all parameters
|
||||
guidance_scale = 5
|
||||
quality = 'high'
|
||||
quality = 'medium'
|
||||
fixed_seeds = [69731932, 504430820]
|
||||
|
||||
lb = LatentBlending(sdh, guidance_scale)
|
||||
lb = LatentBlending(sdh)
|
||||
prompt1 = "photo of a beautiful forest covered in white flowers, ambient light, very detailed, magic"
|
||||
prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail"
|
||||
lb.set_prompt1(prompt1)
|
||||
lb.set_prompt2(prompt2)
|
||||
|
||||
lb.autosetup_branching(quality=quality)
|
||||
|
||||
imgs_transition = lb.run_transition(fixed_seeds=fixed_seeds)
|
||||
|
@ -58,7 +58,7 @@ fps = 60
|
|||
imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps)
|
||||
|
||||
# movie saving
|
||||
fp_movie = f"movie_example1_{quality}.mp4"
|
||||
fp_movie = f"movie_example1.mp4"
|
||||
if os.path.isfile(fp_movie):
|
||||
os.remove(fp_movie)
|
||||
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width])
|
||||
|
|
|
@ -47,31 +47,36 @@ class LatentBlending():
|
|||
def __init__(
|
||||
self,
|
||||
sdh: None,
|
||||
guidance_scale: float = 7.5,
|
||||
guidance_scale: float = 4,
|
||||
guidance_scale_mid_damper: float = 0.5,
|
||||
mid_compression_scaler: float = 2.0,
|
||||
):
|
||||
r"""
|
||||
Initializes the latent blending class.
|
||||
Args:
|
||||
FIXME XXX
|
||||
height: int
|
||||
Height of the desired output image. The model was trained on 512.
|
||||
width: int
|
||||
Width of the desired output image. The model was trained on 512.
|
||||
guidance_scale: float
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
seed: int
|
||||
Random seed.
|
||||
guidance_scale_mid_damper: float = 0.5
|
||||
Reduces the guidance scale towards the middle of the transition.
|
||||
A value of 0.5 would decrease the guidance_scale towards the middle linearly by 0.5.
|
||||
mid_compression_scaler: float = 2.0
|
||||
Increases the sampling density in the middle (where most changes happen). Higher values
|
||||
imply more values in the middle. However the inflection point can occur outside the middle,
|
||||
thus high values can give rough transitions. Values around 2 should be fine.
|
||||
|
||||
"""
|
||||
self.sdh = sdh
|
||||
self.device = self.sdh.device
|
||||
self.width = self.sdh.width
|
||||
self.height = self.sdh.height
|
||||
self.seed = 420 #use self.set_seed or fixed_seeds argument in run_transition
|
||||
assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
|
||||
self.guidance_scale_mid_damper = guidance_scale_mid_damper
|
||||
self.mid_compression_scaler = mid_compression_scaler
|
||||
self.seed = 420 # Run self.set_seed or fixed_seeds argument in run_transition
|
||||
|
||||
# Initialize vars
|
||||
self.prompt1 = ""
|
||||
|
@ -109,9 +114,21 @@ class LatentBlending():
|
|||
r"""
|
||||
sets the guidance scale.
|
||||
"""
|
||||
self.guidance_scale_base = guidance_scale
|
||||
self.guidance_scale = guidance_scale
|
||||
self.sdh.guidance_scale = guidance_scale
|
||||
|
||||
def set_guidance_mid_dampening(self, fract_mixing):
    r"""
    Lowers the effective guidance scale as fract_mixing approaches 0.5.

    The reduction is linear in the distance of fract_mixing from 0.5:
    at fract_mixing = 0 or 1 the base guidance scale is used unchanged,
    at fract_mixing = 0.5 the value is
    guidance_scale_base * guidance_scale_mid_damper.

    Args:
        fract_mixing: float
            Mixing fraction of the current branch.

    The result is written to self.guidance_scale and to
    self.sdh.guidance_scale; nothing is returned.
    """
    base = self.guidance_scale_base
    # 0 at the endpoints (fract_mixing 0 or 1), 1 exactly at the midpoint.
    closeness_to_mid = 1 - np.abs(fract_mixing - 0.5) / 0.5
    # Largest amount that can be subtracted, reached at the midpoint.
    largest_drop = base * (1 - self.guidance_scale_mid_damper)
    dampened = base - largest_drop * closeness_to_mid
    self.guidance_scale = dampened
    self.sdh.guidance_scale = dampened
|
||||
|
||||
def set_prompt1(self, prompt: str):
|
||||
r"""
|
||||
Sets the first prompt (for the first keyframe) including text embeddings.
|
||||
|
@ -158,6 +175,7 @@ class LatentBlending():
|
|||
total number of frames
|
||||
nmb_mindist: int = 3
|
||||
minimum distance in terms of diffusion iterations between subsequent injections
|
||||
|
||||
"""
|
||||
|
||||
if quality == 'lowest':
|
||||
|
@ -201,13 +219,13 @@ class LatentBlending():
|
|||
list_injection_idx = list_injection_idx_clean
|
||||
list_nmb_branches = list_nmb_branches_clean
|
||||
|
||||
print(f"num_inference_steps: {num_inference_steps}")
|
||||
print(f"list_injection_idx: {list_injection_idx}")
|
||||
print(f"list_nmb_branches: {list_nmb_branches}")
|
||||
# print(f"num_inference_steps: {num_inference_steps}")
|
||||
# print(f"list_injection_idx: {list_injection_idx}")
|
||||
# print(f"list_nmb_branches: {list_nmb_branches}")
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.list_injection_idx = list_injection_idx
|
||||
self.list_nmb_branches = list_nmb_branches
|
||||
list_nmb_branches = list_nmb_branches
|
||||
list_injection_idx = list_injection_idx
|
||||
self.setup_branching(num_inference_steps, list_nmb_branches=list_nmb_branches, list_injection_idx=list_injection_idx)
|
||||
|
||||
|
||||
def setup_branching(self,
|
||||
|
@ -215,7 +233,7 @@ class LatentBlending():
|
|||
list_nmb_branches: List[int] = None,
|
||||
list_injection_strength: List[float] = None,
|
||||
list_injection_idx: List[int] = None,
|
||||
guidance_downscale: float = 1.0,
|
||||
|
||||
):
|
||||
r"""
|
||||
Sets the branching structure for making transitions.
|
||||
|
@ -229,13 +247,9 @@ class LatentBlending():
|
|||
list_injection_idx: List[int]:
|
||||
list of injection strengths within interval [0, 1), values need to be increasing.
|
||||
Alternatively you can specify the list_injection_strength.
|
||||
guidance_downscale: float = 1.0
|
||||
reduces the guidance scale towards the middle of the transition
|
||||
|
||||
|
||||
"""
|
||||
# Assert
|
||||
assert guidance_downscale>0 and guidance_downscale<=1.0, "guidance_downscale neees to be in interval (0,1]"
|
||||
assert not((list_injection_strength is not None) and (list_injection_idx is not None)), "suppyl either list_injection_strength or list_injection_idx"
|
||||
|
||||
if list_injection_strength is None:
|
||||
|
@ -262,6 +276,7 @@ class LatentBlending():
|
|||
self.sdh.num_inference_steps = num_inference_steps
|
||||
self.list_nmb_branches = list_nmb_branches
|
||||
self.list_injection_idx = list_injection_idx
|
||||
self.guidance_scale_mid_damper = guidance_scale_mid_damper
|
||||
|
||||
|
||||
|
||||
|
@ -341,7 +356,8 @@ class LatentBlending():
|
|||
nmb_blocks_time = len(list_injection_idx_ext)-1
|
||||
for t_block in range(nmb_blocks_time):
|
||||
nmb_branches = list_nmb_branches[t_block]
|
||||
list_fract_mixing_current = np.linspace(0, 1, nmb_branches)
|
||||
# list_fract_mixing_current = np.linspace(0, 1, nmb_branches)
|
||||
list_fract_mixing_current = get_spacing(nmb_branches, self.mid_compression_scaler)
|
||||
self.tree_fracts.append(list_fract_mixing_current)
|
||||
self.tree_latents.append([None]*nmb_branches)
|
||||
self.tree_status.append(['untouched']*nmb_branches)
|
||||
|
@ -403,6 +419,8 @@ class LatentBlending():
|
|||
idx_stop = list_injection_idx_ext[t_block+1]
|
||||
fract_mixing = self.tree_fracts[t_block][idx_branch]
|
||||
text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
|
||||
self.set_guidance_mid_dampening(fract_mixing)
|
||||
# print(f"fract_mixing {fract_mixing} guid {self.sdh.guidance_scale}")
|
||||
if t_block == 0:
|
||||
if fixed_seeds is not None:
|
||||
if idx_branch == 0:
|
||||
|
@ -787,6 +805,31 @@ def add_frames_linear_interp(
|
|||
return list_imgs_interp
|
||||
|
||||
|
||||
def get_spacing(nmb_points: int, scaling: float):
    """
    Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
    Args:
        nmb_points: int
            Number of points between [0, 1]
        scaling: float
            Higher values will return higher sampling density around 0.5.
            For scaling below 1.7 a plain uniform np.linspace(0, 1, nmb_points)
            is returned instead.

    Returns:
        Array of nmb_points fractions, starting at 0 and ending at 1.
    """
    if scaling < 1.7:
        return np.linspace(0, 1, nmb_points)

    half_count = nmb_points // 2 + 1
    # Power-law ramp from 0 up to 0.5; the exponent packs the samples
    # more tightly towards the 0.5 end.
    lower_half = np.abs(np.linspace(1, 0, half_count) ** scaling / 2 - 0.5)

    if np.mod(nmb_points, 2) != 0:
        # Odd count: 0.5 is kept once (in the lower half), so the mirrored
        # upper half skips its first element.
        upper_half = 1 - lower_half[::-1][1:]
    else:
        # Even count: 0.5 is not sampled at all, drop it before mirroring.
        lower_half = lower_half[0:-1]
        upper_half = 1 - lower_half[::-1]

    return np.hstack([lower_half, upper_half])
|
||||
|
||||
|
||||
def get_time(resolution=None):
|
||||
"""
|
||||
Helper function returning a nicely formatted time string, e.g. 221117_1620
|
||||
|
|
Loading…
Reference in New Issue