# gemini/agents/genai-experience-concierge/scripts/langgraph_demo/dataset.py
# Copyright 2025 Google. This software is provided as-is, without warranty or
# representation for any use or purpose. Your use of it is subject to your
# agreement with Google.
"""Tools for generating a mock Cymbal Retail dataset."""
# pylint: disable=too-many-arguments,too-many-positional-arguments
import json
import subprocess
from typing import Callable, TypedDict, TypeVar
from google.api_core import exceptions, retry
from google.cloud import bigquery
from scripts.langgraph_demo import defaults
# Retry wrapper for CLI calls that depend on the connection's IAM state:
# retries `subprocess.CalledProcessError` with exponential backoff, since
# permissions on a freshly created connection service account may lag.
connection_permission_retry_config = retry.Retry(
    predicate=lambda exc: isinstance(exc, subprocess.CalledProcessError),
    initial=1,  # first delay in seconds
    maximum=60,  # cap on any single delay, seconds
    multiplier=2,  # exponential backoff factor
    timeout=120,  # total retry budget, seconds
    on_error=lambda exc: print(f"API Error: {str(exc)}"),
)

# Retry wrapper for the CREATE MODEL query: BigQuery raises BadRequest
# while the connection's permissions are presumably still propagating
# (see the permission-related note raised in `create_embedding_model`).
embedding_model_retry_config = retry.Retry(
    predicate=lambda exc: isinstance(exc, exceptions.BadRequest),
    initial=1,
    maximum=60,
    multiplier=2,
    timeout=120,
    on_error=lambda exc: print(f"API Error: {str(exc)}"),
)

# Generic return type for `with_check`.
_T = TypeVar("_T")

# Vertex AI endpoint backing the BigQuery remote embedding model.
TEXT_EMBEDDING_MODEL = "text-embedding-004"

# DDL template: create a remote embedding model over a resource connection.
CREATE_EMBEDDING_MODEL_QUERY = """
CREATE OR REPLACE MODEL `{embedding_model_uri}`
REMOTE WITH CONNECTION `{connection_uri}`
OPTIONS (ENDPOINT = '{endpoint}');
""".strip()

# DDL template: overwrite `dest_table` with a full copy of `source_table`.
COPY_TABLE_QUERY = """
CREATE OR REPLACE TABLE `{dest_table}`
AS (SELECT * FROM `{source_table}`)
""".strip()

# DDL template: build the products table with one text embedding per row
# (name + description concatenated), dropping rows with empty embeddings.
CREATE_PRODUCTS_WITH_EMBEDDINGS_QUERY = """
CREATE OR REPLACE TABLE `{product_with_embedding_table_uri}` AS
SELECT * FROM ML.GENERATE_TEXT_EMBEDDING(
MODEL `{embedding_model_uri}`,
(
SELECT *, CONCAT(product_name, " ", product_description) AS content
FROM `{product_table_uri}`
)
)
WHERE ARRAY_LENGTH(text_embedding) > 0;
""".strip()
class GeneratedDataset(TypedDict):
    """Represents the generated dataset with URIs for tables and models."""

    # BigQuery dataset that holds the generated tables and model.
    dataset_id: str
    # Fully-qualified URI of the products table (with embeddings).
    products_table_uri: str
    # Fully-qualified URI of the stores table.
    stores_table_uri: str
    # Fully-qualified URI of the inventory table.
    inventory_table_uri: str
    # Fully-qualified URI of the remote embedding model.
    embedding_model_uri: str
    # Resource name of the BigQuery Cloud Resource connection.
    connection_uri: str
