"""FastAPI service: POST an image plus a text prompt, get back a Stable Diffusion img2img PNG."""
import io
import logging
import os
import uuid
from multiprocessing import context  # NOTE(review): unused in this file
from os import getenv

import torch
from diffusers import StableDiffusionImg2ImgPipeline
from dotenv import load_dotenv
# NOTE(review): this `Response` is immediately shadowed by `fastapi.Response` below and is
# never used; kept only so existing imports are not removed.
from httplib2 import Response
from fastapi import FastAPI, Response, HTTPException, File, UploadFile
from PIL import Image
from PIL.PngImagePlugin import PngInfo  # NOTE(review): unused in this file
from pydantic import BaseModel

load_dotenv()

logger = logging.getLogger(__name__)

# Width every upload is resized to before img2img; height follows the aspect ratio.
TARGET_WIDTH = 512
# Smallest height we allow after resizing, so extremely wide uploads never yield a
# zero-height image.
MIN_HEIGHT = 8
# Denoising steps per request. This is a call-time argument of the pipeline;
# passing it to `from_pretrained` has no effect (diffusers ignores unexpected kwargs there).
NUM_INFERENCE_STEPS = 100

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=getenv("TOKEN"),
)
pipe.to("cuda")


class Text(BaseModel):
    text: str


app = FastAPI()


def _load_resized_image(data: bytes) -> Image.Image:
    """Decode uploaded bytes to RGB and resize to TARGET_WIDTH, keeping the aspect ratio.

    Args:
        data: Raw bytes of the uploaded file.

    Returns:
        An RGB PIL image that is TARGET_WIDTH pixels wide.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        img = Image.open(io.BytesIO(data))
        # Stable Diffusion expects 3-channel input; RGBA/L/P uploads would otherwise
        # fail inside the pipeline.
        img = img.convert("RGB")
    except OSError as e:  # PIL.UnidentifiedImageError is a subclass of OSError
        raise HTTPException(status_code=400, detail="Invalid image") from e

    aspect_ratio = img.width / img.height
    height_new = max(MIN_HEIGHT, int(TARGET_WIDTH / aspect_ratio))
    return img.resize((TARGET_WIDTH, height_new), 0)  # 0 == nearest-neighbour resampling


def _to_png_bytes(image: Image.Image) -> bytes:
    """Encode a PIL image as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


@app.post(
    "/",
    responses={200: {"content": {"image/png": {}}}},
    response_class=Response,
)
async def root(text: str, my_file: UploadFile = File(...)):
    """Run img2img on the uploaded image guided by `text`; return the result as PNG.

    Args:
        text: Prompt, taken from the query string. Any remaining '+' characters
            are turned into spaces.
        my_file: Uploaded source image.

    Returns:
        A `Response` whose body is the generated image, encoded as `image/png`.

    Raises:
        HTTPException: 400 for an undecodable upload; 202 "Busy" when the pipeline
            raises RuntimeError; 504 for any other pipeline error; 418 "NSFW" when
            the safety checker flags the output.
    """
    prompt = text.replace("+", " ")
    logger.info("prompt: %s", prompt)

    img = _load_resized_image(await my_file.read())

    # NOTE(review): pipe() is synchronous GPU work inside an async handler, so it blocks
    # the event loop for the duration of inference.
    try:
        resp = pipe(prompt, image=img, num_inference_steps=NUM_INFERENCE_STEPS)
        image = resp.images[0]
    except RuntimeError as e:
        # Presumably CUDA OOM / device contention - TODO confirm.
        logger.warning("pipeline RuntimeError: %s", e)
        # NOTE(review): 202 (Accepted) is unusual for a failure (503 is conventional);
        # kept unchanged for client compatibility.
        raise HTTPException(status_code=202, detail="Busy") from e
    except Exception as e:
        logger.exception("pipeline failed")
        # NOTE(review): 504 (Gateway Timeout) kept for client compatibility; 500 is conventional.
        raise HTTPException(status_code=504, detail=str(e)) from e

    # Attribute access instead of resp["..."]: the dict view omits the key entirely
    # when the safety checker is disabled (value None).
    nsfw_flags = getattr(resp, "nsfw_content_detected", None)
    if nsfw_flags and any(nsfw_flags):
        raise HTTPException(status_code=418, detail="NSFW")

    return Response(content=_to_png_bytes(image), media_type="image/png")