# Dataset OCR Upgrade Tool

* Author: docai-incubator@google.com

## Disclaimer

This tool is not supported by the Google engineering team or product team. It is provided and supported on a best-effort basis by the **DocAI Incubator Team**. No guarantees of performance are implied.


## Background

[Effective April 9, 2025](https://cloud.google.com/document-ai/docs/release-notes#September_26_2024), the following Custom Extractor versions will no longer be accessible:

* pretrained-foundation-model-v1.0-2023-08-22
* pretrained-foundation-model-v1.1-2024-03-12

To avoid service disruptions, you will need to migrate to a later version, such as:

* pretrained-foundation-model-v1.2-2024-05-10
* pretrained-foundation-model-v1.3-2024-08-31

These later versions offer improved quality from the latest proprietary vision models and foundation models.

**NOTE:** [Effective April 9, 2025](https://cloud.google.com/document-ai/docs/release-notes#September_26_2024), OCR versions pretrained-ocr-v1.0-2020-09-23 and pretrained-ocr-v1.1-2022-09-12 will be discontinued in the US and EU regions.

## Need for Dataset Label OCR Upgrade

Simply **using** the new Foundation Models (FMs):

* **v1.2 (pretrained-foundation-model-v1.2-2024-05-10)** and
* **v1.3 (pretrained-foundation-model-v1.3-2024-08-31)**

for few-shot (5 labeled documents) or fine-tuning (1 or more labeled documents for both test and training) with an existing Custom Extractor processor (created before Sept. 29th, 2024) **is not recommended**. These new FMs use OCR 2.1, whereas previous Foundation Models used OCR 2.0. The recommendation is to **UPGRADE the OCR version that the dataset is labeled with**.

For a graphical explanation of the differences between the versions and why this is necessary, please refer to the following figure:

<img src="./Images/first_image.png" width=800 height=800 alt="Graphical Explanation">

The reason for this is that there are TWO OCR engines involved with any processor:

* one for the **dataset** (used for autolabeling and test document processing) and
* one associated with the **specific Foundation Model** for prediction.

These two engines can become mismatched. This is what happens when previously existing processors are used with the new Foundation Models.

All custom processors created before Sept. 29th, 2024 use **OCR 2.0 for the dataset**, regardless of the prediction OCR engine associated with the Foundation Model used for prediction. Processors created after Sept. 29th, 2024 use the **new OCR 2.1 for the dataset**. When the OCR of the training dataset does not match the prediction OCR, accuracy is often not ideal. Therefore, to use the new FM versions listed above, we strongly recommend **re-labeling** any labeled documents in a dataset created before Sept. 29th, 2024, so that their OCR version matches the FM version. Fortunately, **this tool makes that easy**.
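The rule above can be written as a small check. This is a minimal sketch, and several parts are assumptions:

* The `FM_PREDICTION_OCR` mapping assigns OCR 2.0 to v1.0 and v1.1 and OCR 2.1 to v1.2 and v1.3. This is inferred from the statement that previous Foundation Models used OCR 2.0.
* A processor created exactly on Sept. 29th, 2024 is treated as a new processor.
* You supply the processor creation date yourself.

```python
from datetime import date

# Assumption: processors created on or after this date label their dataset with OCR 2.1.
OCR_2_1_DATASET_CUTOFF = date(2024, 9, 29)

# Assumed prediction-side OCR version for each Foundation Model version.
FM_PREDICTION_OCR = {
    "pretrained-foundation-model-v1.0-2023-08-22": "2.0",
    "pretrained-foundation-model-v1.1-2024-03-12": "2.0",
    "pretrained-foundation-model-v1.2-2024-05-10": "2.1",
    "pretrained-foundation-model-v1.3-2024-08-31": "2.1",
}


def dataset_ocr_version(processor_created: date) -> str:
    """Returns the OCR version used to label the processor's dataset."""
    return "2.1" if processor_created >= OCR_2_1_DATASET_CUTOFF else "2.0"


def needs_ocr_upgrade(processor_created: date, fm_version: str) -> bool:
    """True when the dataset OCR differs from the Foundation Model's prediction OCR."""
    return dataset_ocr_version(processor_created) != FM_PREDICTION_OCR[fm_version]


print(needs_ocr_upgrade(date(2024, 6, 1), "pretrained-foundation-model-v1.3-2024-08-31"))
# True: an OCR 2.0 dataset paired with an OCR 2.1 model
```

An older processor paired with v1.2 or v1.3 is exactly the mismatch case that this tool addresses.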

## Objective

This document is intended as a guide to help with the migration of datasets from older processors (created before Sept. 29th, 2024) to newly created processors (created as of Sept. 29th, 2024).

The tool converts the OCR in a labeled dataset from an older version to an updated version in 3 steps:

1. **Exporting** datasets from a processor
2. **Reprocessing** exported OCR-labeled JSON data with an updated OCR engine which is used in the new processor.
3. **Importing** the enhanced OCR JSON files into a new processor, while also migrating the **schema** from the original processor.

<img src="./Images/second_image.png" width=800 height=800 alt="OCR Upgrades">
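The three steps share a common folder layout in Cloud Storage:

* **Export (step 1):** each labeled document is written to `<export_dataset_path>/<split>/<display_name>.json`. The split folder comes from the document's `dataset_type` (1 = train, 2 = test, 3 = unassigned).
* **Reprocess (step 2):** the `<split>/<name>.json` part is kept under `updated_ocr_files_path`.
* **Import (step 3):** documents are read from the `train/`, `test/`, and `unassigned/` subfolders.

The following sketch shows that layout. The helper names are hypothetical and are not part of the tool:

```python
from typing import Dict

# DocumentMetadata.dataset_type values handled by main(): 1 = train, 2 = test, 3 = unassigned
SPLIT_FOLDERS = {1: "train", 2: "test", 3: "unassigned"}


def split_folder(dataset_type: int) -> str:
    """Returns the export folder for a dataset split; other values fall back to "unknown"."""
    return SPLIT_FOLDERS.get(dataset_type, "unknown")


def exported_object_name(prefix: str, dataset_type: int, display_name: str) -> str:
    """Object name written in step 1 (export) for one labeled document."""
    return f"{prefix.rstrip('/')}/{split_folder(dataset_type)}/{display_name}.json"


def import_prefixes(updated_path: str) -> Dict[str, str]:
    """GCS prefixes read in step 3 (import), keyed by dataset split."""
    base = updated_path if updated_path.endswith("/") else updated_path + "/"
    return {
        "DATASET_SPLIT_TRAIN": base + "train/",
        "DATASET_SPLIT_TEST": base + "test/",
        "DATASET_SPLIT_UNASSIGNED": base + "unassigned/",
    }


print(exported_object_name("path/to/export_dataset/", 2, "invoice_01"))
# path/to/export_dataset/test/invoice_01.json
print(import_prefixes("gs://bucket/path/to/updated_dataset/")["DATASET_SPLIT_TRAIN"])
# gs://bucket/path/to/updated_dataset/train/
```

One consequence of this layout: a document whose split is unspecified lands in `unknown/`, which the import step does not read.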

## Prerequisites

* Vertex AI Notebook.
* Cloud Storage bucket for storing the exported JSON files and the output JSON files.
* Permissions for Google DocAI processors, Cloud Storage, and Vertex AI Notebook.
* A new DocAI processor to receive the dataset documents with upgraded OCR labels.

## Step by step procedure

The procedure is to:

* Load the following library and code into your Vertex AI notebook,
* Edit it to provide the required input parameters, then
* Run the code.

After the code runs, the new processor (`new_processor_id`) receives the updated JSON documents with the new OCR data, along with the schema of the original processor.

In [None]:
!wget https://raw.githubusercontent.com/GoogleCloudPlatform/document-ai-samples/main/incubator-tools/best-practices/utilities/utilities.py

### 1. Install the required libraries 

Use the pip command to install the required libraries before executing the code.

In [None]:
!pip install google-cloud-documentai google-cloud-storage PyPDF2

### 2. Import the libraries
Import the necessary libraries utilized in the code. If you encounter any import errors, use the pip command to install the missing libraries.


In [None]:
import concurrent.futures
import copy
import io
import json
import time
from typing import List, Optional, Sequence

from google.api_core.client_options import ClientOptions
from google.cloud import documentai_v1beta3 as documentai
from google.cloud import storage
from PIL import Image
from tqdm.notebook import tqdm

import utilities

### 3. Set up the required inputs

In [None]:
project_id = "xxxxxxxxx"
location = "xxxxx"
processor_id = "xxxxxxxxxxx"
export_dataset_path = "gs://bucket/path/to/export_dataset/"
updated_ocr_files_path = "gs://bucket/path/to/updated_dataset/"
new_processor_id = "xxxxxxxx"
new_version_id = "pretrained-foundation-model-v1.3-2024-08-31"
# Margin (in normalized page coordinates) added on every side of each entity's bounding box
# so that all tokens belonging to the entity are captured. Adjust to an optimal value if needed.
offset = 0.005

* `project_id`: Google Cloud project ID.
* `location`: Processor location ("us" or "eu").
* `processor_id`: ID of the existing DocAI processor whose dataset is labeled with the old OCR version.
* `export_dataset_path`: GCS path where the exported JSON files are stored.
* `updated_ocr_files_path`: GCS path where the OCR-updated JSON files are stored. End it with `/`, as in the example.
* `new_processor_id`: ID of the new DocAI processor into which the updated documents are imported.
* `new_version_id`: Version of the original processor (`processor_id`) used to re-run OCR on the exported documents, for example `pretrained-foundation-model-v1.3-2024-08-31`.
* `offset`: Margin added around each entity's bounding box when matching it to the tokens of the new OCR.
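The code derives each bucket name and object prefix from these `gs://` paths by splitting on `/`. It uses `path.split("/")[2]` for the bucket and `"/".join(path.split("/")[3:])` for the prefix. It then appends names such as `train/doc.json` directly to that prefix.

The following sketch mirrors that logic with a hypothetical helper, `split_gcs_uri`. It shows why `updated_ocr_files_path` needs a trailing slash:

```python
from typing import Tuple


def split_gcs_uri(uri: str) -> Tuple[str, str]:
    """Splits "gs://bucket/some/prefix/" into ("bucket", "some/prefix/"), as the tool does."""
    if not uri.startswith("gs://"):
        raise ValueError(f"Not a GCS URI: {uri}")
    parts = uri.split("/")  # e.g. ["gs:", "", "bucket", "some", "prefix", ""]
    return parts[2], "/".join(parts[3:])


# With a trailing slash, the prefix and file name join cleanly
_, prefix = split_gcs_uri("gs://bucket/path/to/updated_dataset/")
print(prefix + "train/doc.json")  # path/to/updated_dataset/train/doc.json

# Without it, the split folder is glued onto the last path component
_, prefix = split_gcs_uri("gs://bucket/path/to/updated_dataset")
print(prefix + "train/doc.json")  # path/to/updated_datasettrain/doc.json
```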

### 4. Run the code

In [None]:
def create_pdf_bytes(json: str) -> bytes:
    """
    Creates PDF bytes from image content in a JSON document (typically ground truth data),
    which is used for further processing of files. This function decodes image data and
    combines them into a single PDF.

    Args:
        json (str): The JSON string representing the ground truth data, typically retrieved
        from Google Cloud's Document AI output or other sources. The JSON should contain image data in
        its content field.

    Returns:
        bytes: A byte representation of the generated PDF containing all images.

    Raises:
        ValueError: If the JSON contains no pages, so there are no images to combine.
        PIL.UnidentifiedImageError: If a page image cannot be decoded.

    Example:
        json_str = '{"pages": [{"image": {"content": "<image_bytes_in_base64>"}}]}'
        pdf_bytes = create_pdf_bytes(json_str)
    """
    from google.cloud import documentai_v1beta3

    def decode_image(image_bytes: bytes) -> Image.Image:
        """Decodes image bytes into a PIL Image object."""
        with io.BytesIO(image_bytes) as image_file:
            image = Image.open(image_file)
            image.load()
        return image

    def create_pdf_from_images(images: Sequence[Image.Image]) -> bytes:
        """Creates a PDF from a sequence of images.

        Args:
            images: A sequence of images to be included in the PDF.

        Returns:
            bytes: The PDF bytes generated from the images.

        Raises:
            ValueError: If no images are provided.
        """
        if not images:
            raise ValueError("At least one image is required to create a PDF")

        # PIL PDF saver does not support RGBA images
        images = [
            image.convert("RGB") if image.mode == "RGBA" else image for image in images
        ]

        with io.BytesIO() as pdf_file:
            images[0].save(
                pdf_file, save_all=True, append_images=images[1:], format="PDF"
            )
            return pdf_file.getvalue()

    d = documentai_v1beta3.Document
    document = d.from_json(json)
    synthesized_images = []
    for i in range(len(document.pages)):
        synthesized_images.append(decode_image(document.pages[i].image.content))
    pdf_bytes = create_pdf_from_images(synthesized_images)

    return pdf_bytes


def process_document_sample(
    project_id: str,
    location: str,
    processor_id: str,
    file_path: str,
    processor_version_id: Optional[str] = None,
    mime_type: Optional[str] = "application/pdf",
    field_mask: Optional[str] = None,
) -> documentai.ProcessResponse:
    """
    Processes a document using a specified Document AI processor in Google Cloud and
    returns the processed result. This function reads a file, processes it through a Document AI processor,
    and retrieves the result which may include text extraction, form parsing, etc.

    Args:
        project_id (str): The Google Cloud project ID where the Document AI processor is located.
        location (str): The location/region of the Document AI processor (e.g., 'us', 'eu').
        processor_id (str): The ID of the Document AI processor to use for processing.
        file_path (bytes): The raw bytes of the document to process. Despite its name, this is the
            document content, not a path.
        processor_version_id (Optional[str], optional): The specific processor version to use, if any.
            If not provided, the default processor version will be used. Defaults to None.
        mime_type (Optional[str], optional): The MIME type of the document. Defaults to 'application/pdf'.
        field_mask (Optional[str], optional): Field mask specifying the parts of the document to process.
            If not provided, the entire document will be processed. Defaults to None.

    Returns:
        documentai.ProcessResponse: The response object containing the processed document data from the processor.
    """
    # You must set the `api_endpoint` if you use a location other than "us".
    opts = ClientOptions(api_endpoint=f"{location}-documentai.googleapis.com")
    client = documentai.DocumentProcessorServiceClient(client_options=opts)
    if processor_version_id:
        name = client.processor_version_path(
            project_id, location, processor_id, processor_version_id
        )
    else:
        name = client.processor_path(project_id, location, processor_id)
    # `file_path` already holds the document bytes (e.g. the PDF built by create_pdf_bytes)
    raw_document = documentai.RawDocument(content=file_path, mime_type=mime_type)
    request = documentai.ProcessRequest(
        name=name,
        raw_document=raw_document,
        field_mask=field_mask,
    )
    # Send the document to the processor and return the full response
    return client.process_document(request=request)


def get(entity: dict, arg: str) -> float:
    """
    Extracts the specified bounding box coordinate (x_min, y_min, x_max, y_max) from the entity object.
    This function calculates the minimum or maximum x and y coordinates of the bounding box around the
    entity based on normalized vertices.

    Args:
        entity (dict): The entity dictionary that contains the bounding box information.
            It may have different key formats (`pageAnchor` or `page_anchor`) depending on the format.
        arg (str): The coordinate to extract. Valid values are 'x_min', 'y_min', 'x_max', 'y_max'.

    Returns:
        float: The value of the requested coordinate (minimum or maximum of x or y).

    Raises:
        ValueError: If an invalid argument is passed for `arg`.

    Example:
        entity = {
            "pageAnchor": {
                "pageRefs": [{
                    "boundingPoly": {
                        "normalizedVertices": [{"x": 0.1, "y": 0.2}, {"x": 0.4, "y": 0.5}]
                    }
                }]
            }
        }
        x_min = get(entity, 'x_min')
        y_max = get(entity, 'y_max')
    """
    # Exported JSON uses camelCase keys; documents built with to_dict() use snake_case keys.
    if "pageAnchor" in entity:
        page_refs = entity["pageAnchor"]["pageRefs"]
        poly_key, vertices_key = "boundingPoly", "normalizedVertices"
    else:
        page_refs = entity["page_anchor"]["page_refs"]
        poly_key, vertices_key = "bounding_poly", "normalized_vertices"

    x_list = []
    y_list = []
    for page_ref in page_refs:
        for vertex in page_ref[poly_key][vertices_key]:
            # A coordinate equal to 0 may be omitted from the JSON
            x_list.append(vertex.get("x", 0.0))
            y_list.append(vertex.get("y", 0.0))

    if arg == "x_min":
        return min(x_list)
    if arg == "y_min":
        return min(y_list)
    if arg == "x_max":
        return max(x_list)
    if arg == "y_max":
        return max(y_list)
    raise ValueError(f"Invalid coordinate name: {arg}")


def find_textSegment_list(
    x_min: float, y_min: float, x_max: float, y_max: float, js: dict, page: int
) -> List[dict]:
    """
    Finds and returns a list of text segments within a specified bounding box (defined by `x_min`, `y_min`,
    `x_max`, `y_max`) from the tokens on a given page of a Document AI JSON structure.

    Args:
        x_min (float): The minimum x-coordinate of the bounding box.
        y_min (float): The minimum y-coordinate of the bounding box.
        x_max (float): The maximum x-coordinate of the bounding box.
        y_max (float): The maximum y-coordinate of the bounding box.
        js (dict): The JSON data (in Document AI format) containing page and token information.
        page (int): The page number from which to extract the text segments.

    Returns:
        List[dict]: A list of text segments (from `text_anchor`) that fall within the specified bounding box.

    Example:
        text_segments = find_textSegment_list(
            x_min=0.1, y_min=0.2, x_max=0.4, y_max=0.5,
            js=document_json, page=0
        )
    """
    textSegments_list = []
    for token in js["pages"][page]["tokens"]:
        token_xMin = get_token(token, "x_min")
        token_xMax = get_token(token, "x_max")
        token_yMin = get_token(token, "y_min")
        token_yMax = get_token(token, "y_max")
        if (
            token_xMin >= x_min
            and token_xMax <= x_max
            and token_yMin >= y_min
            and token_yMax < y_max
        ):
            textSegments_list.extend(token["layout"]["text_anchor"]["text_segments"])
    return textSegments_list


def get_token(token: dict, param: str) -> float:
    """
    Retrieves the specified bounding box coordinate (x_min, y_min, x_max, y_max) from a token's layout information.

    This function extracts the list of normalized vertices (x, y coordinates) from the bounding box of the token
    and returns the minimum or maximum value based on the requested parameter.

    Args:
        token (dict): The token dictionary that contains the bounding box (in the 'layout' field).
        param (str): The coordinate to extract. Valid values are 'x_min', 'x_max', 'y_min', 'y_max'.

    Returns:
        float: The value of the requested coordinate (minimum or maximum of x or y).
    """
    x_list = []
    y_list = []
    for vertex in token["layout"]["bounding_poly"]["normalized_vertices"]:
        # A coordinate equal to 0 may be omitted from the JSON
        x_list.append(vertex.get("x", 0.0))
        y_list.append(vertex.get("y", 0.0))
    if param == "x_min":
        return min(x_list)
    if param == "x_max":
        return max(x_list)
    if param == "y_min":
        return min(y_list)
    if param == "y_max":
        return max(y_list)
    raise ValueError(f"Invalid coordinate name: {param}")


def update_text_anchors_mention_text(entity: dict, js: dict, new_js: dict) -> dict:
    """
    Updates the text anchor of an entity with the corresponding text segments from a new JSON structure.

    This function extracts text segments from the `new_js` based on the bounding box coordinates of the provided
    `entity`, expanded on every side by the module-level `offset` value. It constructs a new entity that includes
    a text anchor and the mention text derived from the text segments.

    Args:
        entity (dict): The original entity containing the bounding box and page anchor information.
        js (dict): The original JSON structure containing page and token information.
        new_js (dict): The new JSON structure containing text and token information to extract text segments from.

    Returns:
        dict: A new entity with updated text anchor and mention text, including the corresponding page references,
            or None if the entity has no page anchor.
    """

    if "pageAnchor" not in entity.keys():
        # print(f"We're skipping the {entity['type']} because there's no PageAnchor.")
        return None

    new_entity = {}
    text_anchor = {}
    x_min = get(entity, "x_min")
    x_max = get(entity, "x_max")
    y_min = get(entity, "y_min")
    y_max = get(entity, "y_max")
    page = 0
    if "page" in entity["pageAnchor"]["pageRefs"][0].keys():
        page = int(entity["pageAnchor"]["pageRefs"][0]["page"])
    # Collect the new OCR tokens that fall inside the entity's box, expanded by `offset`
    textSegmentList = find_textSegment_list(
        x_min - offset, y_min - offset, x_max + offset, y_max + offset, new_js, page
    )
    for j in textSegmentList:
        # A start index of 0 may be omitted from the JSON
        if "start_index" not in j.keys():
            j["start_index"] = str(0)
    textSegmentList = sorted(textSegmentList, key=lambda x: int(x["start_index"]))
    text_anchor["text_segments"] = textSegmentList
    # Rebuild the mention text from the new document text, in reading order
    mentionText = ""
    for j in textSegmentList:
        mentionText += new_js["text"][int(j["start_index"]) : int(j["end_index"])]
    text_anchor["content"] = mentionText
    new_entity["text_anchor"] = text_anchor
    new_entity["mention_text"] = mentionText
    temp_page_anchor = {}

    list_of_page_refs = []
    for i in entity["pageAnchor"]["pageRefs"]:
        temp = {}
        temp2 = {}
        temp3 = []
        for j in i["boundingPoly"]["normalizedVertices"]:
            temp3.append(j)
        temp2["normalized_vertices"] = temp3
        temp["bounding_poly"] = temp2
        temp["layout_type"] = i.get("layoutType", "LAYOUT_TYPE_UNSPECIFIED")
        temp["page"] = str(page)
        list_of_page_refs.append(temp)
    temp_page_anchor["page_refs"] = list_of_page_refs
    new_entity["page_anchor"] = temp_page_anchor
    new_entity["type"] = entity["type"]
    return new_entity


def make_parent_from_child_entities(temp_child: list, new_js: dict) -> dict:
    """
    Creates a parent entity from a list of child entities by combining their text anchors and bounding boxes.

    This function checks the number of child entities. If there is one child, it returns that child directly.
    If there are two or more children, it combines them into a single parent entity, merging their text segments,
    mention text, and bounding box coordinates.

    Args:
        temp_child (list): A list of child entity dictionaries to be combined.
        new_js (dict): The new JSON structure containing text data used for extracting text segments.

    Returns:
        dict: A parent entity that includes the merged text anchor, mention text, and bounding box.

    Example:
        parent_entity = make_parent_from_child_entities(child_entities, new_json)
    """

    def combine_two_entities(entity1: dict, entity2: dict, js: dict) -> dict:
        """Combines two entities into one by merging their text anchors and bounding boxes."""
        new_entity = {}
        new_entity["type"] = entity1[
            "type"
        ]  # It's a temporary placeholder for the parent entity type, which it'll change later in the code.
        text_anchor = {}
        # print("Entity1 : "+entity1['mentionText'])
        # print("Entity2 : "+entity2['mentionText'])
        textAnchorList = []

        entity1["text_anchor"]["text_segments"] = sorted(
            entity1["text_anchor"]["text_segments"], key=lambda x: int(x["start_index"])
        )
        entity2["text_anchor"]["text_segments"] = sorted(
            entity2["text_anchor"]["text_segments"], key=lambda x: int(x["start_index"])
        )
        for j in entity1["text_anchor"]["text_segments"]:
            textAnchorList.append(j)
            # print(js['text'][int(j['startIndex']):int(j['endIndex'])])
        for j in entity2["text_anchor"]["text_segments"]:
            textAnchorList.append(j)
        textAnchorList = sorted(textAnchorList, key=lambda x: int(x["start_index"]))
        mentionText = ""
        for j in textAnchorList:
            mentionText += js["text"][int(j["start_index"]) : int(j["end_index"])]
        new_entity["mention_text"] = mentionText
        text_anchor["content"] = mentionText
        temp_text_anchor_list = []
        for i in range(len(entity1["text_anchor"]["text_segments"])):
            temp_text_anchor_list.append(entity1["text_anchor"]["text_segments"][i])
        for i in range(len(entity2["text_anchor"]["text_segments"])):
            temp_text_anchor_list.append(entity2["text_anchor"]["text_segments"][i])
        text_anchor["text_segments"] = temp_text_anchor_list
        new_entity["text_anchor"] = text_anchor
        min_x = min(get(entity1, "x_min"), get(entity2, "x_min"))
        min_y = min(get(entity1, "y_min"), get(entity2, "y_min"))
        max_x = max(get(entity1, "x_max"), get(entity2, "x_max"))
        max_y = max(get(entity1, "y_max"), get(entity2, "y_max"))
        A = {"x": min_x, "y": min_y}
        B = {"x": max_x, "y": min_y}
        C = {"x": max_x, "y": max_y}
        D = {"x": min_x, "y": max_y}
        new_entity["page_anchor"] = entity1["page_anchor"]
        new_entity["page_anchor"]["page_refs"][0]["bounding_poly"][
            "normalized_vertices"
        ] = [A, B, C, D]
        return new_entity

    if len(temp_child) == 1:
        return temp_child[0]
    if len(temp_child) == 2:
        parent_entity = combine_two_entities(temp_child[0], temp_child[1], new_js)
        return parent_entity
    parent_entity = combine_two_entities(temp_child[0], temp_child[1], new_js)
    for i in range(2, len(temp_child)):
        parent_entity = combine_two_entities(parent_entity, temp_child[i], new_js)
    return parent_entity


def list_documents(
    project_id: str,
    location: str,
    processor: str,
    page_size: int = 100,
    page_token: str = "",
):
    """
    Lists documents in a specified Document AI processor.

    This function retrieves a list of documents from the specified Document AI processor dataset.
    It supports pagination through the `page_size` and `page_token` parameters.

    Args:
        project_id (str): The ID of the Google Cloud project.
        location (str): The location of the Document AI processor (e.g., 'us', 'eu').
        processor (str): The ID of the Document AI processor.
        page_size (int, optional): The maximum number of documents to return per page. Default is 100.
        page_token (str, optional): A token for pagination; it indicates the next page of results. Default is an empty string.

    Returns:
        A pager over the dataset's `DocumentMetadata` entries. Iterating it fetches every page
        automatically; fields of the current response, such as `document_metadata`, `total_size`,
        and `next_page_token`, can also be read from it directly.

    Example:
        for metadata in list_documents('my-project-id', 'us', 'my-processor-id'):
            print(metadata.display_name)
    """
    # Use the regional endpoint so that processors outside "us" (e.g. "eu") are reachable
    opts = ClientOptions(api_endpoint=f"{location}-documentai.googleapis.com")
    client = documentai.DocumentServiceClient(client_options=opts)
    dataset = (
        f"projects/{project_id}/locations/{location}/processors/{processor}/dataset"
    )
    request = documentai.types.ListDocumentsRequest(
        dataset=dataset,
        page_token=page_token,
        page_size=page_size,
        return_total_size=True,
    )
    return client.list_documents(request)


def get_document(
    project_id: str, location: str, processor: str, doc_id: str
) -> documentai.types.Document:
    """
    Retrieves a specific document from a Document AI processor by its document ID.

    This function fetches the details of a document stored in the specified Document AI processor
    dataset using the document's unique ID.

    Args:
        project_id (str): The ID of the Google Cloud project.
        location (str): The location of the Document AI processor (e.g., 'us', 'eu').
        processor (str): The ID of the Document AI processor.
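        doc_id: The identifier of the document to retrieve, i.e. the `DocumentMetadata.document_id`
            value of an entry returned by `list_documents`.

    Returns:
        documentai.types.Document: The document object containing the requested document's details.

    Example:
        metadata = next(iter(list_documents('my-project-id', 'us', 'my-processor-id')))
        document = get_document('my-project-id', 'us', 'my-processor-id', metadata.document_id)
        print(document.text)
    """
    # Use the regional endpoint so that processors outside "us" (e.g. "eu") are reachable
    opts = ClientOptions(api_endpoint=f"{location}-documentai.googleapis.com")
    client = documentai.DocumentServiceClient(client_options=opts)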
    dataset = (
        f"projects/{project_id}/locations/{location}/processors/{processor}/dataset"
    )
    request = documentai.types.GetDocumentRequest(dataset=dataset, document_id=doc_id)
    operation = client.get_document(request)
    return operation.document


def get_dataset_schema(
    project_id: str, processor_id: str, location: str
) -> documentai.types.DatasetSchema:
    """
    Retrieves the dataset schema for a specified Document AI processor.

    This function fetches the schema of the dataset associated with a given Document AI processor,
    which describes the structure and organization of the dataset.

    Args:
        project_id (str): The ID of the Google Cloud project.
        processor_id (str): The ID of the Document AI processor.
        location (str): The location of the Document AI processor (e.g., 'us', 'eu').

    Returns:
        documentai.types.DatasetSchema: The schema of the dataset associated with the processor.

    Example:
        schema = get_dataset_schema('my-project-id', 'my-processor-id', 'us')
        print(schema)
    """
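    processor_name = (
        f"projects/{project_id}/locations/{location}/processors/{processor_id}"
    )
    # Use the regional endpoint so that processors outside "us" (e.g. "eu") are reachable
    opts = ClientOptions(api_endpoint=f"{location}-documentai.googleapis.com")
    client = documentai.DocumentServiceClient(client_options=opts)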
    request = documentai.GetDatasetSchemaRequest(
        name=processor_name + "/dataset/datasetSchema"
    )
    # Make the request
    response = client.get_dataset_schema(request=request)

    return response


def upload_dataset_schema(
    schema: documentai.types.DatasetSchema,
) -> documentai.types.DatasetSchema:
    """
    Uploads a new or updated dataset schema to a Document AI processor.

    This function sends the provided dataset schema to the Document AI processor, allowing
    for the schema to be updated or created as necessary.

    Args:
        schema (documentai.types.DatasetSchema): The dataset schema to be uploaded.

    Returns:
        documentai.types.DatasetSchema: The updated dataset schema returned by the service.

    Example:
        from google.cloud import documentai_v1beta3 as documentai

        schema = documentai.DatasetSchema(
            # populate schema fields as necessary
        )
        updated_schema = upload_dataset_schema(schema)
        print(updated_schema)
    """
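    # The schema name has the form "projects/{project}/locations/{location}/...";
    # use that location's regional endpoint.
    location = schema.name.split("/")[3]
    opts = ClientOptions(api_endpoint=f"{location}-documentai.googleapis.com")
    client = documentai.DocumentServiceClient(client_options=opts)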
    request = documentai.UpdateDatasetSchemaRequest(dataset_schema=schema)
    res = client.update_dataset_schema(request=request)
    return res


def import_documents(
    project_id: str, processor_id: str, location: str, gcs_path: str
):
    """
    Imports documents from Google Cloud Storage (GCS) into a Document AI processor's dataset.

    This function imports documents into the dataset associated with a specified Document AI
    processor, organizing them into training, testing, and unassigned splits based on the
    provided GCS path.

    Args:
        project_id (str): The ID of the Google Cloud project.
        processor_id (str): The ID of the Document AI processor.
        location (str): The location of the Document AI processor (e.g., 'us', 'eu').
        gcs_path (str): The GCS path prefix where the documents are stored. It must end with a
                        trailing slash; documents are read from its `train/`, `test/`, and
                        `unassigned/` subfolders.

    Returns:
        google.api_core.operation.Operation: A long-running operation; call `.result()` on it
        to wait until the import has finished.

    Example:
        operation = import_documents('my-project-id', 'my-processor-id', 'us', 'gs://my-bucket/documents/')
        operation.result()
    """
    # Use the regional endpoint so that processors outside "us" (e.g. "eu") are reachable
    opts = ClientOptions(api_endpoint=f"{location}-documentai.googleapis.com")
    client = documentai.DocumentServiceClient(client_options=opts)
    dataset = (
        f"projects/{project_id}/locations/{location}/processors/{processor_id}/dataset"
    )
    request = documentai.ImportDocumentsRequest(
        dataset=dataset,
        batch_documents_import_configs=[
            {
                "dataset_split": "DATASET_SPLIT_TRAIN",
                "batch_input_config": {
                    "gcs_prefix": {"gcs_uri_prefix": gcs_path + "train/"}
                },
            },
            {
                "dataset_split": "DATASET_SPLIT_TEST",
                "batch_input_config": {
                    "gcs_prefix": {"gcs_uri_prefix": gcs_path + "test/"}
                },
            },
            {
                "dataset_split": "DATASET_SPLIT_UNASSIGNED",
                "batch_input_config": {
                    "gcs_prefix": {"gcs_uri_prefix": gcs_path + "unassigned/"}
                },
            },
        ],
    )
    response = client.import_documents(request=request)

    return response


def retry_function_with_internal_error_handling(
    func, max_retries=3, wait_time=10, *args, **kwargs
):
    """
    Runs a function with retry logic if a 500 Internal Server Error is encountered.

    Args:
        func (function): The function to be called.
        max_retries (int): Maximum number of retries (default is 3).
        wait_time (int): Time to wait between retries in seconds (default is 10 seconds).
        *args: Positional arguments to pass to the function.
        **kwargs: Keyword arguments to pass to the function.

    Returns:
        Result of the function if it succeeds within the retry attempts.

    Raises:
        Exception: If the function fails after max retries.
    """
    attempt = 0
    while attempt < max_retries:
        try:
            # Call the function with passed arguments
            result = func(*args, **kwargs)
            return result  # If successful, return the result immediately
        except Exception as e:
            if "500" in str(e):  # Check if the error is a 500 error
                attempt += 1
                time.sleep(wait_time)  # Wait for a short period before retrying
            else:
                raise  # If it's not a 500 error, raise the error
    # If the function fails after all retries
    raise Exception(
        f"Function failed after {max_retries} attempts due to repeated 500 errors."
    )
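The core of the label migration happens in `update_text_anchors_mention_text` and `find_textSegment_list`:

1. Each labeled box from the old OCR is expanded by `offset`.
2. Every token of the new OCR that lies entirely inside the expanded box is kept.
3. The mention text is rebuilt from those tokens' text segments, in text order.

The following self-contained sketch reproduces that idea on toy data. The simplified `Token` tuple and the helper names are hypothetical. Note that this sketch uses `<=` on all four edges, whereas `find_textSegment_list` uses a strict `<` for the bottom edge.

```python
from typing import Dict, List, Tuple

# Hypothetical simplified token: (vertices, start_index, end_index), where vertices are
# normalized {"x": ..., "y": ...} points, as in the Document AI JSON.
Token = Tuple[List[Dict[str, float]], int, int]


def bounds(vertices: List[Dict[str, float]]) -> Tuple[float, float, float, float]:
    """Returns (x_min, y_min, x_max, y_max) of a polygon; a missing coordinate counts as 0."""
    xs = [v.get("x", 0.0) for v in vertices]
    ys = [v.get("y", 0.0) for v in vertices]
    return min(xs), min(ys), max(xs), max(ys)


def remap_mention_text(
    entity_vertices: List[Dict[str, float]],
    tokens: List[Token],
    text: str,
    offset: float = 0.005,
) -> str:
    """Rebuilds an entity's mention text from the new OCR tokens inside its expanded box."""
    x_min, y_min, x_max, y_max = bounds(entity_vertices)
    # Grow the labeled box by `offset` on every side
    x_min, y_min = x_min - offset, y_min - offset
    x_max, y_max = x_max + offset, y_max + offset
    segments = []
    for vertices, start, end in tokens:
        t_x_min, t_y_min, t_x_max, t_y_max = bounds(vertices)
        # Keep a token only if it lies entirely inside the expanded box
        if t_x_min >= x_min and t_x_max <= x_max and t_y_min >= y_min and t_y_max <= y_max:
            segments.append((start, end))
    # Join the matched text segments in document-text order
    return "".join(text[start:end] for start, end in sorted(segments))


def rect(x0: float, y0: float, x1: float, y1: float) -> List[Dict[str, float]]:
    """Builds the four normalized vertices of an axis-aligned rectangle."""
    return [{"x": x0, "y": y0}, {"x": x1, "y": y0}, {"x": x1, "y": y1}, {"x": x0, "y": y1}]


# Toy page: three tokens from the new OCR
text = "Total: 42.00 USD"
tokens = [
    (rect(0.10, 0.10, 0.20, 0.12), 0, 7),    # "Total: "
    (rect(0.22, 0.10, 0.30, 0.12), 7, 13),   # "42.00 "
    (rect(0.32, 0.10, 0.36, 0.12), 13, 16),  # "USD"
]
# Labeled box from the old OCR, slightly tighter than the new tokens
label_box = rect(0.222, 0.102, 0.358, 0.118)

print(repr(remap_mention_text(label_box, tokens, text)))              # '42.00 USD'
print(repr(remap_mention_text(label_box, tokens, text, offset=0.0)))  # ''
```

With no margin, the slightly shifted tokens of the new OCR fall outside the old box and the entity loses its text. This is the situation that `offset` is meant to absorb.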

### Main Function

In [None]:
def process_file(document_path):
    """Re-runs OCR on one exported document and writes the updated JSON to `updated_ocr_files_path`."""
    storage_client = storage.Client()
    source_bucket = storage_client.bucket(export_dataset_path.split("/")[2])
    try:
        # Keep the split folder and file name, e.g. "train/doc.json"
        file_name = "/".join(document_path.split("/")[-2:])
        print(document_path)
        js_json = source_bucket.blob(document_path).download_as_bytes().decode("utf-8")
        # Rebuild a PDF from the page images stored in the exported JSON
        merged_pdf = create_pdf_bytes(js_json)
        js = json.loads(js_json)
        # Re-run OCR with the new processor version (retrying on 500 errors)
        res = retry_function_with_internal_error_handling(
            process_document_sample,
            project_id=project_id,
            location=location,
            processor_id=processor_id,
            file_path=merged_pdf,
            processor_version_id=new_version_id,
        )
        # Drop the model's own predictions; only the new OCR output is kept
        if res.document.entities:
            del res.document.entities
        new_js = documentai.Document.to_dict(res.document)
        updated_entities = []
        for entity in js.get("entities", []):
            ent = None
            if entity.get("properties"):
                # Parent entity: remap each child, then rebuild the parent from its children
                temp_child = []
                for child_item in entity["properties"]:
                    ent_ch = update_text_anchors_mention_text(child_item, js, new_js)
                    if ent_ch is not None:
                        temp_child.append(ent_ch)
                if temp_child:
                    ent = make_parent_from_child_entities(copy.deepcopy(temp_child), new_js)
                    ent["type"] = entity["type"]
                    ent["properties"] = temp_child
            else:
                ent = update_text_anchors_mention_text(entity, js, new_js)
            if ent is not None:
                updated_entities.append(ent)
        new_js["entities"] = updated_entities
        # Check that the updated JSON still parses as a Document before storing it
        documentai.Document.from_json(json.dumps(new_js))
        output_bucket_path_prefix = "/".join(updated_ocr_files_path.split("/")[3:])
        output_file_name = f"{output_bucket_path_prefix}{file_name}"
        utilities.store_document_as_json(
            json.dumps(new_js), updated_ocr_files_path.split("/")[2], output_file_name
        )
    except Exception as e:
        print(
            ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
            + document_path
            + " was not processed successfully!!!"
        )
        print(e)


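def process_files_concurrently(file_list, max_workers=5):
    """Runs process_file on every exported document, up to `max_workers` at a time.

    The default of 5 parallel processes can be increased through `max_workers`.
    Leaving the `with` block waits until every submitted file has been processed.
    """
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_file, file) for file in file_list]
        for future in concurrent.futures.as_completed(futures):
            # process_file reports its own failures; this re-raises unexpected worker errors
            future.result()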


def main():
    # Iterating the pager returned by list_documents fetches every page of the dataset
    document_list = list(list_documents(project_id, location, processor_id))
    print("Exporting Dataset...")
    # DocumentMetadata.dataset_type values: 1 = train, 2 = test, 3 = unassigned
    split_names = {1: "train", 2: "test", 3: "unassigned"}
    exported_path = "/".join(export_dataset_path.split("/")[3:]).rstrip("/")
    for doc in tqdm(document_list):
        split = split_names.get(doc.dataset_type, "unknown")
        res = get_document(project_id, location, processor_id, doc.document_id)
        output_file_name = f"{exported_path}/{split}/{doc.display_name}.json"
        json_data = documentai.Document.to_json(res)
        utilities.store_document_as_json(
            json_data, export_dataset_path.split("/")[2], output_file_name
        )
    print("Exporting Dataset is completed...")

    # Copy the schema of the original processor to the new processor
    exported_schema = get_dataset_schema(project_id, processor_id, location)
    exported_schema.name = f"projects/{project_id}/locations/{location}/processors/{new_processor_id}/dataset/datasetSchema"
    upload_dataset_schema(exported_schema)

    # Re-run OCR on every exported document and remap its labels
    document_paths = list(utilities.file_names(export_dataset_path)[1].values())
    process_files_concurrently(document_paths)

    print(f"Importing updated OCR documents to {new_processor_id}")
    operation = import_documents(
        project_id, new_processor_id, location, updated_ocr_files_path
    )
    # Block until the long-running import operation has finished
    operation.result(timeout=None)
    print("All documents have been imported.")


main()
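To spot-check the result, you can download a few files from `updated_ocr_files_path`, for example with `storage.Client().bucket(...).blob(...).download_as_bytes()`. Then look for entities whose rebuilt `mention_text` is empty. This happens when no token of the new OCR falls inside the entity's expanded box.

The helper below, `entities_with_empty_text`, is hypothetical. It assumes the snake_case keys that this tool writes, and the sample document is made up for illustration. If many entities come back empty, increasing `offset` may help.

```python
import json
from typing import List


def entities_with_empty_text(document_json: str) -> List[str]:
    """Returns the types of entities and child properties whose mention_text is empty."""
    document = json.loads(document_json)
    empty = []
    for entity in document.get("entities", []):
        # Check the entity itself and each of its child properties
        for item in [entity] + entity.get("properties", []):
            if not item.get("mention_text"):
                empty.append(item.get("type", "<no type>"))
    return empty


sample = json.dumps({
    "entities": [
        {"type": "total_amount", "mention_text": "42.00"},
        {"type": "supplier_name", "mention_text": ""},
        {
            "type": "line_item",
            "mention_text": "2 x Widget",
            "properties": [
                {"type": "line_item/quantity", "mention_text": "2"},
                {"type": "line_item/description", "mention_text": ""},
            ],
        },
    ]
})
print(entities_with_empty_text(sample))  # ['supplier_name', 'line_item/description']
```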