def create(
    project: str,
    location: str = "US",
    dataset_id: str = "cymbal_retail",
    connection_id: str = "cymbal_connection",
    product_path: str = str(defaults.PRODUCT_GCS_DATASET_PATH),
    store_path: str = str(defaults.STORE_GCS_DATASET_PATH),
    inventory_path: str = str(defaults.INVENTORY_GCS_DATASET_PATH),
) -> GeneratedDataset:
    """
    Build the full Cymbal Retail dataset in BigQuery.

    Only the project must exist beforehand. The function provisions the
    dataset and Cloud Resource connection, creates the remote embedding
    model, loads the store/inventory/product Parquet files into tables,
    and finally materializes a products table with text embeddings.

    Args:
        project (str): Project of the Cymbal dataset and connection.
        location (str): Location for the Cymbal dataset and connection.
        dataset_id (str): Dataset name for the generated Cymbal retail tables.
        connection_id (str): Connection ID used for the BQ resource connection
            and embedding model.
        product_path (str): Path to a Parquet file containing product data.
        store_path (str): Path to a Parquet file containing store data.
        inventory_path (str): Path to a Parquet file containing inventory data.

    Returns:
        GeneratedDataset: URIs of the created tables, model, and connection.

    Raises:
        Exception: If any BigQuery operation fails.
    """
    client = bigquery.Client(project=project, location=location)

    # Make sure the dataset and resource connection exist before anything else.
    setup_dataset(
        client=client,
        project=project,
        location=location,
        dataset_id=dataset_id,
        connection_id=connection_id,
    )

    conn_uri = f"projects/{project}/locations/{location}/connections/{connection_id}"

    model_uri = create_embedding_model(
        client=client,
        project=project,
        dataset_id=dataset_id,
        connection_uri=conn_uri,
        embedding_model_name=defaults.EMBEDDING_MODEL_NAME,
    )

    prefix = f"{project}.{dataset_id}"
    raw_products_uri = f"{prefix}.cymbal_product_only"
    stores_uri = f"{prefix}.{defaults.STORE_TABLE_NAME}"
    inventory_uri = f"{prefix}.{defaults.INVENTORY_TABLE_NAME}"
    products_uri = f"{prefix}.{defaults.PRODUCT_TABLE_NAME}"

    # Load each raw Parquet export into its BigQuery table.
    for table_uri, path in (
        (stores_uri, store_path),
        (inventory_uri, inventory_path),
        (raw_products_uri, product_path),
    ):
        load_table_from_parquet(
            client=client,
            table_uri=table_uri,
            source_path=path,
        )

    # Derive the searchable products table (adds a text embedding per row).
    create_product_table_with_embeddings(
        client=client,
        source_table_uri=raw_products_uri,
        products_table_uri=products_uri,
        embedding_model_uri=model_uri,
    )

    return GeneratedDataset(
        dataset_id=dataset_id,
        products_table_uri=products_uri,
        stores_table_uri=stores_uri,
        inventory_table_uri=inventory_uri,
        embedding_model_uri=model_uri,
        connection_uri=conn_uri,
    )
def load_table_from_parquet(
    client: bigquery.Client,
    table_uri: str,
    source_path: str,
) -> None:
    """Load a Parquet file into a BigQuery table, overwriting existing rows."""
    config = bigquery.LoadJobConfig(
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
        source_format=bigquery.SourceFormat.PARQUET,
    )

    def _run_load():
        # Block until the load job finishes so failures surface here.
        return client.load_table_from_uri(
            source_uris=source_path,
            destination=table_uri,
            job_config=config,
        ).result()

    with_check(f"Creating table: `{table_uri}`", _run_load)
def create_product_table_with_embeddings(
    client: bigquery.Client,
    source_table_uri: str,
    products_table_uri: str,
    embedding_model_uri: str,
) -> None:
    """Create a table with embeddings for product semantic search."""
    ddl = CREATE_PRODUCTS_WITH_EMBEDDINGS_QUERY.format(
        embedding_model_uri=embedding_model_uri,
        product_table_uri=source_table_uri,
        product_with_embedding_table_uri=products_table_uri,
    )

    def _execute():
        return client.query_and_wait(ddl)

    with_check(f"Creating table: `{products_table_uri}`", _execute)
def create_embedding_model(
    client: bigquery.Client,
    project: str,
    dataset_id: str,
    connection_uri: str,
    embedding_model_name: str = defaults.EMBEDDING_MODEL_NAME,
) -> str:
    """Create a BigQuery embedding model in the dataset using the provided connection.

    Returns:
        str: Fully-qualified URI of the created model.

    Raises:
        exceptions.RetryError: If creation keeps failing after retries.
    """
    model_uri = f"{project}.{dataset_id}.{embedding_model_name}"
    ddl = CREATE_EMBEDDING_MODEL_QUERY.format(
        embedding_model_uri=model_uri,
        connection_uri=connection_uri,
        endpoint=TEXT_EMBEDDING_MODEL,
    )

    def _create_model():
        with_check(
            f"Creating embedding model: `{model_uri}`",
            lambda: client.query_and_wait(ddl),
        )

    try:
        # Retried because permissions on the connection may still be propagating.
        embedding_model_retry_config(_create_model)()
    except exceptions.RetryError as err:
        err.add_note(
            "Please wait and try again if the error is permission-related."
            " It is safe to re-run this command with the same inputs."
        )
        raise

    return model_uri
