backend/app.py (351 lines of code) (raw):

import html
import json
import logging
import os
import time
import uuid
from urllib.parse import unquote

import msal
import requests
from azure.identity import ManagedIdentityCredential, AzureCliCredential, ChainedTokenCredential
from azure.storage.blob import BlobServiceClient
from dotenv import load_dotenv
from flask import Flask, Response, jsonify, request, session, redirect, url_for
from flask_cors import CORS
from flask_session import Session
from werkzeug.middleware.proxy_fix import ProxyFix

# Project-local secret retrieval helper.
# NOTE(review): earlier comments described get_secret as asynchronous, but it
# is invoked below as a plain synchronous call -- confirm against keyvault.py.
from keyvault import get_secret

load_dotenv()


# ------------------------------------------------------------------------------
# Helper functions for reading environment variables
# ------------------------------------------------------------------------------
def read_env_variable(var_name, default=None):
    """Return the stripped value of an environment variable.

    Args:
        var_name: Name of the environment variable.
        default: Value returned when the variable is unset, empty, or
            contains only whitespace.

    Returns:
        The stripped string value, or ``default``.
    """
    value = os.getenv(var_name, default)
    if isinstance(value, str):
        value = value.strip()
    # A whitespace-only value is treated like an unset one so that callers
    # (e.g. LOGLEVEL) never receive an empty string instead of their default.
    return value if value else default


def read_env_list(var_name):
    """Return a comma-separated environment variable as a list of stripped, non-empty items."""
    value = os.getenv(var_name, "")
    return [item.strip() for item in value.split(",") if item.strip()]


def read_env_boolean(var_name, default=False):
    """Return True when the variable (case-insensitively) equals 'true', '1' or 'yes'."""
    value = os.getenv(var_name, str(default)).strip().lower()
    return value in ['true', '1', 'yes']


# ------------------------------------------------------------------------------
# Read Environment Variables
# ------------------------------------------------------------------------------
SPEECH_REGION = read_env_variable('SPEECH_REGION')
ORCHESTRATOR_ENDPOINT = read_env_variable('ORCHESTRATOR_ENDPOINT')
STORAGE_ACCOUNT = read_env_variable('STORAGE_ACCOUNT')
LOGLEVEL = read_env_variable('LOGLEVEL', 'INFO').upper()

# MSAL / OIDC configuration for custom authentication
ENABLE_AUTHENTICATION = read_env_boolean('ENABLE_AUTHENTICATION')
FORWARD_ACCESS_TOKEN_TO_ORCHESTRATOR = read_env_boolean('FORWARD_ACCESS_TOKEN_TO_ORCHESTRATOR')
OTHER_AUTH_SCOPES = read_env_list('OTHER_AUTH_SCOPES')
CLIENT_ID = os.getenv("CLIENT_ID", "your_client_id")
APP_SERVICE_CLIENT_SECRET_NAME = os.getenv("APP_SERVICE_CLIENT_SECRET_NAME", "appServiceClientSecretKey")
FLASK_SECRET_KEY_NAME = os.getenv("FLASK_SECRET_KEY_NAME", "flaskSecretKey")
AUTHORITY = os.getenv("AUTHORITY", "https://login.microsoftonline.com/your_tenant_id")
REDIRECT_PATH = os.getenv("REDIRECT_PATH", "/getAToken")  # Must match the Azure AD app registration redirect URI.
SCOPE = ["User.Read"]

# Authorization settings
ALLOWED_GROUP_NAMES = read_env_list('ALLOWED_GROUP_NAMES')
ALLOWED_USER_PRINCIPALS = read_env_list('ALLOWED_USER_PRINCIPALS')
ALLOWED_USER_NAMES = read_env_list('ALLOWED_USER_NAMES')

