retrieval_service/app/routes.py (168 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Mapping, Optional
from fastapi import APIRouter, HTTPException, Request
from google.auth.transport import requests # type:ignore
from google.oauth2 import id_token # type:ignore
from langchain_core.embeddings import Embeddings
import datastore
# Router that the endpoint decorators in this module register their handlers on.
routes = APIRouter()
def _ParseUserIdToken(headers: Mapping[str, Any]) -> Optional[str]:
"""Parses the bearer token out of the request headers."""
# authorization_header = headers.lower()
user_id_token_header = headers.get("User-Id-Token")
if not user_id_token_header:
raise Exception("no user authorization header")
parts = str(user_id_token_header).split(" ")
if len(parts) != 2 or parts[0] != "Bearer":
raise Exception("Invalid ID token")
return parts[1]
async def get_user_info(request):
    """Resolve the calling user from the request's ``User-Id-Token`` header.

    The bearer token is parsed from the request headers and verified with
    ``id_token.verify_oauth2_token`` against the client id stored on
    ``request.app.state.client_id``.

    Args:
        request: Incoming request; its headers and ``app.state`` are read.

    Returns:
        A dict with ``user_id``, ``user_name`` and ``user_email`` taken from
        the ``sub``, ``name`` and ``email`` claims of the verified token, or
        ``None`` when the token is missing, malformed, or fails verification.
    """
    try:
        # Header parsing lives inside the try: a missing or malformed header
        # must produce the same "no user" result (None) as a token that fails
        # verification, so callers can answer 401 instead of crashing with 500.
        token = _ParseUserIdToken(request.headers)
        id_info = id_token.verify_oauth2_token(
            token, requests.Request(), audience=request.app.state.client_id
        )
    except Exception as e:  # pylint: disable=broad-except
        print(e)
        return None
    else:
        return {
            "user_id": id_info.get("sub"),
            "user_name": id_info.get("name"),
            "user_email": id_info.get("email"),
        }
@routes.get("/")
async def root():
    """Return a fixed greeting payload."""
    greeting = {"message": "Hello World"}
    return greeting
@routes.get("/airports")
async def get_airport(
    request: Request,
    id: Optional[int] = None,
    iata: Optional[str] = None,
):
    """Fetch an airport by numeric id or by IATA code.

    When both parameters are supplied, ``id`` takes precedence.

    Args:
        request: Incoming request; the datastore client is read from
            ``request.app.state.datastore``.
        id: Airport id to look up.
        iata: Airport IATA code to look up when no id is given.

    Returns:
        A dict with ``results`` and ``sql``, the two values returned by the
        datastore lookup.

    Raises:
        HTTPException: 422 when neither ``id`` nor a non-empty ``iata`` is
            provided.
    """
    ds: datastore.Client = request.app.state.datastore
    # Test against None rather than truthiness so an explicit id of 0 is
    # treated as a supplied id instead of being silently ignored.
    if id is not None:
        results, sql = await ds.get_airport_by_id(id)
    elif iata:
        results, sql = await ds.get_airport_by_iata(iata)
    else:
        raise HTTPException(
            status_code=422,
            detail="Request requires query params: airport id or iata",
        )
    return {"results": results, "sql": sql}
@routes.get("/airports/search")
async def search_airports(
    request: Request,
    country: Optional[str] = None,
    city: Optional[str] = None,
    name: Optional[str] = None,
):
    """Search airports by any combination of country, city and name.

    Args:
        request: Incoming request; the datastore client is read from
            ``request.app.state.datastore``.
        country: Optional country filter.
        city: Optional city filter.
        name: Optional airport-name filter.

    Returns:
        A dict with ``results`` and ``sql``, the two values returned by the
        datastore search.

    Raises:
        HTTPException: 422 when all three filters are ``None``.
    """
    filters = (country, city, name)
    if all(value is None for value in filters):
        raise HTTPException(
            status_code=422,
            detail="Request requires at least one query params: country, city, or airport name",
        )

    client: datastore.Client = request.app.state.datastore
    matches, query_sql = await client.search_airports(*filters)
    return {"results": matches, "sql": query_sql}
@routes.get("/amenities")
async def get_amenity(id: int, request: Request):
    """Fetch a single amenity by its id.

    Args:
        id: Amenity id passed to the datastore lookup.
        request: Incoming request; the datastore client is read from
            ``request.app.state.datastore``.

    Returns:
        A dict with ``results`` and ``sql``, the two values returned by the
        datastore lookup.
    """
    client: datastore.Client = request.app.state.datastore
    amenity, query_sql = await client.get_amenity(id)
    return {"results": amenity, "sql": query_sql}
@routes.get("/amenities/search")
async def amenities_search(query: str, top_k: int, request: Request):
    """Search amenities using an embedding of the query text.

    Args:
        query: Free-text query; embedded with the app's embedding service.
        top_k: Forwarded to the datastore search as its third argument.
        request: Incoming request; the datastore client and embedding
            service are read from ``request.app.state``.

    Returns:
        A dict with ``results`` and ``sql``, the two values returned by the
        datastore search.
    """
    app_state = request.app.state
    client: datastore.Client = app_state.datastore
    embedder: Embeddings = app_state.embed_service

    # NOTE(review): presumably a similarity threshold -- confirm against
    # the datastore client's amenities_search signature.
    threshold = 0.5
    vector = embedder.embed_query(query)
    matches, query_sql = await client.amenities_search(vector, threshold, top_k)
    return {"results": matches, "sql": query_sql}
@routes.get("/flights")
async def get_flight(flight_id: int, request: Request):
    """Fetch a single flight by its id.

    Args:
        flight_id: Flight id passed to the datastore lookup.
        request: Incoming request; the datastore client is read from
            ``request.app.state.datastore``.

    Returns:
        A dict with ``results`` and ``sql``, the two values returned by the
        datastore lookup.
    """
    client: datastore.Client = request.app.state.datastore
    flight, query_sql = await client.get_flight(flight_id)
    return {"results": flight, "sql": query_sql}
@routes.get("/flights/search")
async def search_flights(
    request: Request,
    departure_airport: Optional[str] = None,
    arrival_airport: Optional[str] = None,
    date: Optional[str] = None,
    airline: Optional[str] = None,
    flight_number: Optional[str] = None,
):
    """Search flights by date plus airport(s), or by airline and number.

    The date/airport search is tried first; the airline/flight-number
    search is used only when that combination is not satisfied.

    Args:
        request: Incoming request; the datastore client is read from
            ``request.app.state.datastore``.
        departure_airport: Optional departure airport filter.
        arrival_airport: Optional arrival airport filter.
        date: Date filter; required for the airport-based search.
        airline: Airline; required for the flight-number search.
        flight_number: Flight number; required for the flight-number search.

    Returns:
        A dict with ``results`` and ``sql``, the two values returned by the
        datastore search that was run.

    Raises:
        HTTPException: 422 when neither parameter combination is present.
    """
    client: datastore.Client = request.app.state.datastore

    if date and (arrival_airport or departure_airport):
        flights, query_sql = await client.search_flights_by_airports(
            date, departure_airport, arrival_airport
        )
        return {"results": flights, "sql": query_sql}

    if airline and flight_number:
        flights, query_sql = await client.search_flights_by_number(
            airline, flight_number
        )
        return {"results": flights, "sql": query_sql}

    raise HTTPException(
        status_code=422,
        detail="Request requires query params: arrival_airport, departure_airport, date, or both airline and flight_number",
    )
@routes.post("/tickets/insert")
async def insert_ticket(
    request: Request,
    airline: str,
    flight_number: str,
    departure_airport: str,
    arrival_airport: str,
    departure_time: str,
    arrival_time: str,
):
    """Store a ticket for the authenticated user.

    Args:
        request: Incoming request; used to resolve the user and to read the
            datastore client from ``request.app.state.datastore``.
        airline: Airline of the ticketed flight.
        flight_number: Flight number of the ticketed flight.
        departure_airport: Departure airport of the ticketed flight.
        arrival_airport: Arrival airport of the ticketed flight.
        departure_time: Departure time, forwarded as received.
        arrival_time: Arrival time, forwarded as received.

    Returns:
        Whatever the datastore's ``insert_ticket`` call returns.

    Raises:
        HTTPException: 401 when no user could be resolved from the request.
    """
    user = await get_user_info(request)
    if user is None:
        raise HTTPException(
            status_code=401,
            detail="User login required for data insertion",
        )

    client: datastore.Client = request.app.state.datastore
    # Positional order matters: user identity first, then flight details.
    traveller = (user["user_id"], user["user_name"], user["user_email"])
    itinerary = (
        airline,
        flight_number,
        departure_airport,
        arrival_airport,
        departure_time,
        arrival_time,
    )
    outcome = await client.insert_ticket(*traveller, *itinerary)
    return outcome
@routes.get("/tickets/validate")
async def validate_ticket(
    request: Request,
    airline: str,
    flight_number: str,
    departure_airport: str,
    departure_time: str,
):
    """Run the datastore's ticket validation for the given flight details.

    Args:
        request: Incoming request; the datastore client is read from
            ``request.app.state.datastore``.
        airline: Airline of the flight to validate.
        flight_number: Flight number of the flight to validate.
        departure_airport: Departure airport of the flight to validate.
        departure_time: Departure time, forwarded as received.

    Returns:
        A dict with ``results`` and ``sql``, the two values returned by the
        datastore call.
    """
    client: datastore.Client = request.app.state.datastore
    flight_details = (airline, flight_number, departure_airport, departure_time)
    outcome, query_sql = await client.validate_ticket(*flight_details)
    return {"results": outcome, "sql": query_sql}
@routes.get("/tickets/list")
async def list_tickets(
    request: Request,
):
    """List the tickets belonging to the authenticated user.

    Args:
        request: Incoming request; used to resolve the user and to read the
            datastore client from ``request.app.state.datastore``.

    Returns:
        A dict with ``results`` and ``sql``, the two values returned by the
        datastore's ``list_tickets`` call for the user's id.

    Raises:
        HTTPException: 401 when no user could be resolved from the request.
    """
    user_info = await get_user_info(request)
    if user_info is None:
        # This endpoint only reads data, so the message must not talk about
        # insertion (it had been copied from the insert endpoint).
        raise HTTPException(
            status_code=401,
            detail="User login required to list tickets",
        )
    ds: datastore.Client = request.app.state.datastore
    results, sql = await ds.list_tickets(user_info["user_id"])
    return {"results": results, "sql": sql}
@routes.get("/policies/search")
async def policies_search(query: str, top_k: int, request: Request):
    """Search policies using an embedding of the query text.

    Args:
        query: Free-text query; embedded with the app's embedding service.
        top_k: Forwarded to the datastore search as its third argument.
        request: Incoming request; the datastore client and embedding
            service are read from ``request.app.state``.

    Returns:
        A dict with ``results`` and ``sql``, the two values returned by the
        datastore search.
    """
    app_state = request.app.state
    client: datastore.Client = app_state.datastore
    embedder: Embeddings = app_state.embed_service

    # NOTE(review): presumably a similarity threshold -- confirm against
    # the datastore client's policies_search signature.
    threshold = 0.5
    vector = embedder.embed_query(query)
    matches, query_sql = await client.policies_search(vector, threshold, top_k)
    return {"results": matches, "sql": query_sql}