def setup_dataset(
    client: bigquery.Client,
    project: str,
    location: str,
    dataset_id: str,
    connection_id: str,
) -> None:
    """Ensure a BigQuery dataset with a Cloud Resource connection is correctly configured.

    Creates the dataset if absent, creates the Cloud Resource connection via
    the `bq` CLI if it is missing, and grants the connection's service
    account the Vertex AI User role (retried, since a freshly created
    service account may not be immediately usable).

    Args:
        client (bigquery.Client): Client used to create the dataset.
        project (str): Project that owns the dataset and connection.
        location (str): Location for the dataset and connection.
        dataset_id (str): Dataset to create if it does not exist.
        connection_id (str): Cloud Resource connection to create if missing.

    Raises:
        RuntimeError: If the connection service account cannot be determined.
        subprocess.CalledProcessError: If a CLI command fails after retries.
    """
    dataset_uri = f"{project}.{dataset_id}"
    dataset = bigquery.Dataset(dataset_uri)
    dataset.location = location
    with_check(
        "Creating dataset (if not exists)...",
        lambda: client.create_dataset(dataset, exists_ok=True),
    )

    connection_service_account: str | None = None
    try:
        connection_service_account = get_connection_service_account(
            project=project,
            location=location,
            connection_id=connection_id,
        )
    except subprocess.CalledProcessError:
        # Lookup failed, so the connection likely doesn't exist yet: create it.
        with_check(
            "Connection not found, attempting to create connection",
            lambda: subprocess.run(
                [
                    "bq",
                    "mk",
                    "--connection",
                    "--location",
                    location,
                    "--project_id",
                    project,
                    "--connection_type",
                    "CLOUD_RESOURCE",
                    connection_id,
                ],
                check=True,
            ),
        )
        # try to get service account again...
        connection_service_account = get_connection_service_account(
            project=project,
            location=location,
            connection_id=connection_id,
        )

    if not connection_service_account:
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, and this must always fail loudly.
        raise RuntimeError("Connection service account not set.")

    # Grant the connection access to Vertex AI; retried because IAM on a
    # brand-new service account can take time to become effective.
    connection_permission_retry_config(
        lambda: with_check(
            "Granting BQ connection the Vertex AI User role",
            lambda: subprocess.run(
                [
                    "gcloud",
                    "projects",
                    "add-iam-policy-binding",
                    project,
                    "--member",
                    f"serviceAccount:{connection_service_account}",
                    "--role",
                    "roles/aiplatform.user",
                ],
                check=True,
            ),
        )
    )()
def get_connection_service_account(
    project: str,
    location: str,
    connection_id: str,
) -> str:
    """Retrieve the service account associated with a BigQuery connection.

    Raises:
        subprocess.CalledProcessError: If the `bq show` command fails
            (e.g. the connection does not exist).
    """
    connection_ref = f"{project}.{location}.{connection_id}"
    show_command = [
        "bq",
        "show",
        "--format",
        "json",
        "--connection",
        connection_ref,
    ]

    completed = subprocess.run(show_command, check=True, capture_output=True)

    # `bq show --format json` prints the connection details as JSON on stdout.
    details = json.loads(completed.stdout)
    return str(details["cloudResource"]["serviceAccountId"])
def with_check(start_message: str, fn: Callable[[], _T]) -> _T:
    """
    Run a callable, printing SUCCESS or FAILURE next to a status message.

    Args:
        start_message (str): Message printed before the callable runs.
        fn (Callable[[], _T]): Zero-argument callable to execute.

    Returns:
        _T: Whatever the callable returns.

    Raises:
        Exception: Re-raises anything the callable raises (after printing
            FAILURE).
    """
    print(f"{start_message}... ", end="")
    try:
        result = fn()
    except Exception:
        print("FAILURE")
        raise
    print("SUCCESS")
    return result