SPEECH_RECOGNITION_LANGUAGE = read_env_variable('SPEECH_RECOGNITION_LANGUAGE')
SPEECH_SYNTHESIS_LANGUAGE = read_env_variable('SPEECH_SYNTHESIS_LANGUAGE')
SPEECH_SYNTHESIS_VOICE_NAME = read_env_variable('SPEECH_SYNTHESIS_VOICE_NAME')

# Upper bound (seconds) for the Azure Resource Manager call that lists the
# orchestrator function keys, so a stalled call cannot hang a request forever.
ARM_REQUEST_TIMEOUT_SECONDS = 30

# Set logging
logging.basicConfig(level=LOGLEVEL)

# ------------------------------------------------------------------------------
# Load secrets from Key Vault once at startup so request handlers can reuse
# them without fetching them again.
# ------------------------------------------------------------------------------
FLASK_SECRET_KEY = get_secret(FLASK_SECRET_KEY_NAME)
APP_SERVICE_CLIENT_SECRET = get_secret(APP_SERVICE_CLIENT_SECRET_NAME)


def get_managed_identity_token():
    """Return an Azure Resource Manager bearer token.

    Tries the managed identity first and falls back to the Azure CLI login.
    """
    credential = ChainedTokenCredential(
        ManagedIdentityCredential(),
        AzureCliCredential()
    )
    token = credential.get_token("https://management.azure.com/.default").token
    return token


def get_function_key():
    """Fetch the 'default' key of the orchestrator's ``orc`` function.

    The subscription, resource group and function app name are read from the
    AZURE_SUBSCRIPTION_ID, AZURE_RESOURCE_GROUP_NAME and
    AZURE_ORCHESTRATOR_FUNC_NAME environment variables.

    Returns:
        The key string, or ``None`` when the management call fails or the
        response does not contain a 'default' key (the failure is logged).
    """
    subscription_id = os.getenv('AZURE_SUBSCRIPTION_ID')
    resource_group = os.getenv('AZURE_RESOURCE_GROUP_NAME')
    function_app_name = os.getenv('AZURE_ORCHESTRATOR_FUNC_NAME')
    token = get_managed_identity_token()
    logging.info("[webbackend] Obtaining function key.")

    # URL to get all function keys, including the default one
    request_url = (
        f"https://management.azure.com/subscriptions/{subscription_id}"
        f"/resourceGroups/{resource_group}"
        f"/providers/Microsoft.Web/sites/{function_app_name}"
        f"/functions/orc/listKeys?api-version=2022-03-01"
    )
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }
    try:
        response = requests.post(
            request_url,
            headers=request_headers,
            timeout=ARM_REQUEST_TIMEOUT_SECONDS
        )
        response.raise_for_status()
        return response.json()['default']
    except (requests.RequestException, ValueError, KeyError, TypeError) as e:
        # Transport errors, HTTP error statuses, non-JSON bodies and bodies
        # without a 'default' entry all yield None instead of propagating.
        logging.error(f"[webbackend] Error when getting function key. Details: {str(e)}.")
        return None


app = Flask(__name__)
CORS(app)

# Use the Flask secret key retrieved from Key Vault
app.secret_key = FLASK_SECRET_KEY

# Honour one hop of X-Forwarded-Proto / X-Forwarded-Host so that externally
# visible URLs (e.g. the OAuth redirect URI) are built from the proxy's view.
app.wsgi_app = ProxyFix(app.wsgi_app, x_proto=1, x_host=1)

# Configure server-side session storage
app.config["SESSION_TYPE"] = "filesystem"
app.config["SESSION_FILE_DIR"] = "./flask_session_files"
app.config["SESSION_PERMANENT"] = False
app.config["PREFERRED_URL_SCHEME"] = "https"
Session(app)


# --- Helper function to obtain a valid (refreshed) access token ---
def get_valid_access_token(scopes):
    """Return an access token for ``scopes`` from the session's MSAL cache.

    Args:
        scopes: List of scope strings to request.

    Returns:
        The access token string from MSAL's silent acquisition.

    Raises:
        Exception: When MSAL finds no usable token in the cache, or returns
            an error result.
    """
    cache = _load_cache()
    msal_app = _build_msal_app(cache=cache)
    accounts = msal_app.get_accounts()
    account = accounts[0] if accounts else None
    result = msal_app.acquire_token_silent(scopes, account=account)
    if not result:
        raise Exception("Could not refresh token silently: no token found in cache.")
    if "error" in result:
        raise Exception(result.get("error_description", "Could not refresh token silently."))
    # Persist any refreshed tokens back into the session.
    _save_cache(cache)
    return result.get("access_token")


# --- Authentication Endpoints ---
@app.route("/login")
def login():
    """Start the authorization-code flow, or go straight to the index when auth is disabled."""
    if not ENABLE_AUTHENTICATION:
        return redirect(url_for("index"))
    # The random state value is checked again in authorized().
    session["state"] = str(uuid.uuid4())
    auth_url = _build_auth_url(scopes=SCOPE, state=session["state"])
    return redirect(auth_url)


@app.route(REDIRECT_PATH)
def authorized():
    """Handle the identity provider's redirect and store the user's tokens in the session.

    Returns:
        A redirect to the index on success (or when the state check fails),
        or a plain-text message with status 400 when sign-in fails.
    """
    if not ENABLE_AUTHENTICATION:
        return redirect(url_for("index"))

    # Require a state value issued by /login. Without the truthiness check a
    # callback with no state and a session with no state would compare
    # None == None and pass.
    expected_state = session.get("state")
    if not expected_state or request.args.get("state") != expected_state:
        return redirect(url_for("index"))

    if "error" in request.args:
        # The description comes from the query string; escape it because the
        # response is rendered as HTML by browsers.
        description = html.escape(str(request.args.get('error_description')))
        return f"Error: {description}", 400

    if request.args.get("code"):
        logging.info("[webbackend] Attempting to acquire token for user.")
        cache = _load_cache()
        result = _build_msal_app(cache=cache).acquire_token_by_authorization_code(
            request.args["code"],
            scopes=SCOPE,
            redirect_uri=url_for("authorized", _external=True)
        )
        if "error" in result:
            logging.warning(f"Could not acquire token for user. Error: {result.get('error_description')}")
            description = html.escape(str(result.get('error_description')))
            return f"Login failure: {description}", 400

        session["user"] = result.get("id_token_claims")
        session["graph_access_token"] = result.get("access_token")
        session["refresh_token"] = result.get("refresh_token")
        _save_cache(cache)

        if OTHER_AUTH_SCOPES:
            logging.info("[webbackend] Attempting to acquire token for other scopes.")
            try:
                other_access_token = get_valid_access_token(OTHER_AUTH_SCOPES)
                session["other_access_token"] = other_access_token
            except Exception as ex:
                logging.warning(f"Could not acquire token for other scopes {OTHER_AUTH_SCOPES}. Error: {str(ex)}")
                return f"Other scopes token acquisition failure: {html.escape(str(ex))}", 400

    return redirect(url_for("index"))


@app.route("/logout")
def logout():
    """Clear the session and redirect to the authority's logout endpoint (or the index when auth is disabled)."""
    if ENABLE_AUTHENTICATION:
        session.clear()
        return redirect(
            AUTHORITY + "/oauth2/v2.0/logout" +
            "?post_logout_redirect_uri=" + url_for("index", _external=True)
        )
    else:
        return redirect(url_for("index"))


def _build_auth_url(scopes=None, state=None):
    """Return the authorization request URL for ``scopes``; a random state is generated when none is given."""
    return _build_msal_app().get_authorization_request_url(
        scopes or [],
        state=state or str(uuid.uuid4()),
        redirect_uri=url_for("authorized", _external=True)
    )


def _build_msal_app(cache=None):
    """Return an MSAL confidential client using the client secret loaded at startup."""
    client_secret = APP_SERVICE_CLIENT_SECRET
    return msal.ConfidentialClientApplication(
        CLIENT_ID,
        authority=AUTHORITY,
        client_credential=client_secret,
        token_cache=cache
    )


def _load_cache():
    """Return an MSAL token cache, restored from the session when one was saved."""
    cache = msal.SerializableTokenCache()
    if session.get("token_cache"):
        cache.deserialize(session["token_cache"])
    return cache


def _save_cache(cache):
    """Write the token cache back to the session when it has changed."""
    if cache.has_state_changed:
        session["token_cache"] = cache.serialize()
# --- End Authentication Endpoints ---


@app.route("/")
def index():
    """Serve index.html, redirecting to /login first when authentication is enabled and no user is signed in."""
    if ENABLE_AUTHENTICATION and not session.get("user"):
        return redirect(url_for("login"))
    return app.send_static_file("index.html")

# Upper bound (seconds) for calls to Microsoft Graph and the Speech token
# service, so a stalled upstream cannot hang a request indefinitely.
_AZURE_API_TIMEOUT_SECONDS = 30


@app.route("/<path:path>")
def static_files(path):
    """Serve a file from the application's static folder."""
    return app.send_static_file(path)


def _fetch_user_group_names(graph_access_token):
    """Return the display names of every directory object the signed-in user is a member of.

    Follows ``@odata.nextLink`` so memberships beyond the first page are
    included. Entries without a ``displayName`` are reported as
    'missing-group-read-all-permission'.

    Raises:
        requests.RequestException: On transport failure or an HTTP error status.
        ValueError: When a response body is not valid JSON.
    """
    graph_headers = {'Authorization': f'Bearer {graph_access_token}'}
    next_url = 'https://graph.microsoft.com/v1.0/me/memberOf'
    names = []
    while next_url:
        graph_response = requests.get(
            next_url,
            headers=graph_headers,
            timeout=_AZURE_API_TIMEOUT_SECONDS
        )
        graph_response.raise_for_status()
        group_data = graph_response.json()
        names.extend(
            group.get('displayName', 'missing-group-read-all-permission')
            for group in group_data.get('value', [])
        )
        next_url = group_data.get('@odata.nextLink')
    return names


def check_authorization():
    """Determine whether the current session's user may use the service.

    When authentication is disabled everyone is authorized as 'anonymous'.
    Otherwise the user must be signed in and, if any of ALLOWED_USER_NAMES,
    ALLOWED_USER_PRINCIPALS or ALLOWED_GROUP_NAMES is configured, must match
    at least one of them.

    Returns:
        dict with keys 'authorized' (bool), 'client_principal_id',
        'client_principal_name', 'client_group_names' (list) and
        'access_token' (the other-scopes token when available, else the
        Graph token, else None).
    """
    if not ENABLE_AUTHENTICATION:
        return {
            'authorized': True,
            'client_principal_id': 'no-auth',
            'client_principal_name': 'anonymous',
            'client_group_names': [],
            'access_token': None
        }

    user = session.get("user")
    if not user:
        logging.info("[webbackend] No user in session; user is not authenticated.")
        return {
            'authorized': False,
            'client_principal_id': None,
            'client_principal_name': None,
            'client_group_names': [],
            'access_token': None
        }

    client_principal_id = user.get("oid")
    client_principal_name = user.get("preferred_username") or user.get("upn")

    # Refresh the Graph token; fall back to the one stored at login.
    try:
        graph_access_token = get_valid_access_token(SCOPE)
        session["graph_access_token"] = graph_access_token
    except Exception as ex:
        logging.error(f"[webbackend] Failed to refresh Graph token: {str(ex)}")
        graph_access_token = session.get("graph_access_token", None)

    other_access_token = None
    if OTHER_AUTH_SCOPES:
        try:
            other_access_token = get_valid_access_token(OTHER_AUTH_SCOPES)
            session["other_access_token"] = other_access_token
        except Exception as ex:
            logging.error(f"[webbackend] Failed to refresh other scopes token: {str(ex)}")
            other_access_token = session.get("other_access_token", None)

    access_token = other_access_token if other_access_token else graph_access_token

    groups = []
    if graph_access_token:
        try:
            groups = _fetch_user_group_names(graph_access_token)
            logging.info(f"[webbackend] User groups from Graph API: {groups}")
        except Exception as e:
            # Best effort: a failed lookup leaves the user with no groups.
            logging.info(f"[webbackend] Failed to get user groups from Graph API: {e}")
    else:
        logging.info("[webbackend] No valid Graph access token available; cannot get user groups")

    # With no allow-lists configured, every signed-in user is authorized.
    authorized = True
    if ALLOWED_GROUP_NAMES or ALLOWED_USER_PRINCIPALS or ALLOWED_USER_NAMES:
        authorized = False
        if client_principal_name in ALLOWED_USER_NAMES:
            authorized = True
        elif client_principal_id in ALLOWED_USER_PRINCIPALS:
            authorized = True
        elif any(group in ALLOWED_GROUP_NAMES for group in groups):
            authorized = True

    if not authorized:
        logging.info("[webbackend] User is not in allowed groups or users.")

    return {
        'authorized': authorized,
        'client_principal_id': client_principal_id,
        'client_principal_name': client_principal_name,
        'client_group_names': groups,
        'access_token': access_token
    }


