gemini/sample-apps/llamaindex-rag/backend/indexing/docai_parser.py
import json
import logging
import time
import traceback
from google.api_core.client_options import ClientOptions
from google.cloud import documentai, storage
from google.cloud.storage import Blob
from llama_index.core import Document
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DocAIParser:
"""
Class for interfacing with DocAIParser
"""
def __init__(
self,
project_id: str,
location: str,
processor_name: str,
gcs_output_path: str,
):
self.project_id = project_id
self.location = location
self.processor_name = processor_name
self.gcs_output_path = gcs_output_path
self._client = self._initialize_client()
def _initialize_client(self):
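        """Creates a DocumentProcessorServiceClient pointed at the regional
        Document AI endpoint for `self.location`."""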
options = ClientOptions(
api_endpoint=f"{self.location}-documentai.googleapis.com"
)
return documentai.DocumentProcessorServiceClient(client_options=options)
def batch_parse(
self,
blobs: list[Blob],
chunk_size: int = 500,
include_ancestor_headings: bool = True,
timeout_sec: int = 3600,
check_in_interval_sec: int = 60,
) -> tuple[list[Document], list["DocAIParsingResults"]]: # noqa: F821
"""
Parses a list of blobs using Document AI.
Args:
blobs: List of GCS Blobs to parse.
chunk_size: Chunk size for Document AI processing.
include_ancestor_headings: Whether to include ancestor headings.
timeout_sec: Timeout in seconds for the operation.
check_in_interval_sec: Check-in interval in seconds.
Returns:
A tuple containing a list of parsed documents and a list of
DocAIParsingResults.
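
        Example (illustrative; the bucket and prefix below are hypothetical):
            blobs = list(storage.Client().bucket("my-bucket").list_blobs(prefix="docs/"))
            documents, results = parser.batch_parse(blobs)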
"""
try:
operations = self._start_batch_process(
blobs, chunk_size, include_ancestor_headings
)
print(f"Number of operations started: {len(operations)}")
self._wait_for_operations(operations, timeout_sec, check_in_interval_sec)
print("Operations completed successfully")
for i, operation in enumerate(operations):
print(f"Operation {i + 1} metadata: {operation.metadata}")
results = self._get_results(operations)
print(f"Number of results: {len(results)}")
parsed_docs = self._parse_from_results(results)
print(f"Number of parsed documents: {len(parsed_docs)}")
return parsed_docs, results
except Exception as e:
print(f"Error in batch_parse: {str(e)}")
traceback.print_exc()
# Return any successfully parsed documents
# instead of raising an exception
return [], []
def _start_batch_process(
self, blobs: list[Blob], chunk_size: int, include_ancestor_headings: bool
):
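        """Builds a BatchProcessRequest for the given blobs and starts the
        long-running batch operation, returning the list of started operations."""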
        input_config = documentai.BatchDocumentsInputConfig(
            gcs_documents=documentai.GcsDocuments(
                documents=[
                    documentai.GcsDocument(
                        # Build the gs:// URI from the storage Blob: Blob.path
                        # is an API URL path, not a GCS URI, and storage Blobs
                        # expose content_type rather than a mimetype attribute.
                        gcs_uri=f"gs://{blob.bucket.name}/{blob.name}",
                        mime_type=blob.content_type or "application/pdf",
                    )
                    for blob in blobs
                ]
            )
        )
output_config = documentai.DocumentOutputConfig(
gcs_output_config=documentai.DocumentOutputConfig.GcsOutputConfig(
gcs_uri=self.gcs_output_path
)
)
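        # Layout Parser chunking options: chunk_size bounds the size of each
        # returned chunk and include_ancestor_headings controls whether parent
        # headings are included in the chunk text.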
layout_config = documentai.ProcessOptions.LayoutConfig(
chunking_config=documentai.ProcessOptions.LayoutConfig.ChunkingConfig(
chunk_size=chunk_size,
include_ancestor_headings=include_ancestor_headings,
)
)
process_options = documentai.ProcessOptions(layout_config=layout_config)
request = documentai.BatchProcessRequest(
name=self.processor_name,
input_documents=input_config,
document_output_config=output_config,
process_options=process_options,
skip_human_review=True,
)
try:
operation = self._client.batch_process_documents(request)
print(f"Batch process started. Operation: {operation}")
return [operation]
except Exception as e:
print(f"Error starting batch process: {str(e)}")
raise
def _wait_for_operations(self, operations, timeout_sec, check_in_interval_sec):
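        """Polls the operations every check_in_interval_sec seconds until all
        complete, raising TimeoutError once timeout_sec is exceeded."""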
time_elapsed = 0
while any(not operation.done() for operation in operations):
time.sleep(check_in_interval_sec)
time_elapsed += check_in_interval_sec
if time_elapsed > timeout_sec:
raise TimeoutError("Timeout exceeded!")
# Check for errors in completed operations
for operation in operations:
            if operation.exception():
                # A failed long-running operation is a runtime failure, not a
                # missing key, so surface it as RuntimeError.
                raise RuntimeError(f"Operation failed: {operation.exception()}")
def _get_results(self, operations) -> list["DocAIParsingResults"]: # noqa: F821
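        """Collects the input/output GCS paths for each processed document from
        the operations' batch-process metadata."""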
results = []
for operation in operations:
metadata = operation.metadata
if hasattr(metadata, "individual_process_statuses"):
for status in metadata.individual_process_statuses:
results.append(
DocAIParsingResults(
source_path=status.input_gcs_source,
parsed_path=status.output_gcs_destination,
)
)
else:
print(f"Warning: Unexpected metadata structure: {metadata}")
return results
def _parse_from_results(self, results: list["DocAIParsingResults"]): # noqa: F821
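        """Downloads the JSON output written by Document AI for each result and
        converts every chunk under chunkedDocument.chunks into a llama_index
        Document."""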
documents = []
storage_client = storage.Client()
for result in results:
            logger.info(
                "Processing result: source_path=%s, parsed_path=%s",
                result.source_path,
                result.parsed_path,
            )
if not result.parsed_path:
                logger.warning(
                    "Empty parsed_path for source %s. Skipping.",
                    result.source_path,
                )
continue
try:
bucket_name, prefix = result.parsed_path.replace("gs://", "").split(
"/", 1
)
except ValueError:
                logger.error(
                    "Invalid parsed_path format for %s. Skipping.", result.source_path
                )
continue
bucket = storage_client.bucket(bucket_name)
blobs = list(bucket.list_blobs(prefix=prefix))
print(f"Found {len(blobs)} blobs in {result.parsed_path}")
for blob in blobs:
if blob.name.endswith(".json"):
print(f"Processing JSON blob: {blob.name}")
try:
content = blob.download_as_text()
doc_data = json.loads(content)
if (
"chunkedDocument" in doc_data
and "chunks" in doc_data["chunkedDocument"]
):
for chunk in doc_data["chunkedDocument"]["chunks"]:
doc = Document(
text=chunk["content"],
metadata={
"chunk_id": chunk["chunkId"],
"source": result.source_path,
},
)
documents.append(doc)
else:
                            logger.warning(
                                "Expected 'chunkedDocument' structure not found in %s",
                                blob.name,
                            )
except Exception as e:
print(f"Error processing blob {blob.name}: {str(e)}")
print(f"Total documents created: {len(documents)}")
return documents
class DocAIParsingResults:
"""
Document AI Parsing Results
"""
def __init__(self, source_path: str, parsed_path: str):
self.source_path = source_path
self.parsed_path = parsed_path
def get_or_create_docai_processor(
project_id: str,
location: str,
processor_display_name: str,
processor_id: str | None = None,
create_new: bool = False,
processor_type: str = "LAYOUT_PARSER_PROCESSOR",
) -> documentai.Processor:
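    """
    Returns a Document AI processor. When create_new is False, looks the
    processor up by ID and then by display name, raising ValueError if none is
    found; when create_new is True, a new processor of processor_type is
    created under the project and location instead.
    """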
client_options = ClientOptions(
api_endpoint=f"{location}-documentai.googleapis.com",
quota_project_id=project_id,
)
client = documentai.DocumentProcessorServiceClient(client_options=client_options)
if not create_new:
if processor_id:
# Try to get the existing processor by ID
name = client.processor_path(project_id, location, processor_id)
try:
return client.get_processor(name=name)
except Exception as e:
print(f"Error getting processor by ID: {e}")
print("Falling back to searching by display name...")
# Search for the processor by display name
parent = client.common_location_path(project_id, location)
processors = [
p
for p in client.list_processors(parent=parent)
if p.display_name == processor_display_name
]
if processors:
return processors[0]
        else:
raise ValueError(
f"No processor found with display name "
f"'{processor_display_name}' and create_new is False"
)
# If we reach here, we need to create a new processor
parent = client.common_location_path(project_id, location)
return client.create_processor(
parent=parent,
processor=documentai.Processor(
display_name=processor_display_name, type_=processor_type
),
)
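

# Example usage (a minimal sketch; the project, location, display name, bucket,
# and prefix below are hypothetical placeholders):
#
#   processor = get_or_create_docai_processor(
#       project_id="my-project",
#       location="us",
#       processor_display_name="layout-parser",
#   )
#   parser = DocAIParser(
#       project_id="my-project",
#       location="us",
#       processor_name=processor.name,
#       gcs_output_path="gs://my-bucket/docai-output/",
#   )
#   blobs = list(storage.Client().bucket("my-bucket").list_blobs(prefix="docs/"))
#   documents, results = parser.batch_parse(blobs)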