speechtoimage/imageserver/main.py

56 lines
1.4 KiB
Python
Raw Normal View History

2022-10-21 04:42:59 +00:00
from multiprocessing import context
from httplib2 import Response
2022-10-19 11:25:02 +00:00
import torch
import uuid
import os
from diffusers import StableDiffusionPipeline
from dotenv import load_dotenv
from os import getenv
2022-10-24 05:25:33 +00:00
from fastapi import FastAPI, Response, HTTPException
2022-10-21 04:42:59 +00:00
from pydantic import BaseModel
import io
2022-11-01 01:00:35 +00:00
from PIL.PngImagePlugin import PngInfo
2022-10-19 11:25:02 +00:00
# Load environment settings via python-dotenv; TOKEN is read from the
# environment just below.
load_dotenv()

# Build the Stable Diffusion v1.4 pipeline from its half-precision ("fp16")
# revision, authenticating with the TOKEN environment variable, then move the
# pipeline onto the CUDA device. This runs once, at import time.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=getenv("TOKEN"),
)
pipe.to("cuda")
2022-10-21 04:42:59 +00:00
# Pydantic model wrapping a single string field. The route visible in this
# file takes its text as a plain query parameter and does not use this model.
class Text(BaseModel):
    # The text payload.
    text: str
2022-10-19 11:25:02 +00:00
2022-10-24 05:25:33 +00:00
# FastAPI application instance; the route below registers itself on it.
app = FastAPI()
2022-10-21 04:42:59 +00:00
@app.get(
    "/",
    responses={
        200: {
            "content": {"image/png": {}}
        }
    },
    response_class=Response,
)
def root(text: str):
    """Generate an image for *text* and return it as a PNG response.

    The prompt is derived from ``text`` by replacing every ``+`` with a
    space. The generated image is written to ``/images/<uuid4>.png`` with the
    prompt stored in a PNG text chunk under the key ``"text"``, and is also
    returned in the response body (encoded without that text chunk).

    Args:
        text: The prompt, supplied as a query parameter.

    Returns:
        A ``Response`` whose body is the PNG-encoded image, with media type
        ``image/png``.

    Raises:
        HTTPException: Status 202 with detail ``"Busy"`` when the pipeline
            raises ``RuntimeError``; status 504 for any other exception
            raised while generating the image.
    """
    prompt = text.replace('+', ' ')
    print(prompt)

    # NOTE(review): the previous revision built
    # torch.Generator("cuda").manual_seed(1024) here but never handed it to
    # the pipeline, so it did not influence the output. It has been dropped;
    # pass generator=... to pipe() if reproducible images are wanted.
    try:
        resp = pipe(prompt)
        print(resp)
        image = resp.images[0]
    except RuntimeError as err:
        # Status code kept as-is so existing clients that poll on 202 keep
        # working; the original error is chained and printed for debugging.
        print(f"generation failed: {err!r}")
        raise HTTPException(status_code=202, detail="Busy") from err
    except Exception as err:
        # Was a bare ``except:``, which also trapped KeyboardInterrupt and
        # SystemExit; only ordinary exceptions are translated now.
        print(f"generation failed: {err!r}")
        raise HTTPException(status_code=504) from err

    # Persist a copy on disk, tagging the file with the prompt that made it.
    metadata = PngInfo()
    metadata.add_text("text", prompt)
    image.save(f'/images/{uuid.uuid4()}.png', pnginfo=metadata)

    # Encode a second time into memory for the HTTP response body.
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    return Response(content=png_buffer.getvalue(), media_type="image/png")