compile flag with sfast
This commit is contained in:
parent
1775c9a90a
commit
4b235b874e
|
@ -19,6 +19,7 @@ class BlendingEngine():
|
|||
def __init__(
|
||||
self,
|
||||
dh: None,
|
||||
do_compile: bool = False,
|
||||
guidance_scale_mid_damper: float = 0.5,
|
||||
mid_compression_scaler: float = 1.2):
|
||||
r"""
|
||||
|
@ -80,6 +81,16 @@ class BlendingEngine():
|
|||
self.benchmark_speed()
|
||||
self.set_branching()
|
||||
|
||||
if do_compile:
|
||||
print("starting compilation")
|
||||
from sfast.compilers.diffusion_pipeline_compiler import (compile, CompilationConfig)
|
||||
self.dh.pipe.enable_xformers_memory_efficient_attention()
|
||||
config = CompilationConfig.Default()
|
||||
config.enable_xformers = True
|
||||
config.enable_triton = True
|
||||
config.enable_cuda_graph = True
|
||||
self.dh.pipe = compile(self.dh.pipe, config)
|
||||
|
||||
|
||||
|
||||
def benchmark_speed(self):
|
||||
|
@ -833,19 +844,20 @@ if __name__ == "__main__":
|
|||
duration_transition = 12 # In seconds
|
||||
|
||||
# Spawn latent blending
|
||||
lb = LatentBlending(dh)
|
||||
lb.set_prompt1(prompt1)
|
||||
lb.set_prompt2(prompt2)
|
||||
lb.set_negative_prompt(negative_prompt)
|
||||
be = BlendingEngine(dh)
|
||||
be.set_prompt1(prompt1)
|
||||
be.set_prompt2(prompt2)
|
||||
be.set_negative_prompt(negative_prompt)
|
||||
|
||||
# Run latent blending
|
||||
t0 = time.time()
|
||||
lb.run_transition(fixed_seeds=[420, 421])
|
||||
be.run_transition(fixed_seeds=[420, 421])
|
||||
dt = time.time() - t0
|
||||
print(f"dt = {dt}")
|
||||
|
||||
# Save movie
|
||||
fp_movie = f'test.mp4'
|
||||
lb.write_movie_transition(fp_movie, duration_transition)
|
||||
be.write_movie_transition(fp_movie, duration_transition)
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue