packages/blueprints/gen-ai-chatbot/static-assets/chatbot-genai-components/backend/python/embedding/main.py
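"""Embedding task entry point.

Reads a bot definition via the custom_bot repository, loads its knowledge
sources (URLs and uploaded S3 documents), splits and embeds the contents,
stores the vectors in the Postgres `items` table, and keeps the bot's sync
status in DynamoDB up to date. Intended to run as an ECS task (the execution
id is taken from the ECS task metadata endpoint).
"""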
import argparse
import json
import logging
import multiprocessing
import os
from multiprocessing.managers import ListProxy
from typing import Any
import pg8000
import requests
from app.repositories.common import _get_table_client
from app.repositories.custom_bot import (
compose_bot_id,
decompose_bot_id,
find_private_bot_by_id,
)
from app.routes.schemas.bot import type_sync_status
from app.utils import compose_upload_document_s3_path
from aws_lambda_powertools.utilities import parameters
from embedding.loaders import UrlLoader
from embedding.loaders.base import BaseLoader
from embedding.loaders.s3 import S3FileLoader
from embedding.wrapper import DocumentSplitter, Embedder
from llama_index.core.node_parser import SentenceSplitter
from retry import retry
from ulid import ULID

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
BEDROCK_REGION = os.environ.get("BEDROCK_REGION", "us-east-1")
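# Retry counts and delays (in seconds) used by the `retry` decorators below.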
RETRIES_TO_INSERT_TO_POSTGRES = 4
RETRY_DELAY_TO_INSERT_TO_POSTGRES = 2
RETRIES_TO_UPDATE_SYNC_STATUS = 4
RETRY_DELAY_TO_UPDATE_SYNC_STATUS = 2
DB_SECRETS_ARN = os.environ.get("DB_SECRETS_ARN", "")
DOCUMENT_BUCKET = os.environ.get("DOCUMENT_BUCKET", "documents")
METADATA_URI = os.environ.get("ECS_CONTAINER_METADATA_URI_V4")


def get_exec_id() -> str:
# Get task id from ECS metadata
# Ref: https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-metadata-endpoint-v4.html#task-metadata-endpoint-v4-enable
    # A timeout keeps a hung metadata endpoint from stalling the task indefinitely.
    response = requests.get(f"{METADATA_URI}/task", timeout=10)
data = response.json()
task_arn = data.get("TaskARN", "")
task_id = task_arn.split("/")[-1]
return task_id


@retry(tries=RETRIES_TO_INSERT_TO_POSTGRES, delay=RETRY_DELAY_TO_INSERT_TO_POSTGRES)
def insert_to_postgres(
bot_id: str, contents: ListProxy, sources: ListProxy, embeddings: ListProxy
):
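    """Replace a bot's rows in the vector store.

    Deletes existing rows for `bot_id` from the `items` table, then bulk-inserts
    the collected (id, botid, content, source, embedding) tuples. Connection
    details come from the Secrets Manager secret referenced by DB_SECRETS_ARN;
    the whole operation is retried on failure via the `retry` decorator.
    """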
secrets: Any = parameters.get_secret(DB_SECRETS_ARN) # type: ignore
db_info = json.loads(secrets)
conn = pg8000.connect(
database=db_info["dbname"],
host=db_info["host"],
port=db_info["port"],
user=db_info["username"],
password=db_info["password"],
)
try:
with conn.cursor() as cursor:
delete_query = "DELETE FROM items WHERE botid = %s"
cursor.execute(delete_query, (bot_id,))
insert_query = f"INSERT INTO items (id, botid, content, source, embedding) VALUES (%s, %s, %s, %s, %s)"
values_to_insert = []
for i, (source, content, embedding) in enumerate(
zip(sources, contents, embeddings)
):
id_ = str(ULID())
logger.info(f"Preview of content {i}: {content[:200]}")
values_to_insert.append(
(id_, bot_id, content, source, json.dumps(embedding))
)
cursor.executemany(insert_query, values_to_insert)
conn.commit()
logger.info(f"Successfully inserted {len(values_to_insert)} records.")
except Exception as e:
conn.rollback()
raise e
finally:
conn.close()


@retry(tries=RETRIES_TO_UPDATE_SYNC_STATUS, delay=RETRY_DELAY_TO_UPDATE_SYNC_STATUS)
def update_sync_status(
user_id: str,
bot_id: str,
sync_status: type_sync_status,
sync_status_reason: str,
last_exec_id: str,
):
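    """Persist the sync status, its reason, and the last execution id
    on the bot record in DynamoDB.
    """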
table = _get_table_client(user_id)
table.update_item(
Key={"PK": user_id, "SK": compose_bot_id(user_id, bot_id)},
UpdateExpression="SET SyncStatus = :sync_status, SyncStatusReason = :sync_status_reason, LastExecId = :last_exec_id",
ExpressionAttributeValues={
":sync_status": sync_status,
":sync_status_reason": sync_status_reason,
":last_exec_id": last_exec_id,
},
)


def embed(
loader: BaseLoader,
contents: ListProxy,
sources: ListProxy,
embeddings: ListProxy,
chunk_size: int,
chunk_overlap: int,
):
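    """Load documents with `loader`, split them into chunks, embed each chunk,
    and append the texts, sources, and vectors to the shared manager lists.
    """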
splitter = DocumentSplitter(
splitter=SentenceSplitter(
paragraph_separator=r"\n\n\n",
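            # NOTE: this is a raw string, i.e. literal backslash-n characters,
            # not actual newlines; use "\n\n\n" if blank lines should separate paragraphs.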
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
# Use length of text as token count for cohere-multilingual-v3
tokenizer=lambda text: [0] * len(text),
)
)
embedder = Embedder(verbose=True)
documents = loader.load()
splitted = splitter.split_documents(documents)
splitted_embeddings = embedder.embed_documents(splitted)
contents.extend([t.page_content for t in splitted])
sources.extend([t.metadata["source"] for t in splitted])
embeddings.extend(splitted_embeddings)


def main(
user_id: str,
bot_id: str,
sitemap_urls: list[str],
source_urls: list[str],
filenames: list[str],
chunk_size: int,
chunk_overlap: int,
enable_partition_pdf: bool,
):
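    """Embed a bot's knowledge sources and record the sync status.

    Marks the bot as RUNNING, loads and embeds the configured URLs and uploaded
    files (files are processed in a multiprocessing pool), writes the resulting
    vectors to Postgres, and finally marks the sync as SUCCEEDED or FAILED.
    """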
exec_id = ""
try:
exec_id = get_exec_id()
except Exception as e:
logger.error(f"[ERROR] Failed to get exec_id: {e}")
exec_id = "FAILED_TO_GET_ECS_EXEC_ID"
update_sync_status(
user_id,
bot_id,
"RUNNING",
"",
exec_id,
)
status_reason = ""
try:
if len(sitemap_urls) + len(source_urls) + len(filenames) == 0:
logger.info("No contents to embed. Skipping.")
status_reason = "No contents to embed."
update_sync_status(
user_id,
bot_id,
"SUCCEEDED",
status_reason,
exec_id,
)
return
        # Load, split, and embed the knowledge sources
with multiprocessing.Manager() as manager:
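            # These manager-backed lists are shared with the worker processes so
            # each embed() call can append its chunks, sources, and vectors.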
contents: ListProxy = manager.list()
sources: ListProxy = manager.list()
embeddings: ListProxy = manager.list()
if len(source_urls) > 0:
embed(
UrlLoader(source_urls),
contents,
sources,
embeddings,
chunk_size,
chunk_overlap,
)
            if len(sitemap_urls) > 0:
                # Sitemap crawling is not implemented yet; fail the sync explicitly.
                raise NotImplementedError("Sitemap URL crawling is not supported yet.")
if len(filenames) > 0:
with multiprocessing.Pool(processes=None) as pool:
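                    # processes=None sizes the pool to os.cpu_count(); one async
                    # task is submitted per uploaded file.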
futures = [
pool.apply_async(
embed,
args=(
S3FileLoader(
bucket=DOCUMENT_BUCKET,
key=compose_upload_document_s3_path(
user_id, bot_id, filename
),
enable_partition_pdf=enable_partition_pdf,
),
contents,
sources,
embeddings,
chunk_size,
chunk_overlap,
),
)
for filename in filenames
]
for future in futures:
future.get()
logger.info(f"Number of chunks: {len(contents)}")
# Insert records into postgres
insert_to_postgres(bot_id, contents, sources, embeddings)
status_reason = "Successfully inserted to vector store."
except Exception as e:
logger.error("[ERROR] Failed to embed.")
logger.error(e)
update_sync_status(
user_id,
bot_id,
"FAILED",
f"{e}",
exec_id,
)
return
update_sync_status(
user_id,
bot_id,
"SUCCEEDED",
status_reason,
exec_id,
)


if __name__ == "__main__":
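    # The positional "Keys" argument is the DynamoDB key of the bot record in
    # DynamoDB-JSON form, e.g. '{"PK": {"S": "<user_id>"}, "SK": {"S": "<bot sort key>"}}';
    # the exact SK format is defined by compose_bot_id/decompose_bot_id.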
parser = argparse.ArgumentParser()
    parser.add_argument("Keys", type=str, help="JSON-encoded DynamoDB key of the bot record")
args = parser.parse_args()
keys = json.loads(args.Keys)
sk = keys["SK"]["S"]
bot_id = decompose_bot_id(sk)
pk = keys["PK"]["S"]
user_id = pk
    bot = find_private_bot_by_id(user_id, bot_id)
    embedding_params = bot.embedding_params
    chunk_size = embedding_params.chunk_size
    chunk_overlap = embedding_params.chunk_overlap
    enable_partition_pdf = embedding_params.enable_partition_pdf
    knowledge = bot.knowledge
sitemap_urls = knowledge.sitemap_urls
source_urls = knowledge.source_urls
filenames = knowledge.filenames
logger.info(f"source_urls to crawl: {source_urls}")
logger.info(f"sitemap_urls to crawl: {sitemap_urls}")
logger.info(f"filenames: {filenames}")
logger.info(f"chunk_size: {chunk_size}")
logger.info(f"chunk_overlap: {chunk_overlap}")
logger.info(f"enable_partition_pdf: {enable_partition_pdf}")
main(
user_id,
bot_id,
sitemap_urls,
source_urls,
filenames,
chunk_size,
chunk_overlap,
enable_partition_pdf,
)