Compare commits

...

4 Commits

Author SHA1 Message Date
Jimmy 6a550d5847 Example client 2022-12-13 21:23:21 +13:00
Jimmy 7dba63c954 Working image server 2022-12-13 21:23:05 +13:00
Jimmy d261d8e46f Ignore png and pycache 2022-12-13 21:22:41 +13:00
Jimmy 59307bff41 Ignore images and models folders 2022-12-13 19:14:55 +13:00
3 changed files with 82 additions and 33 deletions

6
.gitignore vendored
View File

@ -1,2 +1,6 @@
*.jpeg
*.jpg
*.jpg
*.png
imageserver/images/
imageserver/models
__pycache__

40
imageserver/client.py Normal file
View File

@ -0,0 +1,40 @@
# import aiohttp
# import aiofiles
import asyncio
import requests
from io import BytesIO
from PIL import Image
import shutil
from random import randint
def main():
print("Starting")
img = Image.new('RGB', (25, 25), color = (randint(0, 255), randint(0, 255), randint(0, 255)))
img = Image.open("/home/jimmy/image.png")
byte_io = BytesIO()
img.save(byte_io, 'png')
byte_io.seek(0)
r = requests.post(url='http://localhost:8000?text=cartoon',
files={
'my_file': (
'1.png',
byte_io,
'image/png'
),
},
stream=True
)
print(r.status_code)
if r.status_code == 200:
byte_io = BytesIO(r.content)
img = Image.open(byte_io)
img.show()
if __name__ == '__main__':
main()

View File

@ -14,8 +14,8 @@ from PIL import Image
load_dotenv()
# pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", num_inference_steps=100, revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN"))
# pipe.to("cuda")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", num_inference_steps=100, revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN"))
pipe.to("cuda")
class Text(BaseModel):
text: str
@ -25,39 +25,44 @@ class Text(BaseModel):
app = FastAPI()
@app.post("/",
# responses = {
# 200: {
# "content": {"image/png": {}}
# }
# },
# response_class=Response
responses = {
200: {
"content": {"image/png": {}}
}
},
response_class=Response
)
def root(text: str):
async def root(text: str, my_file: UploadFile = File(...)):
prompt = text.replace('+', ' ')
print(prompt)
#request_object_content = file.read()
# img = Image.open(io.BytesIO(request_object_content))
request_object_content = await my_file.read()
img = Image.open(io.BytesIO(request_object_content))
# height_orig = img.height
# width_orig = img.width
# aspect_ratio = width_orig / height_orig
# width_new = 512
# height_new = int(width_new / aspect_ratio)
# img = img.resize((width_new, height_new), 0)
height_orig = img.height
width_orig = img.width
aspect_ratio = width_orig / height_orig
width_new = 512
height_new = int(width_new / aspect_ratio)
img = img.resize((width_new, height_new), 0)
try:
resp = pipe(prompt, image=img)
print(resp)
image = resp.images[0]
except RuntimeError as e:
print(e)
raise HTTPException(status_code=202, detail="Busy")
except Exception as e:
raise HTTPException(status_code=504, detail=str(e))
if resp["nsfw_content_detected"] == [True]:
raise HTTPException(status_code=418, detail="NSFW")
# try:
# resp = pipe(prompt, init_image=img)
# print(resp)
# image = resp.images[0]
# except RuntimeError as e:
# print(e)
# raise HTTPException(status_code=202, detail="Busy")
# except:
# raise HTTPException(status_code=504)
imgByteArr = io.BytesIO()
image.save(imgByteArr, format="PNG")
imgByteArr = imgByteArr.getvalue()
running = False
return Response(content=imgByteArr, media_type="image/png")
# imgByteArr = io.BytesIO()
# image.save(imgByteArr, format="PNG")
# imgByteArr = imgByteArr.getvalue()
# running = False
# return Response(content=imgByteArr, media_type="image/png")