"""Minimal FastAPI service that renders a text prompt with Stable Diffusion.

``GET /?text=...`` runs the prompt through the CompVis Stable Diffusion v1.4
pipeline (fp16, on CUDA), archives the resulting PNG under ``IMAGE_DIR`` with
the prompt stored in a ``text`` PNG chunk, and returns the PNG bytes.
"""

import io
import logging
import os
import threading
import uuid
from multiprocessing import context  # NOTE(review): unused
from os import getenv
from typing import Optional

# NOTE(review): unused. This line MUST stay above the fastapi import: both
# bind the name ``Response`` and the endpoint relies on fastapi's one winning.
from httplib2 import Response
import torch
from diffusers import StableDiffusionPipeline
from dotenv import load_dotenv
from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel
from PIL.PngImagePlugin import PngInfo

load_dotenv()

logger = logging.getLogger(__name__)

# Directory where every generated image is archived. Defaults to the path that
# used to be hard-coded; override with the IMAGE_DIR env var (or .env file).
IMAGE_DIR = getenv("IMAGE_DIR", "/images")

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=getenv("TOKEN"),
)
pipe.to("cuda")

# FastAPI executes plain ``def`` endpoints in a worker thread pool, so several
# requests can call ``pipe`` at the same time. The pipeline's scheduler keeps
# per-call state on the shared object and is not thread-safe, and parallel
# runs on one GPU also multiply memory use. Serialise generation instead.
_pipe_lock = threading.Lock()


class Text(BaseModel):
    """Request model carrying a single ``text`` field.

    NOTE(review): not used by any endpoint in this file; ``root`` takes the
    prompt as a query parameter instead.
    """

    text: str


app = FastAPI()


@app.get(
    "/",
    responses={200: {"content": {"image/png": {}}}},
    response_class=Response,
)
def root(text: str, seed: Optional[int] = None):
    """Generate an image for the prompt ``text`` and return it as PNG.

    Args:
        text: The prompt. Every literal ``+`` is replaced by a space.
        seed: Optional RNG seed. When given, the pipeline is driven by a CUDA
            ``torch.Generator`` seeded with it, so the same prompt/seed pair
            reproduces the same image. When omitted, no generator is passed
            (the previous behaviour).

    Returns:
        A ``Response`` whose body is the PNG-encoded image.

    Raises:
        HTTPException: 503 ("Busy") if generation raises ``RuntimeError``
            (NOTE(review): presumably CUDA out-of-memory — confirm), 500 for
            any other generation failure.
    """
    prompt = text.replace("+", " ")
    logger.info("Generating image for prompt: %r", prompt)

    try:
        generator = (
            torch.Generator("cuda").manual_seed(seed) if seed is not None else None
        )
        with _pipe_lock:
            image = pipe(prompt, generator=generator).images[0]
    except RuntimeError as e:
        # 202 (the old code) is a success status; report a retryable 503.
        logger.exception("Image generation failed with RuntimeError")
        raise HTTPException(status_code=503, detail="Busy") from e
    except Exception as e:
        # Narrower than the old bare ``except:``, which also trapped
        # KeyboardInterrupt/SystemExit and reported everything as 504.
        logger.exception("Image generation failed")
        raise HTTPException(status_code=500) from e

    # Archive a copy with the prompt embedded as PNG text metadata.
    metadata = PngInfo()
    metadata.add_text("text", prompt)
    image.save(os.path.join(IMAGE_DIR, f"{uuid.uuid4()}.png"), pnginfo=metadata)

    # Encode the response body in memory (without the metadata chunk, as before).
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return Response(content=buffer.getvalue(), media_type="image/png")