# run/chatbot-api/app/main.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
import os
from typing import Union
import asyncpg
from fastapi import FastAPI, Request
import google.auth
from google.auth.transport.requests import Request as GRequest
from google.cloud import aiplatform
from langchain.chains.summarize import load_summarize_chain
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_google_vertexai import VertexAI, VertexAIEmbeddings
from pgvector.asyncpg import register_vector
# Deployment configuration, injected via environment variables by the
# hosting platform (e.g. Cloud Run). All are presumably required; no
# defaults are applied here — TODO confirm deploy-time validation.
REGION = os.getenv("REGION")  # Vertex AI / deployment region
PROJECT_ID = os.getenv("PROJECT_ID")  # Google Cloud project id
DB_HOST = os.getenv("DB_HOST")  # PostgreSQL host (AlloyDB/Cloud SQL)
DB_USER = os.getenv("DB_USER")  # IAM database user
DB_NAME = os.getenv("DB_NAME")  # target database name
# Initialize the Vertex AI SDK once at import time; all model clients
# below inherit this project/region.
aiplatform.init(project=f"{PROJECT_ID}", location=f"{REGION}")
# Text-generation model used by the summarize chain (provider default model).
llm = VertexAI()
# Embedding model; must match the model used to populate the
# product_embeddings table, or cosine similarity is meaningless.
embeddings_service = VertexAIEmbeddings(
    model_name="textembedding-gecko@003",
)
async def find_by_query(
    pool,
    q,
    *,
    min_price=25,
    max_price=100,
    similarity_threshold=0.1,
    num_matches=25,
):
    """Find toy products similar to a free-text query using pgvector.

    Embeds the query, runs a cosine-similarity search over the
    ``product_embeddings`` table, and returns matching rows from
    ``products`` filtered by list price.

    Args:
        pool: asyncpg connection pool.
        q: Free-text user query to embed and match.
        min_price: Lowest list price (inclusive) to include.
        max_price: Highest list price (inclusive) to include.
        similarity_threshold: Minimum cosine similarity (0..1) an
            embedding must reach to count as a match.
        num_matches: Maximum number of candidate embeddings to consider.

    Returns:
        List of dicts with ``product_name``, ``description`` and
        ``list_price`` (rounded to cents).

    Raises:
        ValueError: If no product satisfies the filters.
    """
    # embed_query performs a blocking network request; run it in a worker
    # thread so the event loop stays responsive under concurrent requests.
    query_embedding = await asyncio.to_thread(embeddings_service.embed_query, q)
    async with pool.acquire() as conn:
        await register_vector(conn)
        # Cosine similarity search over all vector embeddings: pgvector's
        # `<=>` operator is cosine *distance*, so similarity = 1 - distance.
        results = await conn.fetch(
            """
            WITH vector_matches AS (
                SELECT product_id, 1 - (embedding <=> $1) AS similarity
                FROM product_embeddings
                WHERE 1 - (embedding <=> $1) > $2
                ORDER BY similarity DESC
                LIMIT $3
            )
            SELECT product_name, list_price, description FROM products
            WHERE product_id IN (SELECT product_id FROM vector_matches)
            AND list_price >= $4 AND list_price <= $5
            """,
            query_embedding,
            similarity_threshold,
            num_matches,
            min_price,
            max_price,
        )
    if not results:
        raise ValueError("Did not find any results. Adjust the query parameters.")
    # Shape each record into a plain dict so FastAPI can serialize it.
    return [
        {
            "product_name": r["product_name"],
            "description": r["description"],
            "list_price": round(r["list_price"], 2),
        }
        for r in results
    ]
