mid scaling

This commit is contained in:
lugo 2022-11-28 15:34:18 +01:00
parent b5b703c516
commit 14bc3323b5
2 changed files with 68 additions and 25 deletions

View File

@ -39,15 +39,15 @@ sdh = StableDiffusionHolder(fp_ckpt, fp_config, device)
#%% Next let's set up all parameters #%% Next let's set up all parameters
guidance_scale = 5 quality = 'medium'
quality = 'high'
fixed_seeds = [69731932, 504430820] fixed_seeds = [69731932, 504430820]
lb = LatentBlending(sdh, guidance_scale) lb = LatentBlending(sdh)
prompt1 = "photo of a beautiful forest covered in white flowers, ambient light, very detailed, magic" prompt1 = "photo of a beautiful forest covered in white flowers, ambient light, very detailed, magic"
prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail" prompt2 = "photo of an golden statue with a funny hat, surrounded by ferns and vines, grainy analog photograph, mystical ambience, incredible detail"
lb.set_prompt1(prompt1) lb.set_prompt1(prompt1)
lb.set_prompt2(prompt2) lb.set_prompt2(prompt2)
lb.autosetup_branching(quality=quality) lb.autosetup_branching(quality=quality)
imgs_transition = lb.run_transition(fixed_seeds=fixed_seeds) imgs_transition = lb.run_transition(fixed_seeds=fixed_seeds)
@ -58,7 +58,7 @@ fps = 60
imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps) imgs_transition_ext = add_frames_linear_interp(imgs_transition, duration_transition, fps)
# movie saving # movie saving
fp_movie = f"movie_example1_{quality}.mp4" fp_movie = f"movie_example1.mp4"
if os.path.isfile(fp_movie): if os.path.isfile(fp_movie):
os.remove(fp_movie) os.remove(fp_movie)
ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width]) ms = MovieSaver(fp_movie, fps=fps, shape_hw=[sdh.height, sdh.width])

View File

