gemini/sample-apps/llamaindex-rag/backend/indexing/run_parse_embed_index.py

"""Master script for parsing, embedding and indexing data living in a GCS bucket""" import asyncio import logging import os from backend.indexing.docai_parser import DocAIParser from backend.indexing.prompts import QA_EXTRACTION_PROMPT, QA_PARSER_PROMPT from backend.indexing.vector_search_utils import ( get_or_create_existing_index, ) # noqa: E501 from common.utils import ( create_pdf_blob_list, download_bucket_with_transfer_manager, link_nodes, ) from google.cloud import aiplatform from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex from llama_index.core.extractors import QuestionsAnsweredExtractor from llama_index.core.node_parser import HierarchicalNodeParser, SentenceSplitter from llama_index.core.program import LLMTextCompletionProgram from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode from llama_index.embeddings.vertex import VertexTextEmbedding from llama_index.llms.vertex import Vertex from llama_index.storage.docstore.firestore import FirestoreDocumentStore from llama_index.vector_stores.vertexaivectorsearch import VertexAIVectorStore from pydantic import BaseModel from tqdm.asyncio import tqdm_asyncio import yaml logging.basicConfig(level=logging.INFO) # Set the desired logging level logger = logging.getLogger(__name__) # Load configuration from config.yaml def load_config(): config_path = os.path.join( os.path.dirname(__file__), "..", "..", "common", "config.yaml" ) with open(config_path) as config_file: return yaml.safe_load(config_file) # Load configuration config = load_config() # Initialize parameters PROJECT_ID = config["project_id"] LOCATION = config["location"] INPUT_BUCKET_NAME = config["input_bucket_name"] DOCSTORE_BUCKET_NAME = config["docstore_bucket_name"] INDEX_ID = config["index_id"] VECTOR_INDEX_NAME = config["vector_index_name"] INDEX_ENDPOINT_NAME = config["index_endpoint_name"] INDEXING_METHOD = config["indexing_method"] CHUNK_SIZES = config["chunk_sizes"] EMBEDDINGS_MODEL_NAME = config["embeddings_model_name"] APPROXIMATE_NEIGHBORS_COUNT = config["approximate_neighbors_count"] BUCKET_PREFIX = config["bucket_prefix"] VECTOR_DATA_PREFIX = config["vector_data_prefix"] CHUNK_SIZE = config.get("chunk_size", 512) CHUNK_OVERLAP = config.get("chunk_overlap", 50) DOCAI_LOCATION = config["docai_location"] DOCAI_PROCESSOR_DISPLAY_NAME = config["document_ai_processor_display_name"] DOCAI_PROCESSOR_ID = config.get("docai_processor_id") CREATE_DOCAI_PROCESSOR = config.get("create_docai_processor", False) FIRESTORE_DB_NAME = config.get("firestore_db_name") FIRESTORE_NAMESPACE = config.get("firestore_namespace") QA_INDEX_NAME = config.get("qa_index_name") QA_ENDPOINT_NAME = config.get("qa_endpoint_name") class QuesionsAnswered(BaseModel): """List of Questions Answered by Document""" questions_list: list[str] def create_qa_index(li_docs, docstore, embed_model, llm): """creates index of hypothetical questions""" qa_index, qa_endpoint = get_or_create_existing_index( QA_INDEX_NAME, QA_ENDPOINT_NAME, APPROXIMATE_NEIGHBORS_COUNT ) qa_vector_store = VertexAIVectorStore( project_id=PROJECT_ID, region=LOCATION, index_id=qa_index.name, # Use .name instead of .resource_name endpoint_id=qa_endpoint.name, gcs_bucket_name=DOCSTORE_BUCKET_NAME, ) qa_extractor = QuestionsAnsweredExtractor( llm, questions=5, prompt_template=QA_EXTRACTION_PROMPT ) async def extract_batch(li_docs): return await tqdm_asyncio.gather( *[qa_extractor._aextract_questions_from_node(doc) for doc in li_docs] ) loop = asyncio.get_event_loop() metadata_list = 
class QuestionsAnswered(BaseModel):
    """List of questions answered by a document"""

    questions_list: list[str]


