# aa-integration-backend/ui-connector/main.py (197 lines of code) (raw):
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gzip
import hashlib
import json
import logging
import os
import random
import time
import zlib
from datetime import datetime

import redis
from flask import Flask, request, make_response, jsonify, render_template
from flask_cors import CORS
from flask_socketio import SocketIO, emit, join_room, leave_room, rooms
from socketio.exceptions import ConnectionRefusedError

import config
import dialogflow
from auth import check_auth, generate_jwt, token_required, check_jwt, load_jwt_secret_key, check_app_auth
# Flask application object; the HTTP routes in this module are registered on it.
app = Flask(__name__)
# Configure CORS for the HTTP routes with the origins listed in config.
CORS(app, origins=config.CORS_ALLOWED_ORIGINS)
# Socket.IO server attached to the Flask app, configured with the same origins.
socketio = SocketIO(app, cors_allowed_origins=config.CORS_ALLOWED_ORIGINS)
# Imported from auth; presumably loads the key used to sign and verify JWTs
# at import time -- confirm against auth.py.
load_jwt_secret_key()
def redis_pubsub_handler(message):
    """Relays one Redis Pub/Sub message to the matching Socket.IO room.

    Args:
        message: The message dict passed in by the Redis pub/sub subscription.
            Its 'data' entry is parsed as JSON and must contain the keys
            'data_type' and 'conversation_name'.
    """
    logging.info('Redis Pub/Sub Received data: {}'.format(message))
    payload = json.loads(message['data'])
    event_name = payload['data_type']
    room = payload['conversation_name']
    # The data type doubles as the Socket.IO event name, and the conversation
    # name is the room the event is delivered to.
    socketio.emit(event_name, payload, to=room)
    summary = 'Redis Subscribe: {0},{1},{2},{3}; conversation_name: {4}, data_type: {5}.'.format(
        message['type'],
        message['pattern'],
        message['channel'],
        message['data'],
        room,
        event_name)
    logging.info(summary)
def psubscribe_exception_handler(ex, pubsub, thread):
    """Logs an error raised while reading pub/sub messages, then pauses.

    Args:
        ex: The exception that was raised.
        pubsub: Not used by this handler.
        thread: Not used by this handler.
    """
    error_text = 'An error occurred while getting pubsub messages: {}'.format(ex)
    logging.exception(error_text)
    # Wait two seconds before returning -- presumably so that a persistent
    # failure does not turn into a tight error loop.
    time.sleep(2)
# Identifier for this server process: a random float joined with the current
# timestamp. It is the prefix of the Redis channel pattern subscribed to
# below, and on_join() stores it in Redis as the value for every conversation
# name joined through this process.
SERVER_ID = '{}-{}'.format(random.uniform(0, 322321),
datetime.now().timestamp())
logging.info('--------- SERVER_ID: {} ---------'.format(SERVER_ID))
# Redis client shared by the handlers in this module. The retry policy is
# ExponentialBackoff(cap=5, base=1) with a retry count of 5, applied to the
# error types listed in retry_on_error.
redis_client = redis.StrictRedis(
host=config.REDIS_HOST, port=config.REDIS_PORT,
health_check_interval=10,
socket_connect_timeout=15,
retry_on_timeout=True,
socket_keepalive=True,
retry=redis.retry.Retry(redis.backoff.ExponentialBackoff(cap=5, base=1), 5),
retry_on_error=[redis.exceptions.ConnectionError, redis.exceptions.TimeoutError, redis.exceptions.ResponseError])
# Pattern-subscribe to every channel matching '<SERVER_ID>:*' and route each
# message to redis_pubsub_handler; subscribe confirmations are ignored.
p = redis_client.pubsub(ignore_subscribe_messages=True)
p.psubscribe(**{'{}:*'.format(SERVER_ID): redis_pubsub_handler})
# Read messages in a background thread; exceptions raised while reading are
# passed to psubscribe_exception_handler.
thread = p.run_in_thread(sleep_time=0.001, exception_handler=psubscribe_exception_handler)
def get_conversation_name_without_location(conversation_name):
    """Strips the location segment from a conversation name.

    A name that contains '/locations/' is rebuilt from its first two and last
    two '/'-separated segments; any other name is returned unchanged.

    Args:
        conversation_name: The conversation resource name, with or without a
            location id.

    Returns:
        The conversation name without its location id.
    """
    if '/locations/' not in conversation_name:
        return conversation_name
    segments = conversation_name.split('/')
    kept = (segments[0], segments[1], segments[-2], segments[-1])
    return '/'.join(kept)
