Add error handling

This commit is contained in:
Jimmy 2022-10-24 18:25:33 +13:00
parent 39f104288b
commit f76f62e53d
1 changed files with 11 additions and 13 deletions

View File

@ -6,22 +6,22 @@ import os
from diffusers import StableDiffusionPipeline from diffusers import StableDiffusionPipeline
from dotenv import load_dotenv from dotenv import load_dotenv
from os import getenv from os import getenv
from fastapi import FastAPI, Response from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
import io import io
load_dotenv() load_dotenv()
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN")) pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN"))
pipe.to("cuda") pipe.to("cuda")
class Text(BaseModel): class Text(BaseModel):
text: str text: str
app = FastAPI()
app = FastAPI()
@app.get("/", @app.get("/",
responses = { responses = {
200: { 200: {
@ -30,20 +30,18 @@ app = FastAPI()
}, },
response_class=Response response_class=Response
) )
async def root(text: Text): def root(text: Text):
# get your token at https://huggingface.co/settings/tokens
prompt = text.text prompt = text.text
print(prompt) print(prompt)
image = pipe(prompt).images[0] try:
image = pipe(prompt).images[0]
except RuntimeError as e:
raise HTTPException(status_code=202, detail="Busy")
except:
raise HTTPException(status_code=504)
# print(image)
# image = Image.new('RGB', (1000, 1000), (100,200,10))
imgByteArr = io.BytesIO() imgByteArr = io.BytesIO()
# image.save expects a file as a argument, passing a bytes io ins
image.save(imgByteArr, format="PNG") image.save(imgByteArr, format="PNG")
# Turn the BytesIO object back into a bytes object
imgByteArr = imgByteArr.getvalue() imgByteArr = imgByteArr.getvalue()
# media_type here sets the media type of the actual response sent to the client. running = False
return Response(content=imgByteArr, media_type="image/png") return Response(content=imgByteArr, media_type="image/png")