@app.route("/chatgpt", methods=["POST"])
def chatgpt():
    """Forward a chat question to the orchestrator and return its raw response text.

    Expects a JSON object with 'conversation_id' and 'query'.

    Returns:
        The orchestrator's response body; a JSON "not authorized" answer when
        the caller fails authorization; a JSON "Error in application backend."
        answer when the call fails; or a 400 JSON error for a malformed body.
    """
    start_time = time.time()

    body = request.get_json(silent=True)
    if not isinstance(body, dict) or "conversation_id" not in body or "query" not in body:
        return jsonify({"error": "Request body must be a JSON object with 'conversation_id' and 'query'."}), 400
    conversation_id = body["conversation_id"]
    question = body["query"]
    # f-strings (rather than '+') so a null or non-string id cannot raise here.
    logging.info(f"[webbackend] conversation_id: {conversation_id}")
    logging.info(f"[webbackend] question: {question}")

    auth_info = check_authorization()
    if not auth_info['authorized']:
        response = {
            "answer": "You are not authorized to access this service. Please contact your administrator.",
            "thoughts": "The user attempted to access the service but is not part of any authorized users or groups.",
            "conversation_id": conversation_id
        }
        return jsonify(response)

    client_principal_id = auth_info['client_principal_id']
    client_principal_name = auth_info['client_principal_name']
    client_group_names = auth_info['client_group_names']
    access_token = auth_info['access_token']

    try:
        # Inside the try so that a credential or management-API failure is
        # reported through the same JSON error answer as orchestrator failures.
        function_key = get_function_key()
        url = ORCHESTRATOR_ENDPOINT
        payload = {
            "conversation_id": conversation_id,
            "question": question,
            "client_principal_id": client_principal_id,
            "client_principal_name": client_principal_name,
            "client_group_names": client_group_names
        }
        if FORWARD_ACCESS_TOKEN_TO_ORCHESTRATOR and access_token:
            logging.info("[webbackend] Forwarding access token to orchestrator.")
            payload['access_token'] = access_token
        headers = {
            'Content-Type': 'application/json',
            'x-functions-key': function_key
        }
        logging.info(f"[webbackend] calling orchestrator at: {ORCHESTRATOR_ENDPOINT}")
        # NOTE(review): no timeout is set here; a safe upper bound for the
        # orchestrator's response time is not visible in this file.
        response = requests.post(url, headers=headers, json=payload)
        logging.info(f"[webbackend] response: {response.text[:100]}...")
        return response.text
    except Exception as e:
        logging.error("[webbackend] exception in /chatgpt")
        logging.exception(e)
        response = {
            "answer": "Error in application backend.",
            "thoughts": "",
            "conversation_id": conversation_id
        }
        return jsonify(response)
    finally:
        end_time = time.time()
        duration = end_time - start_time
        logging.info(f"[webbackend] Finished processing in {duration:.2f} seconds")


