From 1571d31d3187e2714fa17e0788bb0e092f1e9f22 Mon Sep 17 00:00:00 2001 From: Jimmy Date: Fri, 21 Oct 2022 17:42:59 +1300 Subject: [PATCH] Working with fastapi --- imageserver/main.py | 40 +++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/imageserver/main.py b/imageserver/main.py index 4becd40..52e58b4 100644 --- a/imageserver/main.py +++ b/imageserver/main.py @@ -1,19 +1,49 @@ +from multiprocessing import context +from httplib2 import Response import torch import uuid import os from diffusers import StableDiffusionPipeline from dotenv import load_dotenv from os import getenv +from fastapi import FastAPI, Response +from pydantic import BaseModel +import io load_dotenv() -# get your token at https://huggingface.co/settings/tokens + pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN")) pipe.to("cuda") -prompt = "metal buttons are often soldiers who just got out of high school or a couple of years graduated from college easy as an air conditioned box about radar the patriot radar known as the a n n e e e pi this is an extremely powerful radar unit so powerful that they actually" +class Text(BaseModel): + text: str -for _ in range(10): - image = pipe(prompt)["sample"][0] +app = FastAPI() - image.save(f"{uuid.uuid4()}.png".replace(" ", "_")) \ No newline at end of file + +@app.get("/", + responses = { + 200: { + "content": {"image/png": {}} + } + }, + response_class=Response +) +async def root(text: Text): + # get your token at https://huggingface.co/settings/tokens + + prompt = text.text + print(prompt) + image = pipe(prompt).images[0] + +# print(image) + + # image = Image.new('RGB', (1000, 1000), (100,200,10)) + imgByteArr = io.BytesIO() + # image.save expects a file as a argument, passing a bytes io ins + image.save(imgByteArr, format="PNG") + # Turn the BytesIO object back into a bytes object + imgByteArr = imgByteArr.getvalue() + # media_type here sets the media type of the actual response sent to the client. + return Response(content=imgByteArr, media_type="image/png")