main.py (328 lines of code) (raw):
import asyncio
import json
import logging
import os
import secrets
import uuid
import aiofiles
import httpx
import msal
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import (FileResponse, JSONResponse, RedirectResponse,
StreamingResponse)
from fastapi.staticfiles import StaticFiles
from starlette.middleware.base import BaseHTTPMiddleware
from keyvault import get_secret
# Load variables from a .env file before any of the os.getenv / os.environ
# lookups further down in this module run.
load_dotenv()
# The ASGI application object; routes and middleware below attach to it.
app = FastAPI()
# Logging configuration
# Pin the 'azure' logger to WARNING independently of the root level set below.
logging.getLogger('azure').setLevel(logging.WARNING)
# Root level comes from the LOGLEVEL environment variable (default DEBUG).
# force=True makes basicConfig replace handlers already attached to the root logger.
logging.basicConfig(level=os.environ.get('LOGLEVEL', 'DEBUG').upper(), force=True)
# Let uvicorn's error and access loggers propagate to the root logger configured above.
logging.getLogger("uvicorn.error").propagate = True
logging.getLogger("uvicorn.access").propagate = True
# -------------------------------
# File-based Session Middleware using request.state
# -------------------------------
class FileSessionMiddleware(BaseHTTPMiddleware):
    """Session middleware that keeps each session in ``<session_dir>/<id>.json``.

    For every request the session dict is loaded into ``request.state.session``
    before the downstream handler runs, and written back to disk afterwards.
    A request without a (valid) session cookie gets a freshly generated id and
    a matching ``Set-Cookie`` header on the response.

    Args:
        app: The ASGI application being wrapped.
        session_dir: Directory holding the session files; created if missing.
        cookie_name: Name of the cookie carrying the session id.
        max_age: Cookie lifetime in seconds.
    """

    # Ids are minted with ``secrets.token_hex(16)``: 32 lowercase hex characters.
    _SESSION_ID_LENGTH = 32
    _SESSION_ID_ALPHABET = frozenset("0123456789abcdef")

    def __init__(self, app, session_dir: str = "sessions", cookie_name: str = "session_id", max_age: int = 86400):
        super().__init__(app)
        self.session_dir = session_dir
        self.cookie_name = cookie_name
        self.max_age = max_age
        os.makedirs(session_dir, exist_ok=True)

    @classmethod
    def _is_valid_session_id(cls, session_id) -> bool:
        """Return True only for ids shaped like the ones this middleware issues."""
        return (
            isinstance(session_id, str)
            and len(session_id) == cls._SESSION_ID_LENGTH
            and cls._SESSION_ID_ALPHABET.issuperset(session_id)
        )

    def _session_path(self, session_id: str) -> str:
        """Return the path of the JSON file backing ``session_id``."""
        return os.path.join(self.session_dir, f"{session_id}.json")

    async def _load_session(self, session_id: str) -> dict:
        """Read the stored session dict, or return ``{}`` if absent or unreadable."""
        try:
            # Open directly instead of checking existence first, so a file that
            # disappears between check and open cannot raise.
            async with aiofiles.open(self._session_path(session_id), mode="r") as f:
                content = await f.read()
        except FileNotFoundError:
            return {}
        try:
            session_data = json.loads(content)
        except ValueError as e:
            logging.error(f"Error decoding session file: {e}")
            return {}
        # Handlers index the session like a dict; discard any other JSON value.
        return session_data if isinstance(session_data, dict) else {}

    async def dispatch(self, request: Request, call_next):
        """Attach the session to the request, run the handler, persist the session."""
        session_id = request.cookies.get(self.cookie_name)
        if not self._is_valid_session_id(session_id):
            # The cookie is client-controlled and ends up in a file path, so
            # anything that is not one of our own ids (path separators, "..",
            # attacker-chosen ids) is dropped and replaced with a new id below.
            session_id = None

        session_data = await self._load_session(session_id) if session_id else {}

        # Instead of assigning to request.session, we assign to request.state.session.
        request.state.session = session_data
        response = await call_next(request)

        # If no session_id exists, create one and set it as a cookie.
        if not session_id:
            session_id = secrets.token_hex(16)
            response.set_cookie(
                key=self.cookie_name,
                value=session_id,
                max_age=self.max_age,
                httponly=True,
                samesite="lax"
            )

        # Persist whatever the handler left in the session (it may have mutated it).
        async with aiofiles.open(self._session_path(session_id), mode="w") as f:
            await f.write(json.dumps(request.state.session))
        return response
