diff --git a/requirements.txt b/requirements.txt
index b8b178a..8d812e6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,5 @@ lpips==0.1.4
 opencv-python==4.7.0.68
 ffmpeg-python
 diffusers["torch"]==0.23.0
-transformers==4.35.2
\ No newline at end of file
+transformers==4.35.2
+pytest
\ No newline at end of file
diff --git a/test_latentblending.py b/test_latentblending.py
new file mode 100644
index 0000000..78f1688
--- /dev/null
+++ b/test_latentblending.py
@@ -0,0 +1,56 @@
+import unittest
+from latent_blending import LatentBlending
+from diffusers_holder import DiffusersHolder
+from diffusers import DiffusionPipeline
+import torch
+
+default_pipe = "stabilityai/stable-diffusion-xl-base-1.0"
+
+
+class TestDiffusersHolder(unittest.TestCase):
+
+    def test_load_diffusers_holder(self):
+        # Wrap a pretrained SDXL pipeline in a DiffusersHolder.
+        pipe = DiffusionPipeline.from_pretrained(default_pipe, torch_dtype=torch.float16).to('cuda')
+        dh = DiffusersHolder(pipe)
+        self.assertIsNotNone(dh, "Failed to load DiffusersHolder")
+
+
+class TestSingleImageGeneration(unittest.TestCase):
+
+    def test_single_image_generation(self):
+        pipe = DiffusionPipeline.from_pretrained(default_pipe, torch_dtype=torch.float16).to('cuda')
+        dh = DiffusersHolder(pipe)
+        dh.set_dimensions((1024, 704))
+        dh.set_num_inference_steps(40)
+        prompt = "photo of a red house in a forest"
+        text_embeddings = dh.get_text_embedding(prompt)
+        latents_start = dh.get_noise()
+        # Run the full denoising loop, then decode the final latent to an image.
+        list_latents_1 = dh.run_diffusion(text_embeddings, latents_start)
+        img_orig = dh.latent2image(list_latents_1[-1])
+        self.assertIsNotNone(img_orig, "Failed to generate an image")
+
+
+class TestImageTransition(unittest.TestCase):
+
+    def test_image_transition(self):
+        pipe = DiffusionPipeline.from_pretrained(default_pipe, torch_dtype=torch.float16).to('cuda')
+        dh = DiffusersHolder(pipe)
+        lb = LatentBlending(dh)
+
+        lb.set_prompt1('photo of my first prompt')
+        lb.set_prompt2('photo of my second prompt')
+        depth_strength = 0.6
+        t_compute_max_allowed = 10
+        num_inference_steps = 30
+        imgs_transition = lb.run_transition(
+            depth_strength=depth_strength,
+            num_inference_steps=num_inference_steps,
+            t_compute_max_allowed=t_compute_max_allowed)
+
+        self.assertTrue(len(imgs_transition) > 0, "No transition images generated")
+
+
+if __name__ == '__main__':
+    unittest.main()
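A minimal sketch, not part of the patch above, of how these GPU-bound tests could be skipped cleanly on machines without CUDA instead of failing inside .to('cuda'); unittest.skipUnless and torch.cuda.is_available() are standard unittest/PyTorch calls, and the decorator would be applied to each test class in test_latentblending.py:

import unittest
import torch

# Skip the whole TestCase when no CUDA device is present, rather than
# erroring out in DiffusionPipeline.from_pretrained(...).to('cuda').
@unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
class TestDiffusersHolder(unittest.TestCase):
    def test_load_diffusers_holder(self):
        ...  # body as in the patch above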