auto branching fix

This commit is contained in:
Johannes Stelzer 2022-11-24 11:24:23 +01:00
parent 78505be6da
commit 4ad098a20a
1 changed files with 17 additions and 29 deletions

View File

@ -1029,9 +1029,9 @@ def get_time(resolution=None):
def get_branching(
quality: str = 'medium',
depth: str = 'medium',
strength_injection_first: float = 0.65,
deepth_strength: float = 0.65,
nmb_frames: int = 360,
nmb_mindist: int = 3,
):
r"""
Helper function to set up the branching structure automatically.
@ -1040,51 +1040,38 @@ def get_branching(
quality: str
Determines how many diffusion steps are being made + how many branches in total.
Choose: fast, medium, high, ultra
quality: depth
deepth_strength: float = 0.65,
Determines how deep the first injection will happen.
Deeper injections will cause (unwanted) formation of new structures,
more shallow values will go into alpha-blendy land
Choose: verydeep, deep, medium, shallow, veryshallow
strength_injection_first: float = 0.65,
...
more shallow values will go into alpha-blendy land.
nmb_frames: int = 360,
total number of frames
nmb_mindist: int = 3
minimum distance in terms of diffusion iteratinos between subsequent injections
"""
nmb_mindist = 3 #minimum distance between injections
depth = 'override'
#FIXME: XXX nmb frames last has to be enforced. avoid weird cases where no injection...
if depth == 'verydeep':
strength_injection_first = 0.35
elif depth == 'deep':
strength_injection_first = 0.45
elif depth == 'medium':
strength_injection_first = 0.6
elif depth == 'shallow':
strength_injection_first = 0.8
elif depth == 'veryshallow':
strength_injection_first = 0.9
if quality == 'superfast':
num_inference_steps = 8
if quality == 'lowest':
num_inference_steps = 12
nmb_branches_final = 5
elif quality == 'fast':
elif quality == 'low':
num_inference_steps = 15
nmb_branches_final = nmb_frames//30
nmb_branches_final = nmb_frames//16
elif quality == 'medium':
num_inference_steps = 30
nmb_branches_final = nmb_frames//10
nmb_branches_final = nmb_frames//8
elif quality == 'high':
num_inference_steps = 60
nmb_branches_final = nmb_frames//3
nmb_branches_final = nmb_frames//4
elif quality == 'ultra':
num_inference_steps = 100
nmb_branches_final = nmb_frames
nmb_branches_final = nmb_frames//2
else:
raise ValueError("quality = '{quality}' not supported")
idx_injection_first = int(np.round(num_inference_steps*strength_injection_first))
idx_injection_first = int(np.round(num_inference_steps*deepth_strength))
idx_injection_last = num_inference_steps - 3
nmb_injections = int(np.floor(num_inference_steps/5)) - 1
@ -1110,6 +1097,7 @@ def get_branching(
return num_inference_steps, list_injection_idx_clean, list_nmb_branches_clean
#%% le main
if __name__ == "__main__":