# Install the file-backed session middleware; its constructor defaults apply.
app.add_middleware(FileSessionMiddleware)
# -------------------------------
# Authentication Configuration
# -------------------------------
# Authentication is on only when the variable equals "true" (any letter case); default is off.
ENABLE_AUTHENTICATION = os.getenv("ENABLE_AUTHENTICATION", "false").lower() == "true"
# Client id and authority URL handed to MSAL; the defaults are placeholders, not working values.
CLIENT_ID = os.getenv("CLIENT_ID", "your_client_id")
AUTHORITY = os.getenv("AUTHORITY", "https://login.microsoftonline.com/your_tenant_id")
# Route path of the sign-in callback endpoint, and the absolute redirect URI passed to MSAL.
REDIRECT_PATH = os.getenv("REDIRECT_PATH", "/getAToken")
REDIRECT_URI = os.getenv("REDIRECT_URI", f"http://localhost:8000{REDIRECT_PATH}")
# Scope requested at sign-in and when refreshing the Graph token.
BASIC_SCOPE = ["User.Read"]
# Optional extra scopes as one comma-separated string; empty skips the extra token lookup.
OTHER_AUTH_SCOPES = os.getenv("OTHER_AUTH_SCOPES", "")
# -------------------------------
# Authentication Secrets
# -------------------------------
# Secrets are fetched once, at import time, through get_secret from the
# project's keyvault module.
MSAL_CLIENT_SECRET = get_secret(os.getenv("APP_SERVICE_CLIENT_SECRET_NAME", "appServiceClientSecretKey"))
# NOTE(review): SESSION_SECRET_KEY is not referenced anywhere else in the visible file.
SESSION_SECRET_KEY = get_secret("avatarSessionSecretKey")
FUNCTION_KEY = get_secret("avatarOrchestratorFunctionKey")
AZURE_SPEECH_API_KEY = get_secret("avatarSpeechApiKey")
if not FUNCTION_KEY:
    # Fail at startup: /speak cannot call the orchestrator without this key.
    # RuntimeError is a subclass of Exception, so existing handlers still catch it.
    raise RuntimeError("FUNCTION_KEY not found in KeyVault.")
# Serve files from the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# -------------------------------
# MSAL Helper Functions (using request.state.session)
# -------------------------------
def _build_msal_app(cache=None):
    """Create an MSAL confidential client for the configured app registration.

    Args:
        cache: Optional token cache to attach to the client; ``None`` is
            passed through to MSAL unchanged.

    Returns:
        A new ``msal.ConfidentialClientApplication``.
    """
    client_options = {
        "authority": AUTHORITY,
        "client_credential": MSAL_CLIENT_SECRET,
        "token_cache": cache,
    }
    return msal.ConfidentialClientApplication(CLIENT_ID, **client_options)
def _build_auth_url(state: str):
    """Return the authorization URL the browser is redirected to for sign-in.

    Args:
        state: Opaque value echoed back on the callback so the callback
            handler can match it against the session.

    Returns:
        The URL produced by MSAL for ``BASIC_SCOPE`` and ``REDIRECT_URI``.
    """
    # A cache-less client is enough here: building the URL acquires no token.
    return _build_msal_app().get_authorization_request_url(
        scopes=BASIC_SCOPE,
        redirect_uri=REDIRECT_URI,
        state=state,
    )
