assert for branch1_influence

This commit is contained in:
Johannes Stelzer 2023-01-11 11:39:45 +01:00
parent 7ebe6aaa66
commit d87cefc75d
1 changed files with 1 additions and 1 deletions

View File

@ -419,6 +419,7 @@ class LatentBlending():
# Split the first block if there is branch1 crossfeeding # Split the first block if there is branch1 crossfeeding
if self.branch1_influence > 0.0 and not self.branch1_insertion_completed: if self.branch1_influence > 0.0 and not self.branch1_insertion_completed:
assert self.list_nmb_branches[0]==2, 'branch1 influence currently requires the self.list_nmb_branches[0] = 0'
self.list_nmb_branches.insert(1, 2) self.list_nmb_branches.insert(1, 2)
idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_fract_crossfeed)) idx_crossfeed = int(round(self.list_injection_idx[1]*self.branch1_fract_crossfeed))
self.list_injection_idx_ext.insert(1, idx_crossfeed) self.list_injection_idx_ext.insert(1, idx_crossfeed)
@ -507,7 +508,6 @@ class LatentBlending():
list_latents = self.run_diffusion(list_conditionings, idx_stop=idx_stop) list_latents = self.run_diffusion(list_conditionings, idx_stop=idx_stop)
# Inject latents from first branch for very first block # Inject latents from first branch for very first block
# FIXME: if more than 2 base branches?
if idx_branch==1 and self.branch1_influence > 0: if idx_branch==1 and self.branch1_influence > 0:
fract_base_influence = np.clip(self.branch1_influence, 0, 1) fract_base_influence = np.clip(self.branch1_influence, 0, 1)
for i in range(len(list_latents)): for i in range(len(list_latents)):