2022-10-21 04:42:59 +00:00
|
|
|
from multiprocessing import context
|
|
|
|
from httplib2 import Response
|
2022-10-19 11:25:02 +00:00
|
|
|
import torch
|
|
|
|
import uuid
|
|
|
|
import os
|
|
|
|
from diffusers import StableDiffusionPipeline
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
from os import getenv
|
2022-10-24 05:25:33 +00:00
|
|
|
from fastapi import FastAPI, Response, HTTPException
|
2022-10-21 04:42:59 +00:00
|
|
|
from pydantic import BaseModel
|
|
|
|
import io
|
2022-11-01 01:00:35 +00:00
|
|
|
from PIL.PngImagePlugin import PngInfo
|
2022-10-19 11:25:02 +00:00
|
|
|
|
|
|
|
# Pull settings from a local .env file into the process environment so that
# the TOKEN variable read below is available.
load_dotenv()

# Build the Stable Diffusion v1.4 pipeline once, at import time: weights come
# from the "fp16" revision and are loaded as float16 tensors. The TOKEN
# environment variable is passed as the auth token (presumably a Hugging Face
# access token — confirm). The pipeline is then moved onto the CUDA device.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=os.environ.get("TOKEN"),
)
pipe.to("cuda")
|
|
|
|
|
2022-10-21 04:42:59 +00:00
|
|
|
class Text(BaseModel):
    """Pydantic model wrapping a single text string.

    NOTE(review): not referenced by the endpoint visible here (``root`` takes
    ``text`` as a plain query parameter) — confirm whether it is still used.
    """

    # The text payload; pydantic validates it as a string.
    text: str
|
|
|
|
|
|
|
|
|
2022-10-19 11:25:02 +00:00
|
|
|
|
2022-10-24 05:25:33 +00:00
|
|
|
# The ASGI application object; routes are registered on it via decorators.
app = FastAPI()
|
|
|
|
|
2022-10-21 04:42:59 +00:00
|
|
|
@app.get(
    "/",
    # Advertise in the OpenAPI schema that a successful response is a PNG.
    responses={200: {"content": {"image/png": {}}}},
    response_class=Response,
)
def root(text: str):
    """Generate an image for a text prompt and return it as PNG bytes.

    Every ``+`` in ``text`` is replaced with a space to form the prompt.
    The prompt is run through the module-level Stable Diffusion ``pipe``, and
    the first generated image is used. A copy of that image is written to
    ``/images/<random uuid4>.png`` with the prompt stored under the PNG text
    key ``"text"``. The same image is then returned as the response body.

    Args:
        text: The prompt, taken from the ``text`` query parameter.

    Returns:
        A ``Response`` with media type ``image/png`` containing the image.

    Raises:
        HTTPException: status 202 with detail "Busy" if the pipeline raises a
            ``RuntimeError``; status 504 if generation fails with any other
            ``Exception`` (including an empty ``images`` list).
    """
    # NOTE(review): query-string decoding normally already maps '+' to a
    # space, so this presumably only affects '%2B'-encoded plus signs —
    # confirm this double conversion is intended.
    prompt = text.replace('+', ' ')
    print(prompt)

    # The original also built torch.Generator("cuda").manual_seed(1024) here
    # but never passed it to `pipe`, so generation was never seeded. The dead
    # local was removed. To get reproducible output, pass
    # `generator=...` to `pipe` explicitly.
    try:
        resp = pipe(prompt)
        print(resp)
        image = resp.images[0]
    except RuntimeError as e:
        # NOTE(review): every RuntimeError is reported as "Busy" (presumably
        # aimed at CUDA out-of-memory under concurrent requests). 202 is a
        # success status, and 503 would be conventional. It is kept as-is
        # because existing clients may check for 202.
        raise HTTPException(status_code=202, detail="Busy") from e
    except Exception as e:
        # Previously a bare `except:`. It is narrowed so that KeyboardInterrupt
        # and SystemExit propagate instead of being turned into a 504.
        raise HTTPException(status_code=504) from e

    # Archive a copy on disk, with the prompt embedded as PNG text metadata.
    metadata = PngInfo()
    metadata.add_text("text", prompt)
    image.save(f'/images/{uuid.uuid4()}.png', pnginfo=metadata)

    # Encode the image in memory for the HTTP response body.
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    return Response(content=png_buffer.getvalue(), media_type="image/png")
|