def _load_cache(request: Request):
    """Build a token cache populated from the session's ``msal_cache`` entry.

    Args:
        request: Current request; its ``state.session`` dict is read.

    Returns:
        A ``msal.SerializableTokenCache``. It is empty when the session has no
        stored cache or when the stored value fails to deserialize (the
        failure is logged, not raised).
    """
    token_cache = msal.SerializableTokenCache()
    session = request.state.session
    if "msal_cache" not in session:
        return token_cache
    try:
        token_cache.deserialize(session["msal_cache"])
    except Exception as e:
        logging.error(f"Error deserializing cache: {e}")
    return token_cache
def _save_cache(request: Request, cache):
    """Store the serialized token cache in the session if it has changed.

    Args:
        request: Current request; its ``state.session`` dict is updated.
        cache: Token cache exposing ``has_state_changed`` and ``serialize()``.
    """
    if not cache.has_state_changed:
        # Nothing new to persist; leave the session entry untouched.
        return
    request.state.session["msal_cache"] = cache.serialize()
async def get_valid_access_token(request: Request, scopes: list):
    """Return an access token for ``scopes`` using only the session's token cache.

    Args:
        request: Current request; the MSAL cache is loaded from and saved back
            to ``request.state.session``.
        scopes: Scopes the token must cover.

    Returns:
        The access token string.

    Raises:
        Exception: If MSAL returns nothing, returns an error payload, or
            returns a result without an ``access_token``.
    """
    cache = _load_cache(request)
    msal_app = _build_msal_app(cache=cache)
    accounts = msal_app.get_accounts()
    # Use the first cached account, if any; None is passed through to MSAL otherwise.
    account = accounts[0] if accounts else None
    # NOTE(review): this is a synchronous call inside a coroutine; if MSAL has
    # to hit the network to redeem a refresh token it will block the event loop.
    result = msal_app.acquire_token_silent(scopes, account=account)
    if not result:
        raise Exception("Could not refresh token silently: no token found in cache.")
    # Check for an error payload BEFORE the missing-token check: an error
    # result has no "access_token", so the other order would hide the real
    # error description behind the generic "no token" message.
    if "error" in result:
        raise Exception(result.get("error_description", "Could not refresh token silently."))
    if "access_token" not in result:
        raise Exception("Could not refresh token silently: no token found in cache.")
    _save_cache(request, cache)
    return result["access_token"]
# -------------------------------
# Authentication Endpoints
# -------------------------------
@app.get("/login")
async def login(request: Request):
    """Start the sign-in flow, or go straight home when authentication is off.

    With authentication enabled, a fresh random ``state`` is stored in the
    session and the browser is redirected (303) to the authorization URL.
    """
    if ENABLE_AUTHENTICATION:
        flow_state = str(uuid.uuid4())
        # Remembered so the callback can compare it with the echoed value.
        request.state.session["state"] = flow_state
        target = _build_auth_url(flow_state)
    else:
        target = "/"
    return RedirectResponse(url=target, status_code=303)