@app.route('/')
def test():
    """Serves the test page for conversation runtime handling."""
    page = render_template('index.html')
    return page
@app.route('/status')
def check_status():
    """Answers with a fixed greeting so a domain can check service availability.

    This endpoint is optional; remove it if you do not need it.
    """
    greeting = 'Hello, cross-origin-world!'
    return greeting
@app.route('/register', methods=['POST'])
def register_token():
    """Issues a JWT when the request's Authorization header passes check_auth.

    Returns:
        A JSON response {'token': <jwt>} on success, or a 401 response when
        the Authorization header is rejected.
    """
    authorization = request.headers.get('Authorization', '')
    if check_auth(authorization):
        # The body is parsed leniently: any content type is accepted and an
        # unparsable body yields None.
        payload = request.get_json(force=True, silent=True)
        return jsonify({'token': generate_jwt(payload)})
    return make_response(
        'Could not authenticate user',
        401,
        {'Authentication': 'valid token required'})
@app.route('/register-app', methods=['POST'])
def register_app_token():
    """Issues a JWT when the request body passes application-level auth.

    The JSON body is parsed once and the same payload is used both for the
    check_app_auth check and for token generation, so the token is always
    built from exactly the data that was authenticated.

    Returns:
        A JSON response {'token': <jwt>} on success, or a 401 response when
        check_app_auth rejects the body.
    """
    data = request.get_json()
    if not check_app_auth(data):
        return make_response(
            'Could not authenticate user',
            401,
            {'Authentication': 'valid application level auth required'})
    return jsonify({'token': generate_jwt(data)})
def _loggable_body(raw_body):
    """Returns raw_body gunzipped if possible, else unchanged (used for logging only)."""
    try:
        return gzip.decompress(raw_body)
    except (OSError, EOFError, zlib.error):
        # Not gzip data (or truncated/corrupt): log the bytes as received
        # rather than failing the request because of a log statement.
        return raw_body


# Note: Dialogflow methods projects.locations.conversations.list and projects.locations.answerRecords.list are not supported.
# projects.locations.answerRecords.patch
@app.route('/<version>/projects/<project>/locations/<location>/answerRecords/<path:path>', methods=['PATCH'])
# projects.locations.conversations.participants.patch
@app.route('/<version>/projects/<project>/locations/<location>/conversations/<path:path>', methods=['PATCH'])
# projects.locations.conversations.create
@app.route('/<version>/projects/<project>/locations/<location>/conversations', defaults={'path': ''}, methods=['POST'])
# GET:
# projects.locations.conversations.get
# projects.locations.conversations.messages.list
# projects.locations.conversations.participants.get
# projects.locations.conversations.participants.list
# POST:
# projects.locations.conversations.complete
# projects.locations.conversations.create
# projects.locations.conversations.messages.batchCreate
# projects.locations.conversations.participants.analyzeContent
# projects.locations.conversations.participants.create
# projects.locations.conversations.participants.suggestions.suggestArticles
# projects.locations.conversations.participants.suggestions.suggestFaqAnswers
# projects.locations.conversations.participants.suggestions.suggestSmartReplies
@app.route('/<version>/projects/<project>/locations/<location>/conversations/<path:path>', methods=['GET', 'POST'])
# projects.locations.conversationProfiles.get
@app.route('/<version>/projects/<project>/locations/<location>/conversationProfiles/<path:path>', methods=['GET'])
# projects.locations.conversationModels.get
@app.route('/<version>/projects/<project>/locations/<location>/conversationModels/<path:path>', methods=['GET'])
# projects.locations.suggestions.searchKnowledge
@app.route('/<version>/projects/<project>/locations/<location>/suggestions:searchKnowledge', defaults={'path': None}, methods=['POST'])
# projects.locations.conversations.generateStatelessSuggestion
@app.route('/<version>/projects/<project>/locations/<location>/statelessSuggestion:generate', defaults={'path': None}, methods=['POST'])
# projects.locations.generators.get
@app.route('/<version>/projects/<project>/locations/<location>/generators/<path:path>', methods=['GET'])
@token_required
def call_dialogflow(version, project, location, path):
    """Forwards a valid request to Dialogflow and returns its response.

    Args:
        version: API version segment of the URL (only used for routing).
        project: Project segment of the URL (only used for routing).
        location: Location segment of the URL; passed to the dialogflow
            helper together with the full request path.
        path: Remainder of the URL (only used for routing).

    Returns:
        A (body, status code, headers) tuple taken from the Dialogflow
        response. The body is the raw upstream payload, passed through as
        received.
    """
    logging.info(
        'Called Dialogflow for request path: {}'.format(request.full_path))
    method = request.method
    if method == 'GET':
        response = dialogflow.get_dialogflow(location, request.full_path)
        # Decompression is only for the log line and is best effort: an
        # upstream body that is not gzip-encoded must not break the proxy.
        logging.info('get_dialogflow response: {0}, {1}, {2}'.format(
            _loggable_body(response.raw.data), response.status_code, response.headers))
    elif method == 'POST':
        if request.path.endswith(':complete'):
            # projects.conversations.complete expects an empty request body.
            response = dialogflow.post_dialogflow(location, request.full_path)
        else:
            response = dialogflow.post_dialogflow(
                location, request.full_path, request.get_json())
        logging.info('post_dialogflow response: {0}, {1}, {2}'.format(
            response.raw.data, response.status_code, response.headers))
    else:
        # The routes above only admit GET, POST and PATCH, so this is PATCH.
        response = dialogflow.patch_dialogflow(
            location, request.full_path, request.get_json())
        logging.info('patch_dialogflow response: {0}, {1}, {2}'.format(
            response.raw.data, response.status_code, response.headers))
    return response.raw.data, response.status_code, response.headers.items()
