retrieval_service/datastore/providers/alloydb.py (156 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal, Optional
import asyncpg
from google.cloud.alloydb.connector import AsyncConnector, RefreshStrategy
from pgvector.asyncpg import register_vector
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
import models
from .. import datastore
from .postgres import Client as PostgresClient
# Identifier of this datastore provider. The same string is the only value
# accepted by ``Config.kind`` and is what ``Client.kind`` returns.
ALLOYDB_PG_IDENTIFIER = "alloydb-postgres"
class Config(BaseModel, datastore.AbstractConfig):
    """Connection settings for the AlloyDB for PostgreSQL provider.

    ``project``, ``region``, ``cluster`` and ``instance`` are combined into the
    AlloyDB instance path
    ``projects/<project>/locations/<region>/clusters/<cluster>/instances/<instance>``;
    ``user``, ``password`` and ``database`` are the database login.
    """

    # Literal tag: only the string "alloydb-postgres" validates.
    kind: Literal["alloydb-postgres"]
    # Project segment of the instance path ("projects/<project>").
    project: str
    # Location segment of the instance path ("locations/<region>").
    region: str
    # AlloyDB cluster name ("clusters/<cluster>").
    cluster: str
    # AlloyDB instance name ("instances/<instance>").
    instance: str
    # Database user to log in as.
    user: str
    # Database password. NOTE(review): held as a plain ``str``; consider
    # pydantic ``SecretStr`` so it is masked in reprs/logs.
    password: str
    # Name of the database to connect to.
    database: str
class Client(datastore.Client[Config]):
    """Datastore client backed by an AlloyDB for PostgreSQL instance.

    Connections are opened through the AlloyDB ``AsyncConnector`` using the
    asyncpg driver and wrapped in a SQLAlchemy async engine. Every
    data-access method forwards its arguments unchanged to the matching
    method of the wrapped Postgres provider client and returns its result;
    this class itself only owns connection setup and teardown.
    """

    # AlloyDB connector that produces the engine's connections. Owned by
    # this instance and closed by ``close()``; ``None`` when the client was
    # built directly from an engine without a connector.
    __connector: Optional[AsyncConnector] = None
    # Postgres provider client that executes all queries over the engine.
    __pg_client: PostgresClient

    @datastore.classproperty
    def kind(cls):
        """Return this provider's identifier ("alloydb-postgres")."""
        return ALLOYDB_PG_IDENTIFIER

    def __init__(
        self,
        async_engine: AsyncEngine,
        connector: Optional[AsyncConnector] = None,
    ):
        """Wrap ``async_engine`` in a Postgres provider client.

        Args:
            async_engine: SQLAlchemy async engine used for every query.
            connector: AlloyDB connector backing ``async_engine``, if any.
                When given, ``close()`` closes it after the engine client.
        """
        self.__pg_client = PostgresClient(async_engine)
        self.__connector = connector

    @classmethod
    async def create(cls, config: Config) -> "Client":
        """Build a client connected to the AlloyDB instance in ``config``.

        Args:
            config: Instance location and database credentials.

        Returns:
            A client whose engine opens connections through a dedicated
            ``AsyncConnector`` that the client owns and later closes.

        Raises:
            TypeError: if the async engine could not be created.
        """
        # Created per client (not shared on the class) so that close() can
        # release it without affecting other clients.
        connector = AsyncConnector(refresh_strategy=RefreshStrategy.LAZY)
        # AlloyDB instance connection name.
        instance_uri = (
            f"projects/{config.project}/locations/{config.region}"
            f"/clusters/{config.cluster}/instances/{config.instance}"
        )

        async def getconn() -> asyncpg.Connection:
            # Handed to SQLAlchemy as ``async_creator``: returns a new
            # driver-level asyncpg connection.
            conn: asyncpg.Connection = await connector.connect(
                instance_uri,
                "asyncpg",
                user=config.user,
                password=config.password,
                db=config.database,
                ip_type="PUBLIC",
            )
            # Register pgvector's type support on this connection.
            await register_vector(conn)
            return conn

        async_engine = create_async_engine(
            "postgresql+asyncpg://",
            async_creator=getconn,
        )
        if async_engine is None:
            # Don't leak the connector if we cannot hand it to a client.
            await connector.close()
            raise TypeError("async_engine not instantiated")
        return cls(async_engine, connector)

    async def initialize_data(
        self,
        airports: list[models.Airport],
        amenities: list[models.Amenity],
        flights: list[models.Flight],
        policies: list[models.Policy],
    ) -> None:
        """Forward to the Postgres client's ``initialize_data``."""
        await self.__pg_client.initialize_data(airports, amenities, flights, policies)

    async def export_data(
        self,
    ) -> tuple[
        list[models.Airport],
        list[models.Amenity],
        list[models.Flight],
        list[models.Policy],
    ]:
        """Forward to the Postgres client's ``export_data``."""
        return await self.__pg_client.export_data()

    async def get_airport_by_id(
        self, id: int
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Forward to the Postgres client's ``get_airport_by_id``."""
        return await self.__pg_client.get_airport_by_id(id)

    async def get_airport_by_iata(
        self, iata: str
    ) -> tuple[Optional[models.Airport], Optional[str]]:
        """Forward to the Postgres client's ``get_airport_by_iata``."""
        return await self.__pg_client.get_airport_by_iata(iata)

    async def search_airports(
        self,
        country: Optional[str] = None,
        city: Optional[str] = None,
        name: Optional[str] = None,
    ) -> tuple[list[models.Airport], Optional[str]]:
        """Forward to the Postgres client's ``search_airports``."""
        return await self.__pg_client.search_airports(country, city, name)

    async def get_amenity(
        self, id: int
    ) -> tuple[Optional[models.Amenity], Optional[str]]:
        """Forward to the Postgres client's ``get_amenity``."""
        return await self.__pg_client.get_amenity(id)

    async def amenities_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[Any], Optional[str]]:
        """Forward to the Postgres client's ``amenities_search``."""
        return await self.__pg_client.amenities_search(
            query_embedding, similarity_threshold, top_k
        )

    async def get_flight(
        self, flight_id: int
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Forward to the Postgres client's ``get_flight``."""
        return await self.__pg_client.get_flight(flight_id)

    async def search_flights_by_number(
        self,
        airline: str,
        number: str,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Forward to the Postgres client's ``search_flights_by_number``."""
        return await self.__pg_client.search_flights_by_number(airline, number)

    async def search_flights_by_airports(
        self,
        date: str,
        departure_airport: Optional[str] = None,
        arrival_airport: Optional[str] = None,
    ) -> tuple[list[models.Flight], Optional[str]]:
        """Forward to the Postgres client's ``search_flights_by_airports``."""
        return await self.__pg_client.search_flights_by_airports(
            date, departure_airport, arrival_airport
        )

    async def validate_ticket(
        self,
        airline: str,
        flight_number: str,
        departure_airport: str,
        departure_time: str,
    ) -> tuple[Optional[models.Flight], Optional[str]]:
        """Forward to the Postgres client's ``validate_ticket``."""
        return await self.__pg_client.validate_ticket(
            airline, flight_number, departure_airport, departure_time
        )

    async def insert_ticket(
        self,
        user_id: str,
        user_name: str,
        user_email: str,
        airline: str,
        flight_number: str,
        departure_airport: str,
        arrival_airport: str,
        departure_time: str,
        arrival_time: str,
    ) -> None:
        """Forward to the Postgres client's ``insert_ticket``."""
        await self.__pg_client.insert_ticket(
            user_id,
            user_name,
            user_email,
            airline,
            flight_number,
            departure_airport,
            arrival_airport,
            departure_time,
            arrival_time,
        )

    async def list_tickets(
        self,
        user_id: str,
    ) -> tuple[list[Any], Optional[str]]:
        """Forward to the Postgres client's ``list_tickets``."""
        return await self.__pg_client.list_tickets(user_id)

    async def policies_search(
        self, query_embedding: list[float], similarity_threshold: float, top_k: int
    ) -> tuple[list[str], Optional[str]]:
        """Forward to the Postgres client's ``policies_search``."""
        return await self.__pg_client.policies_search(
            query_embedding, similarity_threshold, top_k
        )

    async def close(self):
        """Close the Postgres client, then the owned AlloyDB connector.

        The connector is closed even if closing the Postgres client raises,
        and is released at most once (repeat calls skip it).
        """
        try:
            await self.__pg_client.close()
        finally:
            connector, self.__connector = self.__connector, None
            if connector is not None:
                await connector.close()