@app.get(REDIRECT_PATH)
async def authorized(request: Request):
    """Handle the sign-in callback and establish the user session.

    Validates the ``state`` parameter against the one stored by ``/login``,
    exchanges the authorization code for tokens, and stores a minimal user
    record plus the tokens in ``request.state.session``.

    Returns:
        A 303 redirect to ``/`` on success (or when authentication is disabled
        or the user is already signed in); a 400 JSON error otherwise.
    """
    if not ENABLE_AUTHENTICATION:
        return RedirectResponse(url="/", status_code=303)

    session = request.state.session
    if session.get("user"):
        return RedirectResponse(url="/", status_code=303)

    params = request.query_params
    expected_state = session.get("state")
    # The session MUST hold a state issued by /login. Without this check a
    # callback with no "state" parameter against a session with no stored
    # state compares None to None and slips through.
    if not expected_state or params.get("state") != expected_state:
        return JSONResponse(content={"error": "State mismatch"}, status_code=400)

    if "error" in params:
        error_desc = params.get("error_description", "Unknown error")
        return JSONResponse(content={"error": error_desc}, status_code=400)

    code = params.get("code")
    if not code:
        return JSONResponse(content={"error": "Authorization code not found"}, status_code=400)

    cache = _load_cache(request)
    msal_app = _build_msal_app(cache=cache)
    result = msal_app.acquire_token_by_authorization_code(
        code,
        scopes=BASIC_SCOPE,
        redirect_uri=REDIRECT_URI
    )
    if "error" in result:
        return JSONResponse(
            content={"error": result.get("error_description", "Could not acquire token")},
            status_code=400
        )

    # Keep only the claims the rest of the app reads, not the full id token.
    user_claims = result.get("id_token_claims", {})
    minimal_user = {
        "oid": user_claims.get("oid"),
        "preferred_username": user_claims.get("preferred_username") or user_claims.get("upn")
    }
    session["user"] = minimal_user
    session["graph_access_token"] = result.get("access_token")
    session["refresh_token"] = result.get("refresh_token")
    _save_cache(request, cache)
    # The state has served its purpose; drop it so it cannot be reused.
    session.pop("state", None)
    return RedirectResponse(url="/", status_code=303)
@app.get("/logout")
async def logout(request: Request):
    """Clear the local session and sign the user out at the identity provider.

    Returns:
        A 303 redirect to the authority's logout endpoint, or to ``/`` when
        authentication is disabled (matching ``/login`` and the callback).
    """
    from urllib.parse import quote

    request.state.session.clear()
    if not ENABLE_AUTHENTICATION:
        # With authentication off AUTHORITY may still be the placeholder
        # default, so there is no identity-provider session to end.
        return RedirectResponse(url="/", status_code=303)
    # Percent-encode the redirect URI so characters such as "&", "?" or "#"
    # inside it cannot be read as part of the logout URL's own query string.
    encoded_redirect = quote(REDIRECT_URI, safe="")
    logout_url = f"{AUTHORITY}/oauth2/v2.0/logout?post_logout_redirect_uri={encoded_redirect}"
    return RedirectResponse(url=logout_url, status_code=303)
# -------------------------------
# Authorization Check (Extra Token Handling)
# -------------------------------
async def check_authorization(request: Request):
    """Resolve the caller's identity, group names and access token.

    Args:
        request: Current request; the session is read and token entries in it
            are refreshed.

    Returns:
        A dict with the keys ``authorized``, ``client_principal_id``,
        ``client_principal_name``, ``client_group_names`` and ``access_token``.
        With authentication disabled a fixed anonymous record is returned.
        Without a user in the session ``authorized`` is False.
    """
    if not ENABLE_AUTHENTICATION:
        return {
            "authorized": True,
            "client_principal_id": "no-auth",
            "client_principal_name": "anonymous",
            "client_group_names": [],
            "access_token": None
        }

    session = request.state.session
    user = session.get("user")
    if not user:
        logging.info("No user in session; user not authenticated.")
        return {
            "authorized": False,
            "client_principal_id": None,
            "client_principal_name": None,
            "client_group_names": [],
            "access_token": None
        }

    client_principal_id = user.get("oid")
    client_principal_name = user.get("preferred_username") or user.get("upn")

    # Best effort: on refresh failure fall back to the token already stored
    # in the session, which may be stale.
    try:
        graph_access_token = await get_valid_access_token(request, BASIC_SCOPE)
        session["graph_access_token"] = graph_access_token
    except Exception as ex:
        logging.error(f"Failed to refresh Graph token: {str(ex)}")
        graph_access_token = session.get("graph_access_token", None)

    other_access_token = None
    if OTHER_AUTH_SCOPES:
        try:
            scopes = [s.strip() for s in OTHER_AUTH_SCOPES.split(",") if s.strip()]
            other_access_token = await get_valid_access_token(request, scopes)
            session["other_access_token"] = other_access_token
        except Exception as ex:
            logging.error(f"Failed to refresh token for other scopes: {str(ex)}")
            other_access_token = session.get("other_access_token", None)

    # The token for the extra scopes, when available, is the one returned to the caller.
    access_token = other_access_token if other_access_token else graph_access_token

    groups = []
    if graph_access_token:
        try:
            graph_headers = {"Authorization": f"Bearer {graph_access_token}"}
            next_url = "https://graph.microsoft.com/v1.0/me/memberOf"
            collected = []
            async with httpx.AsyncClient() as client:
                # Graph pages this collection; keep following "@odata.nextLink"
                # until it is absent so large memberships are not truncated.
                while next_url:
                    response = await client.get(next_url, headers=graph_headers)
                    response.raise_for_status()
                    group_data = response.json()
                    collected.extend(
                        g.get("displayName", "missing-group") for g in group_data.get("value", [])
                    )
                    next_url = group_data.get("@odata.nextLink")
            # Publish only a complete listing; a failure on any page leaves
            # groups empty, as a failed single request did before.
            groups = collected
        except Exception as e:
            logging.info(f"Failed to get user groups from Graph API: {e}")

    return {
        "authorized": True,
        "client_principal_id": client_principal_id,
        "client_principal_name": client_principal_name,
        "client_group_names": groups,
        "access_token": access_token
    }
