Compare commits
4 Commits
40b657d45c
...
6a550d5847
Author | SHA1 | Date |
---|---|---|
Jimmy | 6a550d5847 | |
Jimmy | 7dba63c954 | |
Jimmy | d261d8e46f | |
Jimmy | 59307bff41 |
|
@ -1,2 +1,6 @@
|
||||||
*.jpeg
|
*.jpeg
|
||||||
*.jpg
|
*.jpg
|
||||||
|
*.png
|
||||||
|
imageserver/images/
|
||||||
|
imageserver/models
|
||||||
|
__pycache__
|
|
@ -0,0 +1,40 @@
|
||||||
|
# import aiohttp
|
||||||
|
# import aiofiles
|
||||||
|
import asyncio
|
||||||
|
import requests
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image
|
||||||
|
import shutil
|
||||||
|
from random import randint
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("Starting")
|
||||||
|
|
||||||
|
img = Image.new('RGB', (25, 25), color = (randint(0, 255), randint(0, 255), randint(0, 255)))
|
||||||
|
img = Image.open("/home/jimmy/image.png")
|
||||||
|
byte_io = BytesIO()
|
||||||
|
img.save(byte_io, 'png')
|
||||||
|
byte_io.seek(0)
|
||||||
|
|
||||||
|
r = requests.post(url='http://localhost:8000?text=cartoon',
|
||||||
|
files={
|
||||||
|
'my_file': (
|
||||||
|
'1.png',
|
||||||
|
byte_io,
|
||||||
|
'image/png'
|
||||||
|
),
|
||||||
|
},
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
print(r.status_code)
|
||||||
|
|
||||||
|
if r.status_code == 200:
|
||||||
|
byte_io = BytesIO(r.content)
|
||||||
|
img = Image.open(byte_io)
|
||||||
|
img.show()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,8 +14,8 @@ from PIL import Image
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", num_inference_steps=100, revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN"))
|
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", num_inference_steps=100, revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN"))
|
||||||
# pipe.to("cuda")
|
pipe.to("cuda")
|
||||||
|
|
||||||
class Text(BaseModel):
|
class Text(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
|
@ -25,39 +25,44 @@ class Text(BaseModel):
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
@app.post("/",
|
@app.post("/",
|
||||||
# responses = {
|
responses = {
|
||||||
# 200: {
|
200: {
|
||||||
# "content": {"image/png": {}}
|
"content": {"image/png": {}}
|
||||||
# }
|
}
|
||||||
# },
|
},
|
||||||
# response_class=Response
|
response_class=Response
|
||||||
)
|
)
|
||||||
def root(text: str):
|
async def root(text: str, my_file: UploadFile = File(...)):
|
||||||
prompt = text.replace('+', ' ')
|
prompt = text.replace('+', ' ')
|
||||||
print(prompt)
|
print(prompt)
|
||||||
#request_object_content = file.read()
|
request_object_content = await my_file.read()
|
||||||
# img = Image.open(io.BytesIO(request_object_content))
|
img = Image.open(io.BytesIO(request_object_content))
|
||||||
|
|
||||||
# height_orig = img.height
|
height_orig = img.height
|
||||||
# width_orig = img.width
|
width_orig = img.width
|
||||||
# aspect_ratio = width_orig / height_orig
|
aspect_ratio = width_orig / height_orig
|
||||||
# width_new = 512
|
width_new = 512
|
||||||
# height_new = int(width_new / aspect_ratio)
|
height_new = int(width_new / aspect_ratio)
|
||||||
# img = img.resize((width_new, height_new), 0)
|
img = img.resize((width_new, height_new), 0)
|
||||||
|
try:
|
||||||
|
resp = pipe(prompt, image=img)
|
||||||
|
print(resp)
|
||||||
|
image = resp.images[0]
|
||||||
|
except RuntimeError as e:
|
||||||
|
print(e)
|
||||||
|
raise HTTPException(status_code=202, detail="Busy")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=504, detail=str(e))
|
||||||
|
if resp["nsfw_content_detected"] == [True]:
|
||||||
|
raise HTTPException(status_code=418, detail="NSFW")
|
||||||
|
|
||||||
# try:
|
imgByteArr = io.BytesIO()
|
||||||
|
image.save(imgByteArr, format="PNG")
|
||||||
# resp = pipe(prompt, init_image=img)
|
imgByteArr = imgByteArr.getvalue()
|
||||||
# print(resp)
|
running = False
|
||||||
# image = resp.images[0]
|
return Response(content=imgByteArr, media_type="image/png")
|
||||||
# except RuntimeError as e:
|
|
||||||
# print(e)
|
|
||||||
# raise HTTPException(status_code=202, detail="Busy")
|
|
||||||
# except:
|
|
||||||
# raise HTTPException(status_code=504)
|
|
||||||
|
|
||||||
# imgByteArr = io.BytesIO()
|
|
||||||
# image.save(imgByteArr, format="PNG")
|
|
||||||
# imgByteArr = imgByteArr.getvalue()
|
|
||||||
# running = False
|
|
||||||
# return Response(content=imgByteArr, media_type="image/png")
|
|
||||||
|
|
Loading…
Reference in New Issue