# Map-step prompt for the map_reduce summarize chain: applied to each
# retrieved product description individually to extract name, price and
# features. `{text}` is filled by the chain with one document's content.
map_prompt_template = """
You will be given a detailed description of a toy product.
This description is enclosed in triple backticks (```).
Using this description only, extract the name of the toy,
the price of the toy and its features.
```{text}```
SUMMARY:
"""
# Combine-step prompt: receives the concatenated map-step summaries as
# `{text}` plus the original user question as `{user_query}`, and asks
# the model to pick one toy and answer in Markdown (< 200 words).
combine_prompt_template = """
You will be given a detailed description different toy products
enclosed in triple backticks (```) and a question enclosed in
double backticks(``).
Select one toy that is most relevant to answer the question.
Using that selected toy description, answer the following
question in as much detail as possible.
You should only use the information in the description.
Your answer should include the name of the toy, the price of the toy
and its features. Your answer should be less than 200 words.
Your answer should be in Markdown in a numbered list format.
Description:
```{text}```
Question:
``{user_query}``
Answer:
"""
async def find_by_chatbot(pool, q):
    """Answer a user question about toys via retrieval + summarization.

    Retrieves similar products with ``find_by_query``, renders each
    match as a small document, then runs a map_reduce summarize chain
    to produce a single Markdown answer.

    Args:
        pool: asyncpg connection pool, passed through to find_by_query.
        q: The user's free-text question.

    Returns:
        Dict of the form ``{"answer": <markdown string>}``.

    Raises:
        Whatever ``find_by_query`` raises when no products match.
    """
    matches = await find_by_query(pool, q)
    map_prompt = PromptTemplate(
        template=map_prompt_template,
        input_variables=["text"],
    )
    combine_prompt = PromptTemplate(
        template=combine_prompt_template,
        input_variables=["text", "user_query"],
    )
    # Render each match directly into a Document for the chain.
    docs = [
        Document(
            page_content=f"""
            The name of the toy is {r["product_name"]}.
            The price of the toy is ${round(r["list_price"], 2)}.
            Its description is below:
            {r["description"]}.
            """
        )
        for r in matches
    ]
    chain = load_summarize_chain(
        llm,
        chain_type="map_reduce",
        map_prompt=map_prompt,
        combine_prompt=combine_prompt,
    )
    # Use the chain's async entry point so the (slow) LLM calls do not
    # block the event loop — chain.invoke would run them synchronously.
    answer = await chain.ainvoke(
        {
            "input_documents": docs,
            "user_query": q,
        }
    )
    return {"answer": answer["output_text"]}
# Application Default Credentials scoped for Cloud SQL / AlloyDB IAM
# database login: the OAuth access token is used as the DB password.
creds, _ = google.auth.default(
    scopes=["https://www.googleapis.com/auth/sqlservice.login"]
)
def get_password():
    """Return a valid OAuth access token to use as the database password.

    Refreshes the module-level ADC credentials first if they have expired;
    passed as a callable to asyncpg so each new connection gets a fresh token.
    """
    if creds.valid:
        return creds.token
    creds.refresh(GRequest())
    return creds.token
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the shared asyncpg pool over the app's lifetime.

    Opens the pool on startup and stores it on ``app.state``; closes it
    on shutdown with a 10-second cap so shutdown cannot hang forever.
    """
    app.state.pool = await asyncpg.create_pool(
        host=DB_HOST,
        user=DB_USER,
        # Callable password: asyncpg re-invokes it per connection, so the
        # short-lived IAM token is refreshed as needed.
        password=get_password,
        database=DB_NAME,
        ssl="require",
    )
    try:
        yield
    finally:
        # Close the pool even if shutdown is triggered by an exception,
        # so connections are not leaked.
        await asyncio.wait_for(app.state.pool.close(), 10)
# Register the lifespan handler so the DB pool exists before any request.
app = FastAPI(lifespan=lifespan)
@app.get("/search")
async def do_search(request: Request, q: Union[str, None] = None):
    """GET /search — return toy products similar to the query string ``q``."""
    pool = request.app.state.pool
    return await find_by_query(pool, q)
@app.get("/chatbot")
async def ask_chatbot(request: Request, q: Union[str, None] = None):
    """GET /chatbot — answer the question ``q`` with an LLM-generated summary."""
    pool = request.app.state.pool
    return await find_by_chatbot(pool, q)
@app.get("/")
async def root(request: Request):
    """GET / — connectivity check: report the PostgreSQL server version."""
    async with request.app.state.pool.acquire() as conn:
        # fetchrow returns the single row directly (same as fetch(...)[0]
        # for this one-row query).
        row = await conn.fetchrow("select version()")
    return row