@app.route("/api/get-speech-token", methods=["GET"])
def getGptSpeechToken():
    """Return a Speech service token plus the configured speech settings as JSON text.

    Returns:
        A JSON string with 'token', 'region' and the speech language/voice
        settings, or a 500 JSON error when the token cannot be obtained.
    """
    # NOTE(review): this endpoint performs no check_authorization() call;
    # confirm whether it should be limited to signed-in users.
    try:
        token = get_managed_identity_token()
        fetch_token_url = f"https://{SPEECH_REGION}.api.cognitive.microsoft.com/sts/v1.0/issueToken"
        headers = {
            'Authorization': f'Bearer {token}',
            'Content-Type': 'application/x-www-form-urlencoded'
        }
        response = requests.post(
            fetch_token_url,
            headers=headers,
            timeout=_AZURE_API_TIMEOUT_SECONDS
        )
        # Without this an error body would be handed to the client as a token.
        response.raise_for_status()
        access_token = str(response.text)
        return json.dumps({
            'token': access_token,
            'region': SPEECH_REGION,
            'speechRecognitionLanguage': SPEECH_RECOGNITION_LANGUAGE,
            'speechSynthesisLanguage': SPEECH_SYNTHESIS_LANGUAGE,
            'speechSynthesisVoiceName': SPEECH_SYNTHESIS_VOICE_NAME
        })
    except Exception as e:
        logging.exception("[webbackend] exception in /api/get-speech-token")
        return jsonify({"error": str(e)}), 500


@app.route("/api/get-storage-account", methods=["GET"])
def getStorageAccount():
    """Return the configured storage account name as JSON text, or a 500 error when it is not set."""
    if not STORAGE_ACCOUNT:
        return jsonify({"error": "Add STORAGE_ACCOUNT to frontend app settings"}), 500
    return json.dumps({'storageaccount': STORAGE_ACCOUNT})


@app.route("/api/get-blob", methods=["POST"])
def getBlob():
    """Download a blob from the 'documents' container and return its bytes.

    Expects a JSON object with a URL-encoded 'blob_name' string.

    Returns:
        The blob content as application/octet-stream; 403 when the caller is
        not authorized; 400 for a malformed body; 500 when the download fails.
    """
    # Apply the same access rules as /chatgpt so documents are not readable
    # by callers who are not signed in or not on the allow-lists.
    auth_info = check_authorization()
    if not auth_info['authorized']:
        return jsonify({"error": "You are not authorized to access this service."}), 403

    body = request.get_json(silent=True)
    if not isinstance(body, dict) or not isinstance(body.get("blob_name"), str):
        return jsonify({"error": "Request body must be a JSON object with a 'blob_name' string."}), 400

    blob_name = unquote(body["blob_name"])
    logging.info(f"Starting getBlob function for blob: {blob_name}")
    try:
        client_credential = ChainedTokenCredential(
            ManagedIdentityCredential(),
            AzureCliCredential()
        )
        blob_service_client = BlobServiceClient(
            f"https://{STORAGE_ACCOUNT}.blob.core.windows.net",
            client_credential
        )
        blob_client = blob_service_client.get_blob_client(container='documents', blob=blob_name)
        blob_data = blob_client.download_blob()
        blob_text = blob_data.readall()
        logging.info(f"Successfully fetched blob: {blob_name}")
        return Response(blob_text, content_type='application/octet-stream')
    except Exception as e:
        logging.exception("[webbackend] exception in /api/get-blob")
        logging.exception(blob_name)
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=8000)