Compare commits

...

2 Commits

Author SHA1 Message Date
Jimmy 3a63a37777 Add httplib2 2022-10-24 18:25:57 +13:00
Jimmy f76f62e53d Add error handling 2022-10-24 18:25:33 +13:00
2 changed files with 12 additions and 14 deletions

View File

@ -1,7 +1,7 @@
FROM nvidia/cuda:11.6.0-base-ubuntu20.04 FROM nvidia/cuda:11.6.0-base-ubuntu20.04
RUN apt-get update && apt-get install python3 python3-pip -y RUN apt-get update && apt-get install python3 python3-pip -y
RUN pip3 install --upgrade diffusers transformers scipy python-dotenv cuda-python fastapi uvicorn && \ RUN pip3 install --upgrade diffusers transformers scipy python-dotenv cuda-python fastapi uvicorn httplib2 && \
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 && \ pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 && \
pip install torch==1.11.0+cu115 torchvision==0.12.0+cu115 torchaudio==0.11.0+cu115 -f https://download.pytorch.org/whl/torch_stable.html pip install torch==1.11.0+cu115 torchvision==0.12.0+cu115 torchaudio==0.11.0+cu115 -f https://download.pytorch.org/whl/torch_stable.html

View File

@ -6,22 +6,22 @@ import os
from diffusers import StableDiffusionPipeline from diffusers import StableDiffusionPipeline
from dotenv import load_dotenv from dotenv import load_dotenv
from os import getenv from os import getenv
from fastapi import FastAPI, Response from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
import io import io
load_dotenv() load_dotenv()
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN")) pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN"))
pipe.to("cuda") pipe.to("cuda")
class Text(BaseModel): class Text(BaseModel):
text: str text: str
app = FastAPI()
app = FastAPI()
@app.get("/", @app.get("/",
responses = { responses = {
200: { 200: {
@ -30,20 +30,18 @@ app = FastAPI()
}, },
response_class=Response response_class=Response
) )
async def root(text: Text): def root(text: Text):
# get your token at https://huggingface.co/settings/tokens
prompt = text.text prompt = text.text
print(prompt) print(prompt)
image = pipe(prompt).images[0] try:
image = pipe(prompt).images[0]
except RuntimeError as e:
raise HTTPException(status_code=202, detail="Busy")
except:
raise HTTPException(status_code=504)
# print(image)
# image = Image.new('RGB', (1000, 1000), (100,200,10))
imgByteArr = io.BytesIO() imgByteArr = io.BytesIO()
# image.save expects a file as a argument, passing a bytes io ins
image.save(imgByteArr, format="PNG") image.save(imgByteArr, format="PNG")
# Turn the BytesIO object back into a bytes object
imgByteArr = imgByteArr.getvalue() imgByteArr = imgByteArr.getvalue()
# media_type here sets the media type of the actual response sent to the client. running = False
return Response(content=imgByteArr, media_type="image/png") return Response(content=imgByteArr, media_type="image/png")