"""Generate one image with Stable Diffusion v1.4 and save it as a PNG.

The prompt and the random seed are written into the PNG's text chunks
(keys ``"text"`` and ``"seed"``) so the image can be reproduced later.

Configuration is read from the environment; a ``.env`` file is loaded
first via python-dotenv:

* ``TOKEN``     -- Hugging Face access token passed to ``from_pretrained``.
* ``IMAGE_DIR`` -- output directory for the PNG; defaults to ``/images``.

Usage::

    python this_script.py [PROMPT WORDS ...]

With no arguments the built-in default prompt is used.
"""

import sys
import uuid
from os import getenv
from pathlib import Path
from random import randint

import torch
from diffusers import StableDiffusionPipeline
from dotenv import load_dotenv
from PIL.PngImagePlugin import PngInfo

MODEL_ID = "CompVis/stable-diffusion-v1-4"
DEFAULT_PROMPT = "red horse jumping over a dog"
DEFAULT_IMAGE_DIR = "/images"

# torch.Generator.manual_seed accepts values up to 0xffff_ffff_ffff_ffff.
# random.randint's upper bound is inclusive, so the largest legal seed is
# 2**64 - 1 (the original 2**64 could overflow).
MAX_SEED = 2**64 - 1


def main(argv=None):
    """Generate an image for a prompt, save it as a PNG, and return its path.

    Args:
        argv: Prompt words; joined with spaces to form the prompt. Defaults
            to ``sys.argv[1:]``. If empty, ``DEFAULT_PROMPT`` is used.

    Returns:
        The ``pathlib.Path`` of the saved PNG file.
    """
    words = sys.argv[1:] if argv is None else argv
    prompt = " ".join(words) or DEFAULT_PROMPT

    load_dotenv()

    # Create the output directory before the slow model load and generation,
    # so that a bad or unwritable path fails immediately.
    out_dir = Path(getenv("IMAGE_DIR", DEFAULT_IMAGE_DIR))
    out_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): newer diffusers releases deprecate revision="fp16" and
    # use_auth_token= in favour of variant="fp16" and token=. Both are kept
    # as-is for the diffusers version this script targets; confirm before
    # upgrading.
    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        revision="fp16",
        torch_dtype=torch.float16,
        use_auth_token=getenv("TOKEN"),
    )
    pipe.to("cuda")

    seed = randint(0, MAX_SEED)
    generator = torch.Generator("cuda").manual_seed(seed)
    resp = pipe(prompt, generator=generator)
    print(resp)
    image = resp.images[0]

    # Embed the prompt and seed so the image can be regenerated.
    metadata = PngInfo()
    metadata.add_text("text", prompt)
    metadata.add_text("seed", str(seed))

    image_id = str(uuid.uuid4())
    print(image_id)

    path = out_dir / f"{image_id}.png"
    image.save(path, pnginfo=metadata)
    return path


if __name__ == "__main__":
    main()