# -------------------------------
# Protected Routes and Endpoints (using request.state.session)
# -------------------------------
@app.get("/")
async def serve_index(request: Request):
    """Serve the single-page UI, sending signed-out users to ``/login`` first."""
    # The session is consulted only when authentication is enabled.
    if not ENABLE_AUTHENTICATION or request.state.session.get("user"):
        return FileResponse("static/index.html")
    return RedirectResponse(url="/login", status_code=303)
@app.get("/favicon.ico")
async def serve_favicon(request: Request):
    """Serve the favicon, sending signed-out users to ``/login`` first."""
    # The session is consulted only when authentication is enabled.
    if not ENABLE_AUTHENTICATION or request.state.session.get("user"):
        return FileResponse("static/image/favicon.ico")
    return RedirectResponse(url="/login", status_code=303)
@app.post("/speak")
async def speak(request: Request):
    """Forward a spoken question to the orchestrator and stream its answer back.

    The JSON body must be an object with ``spokenText`` and may carry
    ``conversation_id``.

    Returns:
        A ``text/event-stream`` response relaying the orchestrator's lines, or
        a 401 JSON payload when the caller is not authorized.

    Raises:
        HTTPException: 400 when the body is not a JSON object or
            ``spokenText`` is missing or empty.
    """
    try:
        body = await request.json()
    except ValueError:
        # Malformed JSON is a client error, not a server fault.
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from None
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")

    question = body.get("spokenText")
    conversation_id = body.get("conversation_id", "")
    if not question:
        raise HTTPException(status_code=400, detail="Missing spokenText in request.")

    auth_info = await check_authorization(request)
    if not auth_info.get("authorized"):
        return JSONResponse(
            content={
                "answer": "You are not authorized to access this service. Please contact your administrator.",
                "thoughts": "User not authorized.",
                "conversation_id": conversation_id
            },
            status_code=401
        )

    payload = {
        "conversation_id": conversation_id,
        "question": question,
        "text_only": True,
        "client_principal_id": auth_info.get("client_principal_id"),
        "client_principal_name": auth_info.get("client_principal_name"),
        "client_group_names": auth_info.get("client_group_names")
    }
    access_token = auth_info.get("access_token")
    if access_token:
        payload["access_token"] = access_token
    headers = {
        "x-functions-key": FUNCTION_KEY,
        "Content-Type": "application/json"
    }

    async def stream_generator():
        """Relay upstream lines, interleaving SSE comment heartbeats."""
        # Never write the bearer token to the log.
        loggable_payload = {
            key: ("<redacted>" if key == "access_token" else value)
            for key, value in payload.items()
        }
        logging.info("Sending request to streaming endpoint with payload: %s", loggable_payload)
        # timeout=None: the upstream answer may take arbitrarily long to stream.
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST",
                os.getenv("STREAMING_ENDPOINT", "http://localhost:7071/api/orcstream"),
                json=payload,
                headers=headers
            ) as resp:
                logging.info("Received response with status code: %s", resp.status_code)
                if resp.status_code != 200:
                    yield f"Error: {resp.status_code}"
                    return
                clock = asyncio.get_event_loop().time
                last_yield = clock()
                heartbeat_count = 0
                # NOTE(review): the heartbeat check runs only when the upstream
                # delivers a line, so nothing is sent during a silent gap.
                async for line in resp.aiter_lines():
                    now = clock()
                    if now - last_yield > 15:
                        # Yield heartbeat to keep connection alive
                        heartbeat_count += 1
                        logging.info("Yielding heartbeat %s", heartbeat_count)
                        yield ":\n\n"  # SSE comment heartbeat
                        last_yield = now
                    if line:
                        logging.info("Received line from stream: %s", line)
                        yield line
                        last_yield = now

    return StreamingResponse(stream_generator(), media_type="text/event-stream")
