Working with fastapi
This commit is contained in:
parent
f03d90f378
commit
1571d31d31
|
@ -1,19 +1,49 @@
|
||||||
|
from multiprocessing import context
|
||||||
|
from httplib2 import Response
|
||||||
import torch
|
import torch
|
||||||
import uuid
|
import uuid
|
||||||
import os
|
import os
|
||||||
from diffusers import StableDiffusionPipeline
|
from diffusers import StableDiffusionPipeline
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from os import getenv
|
from os import getenv
|
||||||
|
from fastapi import FastAPI, Response
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import io
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# get your token at https://huggingface.co/settings/tokens
|
|
||||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN"))
|
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN"))
|
||||||
pipe.to("cuda")
|
pipe.to("cuda")
|
||||||
|
|
||||||
prompt = "metal buttons are often soldiers who just got out of high school or a couple of years graduated from college easy as an air conditioned box about radar the patriot radar known as the a n n e e e pi this is an extremely powerful radar unit so powerful that they actually"
|
class Text(BaseModel):
|
||||||
|
text: str
|
||||||
|
|
||||||
for _ in range(10):
|
app = FastAPI()
|
||||||
image = pipe(prompt)["sample"][0]
|
|
||||||
|
|
||||||
image.save(f"{uuid.uuid4()}.png".replace(" ", "_"))
|
|
||||||
|
@app.get("/",
|
||||||
|
responses = {
|
||||||
|
200: {
|
||||||
|
"content": {"image/png": {}}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
response_class=Response
|
||||||
|
)
|
||||||
|
async def root(text: Text):
|
||||||
|
# get your token at https://huggingface.co/settings/tokens
|
||||||
|
|
||||||
|
prompt = text.text
|
||||||
|
print(prompt)
|
||||||
|
image = pipe(prompt).images[0]
|
||||||
|
|
||||||
|
# print(image)
|
||||||
|
|
||||||
|
# image = Image.new('RGB', (1000, 1000), (100,200,10))
|
||||||
|
imgByteArr = io.BytesIO()
|
||||||
|
# image.save expects a file as a argument, passing a bytes io ins
|
||||||
|
image.save(imgByteArr, format="PNG")
|
||||||
|
# Turn the BytesIO object back into a bytes object
|
||||||
|
imgByteArr = imgByteArr.getvalue()
|
||||||
|
# media_type here sets the media type of the actual response sent to the client.
|
||||||
|
return Response(content=imgByteArr, media_type="image/png")
|
||||||
|
|
Loading…
Reference in New Issue