gemini/use-cases/applying-llms-to-data/gemini-and-documentai-for-entity-extraction/extractor.py
import re
from typing import Optional
from google.api_core.client_options import ClientOptions
from google.api_core.exceptions import InternalServerError, RetryError
from google.cloud import documentai, storage
from temp_file_uploader import TempFileUploader
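# TempFileUploader is a companion helper module in this sample. Based on how it
# is used below, it is assumed to copy a local file to the configured GCS temp
# prefix (upload_file returns the resulting gs:// URI) and to remember that
# blob so delete_file() can remove it once batch processing finishes.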
class DocumentExtractor:
"""Abstract base class for document extraction."""
def __init__(
self,
project_id: str,
location: str,
processor_id: str,
processor_version_id: Optional[str] = None,
):
self.project_id = project_id
self.location = location
self.processor_id = processor_id
self.processor_version_id = processor_version_id
self.client = documentai.DocumentProcessorServiceClient(
client_options=ClientOptions(
api_endpoint=f"{location}-documentai.googleapis.com"
)
)
        self.processor_name = self._get_processor_name()

    def _get_processor_name(self) -> str:
        """Returns the full processor resource name, either
        projects/{project}/locations/{location}/processors/{processor} or, when a
        processor version ID is set, the corresponding processorVersions path.
        """
if self.processor_version_id:
return self.client.processor_version_path(
self.project_id,
self.location,
self.processor_id,
self.processor_version_id,
)
return self.client.processor_path(
self.project_id, self.location, self.processor_id
)
def process_document(self, file_path: str, mime_type: str) -> documentai.Document:
"""abstract function for document processing"""
raise NotImplementedError
class OnlineDocumentExtractor(DocumentExtractor):
"""
Processes documents using the online Document AI API.
"""
def process_document(
self, file_path: str, mime_type: str = "application/pdf"
) -> documentai.Document:
with open(file_path, "rb") as image:
image_content = image.read()
request = documentai.ProcessRequest(
name=self.processor_name,
raw_document=documentai.RawDocument(
content=image_content, mime_type=mime_type
),
)
result = self.client.process_document(request=request)
return result.document
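
# Example usage (illustrative sketch only; the project, location, and processor
# IDs below are placeholders for values from your own Document AI setup):
#
#   extractor = OnlineDocumentExtractor(
#       project_id="my-project",          # placeholder
#       location="us",                    # placeholder
#       processor_id="1234567890abcdef",  # placeholder
#   )
#   document = extractor.process_document("invoice.pdf")  # mime_type defaults to PDF
#   print(document.text)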
class BatchDocumentExtractor(DocumentExtractor):
"""
Processes documents using the batch Document AI API.
"""
# pylint: disable=too-many-arguments
def __init__(
self,
project_id: str,
location: str,
processor_id: str,
gcs_output_uri: str,
gcs_temp_uri: str,
processor_version_id: str,
timeout: int = 400,
):
super().__init__(project_id, location, processor_id, processor_version_id)
self.gcs_output_uri = gcs_output_uri
self.timeout = timeout
self.storage_client = storage.Client()
self.temp_file_uploader = TempFileUploader(gcs_temp_uri)
    def process_document(self, file_path: str, mime_type: str) -> documentai.Document:
        gcs_input_uri = self.temp_file_uploader.upload_file(file_path)
        try:
            document = self._process_document_batch(gcs_input_uri, mime_type)
        finally:
            # Always remove the temporary GCS copy, even if batch processing fails.
            self.temp_file_uploader.delete_file()
        return document
# pylint: disable=too-many-locals
def _process_document_batch(
self, gcs_input_uri: str, mime_type: str
) -> documentai.Document:
gcs_document = documentai.GcsDocument(
gcs_uri=gcs_input_uri, mime_type=mime_type
)
gcs_documents = documentai.GcsDocuments(documents=[gcs_document])
input_config = documentai.BatchDocumentsInputConfig(gcs_documents=gcs_documents)
gcs_output_config = documentai.DocumentOutputConfig.GcsOutputConfig(
gcs_uri=self.gcs_output_uri
)
output_config = documentai.DocumentOutputConfig(
gcs_output_config=gcs_output_config
)
request = documentai.BatchProcessRequest(
name=self.processor_name,
input_documents=input_config,
document_output_config=output_config,
)
operation = self.client.batch_process_documents(request)
try:
print(f"Waiting for operation ({operation.operation.name}) to complete...")
operation.result(timeout=self.timeout)
except (RetryError, InternalServerError) as e:
print(e.message)
metadata = documentai.BatchProcessMetadata(operation.metadata)
if metadata.state != documentai.BatchProcessMetadata.State.SUCCEEDED:
raise ValueError(f"Batch Process Failed: {metadata.state_message}")
# Retrieve the processed document from GCS
for process in list(metadata.individual_process_statuses):
matches = re.match(r"gs://(.*?)/(.*)", process.output_gcs_destination)
if not matches:
print(
"Could not parse output GCS destination:",
process.output_gcs_destination,
)
continue
output_bucket, output_prefix = matches.groups()
output_blobs = self.storage_client.list_blobs(
output_bucket, prefix=output_prefix
)
for blob in output_blobs:
if blob.content_type == "application/json":
print(f"Fetching {blob.name}")
return documentai.Document.from_json(
blob.download_as_bytes(), ignore_unknown_fields=True
)
raise FileNotFoundError("Processed document not found in GCS.")
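
# Example usage (illustrative sketch only; all IDs and gs:// URIs are placeholders):
#
#   batch_extractor = BatchDocumentExtractor(
#       project_id="my-project",                         # placeholder
#       location="us",                                   # placeholder
#       processor_id="1234567890abcdef",                 # placeholder
#       gcs_output_uri="gs://my-bucket/docai-output/",   # placeholder
#       gcs_temp_uri="gs://my-bucket/docai-temp/",       # placeholder
#       processor_version_id="pretrained-invoice-v2.0",  # placeholder
#       timeout=600,
#   )
#   document = batch_extractor.process_document("invoice.pdf", "application/pdf")
#   for entity in document.entities:
#       print(entity.type_, entity.mention_text)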