classify-split-extract-workflow/classify-job/split_and_classify.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=logging-fstring-interpolation,import-error,too-many-locals
"""
Module for splitting and classifying documents using Google Cloud Document AI.
This module includes functions for batch classification of documents, processing
classification results, splitting PDF files based on classification, and handling
metadata and callbacks.
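
Typical flow (an illustrative sketch; it assumes docai_helper.get_processor_and_client
returns a (processor, client) pair, as its usage in this module suggests):

    processor, dai_client = docai_helper.get_processor_and_client(processor_name)
    classified = batch_classification(processor, dai_client, input_uris)
    bucket, blob_object = save_classification_results(classified)
    stream_classification_results(call_back_url, bucket, blob_object)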
"""
import json
import os
import re
from typing import Dict, List, Optional, Tuple
import bq_mlops
import config
from config import DOCAI_OUTPUT_BUCKET
from config import METADATA_CONFIDENCE
from config import METADATA_DOCUMENT_TYPE
from config import NO_CLASSIFIER_LABEL
from config import SPLITTER_OUTPUT_DIR
import docai_helper
import gcs_helper
from google.cloud import documentai_v1 as documentai
from google.cloud import storage
from google.cloud.documentai_toolbox import gcs_utilities
from google.cloud.documentai_v1.types.document import Document
from google.cloud.documentai_v1.types.document_processor_service import (
BatchProcessMetadata,
)
from logging_handler import Logger
from pikepdf import Pdf
import utils

storage_client = storage.Client()
logger = Logger.get_logger(__file__)


def batch_classification(
processor: documentai.types.processor.Processor,
dai_client: documentai.DocumentProcessorServiceClient,
input_uris: List[str],
) -> Optional[Dict]:
"""Performs batch classification on a list of documents using Document AI."""
logger.info(f"input_uris = {input_uris}")
if not input_uris:
return None
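    # Wrap each input GCS URI in a GcsDocument so Document AI can read the PDFs
    # directly from Cloud Storage in a single batch request.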
input_docs = [
documentai.GcsDocument(gcs_uri=doc_uri, mime_type=config.PDF_MIME_TYPE)
for doc_uri in input_uris
]
gcs_documents = documentai.GcsDocuments(documents=input_docs)
input_config = documentai.BatchDocumentsInputConfig(gcs_documents=gcs_documents)
gcs_output_uri = f"gs://{DOCAI_OUTPUT_BUCKET}"
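    # Write classifier output under a timestamped prefix so concurrent runs do
    # not overwrite each other's results.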
timestamp = utils.get_utc_timestamp()
gcs_output_uri_prefix = "classifier_out_" + timestamp
destination_uri = f"{gcs_output_uri}/{gcs_output_uri_prefix}/"
output_config = documentai.DocumentOutputConfig(
gcs_output_config={"gcs_uri": destination_uri}
)
logger.info(f"input_config = {input_config}, output_config = {output_config}")
logger.info(
f"Calling DocAI API for {len(input_uris)} document(s) "
f"using {processor.display_name} processor "
f"type={processor.type_}, path={processor.name}"
)
request = documentai.types.document_processor_service.BatchProcessRequest(
name=processor.name,
input_documents=input_config,
document_output_config=output_config,
)
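    # batch_process_documents starts a long-running operation; result() blocks
    # until Document AI has written its output JSON to the GCS destination.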
operation = dai_client.batch_process_documents(request)
logger.info(f"Waiting for operation {operation.operation.name} to complete...")
operation.result()
metadata: BatchProcessMetadata = documentai.BatchProcessMetadata(operation.metadata)
return process_classify_results(metadata)


def process_classify_results(metadata: BatchProcessMetadata) -> Optional[Dict]:
"""Processes the results of a classification operation."""
logger.info(f"handling classification results - operation.metadata={metadata}")
documents = {}
if metadata.state != documentai.BatchProcessMetadata.State.SUCCEEDED:
raise ValueError(f"Batch Process Failed: {metadata.state_message}")
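    # Each individual process status corresponds to one input document and
    # points at the GCS prefix holding its classifier output JSON.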
for process in metadata.individual_process_statuses:
matches = re.match(r"gs://(.*?)/(.*)", process.output_gcs_destination)
if matches:
output_bucket, output_prefix = matches.groups()
else:
logger.error(
f"Invalid GCS destination format: {process.output_gcs_destination}"
)
continue
input_gcs_source = process.input_gcs_source
logger.info(
f"output_bucket = {output_bucket}, "
f"output_prefix={output_prefix}, "
f"input_gcs_source = {input_gcs_source}, "
f"output_gcs_destination = {process.output_gcs_destination}"
)
        # TODO: add shard support using the Document AI Toolbox once this
        # issue is addressed:
        # https://github.com/googleapis/python-documentai-toolbox/issues/332
output_blob = list(
storage_client.list_blobs(output_bucket, prefix=output_prefix + "/")
)[0]
if ".json" not in output_blob.name:
logger.info(
f"Skipping non-supported file: {output_blob.name} - Mimetype: "
f"{output_blob.content_type}"
)
continue
document_out = documentai.Document.from_json(
output_blob.download_as_bytes(), ignore_unknown_fields=True
)
blob_entities = document_out.entities
if not blob_entities:
logger.info(f"No entities found for {input_gcs_source}")
continue
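        # Entities with page references mean the classifier acted as a splitter;
        # otherwise keep the whole file and use the highest-confidence label.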
        if is_splitting_required(blob_entities):
            # Merge split results so documents accumulated from earlier
            # files in the batch are not overwritten.
            for doc_class, uris in split_pdf(input_gcs_source, blob_entities).items():
                documents.setdefault(doc_class, []).extend(uris)
        else:
            max_confidence_entity = max(blob_entities, key=lambda item: item.confidence)
            doc_metadata = get_metadata(max_confidence_entity)
            gcs_helper.add_metadata(input_gcs_source, doc_metadata)
            add_predicted_document_type(
                doc_metadata, input_gcs_source=input_gcs_source, documents=documents
            )
return documents


def get_metadata(entity: Optional[Document.Entity] = None) -> Dict:
"""Get metadata from a Document AI entity."""
if not entity:
confidence = -1
document_type = NO_CLASSIFIER_LABEL
else:
confidence = round(entity.confidence, 3)
document_type = entity.type_
return {
METADATA_CONFIDENCE: confidence,
METADATA_DOCUMENT_TYPE: document_type,
}


def add_predicted_document_type(
metadata: dict, input_gcs_source: str, documents: Dict
) -> None:
"""Add predicted document type to the documents dictionary."""
classification_default_class = config.get_classification_default_class()
predicted_confidence = metadata[METADATA_CONFIDENCE]
predicted_label = metadata[METADATA_DOCUMENT_TYPE]
if check_confidence_threshold_passed(predicted_confidence):
predicted_class = config.get_document_class_by_classifier_label(predicted_label)
else:
logger.warning(
f"Using default document type={classification_default_class} for {input_gcs_source},"
f" due to low confidence={predicted_confidence}"
)
predicted_class = classification_default_class
if not predicted_class:
logger.warning(
f"No document type found for {predicted_label} and no default one defined, "
f"using the default class = {classification_default_class}"
)
predicted_class = classification_default_class
if predicted_class not in documents:
documents[predicted_class] = []
documents[predicted_class].append(input_gcs_source)


def handle_no_classifier(f_uris: List[str]) -> Dict:
"""Handles cases where no classifier is used."""
documents: Dict[str, List[str]] = {}
for uri in f_uris:
add_predicted_document_type(
get_metadata(), input_gcs_source=uri, documents=documents
)
return documents


def stream_classification_results(
call_back_url: str, bucket_name: Optional[str], file_name: Optional[str]
):
"""Streams classification results to a specified callback URL."""
logger.info(f"bucket={bucket_name}, blob_object={file_name}")
success = bool(bucket_name and file_name)
result = (
"Classification Job completed successfully, proceed with extraction"
if success
else "Classification Job failed"
)
payload = {
"result": result,
"success": success,
"bucket": bucket_name,
"object": file_name,
}
utils.send_callback_request(call_back_url, payload)


def save_classification_results(
classified_items: Dict,
) -> Tuple[Optional[str], Optional[str]]:
"""Saves classification results to Google Cloud Storage."""
payload_data = []
try:
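        # For each document type: build a BigQuery object table over its files,
        # create a remote model bound to the matching Document AI processor,
        # and record the table/model names in the payload written to GCS below.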
for document_type, f_uris in classified_items.items():
model_name, out_table_name = config.get_model_name_table_name(document_type)
processor_name = config.get_parser_name_by_doc_type(document_type)
if processor_name:
processor, _ = docai_helper.get_processor_and_client(processor_name)
else:
logger.error(f"No processor found for document type: {document_type}")
continue
object_table_name = bq_mlops.object_table_create(
f_uris=f_uris, document_type=document_type
)
bq_mlops.remote_model_create(processor=processor, model_name=model_name)
payload_data.append(
{
"object_table_name": object_table_name,
"model_name": model_name,
"out_table_name": out_table_name,
}
)
if not payload_data:
logger.warning("Payload data is empty, skipping")
return None, None
prefix = utils.get_utc_timestamp()
bucket, blob_object = gcs_helper.write_data_to_gcs(
bucket_name=config.CLASSIFY_OUTPUT_BUCKET,
blob_name=f"{prefix}_{config.OUTPUT_FILE_JSON}",
content=json.dumps(payload_data, indent=4),
mime_type="application/json",
)
except (json.JSONDecodeError, OSError, ValueError) as e:
logger.error(f"Exception while saving classification results: {e}")
return None, None
logger.info(
f"Saved classification results to, bucket={bucket}, file_name={blob_object}"
)
return bucket, blob_object


def is_splitting_required(entities: List[Document.Entity]) -> bool:
"""Check if splitting is required based on entities."""
try:
return not all(
len(entity.page_anchor.page_refs) == 0
or all(not ref for ref in entity.page_anchor.page_refs)
for entity in entities
)
except AttributeError:
return False


def check_confidence_threshold_passed(predicted_confidence: float) -> bool:
"""Check if the confidence threshold is passed."""
confidence_threshold = config.get_classification_confidence_threshold()
if predicted_confidence < confidence_threshold:
logger.warning(
f"Confidence threshold not passed for "
f"{predicted_confidence} < {confidence_threshold}"
)
return False
return True


def split_pdf(gcs_uri: str, entities: List[Document.Entity]) -> Dict:
    """Splits a PDF file from GCS into multiple PDF files based on the output
    of a Splitter/Classifier processor.

    Args:
        gcs_uri (str):
            Required. The GCS URI of the PDF file.
        entities (List[Document.Entity]):
            Required. The list of entities to split on.

    Returns:
        Dict:
            A mapping from predicted document class to the list of output
            PDF GCS URIs.
    """
documents: Dict = {}
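    # A single entity means the whole file is one document type: attach the
    # classification metadata to the original object instead of splitting it.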
if len(entities) == 1:
metadata = get_metadata(entities[0])
metadata.update({"original": gcs_uri})
gcs_helper.add_metadata(gcs_uri=gcs_uri, metadata=metadata)
add_predicted_document_type(
metadata=metadata, input_gcs_source=gcs_uri, documents=documents
)
else:
temp_local_dir = os.path.join(
os.path.dirname(__file__), "temp_files", utils.get_utc_timestamp()
)
if not os.path.exists(temp_local_dir):
os.makedirs(temp_local_dir)
pdf_path = os.path.join(temp_local_dir, os.path.basename(gcs_uri))
gcs_helper.download_file(gcs_uri=gcs_uri, output_filename=pdf_path)
input_filename, input_extension = os.path.splitext(os.path.basename(pdf_path))
bucket_name, _ = gcs_utilities.split_gcs_uri(gcs_uri)
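        # Copy each entity's page range into its own PDF, upload it alongside the
        # source under SPLITTER_OUTPUT_DIR, and tag it with classification metadata.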
with Pdf.open(pdf_path) as pdf:
for entity in entities:
subdoc_type = entity.type_ or "subdoc"
page_refs = entity.page_anchor.page_refs
if page_refs:
start_page = int(page_refs[0].page)
end_page = int(page_refs[-1].page)
else:
logger.warning(
f"Skipping {pdf_path} entity due to no page refs, no splitting"
)
continue
page_range = (
f"pg{start_page + 1}"
if start_page == end_page
else f"pg{start_page + 1}-{end_page + 1}"
)
output_filename = (
f"{input_filename}_{page_range}_{subdoc_type}{input_extension}"
)
metadata = get_metadata(entity)
metadata.update({"original": gcs_uri})
gcs_path = gcs_utilities.split_gcs_uri(os.path.dirname(gcs_uri))[1]
destination_blob_name = os.path.join(
gcs_path, SPLITTER_OUTPUT_DIR, output_filename
)
destination_blob_uri = f"gs://{bucket_name}/{destination_blob_name}"
local_out_file = os.path.join(temp_local_dir, output_filename)
subdoc = Pdf.new()
subdoc.pages.extend(pdf.pages[start_page : end_page + 1])
subdoc.save(local_out_file, min_version=pdf.pdf_version)
gcs_helper.upload_file(
bucket_name=bucket_name,
source_file_name=local_out_file,
destination_blob_name=destination_blob_name,
)
gcs_helper.add_metadata(destination_blob_uri, metadata)
add_predicted_document_type(
metadata=metadata,
input_gcs_source=destination_blob_uri,
documents=documents,
)
utils.delete_directory(temp_local_dir)
return documents