def create_qa_index(li_docs, docstore, embed_model, llm):
    """Creates an index of hypothetical questions answered by each document"""
    qa_index, qa_endpoint = get_or_create_existing_index(
        QA_INDEX_NAME, QA_ENDPOINT_NAME, APPROXIMATE_NEIGHBORS_COUNT
    )
    qa_vector_store = VertexAIVectorStore(
        project_id=PROJECT_ID,
        region=LOCATION,
        index_id=qa_index.name,  # Use .name instead of .resource_name
        endpoint_id=qa_endpoint.name,
        gcs_bucket_name=DOCSTORE_BUCKET_NAME,
    )
    qa_extractor = QuestionsAnsweredExtractor(
        llm, questions=5, prompt_template=QA_EXTRACTION_PROMPT
    )

    async def extract_batch(li_docs):
        return await tqdm_asyncio.gather(
            *[qa_extractor._aextract_questions_from_node(doc) for doc in li_docs]
        )

    loop = asyncio.get_event_loop()
    metadata_list = loop.run_until_complete(extract_batch(li_docs))

    program = LLMTextCompletionProgram.from_defaults(
        output_cls=QuestionsAnswered,
        prompt_template_str=QA_PARSER_PROMPT,
        verbose=True,
    )

    async def parse_batch(metadata_list):
        return await asyncio.gather(
            *[program.acall(questions_list=x) for x in metadata_list],
            return_exceptions=True,
        )

    parsed_questions = loop.run_until_complete(parse_batch(metadata_list))
    loop.close()

    q_docs = []
    for doc, questions in zip(li_docs, parsed_questions):
        if isinstance(questions, Exception):
            logger.info(f"Unparsable questions exception {questions}")
            continue
        for q in questions.questions_list:
            logger.info(f"Question extracted: {q}")
            q_doc = Document(text=q)
            q_doc.relationships[NodeRelationship.SOURCE] = RelatedNodeInfo(
                node_id=doc.doc_id
            )
            q_docs.append(q_doc)

    docstore.add_documents(li_docs)
    storage_context = StorageContext.from_defaults(
        docstore=docstore, vector_store=qa_vector_store
    )
    VectorStoreIndex(
        nodes=q_docs,
        storage_context=storage_context,
        embed_model=embed_model,
        llm=llm,
    )


def create_hierarchical_index(li_docs, docstore, vector_store, embed_model, llm):
    """Creates a hierarchical (parent/child) index from the parsed documents"""
    # Let hierarchical node parser take care of granular chunking
    node_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=CHUNK_SIZES)
    nodes = node_parser.get_nodes_from_documents(li_docs)
    leaf_nodes = get_leaf_nodes(nodes)
    num_leaf_nodes = len(leaf_nodes)
    num_nodes = len(nodes)
    logger.info(f"There are {num_leaf_nodes} leaf_nodes and {num_nodes} total nodes")

    docstore.add_documents(nodes)
    storage_context = StorageContext.from_defaults(
        docstore=docstore, vector_store=vector_store
    )
    VectorStoreIndex(
        nodes=leaf_nodes,
        storage_context=storage_context,
        embed_model=embed_model,
        llm=llm,
    )


def create_flat_index(li_docs, docstore, vector_store, embed_model, llm):
    """Creates a flat index of fixed-size chunks from the parsed documents"""
    sentence_splitter = SentenceSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
    )
    # Chunk into granular chunks manually
    node_chunk_list = []
    for doc in li_docs:
        doc_dict = doc.to_dict()
        metadata = doc_dict.pop("metadata")
        doc_dict.update(metadata)
        chunks = sentence_splitter.get_nodes_from_documents([doc])
        # Create nodes with relationships and flatten
        nodes = []
        for chunk in chunks:
            chunk_dict = chunk.to_dict()
            text = chunk_dict.pop("text")
            doc_source_id = doc.doc_id
            node = TextNode(text=text, metadata=chunk_dict)
            node.relationships[NodeRelationship.SOURCE] = RelatedNodeInfo(
                node_id=doc_source_id
            )
            nodes.append(node)
        nodes = link_nodes(nodes)
        node_chunk_list.extend(nodes)
    nodes = node_chunk_list

    logger.info("embedding...")
    docstore.add_documents(li_docs)
    storage_context = StorageContext.from_defaults(
        docstore=docstore, vector_store=vector_store
    )
    for node in nodes:
        node.metadata.pop("excluded_embed_metadata_keys", None)
        node.metadata.pop("excluded_llm_metadata_keys", None)
    # Creating an index automatically embeds and creates the
    # vector db collection
    VectorStoreIndex(
        nodes=nodes,
        storage_context=storage_context,
        embed_model=embed_model,
        llm=llm,
    )
