Working with fastapi

This commit is contained in:
Jimmy 2022-10-21 17:42:59 +13:00
parent f03d90f378
commit 1571d31d31
1 changed files with 35 additions and 5 deletions

View File

@ -1,19 +1,49 @@
from multiprocessing import context
from httplib2 import Response
import torch import torch
import uuid import uuid
import os import os
from diffusers import StableDiffusionPipeline from diffusers import StableDiffusionPipeline
from dotenv import load_dotenv from dotenv import load_dotenv
from os import getenv from os import getenv
from fastapi import FastAPI, Response
from pydantic import BaseModel
import io
load_dotenv() load_dotenv()
# get your token at https://huggingface.co/settings/tokens
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN")) pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN"))
pipe.to("cuda") pipe.to("cuda")
prompt = "metal buttons are often soldiers who just got out of high school or a couple of years graduated from college easy as an air conditioned box about radar the patriot radar known as the a n n e e e pi this is an extremely powerful radar unit so powerful that they actually" class Text(BaseModel):
text: str
for _ in range(10): app = FastAPI()
image = pipe(prompt)["sample"][0]
image.save(f"{uuid.uuid4()}.png".replace(" ", "_"))
@app.get("/",
responses = {
200: {
"content": {"image/png": {}}
}
},
response_class=Response
)
async def root(text: Text):
# get your token at https://huggingface.co/settings/tokens
prompt = text.text
print(prompt)
image = pipe(prompt).images[0]
# print(image)
# image = Image.new('RGB', (1000, 1000), (100,200,10))
imgByteArr = io.BytesIO()
# image.save expects a file as a argument, passing a bytes io ins
image.save(imgByteArr, format="PNG")
# Turn the BytesIO object back into a bytes object
imgByteArr = imgByteArr.getvalue()
# media_type here sets the media type of the actual response sent to the client.
return Response(content=imgByteArr, media_type="image/png")