auto branching function

This commit is contained in:
Johannes Stelzer 2022-11-23 20:51:19 +01:00
parent 021e34df60
commit 78505be6da
1 changed files with 31 additions and 19 deletions

View File

@ -1030,6 +1030,8 @@ def get_time(resolution=None):
def get_branching( def get_branching(
quality: str = 'medium', quality: str = 'medium',
depth: str = 'medium', depth: str = 'medium',
strength_injection_first: float = 0.65,
nmb_frames: int = 360,
): ):
r""" r"""
Helper function to set up the branching structure automatically. Helper function to set up the branching structure automatically.
@ -1043,9 +1045,15 @@ def get_branching(
Deeper injections will cause (unwanted) formation of new structures, Deeper injections will cause (unwanted) formation of new structures,
more shallow values will go into alpha-blendy land more shallow values will go into alpha-blendy land
Choose: verydeep, deep, medium, shallow, veryshallow Choose: verydeep, deep, medium, shallow, veryshallow
strength_injection_first: float = 0.65,
...
nmb_frames: int = 360,
""" """
nmb_mindist = 3 #minimum distance between injections nmb_mindist = 3 #minimum distance between injections
depth = 'override'
#FIXME: XXX nmb frames last has to be enforced. avoid weird cases where no injection...
if depth == 'verydeep': if depth == 'verydeep':
strength_injection_first = 0.35 strength_injection_first = 0.35
elif depth == 'deep': elif depth == 'deep':
@ -1056,27 +1064,29 @@ def get_branching(
strength_injection_first = 0.8 strength_injection_first = 0.8
elif depth == 'veryshallow': elif depth == 'veryshallow':
strength_injection_first = 0.9 strength_injection_first = 0.9
else:
raise ValueError("depth = '{depth}' not supported")
if quality == 'superfast':
if quality == 'fast': num_inference_steps = 8
num_iterations = 15 nmb_branches_final = 5
nmb_branches_final = 6 elif quality == 'fast':
num_inference_steps = 15
nmb_branches_final = nmb_frames//30
elif quality == 'medium': elif quality == 'medium':
num_iterations = 30 num_inference_steps = 30
nmb_branches_final = 30 nmb_branches_final = nmb_frames//10
elif quality == 'high': elif quality == 'high':
num_iterations = 60 num_inference_steps = 60
nmb_branches_final = 150 nmb_branches_final = nmb_frames//3
elif quality == 'ultra': elif quality == 'ultra':
num_iterations = 100 num_inference_steps = 100
nmb_branches_final = 300 nmb_branches_final = nmb_frames
else: else:
raise ValueError("quality = '{quality}' not supported") raise ValueError("quality = '{quality}' not supported")
idx_injection_first = int(np.round(num_iterations*strength_injection_first)) idx_injection_first = int(np.round(num_inference_steps*strength_injection_first))
idx_injection_last = num_iterations - 3 idx_injection_last = num_inference_steps - 3
nmb_injections = int(np.floor(num_iterations/5)) - 1 nmb_injections = int(np.floor(num_inference_steps/5)) - 1
list_injection_idx = [0] list_injection_idx = [0]
list_injection_idx.extend(np.linspace(idx_injection_first, idx_injection_last, nmb_injections).astype(int)) list_injection_idx.extend(np.linspace(idx_injection_first, idx_injection_last, nmb_injections).astype(int))
@ -1091,12 +1101,14 @@ def get_branching(
list_injection_idx_clean.append(list_injection_idx[i+1]) list_injection_idx_clean.append(list_injection_idx[i+1])
list_nmb_branches_clean.append(list_nmb_branches[i+1]) list_nmb_branches_clean.append(list_nmb_branches[i+1])
idx_last_check +=1 idx_last_check +=1
list_injection_idx_clean = [int(l) for l in list_injection_idx_clean]
print(f"num_iterations: {num_iterations}") list_nmb_branches_clean = [int(l) for l in list_nmb_branches_clean]
print(f"num_inference_steps: {num_inference_steps}")
print(f"list_injection_idx: {list_injection_idx_clean}") print(f"list_injection_idx: {list_injection_idx_clean}")
print(f"list_nmb_branches: {list_nmb_branches_clean}") print(f"list_nmb_branches: {list_nmb_branches_clean}")
return list_injection_idx_clean, list_nmb_branches_clean return num_inference_steps, list_injection_idx_clean, list_nmb_branches_clean
#%% le main #%% le main
if __name__ == "__main__": if __name__ == "__main__":