@app.get("/get-speech-token")
async def get_speech_token():
    """Exchange the Speech subscription key for a token and return it as JSON.

    Raises:
        HTTPException: 400 when no subscription key is configured; otherwise
            the upstream status code when the token request does not return 200.
    """
    region = os.getenv("AZURE_SPEECH_REGION", "eastus2")
    if not AZURE_SPEECH_API_KEY:
        raise HTTPException(status_code=400, detail="Missing Azure Speech subscription key.")

    issue_token_url = f"https://{region}.api.cognitive.microsoft.com/sts/v1.0/issueToken"
    key_header = {"Ocp-Apim-Subscription-Key": AZURE_SPEECH_API_KEY}
    async with httpx.AsyncClient() as client:
        upstream = await client.post(issue_token_url, headers=key_header)

    if upstream.status_code != 200:
        raise HTTPException(status_code=upstream.status_code, detail="Failed to get speech token.")
    # The token is the response body text.
    return JSONResponse(content={"token": upstream.text})
@app.get("/get-ice-server-token")
async def get_ice_server_token():
    """Fetch the avatar relay token and pass the upstream JSON through.

    Raises:
        HTTPException: 400 when no subscription key is configured; otherwise
            the upstream status code when the relay request does not return 200.
    """
    region = os.getenv("AZURE_SPEECH_REGION", "eastus2")
    if not AZURE_SPEECH_API_KEY:
        raise HTTPException(status_code=400, detail="Missing Azure Speech subscription key.")

    relay_token_url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/avatar/relay/token/v1"
    key_header = {"Ocp-Apim-Subscription-Key": AZURE_SPEECH_API_KEY}
    async with httpx.AsyncClient() as client:
        upstream = await client.get(relay_token_url, headers=key_header)

    if upstream.status_code != 200:
        raise HTTPException(status_code=upstream.status_code, detail="Failed to get ICE server token.")
    return JSONResponse(content=upstream.json())
@app.get("/get-speech-region")
async def get_speech_region():
    """Report the configured Speech region (``eastus2`` when unset)."""
    region_payload = {"speech_region": os.getenv("AZURE_SPEECH_REGION", "eastus2")}
    return JSONResponse(content=region_payload)
@app.get("/get-supported-languages")
async def get_supported_languages():
    """Return the comma-separated SUPPORTED_LANGUAGES setting as a JSON list.

    Each entry is stripped of surrounding whitespace. Empty entries are kept.
    """
    configured = os.getenv("SUPPORTED_LANGUAGES", "en-US,de-DE,zh-CN,nl-NL")
    language_codes = list(map(str.strip, configured.split(",")))
    return JSONResponse(content={"supported_languages": language_codes})
if __name__ == "__main__":
    # Development entry point: run the app under uvicorn with auto-reload.
    import uvicorn

    # Port comes from the PORT environment variable, defaulting to 8000.
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=port,
        reload=True,
    )