Add image server
This commit is contained in:
parent
8ceb9fd08a
commit
40b657d45c
|
@ -0,0 +1 @@
|
|||
TOKEN="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"
|
|
@ -0,0 +1,15 @@
|
|||
FROM nvidia/cuda:11.6.0-base-ubuntu20.04
|
||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||
|
||||
RUN pip3 install --upgrade diffusers transformers scipy python-dotenv cuda-python fastapi uvicorn httplib2 && \
|
||||
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 && \
|
||||
pip3 install torch==1.11.0+cu115 torchvision==0.12.0+cu115 torchaudio==0.11.0+cu115 -f https://download.pytorch.org/whl/torch_stable.html && \
|
||||
pip3 install python-multipart accelerate
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY main.py /app/main.py
|
||||
VOLUME /root/.cache/huggingface/diffusers/
|
||||
|
||||
|
||||
CMD [ "uvicorn", "main:app", "--host", "0.0.0.0", "--reload" ]
|
|
@ -0,0 +1,16 @@
|
|||
version: '3.9'
|
||||
|
||||
services:
|
||||
imageserver:
|
||||
image: imageserver
|
||||
build: ./
|
||||
runtime: nvidia
|
||||
ports:
|
||||
- 8000:8000
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ./models:/root/.cache/huggingface/diffusers/
|
||||
- ./main.py:/app/main.py
|
||||
- ./images:/images
|
||||
restart: unless-stopped
|
|
@ -0,0 +1,63 @@
|
|||
from multiprocessing import context
|
||||
from httplib2 import Response
|
||||
import torch
|
||||
import uuid
|
||||
import os
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
from dotenv import load_dotenv
|
||||
from os import getenv
|
||||
from fastapi import FastAPI, Response, HTTPException, File, UploadFile
|
||||
from pydantic import BaseModel
|
||||
import io
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from PIL import Image
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", num_inference_steps=100, revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN"))
|
||||
# pipe.to("cuda")
|
||||
|
||||
class Text(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.post("/",
|
||||
# responses = {
|
||||
# 200: {
|
||||
# "content": {"image/png": {}}
|
||||
# }
|
||||
# },
|
||||
# response_class=Response
|
||||
)
|
||||
def root(text: str):
|
||||
prompt = text.replace('+', ' ')
|
||||
print(prompt)
|
||||
#request_object_content = file.read()
|
||||
# img = Image.open(io.BytesIO(request_object_content))
|
||||
|
||||
# height_orig = img.height
|
||||
# width_orig = img.width
|
||||
# aspect_ratio = width_orig / height_orig
|
||||
# width_new = 512
|
||||
# height_new = int(width_new / aspect_ratio)
|
||||
# img = img.resize((width_new, height_new), 0)
|
||||
|
||||
# try:
|
||||
|
||||
# resp = pipe(prompt, init_image=img)
|
||||
# print(resp)
|
||||
# image = resp.images[0]
|
||||
# except RuntimeError as e:
|
||||
# print(e)
|
||||
# raise HTTPException(status_code=202, detail="Busy")
|
||||
# except:
|
||||
# raise HTTPException(status_code=504)
|
||||
|
||||
# imgByteArr = io.BytesIO()
|
||||
# image.save(imgByteArr, format="PNG")
|
||||
# imgByteArr = imgByteArr.getvalue()
|
||||
# running = False
|
||||
# return Response(content=imgByteArr, media_type="image/png")
|
Loading…
Reference in New Issue