from multiprocessing import context
from httplib2 import Response
import torch
import uuid
import os
from diffusers import StableDiffusionPipeline
from dotenv import load_dotenv
from os import getenv
from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel
import io
from PIL.PngImagePlugin import PngInfo

# Load environment variables from a local .env file so getenv("TOKEN")
# below can pick up the Hugging Face auth token.
load_dotenv()

# Load Stable Diffusion v1.4 in half precision to fit consumer GPUs.
# NOTE(review): the fp16 branch via `revision="fp16"` is a legacy diffusers
# convention — confirm it is still published for this model.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=getenv("TOKEN"))
# Move all pipeline weights to the GPU; generation below assumes CUDA.
pipe.to("cuda")

class Text(BaseModel):
    """Request body schema carrying a single prompt string.

    NOTE(review): not referenced by any route visible in this file —
    confirm it is used elsewhere, otherwise it is dead code.
    """
    text: str



# ASGI application instance; routes are registered on it below.
app = FastAPI()

@app.get("/",
    responses = {
        200: {
            "content": {"image/png": {}}
        }
    },
    response_class=Response
)
def root(text: str):
    """Generate an image for *text* and return it as a PNG response.

    '+' characters in the query value are treated as spaces. The image is
    also archived to /images/<uuid>.png with the prompt embedded as PNG
    metadata under the "text" key.

    Raises:
        HTTPException 202 ("Busy"): on RuntimeError from the pipeline
            (typically CUDA OOM while another generation is running).
        HTTPException 504: on any other generation failure.
    """
    prompt = text.replace('+', ' ')
    print(prompt)
    try:
        generator = torch.Generator("cuda").manual_seed(1024)
        # BUG FIX: the seeded generator was created but never handed to the
        # pipeline, so the fixed seed had no effect on the output.
        resp = pipe(prompt, generator=generator)
        print(resp)
        image = resp.images[0]
    except RuntimeError as e:
        # Most commonly CUDA out-of-memory while a generation is in flight.
        raise HTTPException(status_code=202, detail="Busy") from e
    except Exception as e:
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit); chain the cause for debuggability.
        raise HTTPException(status_code=504) from e

    # Archive a copy on disk with the prompt stored as PNG metadata.
    metadata = PngInfo()
    metadata.add_text("text", prompt)
    image.save(f'/images/{uuid.uuid4()}.png', pnginfo=metadata)

    # Serialize the PIL image into memory and return it directly.
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return Response(content=buffer.getvalue(), media_type="image/png")