# notebooks/stable_diffusion/stable_diffusion_space/server.py
import base64
import io
import os
import time
from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.security.api_key import APIKey, APIKeyHeader
from pydantic import BaseModel
from starlette.status import HTTP_403_FORBIDDEN
# Parse DEBUG explicitly: os.getenv returns a *string* for any set value,
# so the previous `os.getenv("DEBUG", False)` treated "0"/"false" as truthy.
_IS_DEBUG = os.getenv("DEBUG", "").lower() in ("1", "true", "yes")
# Shared secret for the access_token header; None when not configured.
API_KEY = os.getenv("API_KEY", None)
API_KEY_NAME = "access_token"
if _IS_DEBUG:
    # Debug mode: serve placeholder images and skip the heavy IPU pipeline.
    from PIL import Image
else:
    import torch
    from ipu_models import IPUStableDiffusionPipeline

    # Load the fp16 Stable Diffusion v1.5 pipeline for the IPU.
    pipe = IPUStableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        revision="fp16",
        torch_dtype=torch.float16,
    )
    pipe.enable_attention_slicing()
    # One dummy run so compilation happens at startup, not on the
    # first user request.
    pipe("Pipeline warmup...")
api_key = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
app = FastAPI()
async def get_api_key(api_key: str = Security(api_key)):
    """FastAPI dependency validating the ``access_token`` header.

    Returns the presented key when it matches ``API_KEY``; otherwise
    raises 403. Requests are rejected whenever ``API_KEY`` is unset:
    previously a missing env var plus a missing header compared
    ``None == None`` and silently granted access.
    """
    if API_KEY is not None and api_key == API_KEY:
        return api_key
    raise HTTPException(
        status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials"
    )
class StableDiffusionInputs(BaseModel):
    """Request body schema for the /inference/ endpoint."""
    # Text prompt to condition the image generation on.
    prompt: str
    # Classifier-free guidance strength; 7.5 is the common
    # Stable Diffusion default.
    guidance_scale: float = 7.5
@app.post("/inference/")
async def stable_diffusion(inputs: StableDiffusionInputs, _: APIKey = Depends(get_api_key)):
    """Generate image(s) for the prompt and return them base64-encoded.

    Returns a dict with:
      - "images": list of base64-encoded PNG strings
      - "latency": generation time in seconds (excludes PNG encoding)
    """
    # perf_counter is monotonic, so the latency can't go negative or
    # jump if the system wall clock is adjusted mid-request.
    start = time.perf_counter()
    if _IS_DEBUG:
        # Debug mode: return a solid-color placeholder instead of
        # running the model.
        images = [Image.new("RGB", (512, 512), "blue")]
    else:
        images = pipe(inputs.prompt, guidance_scale=inputs.guidance_scale).images
    latency = time.perf_counter() - start
    images_b64 = []
    for image in images:
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        # b64encode returns bytes; decode to str explicitly so the JSON
        # payload carries plain text instead of relying on FastAPI's
        # implicit bytes-to-str handling.
        images_b64.append(base64.b64encode(buffer.getvalue()).decode("ascii"))
    return {"images": images_b64, "latency": latency}
@app.get("/")
async def root():
    """Landing / health-check route describing the service."""
    greeting = "This is the server running Stable Diffusion v1.5 on Graphcore IPUs!"
    return {"message": greeting}