compile flag with sfast

This commit is contained in:
DGX 2024-01-10 08:58:30 +00:00
parent 1775c9a90a
commit 4b235b874e
1 changed files with 18 additions and 6 deletions

View File

@ -19,6 +19,7 @@ class BlendingEngine():
def __init__(
self,
dh: None,
do_compile: bool = False,
guidance_scale_mid_damper: float = 0.5,
mid_compression_scaler: float = 1.2):
r"""
@ -80,6 +81,16 @@ class BlendingEngine():
self.benchmark_speed()
self.set_branching()
if do_compile:
print("starting compilation")
from sfast.compilers.diffusion_pipeline_compiler import (compile, CompilationConfig)
self.dh.pipe.enable_xformers_memory_efficient_attention()
config = CompilationConfig.Default()
config.enable_xformers = True
config.enable_triton = True
config.enable_cuda_graph = True
self.dh.pipe = compile(self.dh.pipe, config)
def benchmark_speed(self):
@ -833,19 +844,20 @@ if __name__ == "__main__":
duration_transition = 12 # In seconds
# Spawn latent blending
lb = LatentBlending(dh)
lb.set_prompt1(prompt1)
lb.set_prompt2(prompt2)
lb.set_negative_prompt(negative_prompt)
be = BlendingEngine(dh)
be.set_prompt1(prompt1)
be.set_prompt2(prompt2)
be.set_negative_prompt(negative_prompt)
# Run latent blending
t0 = time.time()
lb.run_transition(fixed_seeds=[420, 421])
be.run_transition(fixed_seeds=[420, 421])
dt = time.time() - t0
print(f"dt = {dt}")
# Save movie
fp_movie = f'test.mp4'
lb.write_movie_transition(fp_movie, duration_transition)
be.write_movie_transition(fp_movie, duration_transition)