@app.route('/conversation-name', methods=['POST'])
@token_required
def set_conversation_name():
    """Stores a conversationIntegrationKey:conversationName pair in Redis.

    This is useful in cases where it's not possible to send the DialogFlow
    conversation name to the agent desktop directly. A good example of a
    conversationIntegrationKey is a phone number.

    The pair is read from the JSON request body. The Redis key is the SHA-256
    hex digest of the integration key. Input is validated before Redis is
    touched, so a rejected request never leaves a partial or empty entry.

    Returns:
        A JSON response {<conversationIntegrationKey>: <conversationName>} on
        success, or a 400 'Bad request' response when the body is not a JSON
        object, either value is missing, empty or not a string, or the Redis
        write reports failure.
    """
    payload = request.json
    if not isinstance(payload, dict):
        return make_response('Bad request', 400)
    conversation_integration_key = payload.get('conversationIntegrationKey', '')
    conversation_name = payload.get('conversationName', '')
    values_ok = (
        isinstance(conversation_integration_key, str)
        and isinstance(conversation_name, str)
        and conversation_integration_key
        and conversation_name)
    if not values_ok:
        return make_response('Bad request', 400)
    hashed_key = hashlib.sha256(conversation_integration_key.encode('utf-8')).hexdigest()
    # NOTE(review): this logs the unhashed integration key (possibly a phone
    # number) although only its hash is stored in Redis -- confirm this is
    # acceptable for your logging policy.
    logging.info(
        '/conversation-name - redis: SET %s %s', conversation_integration_key, conversation_name)
    if not redis_client.set(hashed_key, conversation_name):
        return make_response('Bad request', 400)
    return jsonify({conversation_integration_key: conversation_name})
@app.route('/conversation-name', methods=['GET'])
@token_required
def get_conversation_name():
    """Allows agent desktops to get a DialogFlow conversation name from Redis
    using a conversationIntegrationKey.

    The key is read from the 'conversationIntegrationKey' query parameter and
    looked up in Redis by its SHA-256 hex digest. A missing or empty
    parameter is rejected before Redis is queried.

    Returns:
        A JSON response {'conversationName': <name>} where <name> is '' when
        no entry exists, or a 400 'Bad request' response when the query
        parameter is missing or empty.
    """
    # Read the raw value: wrapping a missing parameter in str() would yield
    # the truthy string 'None' and bypass the validation below.
    conversation_integration_key = request.args.get('conversationIntegrationKey', '')
    if not conversation_integration_key:
        return make_response('Bad request', 400)
    hashed_key = hashlib.sha256(conversation_integration_key.encode('utf-8')).hexdigest()
    conversation_name = redis_client.get(hashed_key)
    logging.info(
        '/conversation-name - redis: GET %s -> %s', conversation_integration_key, conversation_name)
    return jsonify({'conversationName': str(conversation_name, encoding='utf-8') if conversation_name else ''})
