# photobooth/imageserver/main.py
#
# (File-viewer metadata from the original listing: 69 lines, 1.8 KiB, Python.)

from multiprocessing import context
from httplib2 import Response
import torch
import uuid
import os
from diffusers import StableDiffusionImg2ImgPipeline
from dotenv import load_dotenv
from os import getenv
from fastapi import FastAPI, Response, HTTPException, File, UploadFile
from pydantic import BaseModel
import io
from PIL.PngImagePlugin import PngInfo
from PIL import Image
# Load variables from a .env file (if present) into the process environment
# so that getenv("TOKEN") below can find the Hugging Face auth token.
load_dotenv()

# The img2img pipeline is loaded once at import time and shared by every
# request. revision="fp16" + torch_dtype=torch.float16 load half-precision
# weights.
#
# num_inference_steps is deliberately NOT passed here: it is an argument of
# the pipeline *call* (pipe(..., num_inference_steps=N)), not of
# from_pretrained, so diffusers does not use it at load time. An earlier
# version passed num_inference_steps=100 here, and it had no effect on the
# step count. Pass it to pipe(...) to change the number of steps.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=getenv("TOKEN"),
)
pipe.to("cuda")
class Text(BaseModel):
    """Pydantic model with a single required ``text`` string field.

    NOTE(review): the ``root`` endpoint in this file takes ``text`` as a plain
    parameter rather than this model; confirm whether anything else uses it.
    """

    text: str
# FastAPI application instance; routes are registered on it via @app.post(...).
app = FastAPI()
def _prepare_init_image(data: bytes) -> Image.Image:
    """Decode uploaded image bytes and resize them to 512 px wide (RGB).

    The aspect ratio is preserved. The height is clamped to at least 8 px:
    Stable Diffusion works on 8x-downsampled latents, so a thinner image would
    collapse to an empty latent.

    Raises:
        OSError: the bytes are not a decodable image (PIL raises
            ``UnidentifiedImageError``, an ``OSError`` subclass, or
            ``OSError`` for truncated data).
    """
    img = Image.open(io.BytesIO(data))
    # Uploads may be RGBA, palette or grayscale. The model needs 3-channel
    # input; anything else makes the pipeline fail with a RuntimeError.
    # convert() also forces the lazy decode, so corrupt data fails here.
    img = img.convert("RGB")
    width_orig, height_orig = img.size
    aspect_ratio = width_orig / height_orig
    width_new = 512
    height_new = max(8, int(width_new / aspect_ratio))
    # resample=0 is PIL's nearest-neighbour filter (unchanged behaviour).
    return img.resize((width_new, height_new), 0)


@app.post(
    "/",
    responses={
        200: {
            "content": {"image/png": {}},
        },
    },
    response_class=Response,
)
async def root(text: str, my_file: UploadFile = File(...)):
    """Run Stable Diffusion img2img on an uploaded image and return a PNG.

    Args:
        text: Prompt text. Any ``+`` characters are replaced with spaces
            before use.
        my_file: The uploaded source image.

    Returns:
        A ``Response`` with the generated image encoded as ``image/png``.

    Raises:
        HTTPException(400, "Invalid image"): the upload could not be decoded.
        HTTPException(202, "Busy"): the pipeline raised ``RuntimeError``.
        HTTPException(504): the pipeline raised any other exception.
        HTTPException(418, "NSFW"): the result was flagged as NSFW.
    """
    prompt = text.replace('+', ' ')
    print(prompt)

    request_object_content = await my_file.read()
    try:
        img = _prepare_init_image(request_object_content)
    except OSError as e:
        # An undecodable upload is a client error, not an internal failure.
        raise HTTPException(status_code=400, detail="Invalid image") from e

    # NOTE(review): pipe(...) is synchronous GPU work running directly on the
    # event loop. It blocks all other requests until it finishes, which also
    # serialises access to the shared pipeline. Moving it to a thread pool
    # would need a lock around `pipe`. No num_inference_steps is passed, so
    # the pipeline's default step count applies.
    try:
        resp = pipe(prompt, image=img)
        image = resp.images[0]
    except RuntimeError as e:
        # Presumably CUDA errors such as out-of-memory — TODO confirm.
        print(e)
        # NOTE(review): 202 is a success status; kept for existing clients.
        raise HTTPException(status_code=202, detail="Busy") from e
    except Exception as e:
        raise HTTPException(status_code=504, detail=str(e)) from e

    # One flag per generated image. The field may be None (e.g. no safety
    # checker), and diffusers' output mapping may then omit the key entirely.
    # Attribute access avoids the KeyError that resp[...] would raise.
    nsfw_flags = getattr(resp, "nsfw_content_detected", None)
    if any(nsfw_flags or ()):
        raise HTTPException(status_code=418, detail="NSFW")

    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    return Response(content=png_buffer.getvalue(), media_type="image/png")