def main():
    """Main parsing, embedding and indexing logic for data living in GCS"""
    # Initialize Vertex AI and create index and endpoint
    aiplatform.init(project=PROJECT_ID, location=LOCATION)

    # Creating Vector Search Index
    vs_index, vs_endpoint = get_or_create_existing_index(
        VECTOR_INDEX_NAME, INDEX_ENDPOINT_NAME, APPROXIMATE_NEIGHBORS_COUNT
    )

    # Vertex AI Vector Search Vector DB and Firestore Docstore
    vector_store = VertexAIVectorStore(
        project_id=PROJECT_ID,
        region=LOCATION,
        index_id=vs_index.name,  # Use .name instead of .resource_name
        endpoint_id=vs_endpoint.name,  # Use .name instead of .resource_name
        gcs_bucket_name=DOCSTORE_BUCKET_NAME,
    )
    docstore = FirestoreDocumentStore.from_database(
        project=PROJECT_ID, database=FIRESTORE_DB_NAME, namespace=FIRESTORE_NAMESPACE
    )

    # Setup embedding model and LLM
    embed_model = VertexTextEmbedding(
        model_name=EMBEDDINGS_MODEL_NAME, project=PROJECT_ID, location=LOCATION
    )
    llm = Vertex(model="gemini-2.0-flash", temperature=0.0)
    Settings.llm = llm
    Settings.embed_model = embed_model

    # Initialize Document AI parser
    GCS_OUTPUT_PATH = f"gs://{DOCSTORE_BUCKET_NAME}/{VECTOR_DATA_PREFIX}/docai_output/"
    parser = DocAIParser(
        project_id=PROJECT_ID,
        location=DOCAI_LOCATION,
        processor_name=f"projects/{PROJECT_ID}/locations/{DOCAI_LOCATION}/processors/{DOCAI_PROCESSOR_ID}",  # noqa: E501
        gcs_output_path=GCS_OUTPUT_PATH,
    )

    # Download data from specified bucket and parse
    local_data_path = os.path.join("/tmp", BUCKET_PREFIX)
    os.makedirs(local_data_path, exist_ok=True)
    blobs = create_pdf_blob_list(INPUT_BUCKET_NAME, BUCKET_PREFIX)
    logger.info("downloading data")
    download_bucket_with_transfer_manager(
        INPUT_BUCKET_NAME, prefix=BUCKET_PREFIX, destination_directory=local_data_path
    )

    # Parse documents using Document AI
    try:
        parsed_docs, raw_results = parser.batch_parse(
            blobs, chunk_size=CHUNK_SIZE, include_ancestor_headings=True
        )
        print(f"Number of documents parsed by Document AI: {len(parsed_docs)}")
        if parsed_docs:
            print(
                f"First parsed document text (first 100 chars): {parsed_docs[0].text[:100]}..."  # noqa: E501
            )
        else:
            print("No documents were parsed by Document AI.")

        # Print raw results for debugging
        print("Raw results:")
        for result in raw_results:
            print(f" Source: {result.source_path}")
            print(f" Parsed: {result.parsed_path}")
    except Exception as e:
        print(f"Error processing documents: {str(e)}")
        parsed_docs = []
        raw_results = []

    # Turn each parsed document into a llamaindex Document
    li_docs = [Document(text=doc.text, metadata=doc.metadata) for doc in parsed_docs]

    if QA_INDEX_NAME or QA_ENDPOINT_NAME:
        create_qa_index(li_docs, docstore, embed_model, llm)
    if INDEXING_METHOD == "hierarchical":
        create_hierarchical_index(li_docs, docstore, vector_store, embed_model, llm)
    elif INDEXING_METHOD == "flat":
        create_flat_index(li_docs, docstore, vector_store, embed_model, llm)


if __name__ == "__main__":
    main()