Compare commits

...

4 Commits

Author SHA1 Message Date
Jimmy 018163fae8 Save images locally. Add manual seed 2022-11-01 14:00:35 +13:00
Jimmy 2f91297ff9 Expose ports, Add volumes 2022-11-01 13:59:54 +13:00
Jimmy 80097d995d Don't use gpui 2022-11-01 13:59:10 +13:00
Jimmy 76813d86ec Ignore images folder 2022-11-01 13:55:21 +13:00
3 changed files with 21 additions and 8 deletions

3
.gitignore vendored
View File

@ -1,2 +1,3 @@
.env .env
image.png image.png
images/

View File

@ -10,10 +10,14 @@ services:
env_file: env_file:
- .env - .env
volumes: volumes:
- ~/Downloads/sd-v1-4-full-ema.ckpt:/app/model.ckpt - ./images:/images
- ./notebooks:/notebooks
- ./models:/root/.cache/huggingface/diffusers/
ports:
- 8888:8888
- 8000:8000
vosk: vosk:
image: alphacep/kaldi-en-gpu image: alphacep/kaldi-en
ports: ports:
- 2700:2700 - 2700:2700
runtime: nvidia

View File

@ -9,6 +9,7 @@ from os import getenv
from fastapi import FastAPI, Response, HTTPException from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
import io import io
from PIL.PngImagePlugin import PngInfo
load_dotenv() load_dotenv()
@ -30,16 +31,23 @@ app = FastAPI()
}, },
response_class=Response response_class=Response
) )
def root(text: Text): def root(text: str):
prompt = text.text prompt = text.replace('+', ' ')
print(prompt) print(prompt)
try: try:
image = pipe(prompt).images[0] generator = torch.Generator("cuda").manual_seed(1024)
resp = pipe(prompt)
print(resp)
image = resp.images[0]
except RuntimeError as e: except RuntimeError as e:
raise HTTPException(status_code=202, detail="Busy") raise HTTPException(status_code=202, detail="Busy")
except: except:
raise HTTPException(status_code=504) raise HTTPException(status_code=504)
metadata = PngInfo()
metadata.add_text("text", prompt)
image.save(f'/images/{str(uuid.uuid4())}.png', pnginfo=metadata)
imgByteArr = io.BytesIO() imgByteArr = io.BytesIO()
image.save(imgByteArr, format="PNG") image.save(imgByteArr, format="PNG")
imgByteArr = imgByteArr.getvalue() imgByteArr = imgByteArr.getvalue()