In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generate and store embeddings with batch processing

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/GoogleCloudPlatform/python-docs-samples/blob/main/alloydb/notebooks/generate_batch_embeddings.ipynb)

---
## Introduction

This notebook demonstrates an efficient way to generate and store vector embeddings in AlloyDB. You'll learn how to:

* **Optimize embedding generation**: Dynamically batch text chunks based on character length to generate more embeddings with each API call.
* **Streamline storage**: Use [Asyncio](https://docs.python.org/3/library/asyncio.html) to seamlessly update AlloyDB with the generated embeddings.

This approach significantly speeds up the process, especially for large datasets, making it ideal for efficiently handling large-scale embedding tasks.

## What you'll need

* A Google Cloud Account and Google Cloud Project

## Basic Setup
### Install dependencies

In [None]:
# `asyncio` is part of the Python standard library (3.4+); the PyPI package of
# the same name is an obsolete backport, so it is intentionally not installed.
%pip install \
    google-cloud-alloydb-connector[asyncpg]==1.4.0 \
    sqlalchemy==2.0.36 \
    pandas==2.2.3 \
    vertexai==1.70.0 \
    greenlet==3.1.1 \
    --quiet

### Authenticate to Google Cloud within Colab
If you are running this notebook on Google Colab, you will need to authenticate as an IAM user.

In [None]:
from google.colab import auth

auth.authenticate_user()

### Connect Your Google Cloud Project

In [None]:
# @markdown Please fill in the value below with your GCP project ID and then run the cell.

# Please fill in these values.
project_id = "my-project-id"  # @param {type:"string"}

# Quick input validations.
assert project_id, "⚠️ Please provide a Google Cloud project ID"

# Configure gcloud.
!gcloud config set project {project_id}

### Enable APIs for AlloyDB and Vertex AI

You will need to enable these APIs in order to create an AlloyDB database and utilize Vertex AI as an embeddings service!

In [None]:
!gcloud services enable alloydb.googleapis.com aiplatform.googleapis.com

## Set up AlloyDB
You will need an AlloyDB for PostgreSQL instance for the following stages of this notebook. Set the following variables to connect to an existing instance or to create a new one.

In [None]:
# @markdown Please fill in the Google Cloud region and the names of your AlloyDB cluster, instance and database. Once filled in, run the cell.

from getpass import getpass

# Please fill in these values.
region = "us-central1"  # @param {type:"string"}
cluster_name = "my-cluster"  # @param {type:"string"}
instance_name = "my-primary"  # @param {type:"string"}
database_name = "test_db"  # @param {type:"string"}
table_name = "investments"
# getpass() reads the password without echoing it, so the secret does not end
# up in the saved cell output the way it would with input().
password = getpass("Please provide a password to be used for 'postgres' database user: ")

In [None]:
# Quick input validations: fail early, with a readable message, if any of the
# settings that later cells interpolate into gcloud commands, the instance URI
# or the database connection are empty.
assert region, "⚠️ Please provide a Google Cloud region"
assert cluster_name, "⚠️ Please provide the name of your cluster"
assert instance_name, "⚠️ Please provide the name of your instance"
assert database_name, "⚠️ Please provide the name of your database_name"
assert password, "⚠️ Please provide a password for the 'postgres' database user"

### Create an AlloyDB Instance
If you have already created an AlloyDB Cluster and Instance, you can skip these steps and skip to the `Connect to AlloyDB` section.

> ⏳ - Creating an AlloyDB cluster may take a few minutes.

In [None]:
!gcloud beta alloydb clusters create {cluster_name} --password={password} --region={region}

Create an instance attached to our cluster with the following command.
> ⏳ - Creating an AlloyDB instance may take a few minutes.

In [None]:
!gcloud beta alloydb instances create {instance_name} --instance-type=PRIMARY --cpu-count=2 --region={region} --cluster={cluster_name}

To connect to your AlloyDB instance from this notebook, you will need to enable public IP on your instance. Alternatively, you can follow [these instructions](https://cloud.google.com/alloydb/docs/connect-external) to connect to an AlloyDB for PostgreSQL instance with Private IP from outside your VPC.

In [None]:
!gcloud beta alloydb instances update {instance_name} --region={region} --cluster={cluster_name} --assign-inbound-public-ip=ASSIGN_IPV4 --database-flags="password.enforce_complexity=on" --no-async

### Connect to AlloyDB

This function will create a connection pool to your AlloyDB instance using the [AlloyDB Python connector](https://github.com/GoogleCloudPlatform/alloydb-python-connector). The AlloyDB Python connector will automatically create secure connections to your AlloyDB instance using mTLS.

In [None]:
import asyncpg

import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from google.cloud.alloydb.connector import AsyncConnector, IPTypes

async def init_connection_pool(connector: AsyncConnector, db_name: str, pool_size: int = 5) -> AsyncEngine:
    """Creates an async SQLAlchemy engine backed by the AlloyDB Python connector.

    Args:
      connector: The `AsyncConnector` used to open connections to the instance.
      db_name: Name of the database every pooled connection attaches to.
      pool_size: Number of connections kept in the pool. Defaults to 5.

    Returns:
      An `AsyncEngine` whose connections are created on demand through
      `connector`, as the 'postgres' user, over the instance's public IP.
    """
    # Full resource path of the instance, built from the notebook-level settings.
    instance_uri = (
        f"projects/{project_id}/locations/{region}"
        f"/clusters/{cluster_name}/instances/{instance_name}"
    )

    async def getconn() -> asyncpg.Connection:
        # Invoked by SQLAlchemy each time the pool needs a new connection.
        return await connector.connect(
            instance_uri,
            "asyncpg",
            user="postgres",
            password=password,
            db=db_name,
            ip_type=IPTypes.PUBLIC,
        )

    # max_overflow=0 caps the pool at exactly `pool_size` connections.
    return create_async_engine(
        "postgresql+asyncpg://",
        async_creator=getconn,
        pool_size=pool_size,
        max_overflow=0,
    )

connector = AsyncConnector()

### Create a Database

Next, you will create a database to store the data using the connection pool. Enabling public IP takes a few minutes, so you may get an error saying that there is no public IP address. If you hit an error, please wait and retry this step.

In [None]:
from sqlalchemy import text, exc

async def create_db(database_name, connector):
    """Creates `database_name` on the AlloyDB instance.

    Connects to the built-in "postgres" database with a temporary connection
    pool, issues `CREATE DATABASE`, and disposes of the pool afterwards.

    Args:
      database_name: Name of the database to create. It is interpolated into
        the SQL statement as-is, so it must be a trusted, valid identifier.
      connector: The `AsyncConnector` used to reach the instance.
    """
    pool = await init_connection_pool(connector, "postgres")
    try:
        async with pool.connect() as conn:
            try:
                # PostgreSQL refuses CREATE DATABASE inside a transaction block,
                # so end the transaction SQLAlchemy opened implicitly first.
                await conn.execute(text("COMMIT"))
                await conn.execute(text(f"CREATE DATABASE {database_name}"))
                print(f"Database '{database_name}' created successfully")
            except exc.ProgrammingError as e:
                # Best effort: report the problem (presumably e.g. the database
                # already exists) and let the notebook continue.
                print(e)
    finally:
        # Release the temporary pool's connections even if connecting failed.
        await pool.dispose()

await create_db(database_name=database_name, connector=connector)

### Download data

The following code has been prepared to help you insert the CSV data into your AlloyDB for PostgreSQL database.

Download the CSV file:

In [None]:
!gcloud storage cp gs://cloud-samples-data/alloydb/investments_data ./investments.csv

The download can be verified by the following command or using the "Files" tab.

In [None]:
!ls

### Import data to your database


In this step you will:

1. Create the table to store the data
2. Insert the data from the CSV file into the database table

In [None]:
# Prepare data
import pandas as pd

data = "./investments.csv"

df = pd.read_csv(data)
# Convert the 't'/'f' flags into real booleans for the BOOLEAN column.
df['etf'] = df['etf'].map({'t': True, 'f': False})
# Fill missing ratings *before* casting to str: casting first would turn NaN
# into the literal string 'nan', leaving fillna() with nothing to replace.
df['rating'] = df['rating'].fillna('').astype(str)

In [None]:
df.head()

The data consists of the following columns:

* **id**
* **ticker**: A string representing the stock symbol or ticker (e.g., "AAPL" for Apple, "GOOG" for Google).
* **etf**: A boolean value indicating whether the asset is an ETF (True) or not (False).
* **market**:  A string representing the stock exchange where the asset is traded.
* **rating**: Whether to hold, buy or sell a stock.
* **overview**: A text field for a general overview or description of the asset.
* **analysis**: A text field, for a more detailed analysis of the asset.
* **overview_embedding** (empty)
* **analysis_embedding** (empty)

In this dataset, we need to embed two columns `overview` and `analysis`. The embeddings corresponding to these columns will be added to the `overview_embedding` and `analysis_embedding` column respectively.

In [None]:
# DDL for the table that holds the source text and its embeddings.
# The two VECTOR(768) columns use the `vector` type, so the `vector` extension
# must exist before this statement runs (insert_data creates it).
# NOTE(review): 768 presumably matches the output size of the embedding model
# used later in this notebook — confirm if the model is changed.
create_table_cmd = sqlalchemy.text(
    f'CREATE TABLE {table_name} ( \
        id SERIAL PRIMARY KEY, \
        ticker VARCHAR(255) NOT NULL UNIQUE, \
        etf BOOLEAN, \
        market VARCHAR(255), \
        rating TEXT,  \
        overview TEXT, \
        overview_embedding VECTOR (768), \
        analysis TEXT,  \
        analysis_embedding VECTOR (768) \
    )'
)


# Parameterized INSERT; the bind names (:id, :ticker, ...) are filled from the
# keys of each record passed at execution time. The *_embedding columns are not
# listed, so they start out NULL and are filled in later by the workflow.
insert_data_cmd = sqlalchemy.text(
    f"INSERT INTO {table_name} (id, ticker, etf, market, rating, overview, analysis)\n"
    "VALUES (:id, :ticker, :etf, :market, :rating, :overview, :analysis)\n",
)

In [10]:
from google.cloud.alloydb.connector import AsyncConnector

# Create table and insert data
async def insert_data(pool):
    """Creates the investments table and loads every row of `df` into it.

    Enables the `vector` extension if it is missing, runs `create_table_cmd`,
    bulk-inserts the DataFrame records with `insert_data_cmd`, and commits.

    Args:
      pool: The AsyncEngine pool corresponding to the AlloyDB database.
    """
    enable_vector = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector;")

    async with pool.connect() as conn:
        await conn.execute(enable_vector)
        await conn.execute(create_table_cmd)

        # One dict per CSV row; passing the list runs the INSERT once per record.
        records = df.to_dict('records')
        await conn.execute(insert_data_cmd, records)

        await conn.commit()

pool = await init_connection_pool(connector, database_name)
await insert_data(pool)
await pool.dispose()

## Building an Embeddings Workflow

Now that we have created our database, we'll define the methods to carry out each step of the embeddings workflow.

The workflow contains four major steps:
1. **Read the Data:** Load the dataset into our program.
2. **Batch the Data:** Divide the data into smaller batches for efficient processing.
3. **Generate Embeddings:** Use an embedding model to create vector representations of the text. The text to be embedded can be spread across multiple columns in the table.
4. **Update Original Table:** Write the generated embeddings into the corresponding embedding columns of our table.

#### Step 0:  Configure Logging

In [None]:
import logging
import sys

# Configure the root logger to output messages with INFO level or above
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s[%(levelname)5s][%(name)14s] - %(message)s',  datefmt='%H:%M:%S', force=True)

#### Step 1: Read the data

This code reads data from a database and yields it for further processing.

In [None]:
from typing import AsyncIterator, List
from sqlalchemy import RowMapping
from sqlalchemy.ext.asyncio import AsyncEngine


async def get_source_data(
    pool: AsyncEngine, embed_cols: List[str]
) -> AsyncIterator[RowMapping]:
    """Streams the rows that still need at least one embedding.

    A row is selected when the `<col>_embedding` column is NULL for any of the
    columns in `embed_cols`.

    Args:
      pool: The AsyncEngine pool corresponding to the AlloyDB database.
      embed_cols: A list of column names containing the data to be embedded.

    Yields:
      One row mapping at a time, holding the 'id' and each of `embed_cols`.
      For example: {'id': 'id1', 'col1': 'val1', 'col2': 'val2'}
    """
    log = logging.getLogger("get_source_data")

    # Build the query: the id plus the text columns, restricted to rows where
    # at least one of the embedding columns has not been filled in yet.
    missing_embedding = [f"{col}_embedding IS NULL" for col in embed_cols]
    selected_cols = ", ".join(embed_cols)
    sql = f"SELECT id, {selected_cols} FROM {table_name} WHERE {' OR '.join(missing_embedding)};"
    log.info(f"Running SQL query: {sql}")

    async with pool.connect() as conn:
        # stream() hands rows over incrementally instead of loading them all.
        result = await conn.stream(text(sql))
        async for row in result:
            mapping = row._mapping
            log.debug(f"yielded row: {mapping['id']}")
            yield mapping

#### Step 2: Batch the data

This code defines a function called `batch_source_data` that takes database rows and groups them into batches based on a character count limit (max_char_count). This batching process is crucial for efficient embedding generation for these reasons:

* **Resource Optimization:**  Instead of sending numerous small requests, batching allows us to send fewer, larger requests. This significantly optimizes resource usage and potentially reduces API costs.

* **Working Within API Limits:**  The max_char_count limit ensures each batch stays within the API's acceptable input size, preventing issues with exceeding the maximum character limit.


In [None]:
from typing import Any, List
import asyncio


async def batch_source_data(
    read_generator: AsyncIterator[RowMapping],
    embed_cols: List[str],
) -> AsyncIterator[List[dict[str, Any]]]:
    """
    Groups data into batches for efficient embedding processing.

    Rows are accumulated until adding the next one would push the batch past
    the module-level limits `max_char_count` (total characters) or
    `max_instances_per_prediction` (number of non-null cells to embed). The
    accumulated batch is then yielded and a new one is started with that row.

    A row that exceeds a limit on its own is still emitted, alone in its batch.
    An empty batch is never yielded. Rows in which every column to embed is
    missing or None are skipped.

    Args:
      read_generator: An asynchronous generator yielding individual data rows.
      embed_cols: A list of column names containing the data to be embedded.

    Yields:
      A list of rows, where each row contains data to be embedded.
      For example:
      [
        {'id' : 'id1', 'col1': 'val1', 'col2': 'val2'},
        ...
      ]
      where col1 and col2 are columns containing data to be embedded.
    """
    logger = logging.getLogger("batch_data")

    # `max_char_count` and `max_instances_per_prediction` are module-level
    # settings. They are only read here, so no `global` statement is needed.

    batch = []
    batch_char_count = 0
    batch_num = 0
    batch_embed_cells = 0

    async for row in read_generator:
        # Characters and embeddable cells contributed by the current row.
        row_char_count = 0
        row_embed_cells = 0
        for col in embed_cols:
            if col in row and row[col] is not None:
                row_char_count += len(row[col])
                row_embed_cells += 1

        # Skip the row if all columns to embed are empty.
        if row_embed_cells == 0:
            continue

        # Flush the current batch if this row would exceed the maximum character
        # count or the maximum number of embedding instances. The `batch` check
        # prevents yielding an empty batch when a single row is over the limit
        # while nothing has been accumulated yet.
        over_limit = (batch_char_count + row_char_count > max_char_count) or (
            batch_embed_cells + row_embed_cells > max_instances_per_prediction
        )
        if batch and over_limit:
            batch_num += 1
            logger.info(f"yielded batch number: {batch_num} with length: {len(batch)}")
            yield batch
            batch, batch_char_count, batch_embed_cells = [], 0, 0

        # Add the current row to the batch
        batch.append(row)
        batch_char_count += row_char_count
        batch_embed_cells += row_embed_cells

    # Emit whatever is left once the source is exhausted.
    if batch:
        batch_num += 1
        logger.info(f"Yielded batch number: {batch_num} with length: {len(batch)}")
        yield batch

#### Step 3: Generate embeddings

This step converts your text data into numerical representations called "embeddings." These embeddings capture the meaning and relationships between words, making them useful for various tasks like search, recommendations, and clustering.

The code uses two functions to efficiently generate embeddings:

**embed_text**

This function takes a batch of your text data and sends it to Vertex AI, transforming the text in the specified columns into embeddings.

**embed_objects_concurrently**

This function is the orchestrator. It manages the embedding generation process for multiple batches of text concurrently. This function ensures that all batches are processed efficiently without overwhelming the system.

In [None]:
from google.api_core.exceptions import ResourceExhausted
from typing import Union
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel


async def embed_text(
    batch_data: List[dict[str, Any]],
    model: TextEmbeddingModel,
    cols_to_embed: List[str],
    task_type: str = "SEMANTIC_SIMILARITY",
    retries: int = 100,
    delay: int = 30,
) -> List[dict[str, Union[List[float], str]]]:
    """Embeds text data from a batch of records using a Vertex AI embedding model.

    All non-empty cells of `cols_to_embed` across the batch are sent to the
    model in a single request. The request is retried on any exception. When
    the batch contains no text at all, the model is not called.

    Args:
      batch_data: A data batch containing records with text data to embed.
      model: The Vertex AI `TextEmbeddingModel` to use for generating embeddings.
      cols_to_embed: A list of column names containing the data to be embedded.
      task_type: The task type for the embedding model. Defaults to
        "SEMANTIC_SIMILARITY".
        Supported task types: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types
      retries: The maximum number of times to attempt embedding generation.
        Defaults to 100.
      delay: The delay in seconds between retries. Defaults to 30.

    Returns:
      A list of records containing ids and embeddings, one per input row. Each
      embedding is the string form of the list of floats; columns that had no
      text get None.
      Example:
        [
          {
            'id': 'id1',
            'col1_embedding': '[1.0, 1.1, ...]',
            'col2_embedding': None,
            ...
          },
          ...
        ]
      where col1 and col2 are columns containing data to be embedded.
      If every attempt fails, the error is logged and an empty list is
      returned; no exception is raised.
    """
    logger = logging.getLogger("embed_objects")
    global total_char_count

    def has_text(row, col) -> bool:
        # A cell is embedded only if the column is present and non-empty.
        return col in row and bool(row[col])

    # Flatten every embeddable cell into one request, in row-then-column order.
    inputs = [
        TextEmbeddingInput(row[col], task_type)
        for row in batch_data
        for col in cols_to_embed
        if has_text(row, col)
    ]

    embeddings = []
    if inputs:  # Nothing to embed means no API call and no retry loop.
        for attempt in range(retries):
            try:
                # Get embeddings for the text data
                embeddings = await model.get_embeddings_async(inputs)
                if len(embeddings) != len(inputs):
                    # Treat a short or long response like any other failure.
                    raise ValueError(
                        f"Expected {len(inputs)} embeddings, got {len(embeddings)}"
                    )
                break
            except Exception as e:
                if attempt < retries - 1:  # Retry only if attempts are left
                    logger.warning(f"Error: {e}. Retrying in {delay} seconds...")
                    await asyncio.sleep(delay)  # Wait before retrying
                else:
                    logger.error(f"Failed to get embeddings for data: {batch_data} after {retries} attempts.")
        else:
            # Reached when no attempt succeeded (or retries <= 0).
            return []

        # Increase total char count
        total_char_count += sum(len(item.text) for item in inputs)

    # Group the results back together by id, consuming the embeddings in the
    # same order in which the inputs were flattened above.
    embedding_iter = iter(embeddings)
    results = []
    for row in batch_data:
        record = {"id": row["id"]}
        for col in cols_to_embed:
            if has_text(row, col):
                record[f"{col}_embedding"] = str(next(embedding_iter).values)
            else:
                record[f"{col}_embedding"] = None
        results.append(record)
    return results


async def embed_objects_concurrently(
    cols_to_embed: List[str],
    batch_data: AsyncIterator[List[dict[str, Any]]],
    model: TextEmbeddingModel,
    task_type: str,
    max_concurrency: int = 5,
) -> AsyncIterator[List[dict[str, Union[str, List[float]]]]]:
    """Runs `embed_text` over a stream of batches with bounded concurrency.

    Up to `max_concurrency` embedding tasks are kept in flight. Each time at
    least one finishes, its result is yielded and the freed slots are refilled
    from `batch_data`. Results are yielded in completion order.

    Args:
      cols_to_embed: A list of column names containing the data to be embedded.
      batch_data: An asynchronous iterator yielding batches of records with
        text data to embed.
      model: The Vertex AI `TextEmbeddingModel` to use for generating embeddings.
      task_type: The task type for the embedding model.
        Supported task types: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types
      max_concurrency: The maximum number of embedding tasks to run concurrently.
        Defaults to 5.
    Yields:
      A list of records containing ids and embeddings.
    """
    log = logging.getLogger("embed_objects")

    in_flight: set[asyncio.Task] = set()
    exhausted = False  # Becomes True once `batch_data` has no more batches.

    while in_flight or not exhausted:
        # Top up the set of running tasks until the cap is hit or input ends.
        while not exhausted and len(in_flight) < max_concurrency:
            try:
                next_batch = await batch_data.__anext__()
            except StopAsyncIteration:
                exhausted = True
            else:
                job = embed_text(next_batch, model, cols_to_embed, task_type)
                in_flight.add(asyncio.ensure_future(job))

        if not in_flight:
            continue

        # Block until at least one task finishes, then hand its rows downstream.
        finished, in_flight = await asyncio.wait(
            in_flight, return_when=asyncio.FIRST_COMPLETED
        )
        for task in finished:
            rows = task.result()
            log.info(f"Embedding task completed: Processed {len(rows)} rows.")
            yield rows

#### Step 4: Update original table

After generating embeddings for your text data, you need to store them in your database. This step efficiently updates your original table with the newly created embeddings.

This process uses two functions to manage database updates:

**batch_update_rows**
1. This function takes a batch of data (including the embeddings) and updates the corresponding rows in your database table.
2. It constructs an SQL UPDATE query to modify specific columns with the embedding values.
3. It ensures that the updates are done efficiently and correctly within a database transaction.


**batch_update_rows_concurrently**

1. This function handles the concurrent updating of multiple batches of data.
2. It creates multiple "tasks" that each execute the batch_update_rows function on a separate batch.
3. It limits the number of concurrent tasks to avoid overloading your database and system resources.
4. It manages the execution of these tasks, ensuring that all batches are processed efficiently.

In [None]:
from sqlalchemy import text


async def batch_update_rows(
    pool: AsyncEngine, data: List[dict[str, Any]], cols_to_embed: List[str]
) -> None:
    """Updates rows in the database with embedding data.

    Runs one parameterized UPDATE per record in `data`, setting each
    `<col>_embedding` column for the row with the matching id, then commits.
    An empty batch is skipped without touching the database.

    Args:
      pool: The AsyncEngine pool corresponding to the AlloyDB database.
      data: A data batch containing records with text embeddings. Every record
        needs an 'id' key and a '<col>_embedding' key for each embedded column.
      cols_to_embed: A list of column names containing the data to be embedded.
    """
    logger = logging.getLogger("update_rows")

    # embed_text returns [] when it gives up on a batch. Executing the UPDATE
    # with an empty parameter list would leave its bind parameters without
    # values, so there is nothing useful to run.
    if not data:
        logger.info("No rows to update; skipping database call.")
        return

    update_query = f"""
    UPDATE {table_name}
    SET {', '.join([f'{col}_embedding = :{col}_embedding' for col in cols_to_embed])}
    WHERE id = :id;
  """
    async with pool.connect() as conn:
        await conn.execute(
            text(update_query),
            # One parameter set per row: the statement runs once for each record.
            parameters=data,
        )
        await conn.commit()
    logger.info(f"Updated {len(data)} rows in database.")


async def batch_update_rows_concurrently(
    pool: AsyncEngine,
    embed_data: AsyncIterator[List[dict[str, Any]]],
    cols_to_embed: List[str],
    max_concurrency: int = 5,
):
    """Updates database rows concurrently with embedding data.

    Keeps up to `max_concurrency` `batch_update_rows` tasks in flight, pulling
    a new batch from `embed_data` whenever a slot frees up. A failed update
    does not stop the remaining batches; it is logged as an error.

    Args:
      pool: The AsyncEngine pool corresponding to the AlloyDB database.
      embed_data: An asynchronous iterator yielding batches of records with
        text embeddings.
      cols_to_embed: A list of column names containing the data to be embedded.
      max_concurrency: The maximum number of database update tasks to run concurrently.
        Defaults to 5.
    """
    logger = logging.getLogger("update_rows")
    # Keep track of pending tasks
    pending: set[asyncio.Task] = set()
    has_next = True
    failed = 0
    while pending or has_next:
        # Fill the free slots with new update tasks while input remains.
        while len(pending) < max_concurrency and has_next:
            try:
                data = await embed_data.__anext__()
                coro = batch_update_rows(pool, data, cols_to_embed)
                pending.add(asyncio.ensure_future(coro))
            except StopAsyncIteration:
                has_next = False
        if pending:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            # Inspect every finished task so that a failed update is reported
            # rather than silently dropped.
            for task in done:
                if task.cancelled():
                    failed += 1
                    logger.error("Database update task was cancelled.")
                elif task.exception() is not None:
                    failed += 1
                    logger.error(f"Database update task failed: {task.exception()!r}")

    if failed:
        logger.warning(f"{failed} database update task(s) did not succeed.")
    logger.info("All database update tasks completed.")

## Run the embeddings workflow

This runs the complete embeddings workflow:

1. Getting source data
2. Batching source data
3. Generating embeddings for batches
4. Updating data batches in the original table


In [None]:
import vertexai
import time
from vertexai.language_models import TextEmbeddingModel

### Define variables ###

# Max token count for the embeddings API
max_tokens = 20000

# For some tokenizers and text, there's a rough approximation that 1 token corresponds to about 3-4 characters.
# This is a very general guideline and can vary significantly.
# Using 3 characters per token is the more conservative end of that range.
# `max_char_count` is the per-batch character budget read by batch_source_data.
max_char_count = max_tokens * 3
# Upper bound on the number of text cells placed in one batch (also read by
# batch_source_data); each cell becomes one input of an embedding request.
max_instances_per_prediction = 250

# Text columns to embed. Each needs a matching `<column>_embedding` column in
# the table, which is where the generated vector is written.
cols_to_embed = ["analysis", "overview"]

# Model to use for generating embeddings
model_name = "text-embedding-004"

# Generate optimised embeddings for a given task
# Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types#supported_task_types
task = "SEMANTIC_SIMILARITY"

# Running total of characters sent for embedding; embed_text adds to it after
# every successful request and the workflow prints it at the end.
total_char_count = 0

### Embeddings workflow ###


async def run_embeddings_workflow(
    pool_size: int = 10,
    embed_data_concurrency: int = 20,
    batch_update_concurrency: int = 10,
):
    """Orchestrates the end-to-end workflow for generating and storing embeddings.

    The workflow includes the following major steps:

    1. Data Retrieval: Fetches data from the database that requires embedding.
    2. Batching: Divides the data into batches for optimized processing.
    3. Embedding Generation: Generates embeddings concurrently for the batched
        data using the Vertex AI model.
    4. Database Update: Updates the database concurrently with the generated
        embeddings.

    The steps are chained asynchronous generators, so nothing runs until the
    final database-update step starts pulling from the pipeline.

    NOTE: the connection pool and the shared module-level `connector` are
    closed when the workflow ends, whether it succeeds or fails. Create a new
    `AsyncConnector` before running the workflow again.

    Args:
        pool_size: The size of the database connection pool. Defaults to 10.
        embed_data_concurrency: The maximum number of concurrent tasks for generating embeddings.
            Defaults to 20.
        batch_update_concurrency: The maximum number of concurrent tasks for updating the database.
            Defaults to 10.
    """
    # Set up connections to the database
    pool = await init_connection_pool(connector, database_name, pool_size=pool_size)

    try:
        # Initialise VertexAI and the model to be used to generate embeddings
        vertexai.init(project=project_id, location=region)
        model = TextEmbeddingModel.from_pretrained(model_name)

        # Wall-clock time is only for the human-readable report. The monotonic
        # clock has an undefined reference point, so it is used only to
        # measure the elapsed duration.
        started_at = time.time()
        start_time = time.monotonic()

        # Fetch source data from the database
        source_data = get_source_data(pool, cols_to_embed)

        # Divide the source data into batches for efficient processing
        batch_data = batch_source_data(source_data, cols_to_embed)

        # Generate embeddings for the batched data concurrently
        embeddings_data = embed_objects_concurrently(
            cols_to_embed, batch_data, model, task, max_concurrency=embed_data_concurrency
        )

        # Update the database with the generated embeddings concurrently
        await batch_update_rows_concurrently(
            pool, embeddings_data, cols_to_embed, max_concurrency=batch_update_concurrency
        )

        elapsed_time = time.monotonic() - start_time
        ended_at = time.time()
    finally:
        # Release database connections and close the connector, even on failure
        await pool.dispose()
        await connector.close()

    print(f"Job started at: {time.ctime(started_at)}")
    print(f"Job ended at: {time.ctime(ended_at)}")
    print(f"Total run time: {elapsed_time:.2f} seconds")
    print(f"Total characters embedded: {total_char_count}")


await run_embeddings_workflow()