@ -47,31 +47,36 @@ class LatentBlending():
def __init__( def __init__(
self, self,
sdh: None, sdh: None,
guidance_scale: float = 7.5, guidance_scale: float = 4,
guidance_scale_mid_damper: float = 0.5,
mid_compression_scaler: float = 2.0,
): ):
r""" r"""
Initializes the latent blending class. Initializes the latent blending class.
Args: Args:
FIXME XXX
height: int
Height of the desired output image. The model was trained on 512.
width: int
Width of the desired output image. The model was trained on 512.
guidance_scale: float guidance_scale: float
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen `guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. usually at the expense of lower image quality.
seed: int guidance_scale_mid_damper: float = 0.5
Random seed. Reduces the guidance scale towards the middle of the transition.
A value of 0.5 would decrease the guidance_scale towards the middle linearly by 0.5.
mid_compression_scaler: float = 2.0
Increases the sampling density in the middle (where most changes happen). Higher value
imply more values in the middle. However the inflection point can occur outside the middle,
thus high values can give rough transitions. Values around 2 should be fine.
""" """
self.sdh = sdh self.sdh = sdh
self.device = self.sdh.device self.device = self.sdh.device
self.width = self.sdh.width self.width = self.sdh.width
self.height = self.sdh.height self.height = self.sdh.height
self.seed = 420 #use self.set_seed or fixed_seeds argument in run_transition assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
self.guidance_scale_mid_damper = guidance_scale_mid_damper
self.mid_compression_scaler = mid_compression_scaler
self.seed = 420 # Run self.set_seed or fixed_seeds argument in run_transition
# Initialize vars # Initialize vars
self.prompt1 = "" self.prompt1 = ""
@ -109,8 +114,20 @@ class LatentBlending():
r""" r"""
sets the guidance scale. sets the guidance scale.
""" """
self.guidance_scale_base = guidance_scale
self.guidance_scale = guidance_scale self.guidance_scale = guidance_scale
self.sdh.guidance_scale = guidance_scale self.sdh.guidance_scale = guidance_scale
def set_guidance_mid_dampening(self, fract_mixing):
    r"""
    Lowers the guidance scale linearly the closer fract_mixing gets to 0.5.
    At fract_mixing = 0 or 1 the base guidance scale is used unchanged, at
    fract_mixing = 0.5 it is multiplied by guidance_scale_mid_damper.
    The result is written to self.guidance_scale and self.sdh.guidance_scale.
    Args:
        fract_mixing: float
            Mixing fraction of the current branch.
    """
    base = self.guidance_scale_base
    # 0 at the ends of the transition (0 and 1), 1 right in the middle (0.5)
    closeness_to_mid = 1 - np.abs(fract_mixing - 0.5) / 0.5
    # Largest amount that may be subtracted from the base scale (reached at 0.5)
    largest_drop = base * (1 - self.guidance_scale_mid_damper)
    damped_scale = base - largest_drop * closeness_to_mid
    self.guidance_scale = damped_scale
    self.sdh.guidance_scale = damped_scale
def set_prompt1(self, prompt: str): def set_prompt1(self, prompt: str):
r""" r"""
@ -158,6 +175,7 @@ class LatentBlending():
total number of frames total number of frames
nmb_mindist: int = 3 nmb_mindist: int = 3
minimum distance in terms of diffusion iteratinos between subsequent injections minimum distance in terms of diffusion iteratinos between subsequent injections
""" """
if quality == 'lowest': if quality == 'lowest':
@ -201,13 +219,13 @@ class LatentBlending():
list_injection_idx = list_injection_idx_clean list_injection_idx = list_injection_idx_clean
list_nmb_branches = list_nmb_branches_clean list_nmb_branches = list_nmb_branches_clean
print(f"num_inference_steps: {num_inference_steps}") # print(f"num_inference_steps: {num_inference_steps}")
print(f"list_injection_idx: {list_injection_idx}") # print(f"list_injection_idx: {list_injection_idx}")
print(f"list_nmb_branches: {list_nmb_branches}") # print(f"list_nmb_branches: {list_nmb_branches}")
self.num_inference_steps = num_inference_steps list_nmb_branches = list_nmb_branches
self.list_injection_idx = list_injection_idx list_injection_idx = list_injection_idx
self.list_nmb_branches = list_nmb_branches self.setup_branching(num_inference_steps, list_nmb_branches=list_nmb_branches, list_injection_idx=list_injection_idx)
def setup_branching(self, def setup_branching(self,
@ -215,7 +233,7 @@ class LatentBlending():
list_nmb_branches: List[int] = None, list_nmb_branches: List[int] = None,
list_injection_strength: List[float] = None, list_injection_strength: List[float] = None,
list_injection_idx: List[int] = None, list_injection_idx: List[int] = None,
guidance_downscale: float = 1.0,
): ):
r""" r"""
Sets the branching structure for making transitions. Sets the branching structure for making transitions.
@ -229,13 +247,9 @@ class LatentBlending():
list_injection_idx: List[int]: list_injection_idx: List[int]:
list of injection strengths within interval [0, 1), values need to be increasing. list of injection strengths within interval [0, 1), values need to be increasing.
Alternatively you can specify the list_injection_strength. Alternatively you can specify the list_injection_strength.
guidance_downscale: float = 1.0
reduces the guidance scale towards the middle of the transition
""" """
# Assert # Assert
assert guidance_downscale>0 and guidance_downscale<=1.0, "guidance_downscale neees to be in interval (0,1]"
assert not((list_injection_strength is not None) and (list_injection_idx is not None)), "suppyl either list_injection_strength or list_injection_idx" assert not((list_injection_strength is not None) and (list_injection_idx is not None)), "suppyl either list_injection_strength or list_injection_idx"
if list_injection_strength is None: if list_injection_strength is None:
@ -262,6 +276,7 @@ class LatentBlending():
self.sdh.num_inference_steps = num_inference_steps self.sdh.num_inference_steps = num_inference_steps
self.list_nmb_branches = list_nmb_branches self.list_nmb_branches = list_nmb_branches
self.list_injection_idx = list_injection_idx self.list_injection_idx = list_injection_idx
self.guidance_scale_mid_damper = guidance_scale_mid_damper
@ -341,7 +356,8 @@ class LatentBlending():
nmb_blocks_time = len(list_injection_idx_ext)-1 nmb_blocks_time = len(list_injection_idx_ext)-1
for t_block in range(nmb_blocks_time): for t_block in range(nmb_blocks_time):
nmb_branches = list_nmb_branches[t_block] nmb_branches = list_nmb_branches[t_block]
list_fract_mixing_current = np.linspace(0, 1, nmb_branches) # list_fract_mixing_current = np.linspace(0, 1, nmb_branches)
list_fract_mixing_current = get_spacing(nmb_branches, self.mid_compression_scaler)
self.tree_fracts.append(list_fract_mixing_current) self.tree_fracts.append(list_fract_mixing_current)
self.tree_latents.append([None]*nmb_branches) self.tree_latents.append([None]*nmb_branches)
self.tree_status.append(['untouched']*nmb_branches) self.tree_status.append(['untouched']*nmb_branches)
@ -403,6 +419,8 @@ class LatentBlending():
idx_stop = list_injection_idx_ext[t_block+1] idx_stop = list_injection_idx_ext[t_block+1]
fract_mixing = self.tree_fracts[t_block][idx_branch] fract_mixing = self.tree_fracts[t_block][idx_branch]
text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing) text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
self.set_guidance_mid_dampening(fract_mixing)
# print(f"fract_mixing {fract_mixing} guid {self.sdh.guidance_scale}")
if t_block == 0: if t_block == 0:
if fixed_seeds is not None: if fixed_seeds is not None:
if idx_branch == 0: if idx_branch == 0:
@ -787,6 +805,31 @@ def add_frames_linear_interp(
return list_imgs_interp return list_imgs_interp
def get_spacing(nmb_points:int, scaling: float):
    """
    Helper function returning nmb_points values between 0 and 1 that are mirrored around 0.5.
    Args:
        nmb_points: int
            Number of values to return.
        scaling: float
            Exponent of the warp. Below 1.7 a plain linear spacing is returned;
            from 1.7 on, higher values place more points close to 0.5.
    """
    if scaling < 1.7:
        return np.linspace(0, 1, nmb_points)

    # Warp a descending ramp so that the lower half runs from 0 up to 0.5,
    # with the points crowding towards 0.5.
    half_count = nmb_points // 2 + 1
    ramp = np.linspace(1, 0, half_count)
    lower_half = np.abs(ramp**scaling / 2 - 0.5)

    if nmb_points % 2 != 0:
        # Odd count: the lower half keeps the 0.5 midpoint, the mirror skips it.
        upper_half = 1 - lower_half[::-1][1:]
    else:
        # Even count: no point sits on 0.5, drop it before mirroring.
        lower_half = lower_half[0:-1]
        upper_half = 1 - lower_half[::-1]

    return np.hstack([lower_half, upper_half])
def get_time(resolution=None): def get_time(resolution=None):
""" """
Helper function returning an nicely formatted time string, e.g. 221117_1620 Helper function returning an nicely formatted time string, e.g. 221117_1620