gke/load-embeddings/main.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import time
import asyncpg
import google.auth
from google.auth.transport.requests import Request as GRequest
from google.cloud import aiplatform
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
import numpy as np
import pandas as pd
from pgvector.asyncpg import register_vector
DB_HOST = os.getenv("DB_HOST")
DB_USER = os.getenv("DB_USER")
DB_NAME = os.getenv("DB_NAME")
REGION = os.getenv("REGION")
PROJECT_ID = os.getenv("PROJECT_ID")
DATASET_FILE = "retail_toy_dataset.csv"
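# DATASET_FILE is a relative path, so the CSV is assumed to be packaged next to
# this script (for example, copied into the container image at build time).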
def load_dataset(location) -> pd.DataFrame:
"""Loads the dataset from the specified location"""
df = pd.read_csv(location)
df = df.loc[:, ["product_id", "product_name", "description", "list_price"]]
df = df.dropna()
return df
async def load_into_db(conn: asyncpg.Connection, df: pd.DataFrame):
"""Loads data into a Postgres database table.
This may take a few minutes to run."""
await conn.execute("DROP TABLE IF EXISTS products CASCADE")
await conn.execute(
"""
CREATE TABLE products(
product_id VARCHAR(1024) PRIMARY KEY,
product_name TEXT,
description TEXT,
list_price NUMERIC
)
"""
)
# Copy the dataframe to the `products` table.
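    # copy_records_to_table streams the rows using PostgreSQL's COPY protocol,
    # which is much faster than issuing one INSERT per row for a bulk load.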
tuples = list(df.itertuples(index=False))
await conn.copy_records_to_table(
"products", records=tuples, columns=list(df), timeout=10
)
def split_product_descriptions(df: pd.DataFrame):
"""Splits long product descriptions into smaller chunks"""
text_splitter = RecursiveCharacterTextSplitter(
separators=[".", "\n"],
chunk_size=500,
chunk_overlap=0,
length_function=len,
)
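    # With this configuration each description is split on sentence or newline
    # boundaries ("." or "\n") into chunks of at most 500 characters, no overlap.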
chunked = []
for _, row in df.iterrows():
product_id = row["product_id"]
desc = row["description"]
splits = text_splitter.create_documents([desc])
for s in splits:
r = {"product_id": product_id, "content": s.page_content}
chunked.append(r)
return chunked
def retry_with_backoff(func, *args, retry_delay=5, backoff_factor=2, **kwargs):
"""Helper function to retry failed API requests with exponential
backoff."""
    max_attempts = 10
    for attempt in range(1, max_attempts + 1):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"error: {e}")
            if attempt == max_attempts:
                raise
            wait = retry_delay * (backoff_factor**attempt)
            print(f"Retry after waiting for {wait} seconds...")
            time.sleep(wait)
def generate_vector_embeddings(df: pd.DataFrame):
"""Generate the vector embeddings for each chunk of text.
Vertex AI text embedding model is used to generate vector embeddings,
which outputs a 768-dimensional vector for each chunk of text.
This may take a few minutes to run."""
aiplatform.init(project=f"{PROJECT_ID}", location=f"{REGION}")
embeddings_service = VertexAIEmbeddings(
model_name="textembedding-gecko@003",
)
chunked = split_product_descriptions(df)
batch_size = 5
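    # Send the chunks to the embedding model in small batches; transient quota
    # or availability errors from the API are retried by retry_with_backoff.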
for i in range(0, len(chunked), batch_size):
request = [x["content"] for x in chunked[i : i + batch_size]]
response = retry_with_backoff(embeddings_service.embed_documents, request)
        # Attach each returned embedding to its corresponding chunk.
for x, e in zip(chunked[i : i + batch_size], response):
x["embedding"] = e
# Store the generated embeddings in a pandas dataframe.
product_embeddings = pd.DataFrame(chunked)
print(product_embeddings.head())
return product_embeddings
async def store_embeddings_in_db(conn: asyncpg.Connection, product_embeddings):
"""Store the generated vector embeddings in a PostgreSQL table."""
await conn.execute("DROP TABLE IF EXISTS product_embeddings")
await conn.execute(
"""
CREATE TABLE product_embeddings(
product_id VARCHAR(1024) NOT NULL REFERENCES products(product_id),
content TEXT,
embedding vector(768)
)
"""
)
# Store all the generated embeddings back into the database.
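    # The numpy array is converted to the pgvector `vector` type by the codec
    # that register_vector() installs on the connection (see main()).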
for index, row in product_embeddings.iterrows():
await conn.execute(
"""
INSERT INTO product_embeddings
(product_id, content, embedding)
VALUES
($1, $2, $3)
""",
row["product_id"],
row["content"],
np.array(row["embedding"]),
)
async def create_embeddings_index(conn: asyncpg.Connection):
"""Create indexes for faster similarity search in pgvector"""
m = 24
ef_construction = 100
operator = "vector_cosine_ops"
lists = 100
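    # m and ef_construction tune the HNSW graph (edges per node and the size of
    # the candidate list used while building the index); lists is the number of
    # IVFFlat clusters the vectors are partitioned into.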
# Create an HNSW index on the `product_embeddings` table.
await conn.execute(
f"""
CREATE INDEX ON product_embeddings
USING hnsw(embedding {operator})
WITH (m = {m}, ef_construction = {ef_construction})
"""
)
# Create an IVFFLAT index on the `product_embeddings` table.
await conn.execute(
f"""
CREATE INDEX ON product_embeddings
USING ivfflat(embedding {operator})
WITH (lists = {lists})
"""
)
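# Illustrative only (not executed by this job): a nearest-neighbour query that
# can use the indexes created above. `<=>` is pgvector's cosine-distance
# operator, matching vector_cosine_ops; the LIMIT of 5 is an arbitrary choice.
#
#   SELECT product_id, content
#   FROM product_embeddings
#   ORDER BY embedding <=> $1
#   LIMIT 5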
creds, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/sqlservice.login"]
)
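# The short-lived OAuth access token obtained with the sqlservice.login scope
# is used as the database password below, which is how IAM database
# authentication works for Cloud SQL and AlloyDB for PostgreSQL.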
def get_password():
if not creds.valid:
request = GRequest()
creds.refresh(request)
return creds.token
async def main():
print("Starting load-embeddings job...")
df = load_dataset(DATASET_FILE)
print(df.head(10))
print("Creating connection pool...")
async with asyncpg.create_pool(
host=DB_HOST,
user=DB_USER,
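        # Passing a callable lets asyncpg call get_password() for every new
        # connection, so a fresh (unexpired) token is used each time.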
password=get_password,
database=DB_NAME,
ssl="require",
) as pool:
async with pool.acquire() as conn:
print("Registering vector type...")
await register_vector(conn)
print("Loading dataset into db...")
await load_into_db(conn, df)
print("Generating embeddings...")
embeddings = generate_vector_embeddings(df)
print("Loading embeddings into db...")
await store_embeddings_in_db(conn, embeddings)
print("Creating embeddings index...")
await create_embeddings_index(conn)
print("Done")
if __name__ == "__main__":
asyncio.run(main())