auth.py (81 lines of code) (raw):

import logging import os from typing import Dict, List import httpx import msal import chainlit as cl def read_env_list(var_name: str) -> List[str]: """Reads a comma-separated list from the environment variable.""" value = os.getenv(var_name, "") return [item.strip() for item in value.split(",") if item.strip()] def get_env_var(name: str, fallback: str = None) -> str: """Helper to fetch and log missing environment variables.""" value = os.getenv(name, fallback) if value is None: logging.warning(f"[auth] Environment variable '{name}' is not set.") return value async def get_user_groups(access_token: str) -> List[str]: """Fetch user group names from Microsoft Graph API.""" graph_url = "https://graph.microsoft.com/v1.0/me/memberOf" headers = {"Authorization": f"Bearer {access_token}"} try: async with httpx.AsyncClient() as client: response = await client.get(graph_url, headers=headers) response.raise_for_status() group_data = response.json() groups = [g.get("displayName", "unknown-group") for g in group_data.get("value", [])] logging.info(f"[auth] User groups: {groups}") return groups except Exception as e: logging.warning(f"[auth] Failed to retrieve groups: {e}") return [] def is_user_authorized(name: str, principal_id: str, groups: List[str]) -> bool: """Check if user is authorized based on group or user criteria.""" allowed_names = read_env_list("ALLOWED_USER_NAMES") allowed_ids = read_env_list("ALLOWED_USER_PRINCIPALS") allowed_groups = read_env_list("ALLOWED_GROUP_NAMES") if not (allowed_names or allowed_ids or allowed_groups): return True if name in allowed_names or principal_id in allowed_ids: return True if any(group in allowed_groups for group in groups): return True logging.info(f"[auth] Access denied for user {name}. Not in allowed users or groups.") return False @cl.oauth_callback async def oauth_callback( provider_id: str, code: str, raw_user_data: Dict[str, str], default_user: cl.User ) -> cl.User: """Handles the OAuth callback and returns a validated Chainlit User.""" client_id = get_env_var("OAUTH_AZURE_AD_CLIENT_ID", get_env_var("CLIENT_ID")) client_secret = get_env_var("OAUTH_AZURE_AD_CLIENT_SECRET") tenant_id = get_env_var("OAUTH_AZURE_AD_TENANT_ID") authority = f"https://login.microsoftonline.com/{tenant_id}" scopes = read_env_list("OAUTH_AZURE_AD_SCOPES") or ["User.Read"] # Build MSAL confidential client msal_app = msal.ConfidentialClientApplication( client_id, authority=authority, client_credential=client_secret ) result = msal_app.acquire_token_by_refresh_token( refresh_token=default_user.metadata.get("refresh_token"), scopes=scopes ) if "error" in result: error_desc = result.get("error_description", "Unknown error") raise Exception(f"Token acquisition failed: {error_desc}") access_token = result.get("access_token") refresh_token = result.get("refresh_token") id_token = result.get("id_token_claims", {}) user_id = id_token.get("oid", "00000000-0000-0000-0000-000000000000") user_name = id_token.get("name", "anonymous") principal_name = id_token.get("preferred_username", "") # Fetch user groups groups = await get_user_groups(access_token) if access_token else [] authorized = is_user_authorized(principal_name, user_id, groups) return cl.User( identifier=user_name, metadata={ "access_token": access_token, "refresh_token": refresh_token, "authorized": authorized, "user_name": user_name, "client_principal_id": user_id, "client_principal_name": principal_name, "client_group_names": groups } )