@app.route('/conversation-name', methods=['DELETE'])
@token_required
def del_conversation_name():
    """Allows agent desktops to delete a DialogFlow conversation name from Redis
    using a conversationIntegrationKey.

    The key is read from the 'conversationIntegrationKey' query parameter and
    the Redis entry is addressed by its SHA-256 hex digest. A missing or
    empty parameter is rejected before anything is deleted.

    Returns:
        A 200 'Success' response when an entry was deleted, 404 'Not found'
        when no entry existed, or 400 'Bad request' when the query parameter
        is missing or empty.
    """
    # Read the raw value instead of comparing str(None) against 'None', so
    # that validation happens before the delete and a key that is literally
    # "None" is handled like any other key.
    conversation_integration_key = request.args.get('conversationIntegrationKey', '')
    if not conversation_integration_key:
        return make_response('Bad request', 400)
    hashed_key = hashlib.sha256(conversation_integration_key.encode('utf-8')).hexdigest()
    result = redis_client.delete(hashed_key)
    logging.info(
        '/conversation-name - redis: DEL %s, result %s', conversation_integration_key, result)
    if not result:
        return make_response('Not found', 404)
    return make_response('Success', 200)
@socketio.on('connect')
def connect(auth=None):
    """Accepts a Socket.IO connection only when it carries a valid JWT.

    Args:
        auth: The auth payload supplied by the client on connect. Only a
            dict containing a 'token' entry is considered; any other value
            (including None) is rejected.

    Returns:
        True when check_jwt accepts the token.

    Raises:
        ConnectionRefusedError: If the token is missing or invalid.
    """
    logging.info(
        'Receives connection request with sid: {0}.'.format(request.sid))
    if isinstance(auth, dict) and 'token' in auth:
        is_valid, log_info = check_jwt(auth['token'])
        logging.info(log_info)
        if is_valid:
            return True
    # Address only the client being rejected: without `to`, SocketIO.emit
    # broadcasts the event to every connected client.
    socketio.emit('unauthenticated', to=request.sid)
    raise ConnectionRefusedError('authentication failed')
@socketio.on('disconnect')
def disconnect(reason):
    """Removes the Redis mappings for the rooms a disconnecting client was in.

    Args:
        reason: The disconnect reason passed to the handler; only logged.
    """
    logging.info('Client disconnected, reason: {}, request.sid: {}'.format(reason, request.sid))
    # Delete mapping for conversation_name and SERVER_ID.
    # The client's own sid room is excluded by value rather than by position,
    # because the order of the list returned by rooms() is not guaranteed to
    # put the sid room first.
    conversation_rooms = [room for room in rooms() if room != request.sid]
    if conversation_rooms:
        redis_client.delete(*conversation_rooms)
@app.errorhandler(500)
def server_error(e):
    """Logs an unhandled request error and returns a short 500 page.

    Args:
        e: The error passed to the handler by Flask; it is formatted into
            the response body.

    Returns:
        A (body, 500) tuple.
    """
    logging.exception('An error occurred during a request.')
    body_template = (
        '\n'
        'An internal error occurred: <pre>{}</pre>\n'
        'See logs for full stacktrace.\n')
    return body_template.format(e), 500
@socketio.on('join-conversation')
def on_join(message):
    """Adds the client to the room named after a conversation.

    Args:
        message: The conversation name sent by the client; a location id in
            it is stripped before it is used as the room name.

    Returns:
        A (True, room name) tuple.
    """
    logging.info('Received event: join-conversation: {}'.format(message))
    room_name = get_conversation_name_without_location(message)
    join_room(room_name)
    # Record in Redis that this conversation is served by this SERVER_ID.
    redis_client.set(room_name, SERVER_ID)
    logging.info('join-conversation for: {}'.format(room_name))
    return True, room_name
@socketio.on('leave-conversation')
def on_leave(message):
    """Removes the client from the room named after a conversation.

    Args:
        message: The conversation name sent by the client; a location id in
            it is stripped before it is used as the room name.

    Returns:
        A (True, room name) tuple.
    """
    logging.info('Received event: leave-conversation: {}'.format(message))
    room_name = get_conversation_name_without_location(message)
    leave_room(room_name)
    # Drop the Redis entry that maps this conversation to a SERVER_ID.
    redis_client.delete(room_name)
    logging.info('leave-conversation for: {}'.format(room_name))
    return True, room_name
@socketio.on_error_default
def default_error_handler(e):
    """Logs an exception raised by a SocketIO event handler.

    Args:
        e: The exception that was raised.
    """
    event_name = request.event['message']
    logging.exception('error from {0} event: {1}'.format(event_name, e))
if __name__ == '__main__':
    # Local development entry point only. On Google App Engine and Cloud Run
    # the application is served by Gunicorn (see the entrypoint in the
    # Dockerfile), so this branch is not executed there.
    port = int(os.getenv('PORT', 8080))
    socketio.run(app, debug=True, port=port, host='127.0.0.1')