incubator-tools/ocr_upgradation_tool/ocr_upgradation_tool.ipynb

{ "cells": [ { "cell_type": "markdown", "id": "151cc42c-8f04-4ba6-9bf3-9ed6735c4400", "metadata": {}, "source": [ "# Dataset OCR Upgrade Tool" ] }, { "cell_type": "markdown", "id": "b10775a0-36c7-42ea-98f4-7b16b19683da", "metadata": {}, "source": [ "* Author: docai-incubator@google.com" ] }, { "cell_type": "markdown", "id": "16002c68-6472-410c-8332-c07f8a0b1fd8", "metadata": {}, "source": [ "## Disclaimer\n", "\n", "This tool is not supported by the Google engineering team or product team. It is provided and supported on a best-effort basis by the **DocAI Incubator Team**. No guarantees of performance are implied.\n" ] }, { "cell_type": "markdown", "id": "5e46ea74-5631-48f7-838f-0e2249c63c85", "metadata": {}, "source": [ "## Background" ] }, { "cell_type": "markdown", "id": "26192f34-df7f-41b6-aa0e-46601cfbf3d8", "metadata": {}, "source": [ "[Effective April 9, 2025](https://cloud.google.com/document-ai/docs/release-notes#September_26_2024), the following Custom Extractor versions will no longer be accessible:\n", "\n", "* pretrained-foundation-model-v1.0-2023-08-22\n", "* pretrained-foundation-model-v1.1-2024-03-12\n", "\n", "You will need to migrate to a later version to avoid any service disruptions, such as : \n", "\n", "* pretrained-foundation-model-v1.2-2024-05-10 and \n", "* pretrained-foundation-model-v1.3-2024-08-31 \n", "\n", "for improved quality from the latest proprietary vision models and foundation models.\n", "\n", "**NOTE :** [Effective April 9 2025](https://cloud.google.com/document-ai/docs/release-notes#September_26_2024) OCR versions pretrained-ocr-v1.0-2020-09-23 and pretrained-ocr-v1.1-2022-09-12 will be discontinued in the US and EU regions. " ] }, { "cell_type": "markdown", "id": "f5abd429-ec04-4b79-bb7c-41bf9e91af82", "metadata": {}, "source": [ "## Need for Dataset Label OCR Upgrade\n", "\n", "Simply **using** the new Foundation Models (FM’s) : \n", "\n", "* **v1.2 (pretrained-foundation-model-v1.2-2024-05-10)** and \n", "* **v1.3 (pretrained-foundation-model-v1.3-2024-08-31)**\n", "\n", "For few-shot (5 labeled documents) or fine-tuning (1 or more labeled documents for both test and training) with an existing Custom Extractor processor (created before Sept. 29th, 2024) **is not recommended**. These new FM’s use OCR 2.1, where previous Foundation Models used OCR 2.0. The recommendation is to **UPGRADE the OCR version that the dataset is labeled with**.\n", "For a graphical explanation of the differences between the versions and why this is necessary, please refer to this graph:\n", "\n", "<img src=\"./Images/first_image.png\" width=800 height=800 alt=\"Graphical Explanation\">\n", "\n", "The reason for this is that there are TWO OCR engines involved with any processor:\n", "\n", "* one for the **dataset** (used for autolabeling and test document processing) and\n", "* one associated with the **specific Foundation Model** for prediction.\n", "\n", "These can become mismatched, which is the case with previously existing processors and these new Foundation Models.\n", "\n", "All custom processors before Sept. 29th, 2024 used **OCR 2.0 for the dataset** (regardless of the prediction OCR engine associated with the specific Foundation Model being used for prediction). Those created after Sept. 29th, 2024 use the **new OCR 2.1 for the dataset**. When there is a mismatch of the training dataset OCR and the prediction OCR, accuracy is often not ideal. 
, { "cell_type": "markdown", "id": "fc2f0355-7c7e-4072-a1c3-40677fc2bc24", "metadata": {}, "source": [ "## Objective\n", "\n", "This document is intended as a guide to help migrate datasets from older processors (created before Sept. 29th, 2024) to newly created processors (as of Sept. 29th, 2024).\n", "\n", "The tool converts the OCR in a labeled dataset from an older version to an updated version in three steps:\n", "\n", "1. **Exporting** the dataset from the existing processor.\n", "2. **Reprocessing** the exported OCR-labeled JSON data with the updated OCR engine used by the new processor.\n", "3. **Importing** the enhanced OCR JSON files into a new processor, while also migrating the **schema** from the original processor.\n", "\n", "<img src=\"./Images/second_image.png\" width=800 height=800 alt=\"OCR Upgrades\">" ] }, { "cell_type": "markdown", "id": "0e650899-dee4-45f7-a5aa-3504140428ed", "metadata": {}, "source": [ "## Prerequisites\n", "\n", "* Vertex AI Notebook.\n", "* Storage bucket for storing the exported JSON files and the output JSON files.\n", "* Permissions for Google DocAI processors, Storage and Vertex AI Notebook.\n", "* A new DocAI processor to receive the upgraded OCR-labeled documents in its dataset." ] }, { "cell_type": "markdown", "id": "22712441-3128-47e5-8331-43f87e9c7696", "metadata": {}, "source": [ "## Step-by-step procedure\n", "\n", "The procedure is basically to:\n", "* Load the following library and code into your Vertex AI notebook,\n", "* Edit the code to provide the required input parameters, then\n", "* Run the code.\n", "\n", "After running this code, the new processor will contain the updated JSONs with the new OCR data, along with the schema." ] }, { "cell_type": "code", "execution_count": null, "id": "d8b481bc-e2cd-4f8d-9148-e8ed39ebaf19", "metadata": { "tags": [] }, "outputs": [], "source": [ "!wget https://raw.githubusercontent.com/GoogleCloudPlatform/document-ai-samples/main/incubator-tools/best-practices/utilities/utilities.py" ] }, { "cell_type": "markdown", "id": "7442803c-ecc0-4961-b82f-a50fff0afd8d", "metadata": {}, "source": [ "### 1. Install the required libraries\n", "\n", "Use the pip command to install the required libraries before executing the code." ] }, { "cell_type": "code", "execution_count": null, "id": "f10671ca-c77c-4360-bce8-3ed98bd6e84d", "metadata": { "tags": [] }, "outputs": [], "source": [ "!pip install google-cloud-documentai google-cloud-storage PyPDF2" ] }, { "cell_type": "markdown", "id": "26e1af2d-e7ae-4be4-874e-8fb35a6163d4", "metadata": {}, "source": [ "### 2. Import the libraries\n", "\n", "Import the necessary libraries used in the code. 
If you encounter any import errors, use the pip command to install the missing libraries.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ce265651-a3bb-4d99-8f70-8176d49bcf4c", "metadata": { "tags": [] }, "outputs": [], "source": [ "from google.cloud import documentai_v1beta3 as documentai\n", "from google.api_core.client_options import ClientOptions\n", "import json\n", "from pathlib import Path\n", "from tqdm import tqdm\n", "from google.cloud import storage\n", "from typing import (\n", " Container,\n", " Iterable,\n", " Iterator,\n", " List,\n", " Mapping,\n", " Optional,\n", " Sequence,\n", " Tuple,\n", " Union,\n", ")\n", "from io import BytesIO\n", "from pprint import pprint\n", "import copy\n", "from PIL import Image\n", "from PyPDF2 import PdfFileReader\n", "import io\n", "from tqdm.notebook import tqdm\n", "import time\n", "import concurrent.futures\n", "import utilities" ] }, { "cell_type": "markdown", "id": "fa0f9d3d-8c4c-431d-b4ea-35828b1c6122", "metadata": {}, "source": [ "### 3. Set up the required inputs" ] }, { "cell_type": "code", "execution_count": null, "id": "71cb352b-b196-4112-9789-cc798a98b7ca", "metadata": { "tags": [] }, "outputs": [], "source": [ "project_id = \"xxxxxxxxx\"\n", "location = \"xxxxx\"\n", "processor_id = \"xxxxxxxxxxx\"\n", "export_dataset_path = \"gs://bucket/path/to/export_dataset/\"\n", "updated_ocr_files_path = \"gs://bucket/path/to/updated_dataset/\"\n", "new_processor_id = \"xxxxxxxx\"\n", "new_version_id = \"pretrained-foundation-model-v1.3-2024-08-31\"\n", "offset = 0.005 # Expands each existing bounding box so that all tokens corresponding to an entity are captured. Adjust to an optimal value if needed." ] }, { "cell_type": "markdown", "id": "f383c955-e8cb-4922-a512-08a9e3f5df1c", "metadata": {}, "source": [ "* `project_id`: Provide the project ID\n", "* `location`: Provide the processor location (“us” or “eu”)\n", "* `processor_id`: Provide the GCP DocumentAI processor ID that holds the JSONs labeled with the old OCR version\n", "* `export_dataset_path`: Provide the GCS path to store the exported JSON files\n", "* `updated_ocr_files_path`: Provide the GCS path to store the OCR-updated JSON files\n", "* `new_processor_id`: Provide the new GCP DocumentAI processor ID to import the updated documents into\n", "* `new_version_id`: Provide the processor version ID whose updated OCR should be used\n", "* `offset`: Margin used to expand each entity bounding box so that all corresponding tokens are captured" ] }, { "cell_type": "markdown", "id": "699e74b9-562e-42fd-a420-219e395fe96d", "metadata": {}, "source": [ "### 4. Run the code" ] }, { "cell_type": "code", "execution_count": null, "id": "8aeb215e-1f52-4506-90c8-6df981153401", "metadata": { "tags": [] }, "outputs": [], "source": [ "def create_pdf_bytes(json: str) -> bytes:\n", " \"\"\"\n", " Creates PDF bytes from image content in a JSON document (typically ground truth data),\n", " which is used for further processing of files. This function decodes image data and\n", " combines them into a single PDF.\n", "\n", " Args:\n", " json (str): The JSON string representing the ground truth data, typically retrieved\n", " from Google Cloud's Document AI output or other sources. 
The JSON should contain image data in\n", " its content field.\n", "\n", " Returns:\n", " bytes: A byte representation of the generated PDF containing all images.\n", "\n", " Raises:\n", " ValueError: If no images are found in the input JSON or an invalid image format is encountered.\n", "\n", " Example:\n", " json_str = '{\"pages\": [{\"image\": {\"content\": \"<image_bytes_in_base64>\"}}]}'\n", " pdf_bytes = create_pdf_bytes(json_str)\n", " \"\"\"\n", " from google.cloud import documentai_v1beta3\n", "\n", " def decode_image(image_bytes: bytes) -> Image.Image:\n", " \"\"\"Decodes image bytes into a PIL Image object.\"\"\"\n", " with io.BytesIO(image_bytes) as image_file:\n", " image = Image.open(image_file)\n", " image.load()\n", " return image\n", "\n", " def create_pdf_from_images(images: Sequence[Image.Image]) -> bytes:\n", " \"\"\"Creates a PDF from a sequence of images.\n", "\n", " Args:\n", " images: A sequence of images to be included in the PDF.\n", "\n", " Returns:\n", " bytes: The PDF bytes generated from the images.\n", "\n", " Raises:\n", " ValueError: If no images are provided.\n", " \"\"\"\n", " if not images:\n", " raise ValueError(\"At least one image is required to create a PDF\")\n", "\n", " # PIL PDF saver does not support RGBA images\n", " images = [\n", " image.convert(\"RGB\") if image.mode == \"RGBA\" else image for image in images\n", " ]\n", "\n", " with io.BytesIO() as pdf_file:\n", " images[0].save(\n", " pdf_file, save_all=True, append_images=images[1:], format=\"PDF\"\n", " )\n", " return pdf_file.getvalue()\n", "\n", " d = documentai_v1beta3.Document\n", " document = d.from_json(json)\n", " synthesized_images = []\n", " for i in range(len(document.pages)):\n", " synthesized_images.append(decode_image(document.pages[i].image.content))\n", " pdf_bytes = create_pdf_from_images(synthesized_images)\n", "\n", " return pdf_bytes\n", "\n", "\n", "def process_document_sample(\n", " project_id: str,\n", " location: str,\n", " processor_id: str,\n", " file_path: str,\n", " processor_version_id: Optional[str] = None,\n", " mime_type: Optional[str] = \"application/pdf\",\n", " field_mask: Optional[str] = None,\n", ") -> documentai.ProcessResponse:\n", " \"\"\"\n", " Processes a document using a specified Document AI processor in Google Cloud and\n", " returns the processed result. This function reads a file, processes it through a Document AI processor,\n", " and retrieves the result which may include text extraction, form parsing, etc.\n", "\n", " Args:\n", " project_id (str): The Google Cloud project ID where the Document AI processor is located.\n", " location (str): The location/region of the Document AI processor (e.g., 'us', 'eu').\n", " processor_id (str): The ID of the Document AI processor to use for processing.\n", " file_path (str): The local path or in-memory string content of the document to be processed.\n", " processor_version_id (Optional[str], optional): The specific processor version to use, if any.\n", " If not provided, the default processor version will be used. Defaults to None.\n", " mime_type (Optional[str], optional): The MIME type of the document. Defaults to 'application/pdf'.\n", " field_mask (Optional[str], optional): Field mask specifying the parts of the document to process.\n", " If not provided, the entire document will be processed. 
Defaults to None.\n", "\n", " Returns:\n", " documentai.ProcessResponse: The response object containing the processed document data from the processor.\n", " \"\"\"\n", " # You must set the `api_endpoint` if you use a location other than \"us\".\n", " opts = ClientOptions(api_endpoint=f\"{location}-documentai.googleapis.com\")\n", " client = documentai.DocumentProcessorServiceClient(client_options=opts)\n", " if processor_version_id:\n", " name = client.processor_version_path(\n", " project_id, location, processor_id, processor_version_id\n", " )\n", " else:\n", " name = client.processor_path(project_id, location, processor_id)\n", " # Read the file into memory\n", " image_content = file_path\n", " # Load binary data\n", " raw_document = documentai.RawDocument(content=image_content, mime_type=mime_type)\n", " request = documentai.ProcessRequest(\n", " name=name,\n", " raw_document=raw_document,\n", " field_mask=field_mask,\n", " # process_options=process_options,\n", " )\n", " result = client.process_document(request=request)\n", " # Read the text recognition output from the processor\n", " return result\n", "\n", "\n", "def get(entity: dict, arg: str) -> float:\n", " \"\"\"\n", " Extracts the specified bounding box coordinate (x_min, y_min, x_max, y_max) from the entity object.\n", " This function calculates the minimum or maximum x and y coordinates of the bounding box around the\n", " entity based on normalized vertices.\n", "\n", " Args:\n", " entity (dict): The entity dictionary that contains the bounding box information.\n", " It may have different key formats (`pageAnchor` or `page_anchor`) depending on the format.\n", " arg (str): The coordinate to extract. Valid values are 'x_min', 'y_min', 'x_max', 'y_max'.\n", "\n", " Returns:\n", " float: The value of the requested coordinate (minimum or maximum of x or y).\n", "\n", " Raises:\n", " ValueError: If an invalid argument is passed for `arg`.\n", "\n", " Example:\n", " entity = {\n", " \"pageAnchor\": {\n", " \"pageRefs\": [{\n", " \"boundingPoly\": {\n", " \"normalizedVertices\": [{\"x\": 0.1, \"y\": 0.2}, {\"x\": 0.4, \"y\": 0.5}]\n", " }\n", " }]\n", " }\n", " }\n", " x_min = get(entity, 'x_min')\n", " y_max = get(entity, 'y_max')\n", " \"\"\"\n", " x_list = []\n", " y_list = []\n", " if \"pageAnchor\" in entity.keys():\n", " for i in entity[\"pageAnchor\"][\"pageRefs\"]:\n", " for j in i[\"boundingPoly\"][\"normalizedVertices\"]:\n", " x_list.append(j[\"x\"])\n", " y_list.append(j[\"y\"])\n", "\n", " if arg == \"x_min\":\n", " return min(x_list)\n", " if arg == \"y_min\":\n", " return min(y_list)\n", " if arg == \"x_max\":\n", " return max(x_list)\n", " if arg == \"y_max\":\n", " return max(y_list)\n", " else:\n", " for i in entity[\"page_anchor\"][\"page_refs\"]:\n", " for j in i[\"bounding_poly\"][\"normalized_vertices\"]:\n", " x_list.append(j[\"x\"])\n", " y_list.append(j[\"y\"])\n", "\n", " if arg == \"x_min\":\n", " return min(x_list)\n", " if arg == \"y_min\":\n", " return min(y_list)\n", " if arg == \"x_max\":\n", " return max(x_list)\n", " if arg == \"y_max\":\n", " return max(y_list)\n", "\n", "\n", "def find_textSegment_list(\n", " x_min: float, y_min: float, x_max: float, y_max: float, js: dict, page: int\n", ") -> List[dict]:\n", " \"\"\"\n", " Finds and returns a list of text segments within a specified bounding box (defined by `x_min`, `y_min`,\n", " `x_max`, `y_max`) from the tokens on a given page of a Document AI JSON structure.\n", "\n", " Args:\n", " x_min (float): The minimum x-coordinate of the bounding 
box.\n", " y_min (float): The minimum y-coordinate of the bounding box.\n", " x_max (float): The maximum x-coordinate of the bounding box.\n", " y_max (float): The maximum y-coordinate of the bounding box.\n", " js (dict): The JSON data (in Document AI format) containing page and token information.\n", " page (int): The page number from which to extract the text segments.\n", "\n", " Returns:\n", " List[dict]: A list of text segments (from `text_anchor`) that fall within the specified bounding box.\n", "\n", " Example:\n", " text_segments = find_textSegment_list(\n", " x_min=0.1, y_min=0.2, x_max=0.4, y_max=0.5,\n", " js=document_json, page=0\n", " )\n", " \"\"\"\n", " textSegments_list = []\n", " for token in js[\"pages\"][page][\"tokens\"]:\n", " token_xMin = get_token(token, \"x_min\")\n", " token_xMax = get_token(token, \"x_max\")\n", " token_yMin = get_token(token, \"y_min\")\n", " token_yMax = get_token(token, \"y_max\")\n", " if (\n", " token_xMin >= x_min\n", " and token_xMax <= x_max\n", " and token_yMin >= y_min\n", " and token_yMax < y_max\n", " ):\n", " textSegments_list.extend(token[\"layout\"][\"text_anchor\"][\"text_segments\"])\n", " return textSegments_list\n", "\n", "\n", "def get_token(token: dict, param: str) -> float:\n", " \"\"\"\n", " Retrieves the specified bounding box coordinate (x_min, y_min, x_max, y_max) from a token's layout information.\n", "\n", " This function extracts the list of normalized vertices (x, y coordinates) from the bounding box of the token\n", " and returns the minimum or maximum value based on the requested parameter.\n", "\n", " Args:\n", " token (dict): The token dictionary that contains the bounding box (in the 'layout' field).\n", " param (str): The coordinate to extract. Valid values are 'x_min', 'x_max', 'y_min', 'y_max'.\n", "\n", " Returns:\n", " float: The value of the requested coordinate (minimum or maximum of x or y).\n", " \"\"\"\n", " x_list = []\n", " y_list = []\n", " for j in token[\"layout\"][\"bounding_poly\"][\"normalized_vertices\"]:\n", " x_list.append(j[\"x\"])\n", " y_list.append(j[\"y\"])\n", " if param == \"x_min\":\n", " return min(x_list)\n", " if param == \"x_max\":\n", " return max(x_list)\n", " if param == \"y_min\":\n", " return min(y_list)\n", " if param == \"y_max\":\n", " return max(y_list)\n", "\n", "\n", "def update_text_anchors_mention_text(entity: dict, js: dict, new_js: dict) -> dict:\n", " \"\"\"\n", " Updates the text anchor of an entity with the corresponding text segments from a new JSON structure.\n", "\n", " This function extracts text segments from the `new_js` based on the bounding box coordinates of the provided\n", " `entity`. 
It constructs a new entity that includes a text anchor and the mention text derived from the text segments.\n", "\n", " Args:\n", " entity (dict): The original entity containing the bounding box and page anchor information.\n", " js (dict): The original JSON structure containing page and token information.\n", " new_js (dict): The new JSON structure containing text and token information to extract text segments from.\n", " offset (float): An offset to be applied to the bounding box coordinates for expanding the search area.\n", "\n", " Returns:\n", " dict: A new entity with updated text anchor and mention text, including the corresponding page references.\n", " \"\"\"\n", "\n", " if \"pageAnchor\" not in entity.keys():\n", " # print(f\"We're skipping the {entity['type']} because there's no PageAnchor.\")\n", " return None\n", "\n", " new_entity = {}\n", " text_anchor = {}\n", " textAnchorList = []\n", " x_min = get(entity, \"x_min\")\n", " x_max = get(entity, \"x_max\")\n", " y_min = get(entity, \"y_min\")\n", " y_max = get(entity, \"y_max\")\n", " page = 0\n", " if \"page\" in entity[\"pageAnchor\"][\"pageRefs\"][0].keys():\n", " page = int(entity[\"pageAnchor\"][\"pageRefs\"][0][\"page\"])\n", " textSegmentList = find_textSegment_list(\n", " x_min - offset, y_min - offset, x_max + offset, y_max + offset, new_js, page\n", " )\n", " for j in textSegmentList:\n", " if \"start_index\" not in j.keys():\n", " j[\"start_index\"] = str(0)\n", " textSegmentList = sorted(textSegmentList, key=lambda x: int(x[\"start_index\"]))\n", " text_anchor[\"text_segments\"] = textSegmentList\n", " mentionText = \"\"\n", " listOfIndex = []\n", " for j in textSegmentList:\n", " mentionText += new_js[\"text\"][int(j[\"start_index\"]) : int(j[\"end_index\"])]\n", " text_anchor[\"content\"] = mentionText\n", " new_entity[\"text_anchor\"] = text_anchor\n", " new_entity[\"mention_text\"] = mentionText\n", " temp_page_anchor = {}\n", "\n", " list_of_page_refs = []\n", " for i in entity[\"pageAnchor\"][\"pageRefs\"]:\n", " temp = {}\n", " temp2 = {}\n", " temp3 = []\n", " for j in i[\"boundingPoly\"][\"normalizedVertices\"]:\n", " temp3.append(j)\n", " temp2[\"normalized_vertices\"] = temp3\n", " temp[\"bounding_poly\"] = temp2\n", " temp[\"layout_type\"] = i.get(\"layoutType\", \"LAYOUT_TYPE_UNSPECIFIED\")\n", " temp[\"page\"] = str(page)\n", " list_of_page_refs.append(temp)\n", " temp_page_anchor[\"page_refs\"] = list_of_page_refs\n", " new_entity[\"page_anchor\"] = temp_page_anchor\n", " new_entity[\"type\"] = entity[\"type\"]\n", " return new_entity\n", "\n", "\n", "def make_parent_from_child_entities(temp_child: list, new_js: dict) -> dict:\n", " \"\"\"\n", " Creates a parent entity from a list of child entities by combining their text anchors and bounding boxes.\n", "\n", " This function checks the number of child entities. 
If there is one child, it returns that child directly.\n", " If there are two or more children, it combines them into a single parent entity, merging their text segments,\n", " mention text, and bounding box coordinates.\n", "\n", " Args:\n", " temp_child (list): A list of child entity dictionaries to be combined.\n", " new_js (dict): The new JSON structure containing text data used for extracting text segments.\n", "\n", " Returns:\n", " dict: A parent entity that includes the merged text anchor, mention text, and bounding box.\n", "\n", " Example:\n", " parent_entity = make_parent_from_child_entities(child_entities, new_json)\n", " \"\"\"\n", "\n", " def combine_two_entities(entity1: dict, entity2: dict, js: dict) -> dict:\n", " \"\"\"Combines two entities into one by merging their text anchors and bounding boxes.\"\"\"\n", " new_entity = {}\n", " new_entity[\"type\"] = entity1[\n", " \"type\"\n", " ] # It's a temporary placeholder for the parent entity type, which it'll change later in the code.\n", " text_anchor = {}\n", " # print(\"Entity1 : \"+entity1['mentionText'])\n", " # print(\"Entity2 : \"+entity2['mentionText'])\n", " textAnchorList = []\n", "\n", " entity1[\"text_anchor\"][\"text_segments\"] = sorted(\n", " entity1[\"text_anchor\"][\"text_segments\"], key=lambda x: int(x[\"start_index\"])\n", " )\n", " entity2[\"text_anchor\"][\"text_segments\"] = sorted(\n", " entity2[\"text_anchor\"][\"text_segments\"], key=lambda x: int(x[\"start_index\"])\n", " )\n", " for j in entity1[\"text_anchor\"][\"text_segments\"]:\n", " textAnchorList.append(j)\n", " # print(js['text'][int(j['startIndex']):int(j['endIndex'])])\n", " for j in entity2[\"text_anchor\"][\"text_segments\"]:\n", " textAnchorList.append(j)\n", " textAnchorList = sorted(textAnchorList, key=lambda x: int(x[\"start_index\"]))\n", " mentionText = \"\"\n", " for j in textAnchorList:\n", " mentionText += js[\"text\"][int(j[\"start_index\"]) : int(j[\"end_index\"])]\n", " new_entity[\"mention_text\"] = mentionText\n", " text_anchor[\"content\"] = mentionText\n", " temp_text_anchor_list = []\n", " for i in range(len(entity1[\"text_anchor\"][\"text_segments\"])):\n", " temp_text_anchor_list.append(entity1[\"text_anchor\"][\"text_segments\"][i])\n", " for i in range(len(entity2[\"text_anchor\"][\"text_segments\"])):\n", " temp_text_anchor_list.append(entity2[\"text_anchor\"][\"text_segments\"][i])\n", " text_anchor[\"text_segments\"] = temp_text_anchor_list\n", " new_entity[\"text_anchor\"] = text_anchor\n", " min_x = min(get(entity1, \"x_min\"), get(entity2, \"x_min\"))\n", " min_y = min(get(entity1, \"y_min\"), get(entity2, \"y_min\"))\n", " max_x = max(get(entity1, \"x_max\"), get(entity2, \"x_max\"))\n", " max_y = max(get(entity1, \"y_max\"), get(entity2, \"y_max\"))\n", " A = {\"x\": min_x, \"y\": min_y}\n", " B = {\"x\": max_x, \"y\": min_y}\n", " C = {\"x\": max_x, \"y\": max_y}\n", " D = {\"x\": min_x, \"y\": max_y}\n", " new_entity[\"page_anchor\"] = entity1[\"page_anchor\"]\n", " new_entity[\"page_anchor\"][\"page_refs\"][0][\"bounding_poly\"][\n", " \"normalized_vertices\"\n", " ] = [A, B, C, D]\n", " return new_entity\n", "\n", " if len(temp_child) == 1:\n", " return temp_child[0]\n", " if len(temp_child) == 2:\n", " parent_entity = combine_two_entities(temp_child[0], temp_child[1], new_js)\n", " return parent_entity\n", " parent_entity = combine_two_entities(temp_child[0], temp_child[1], new_js)\n", " for i in range(2, len(temp_child)):\n", " parent_entity = combine_two_entities(parent_entity, temp_child[i], 
new_js)\n", " return parent_entity\n", "\n", "\n", "def list_documents(\n", " project_id: str,\n", " location: str,\n", " processor: str,\n", " page_size: int = 100,\n", " page_token: str = \"\",\n", ") -> documentai.types.ListDocumentsResponse:\n", " \"\"\"\n", " Lists documents in a specified Document AI processor.\n", "\n", " This function retrieves a list of documents from the specified Document AI processor dataset.\n", " It supports pagination through the `page_size` and `page_token` parameters.\n", "\n", " Args:\n", " project_id (str): The ID of the Google Cloud project.\n", " location (str): The location of the Document AI processor (e.g., 'us', 'eu').\n", " processor (str): The ID of the Document AI processor.\n", " page_size (int, optional): The maximum number of documents to return per page. Default is 100.\n", " page_token (str, optional): A token for pagination; it indicates the next page of results. Default is an empty string.\n", "\n", " Returns:\n", " documentai.types.ListDocumentsResponse: The response containing a list of documents and additional metadata.\n", "\n", " Example:\n", " response = list_documents('my-project-id', 'us', 'my-processor-id')\n", " for document in response.documents:\n", " print(document.name)\n", " \"\"\"\n", " client = documentai.DocumentServiceClient()\n", " dataset = (\n", " f\"projects/{project_id}/locations/{location}/processors/{processor}/dataset\"\n", " )\n", " request = documentai.types.ListDocumentsRequest(\n", " dataset=dataset,\n", " page_token=page_token,\n", " page_size=page_size,\n", " return_total_size=True,\n", " )\n", " operation = client.list_documents(request)\n", " return operation\n", "\n", "\n", "def get_document(\n", " project_id: str, location: str, processor: str, doc_id: str\n", ") -> documentai.types.Document:\n", " \"\"\"\n", " Retrieves a specific document from a Document AI processor by its document ID.\n", "\n", " This function fetches the details of a document stored in the specified Document AI processor\n", " dataset using the document's unique ID.\n", "\n", " Args:\n", " project_id (str): The ID of the Google Cloud project.\n", " location (str): The location of the Document AI processor (e.g., 'us', 'eu').\n", " processor (str): The ID of the Document AI processor.\n", " doc_id (str): The unique identifier of the document to retrieve.\n", "\n", " Returns:\n", " documentai.types.Document: The document object containing the requested document's details.\n", "\n", " Example:\n", " document = get_document('my-project-id', 'us', 'my-processor-id', 'my-document-id')\n", " print(document.name)\n", " \"\"\"\n", " client = documentai.DocumentServiceClient()\n", " dataset = (\n", " f\"projects/{project_id}/locations/{location}/processors/{processor}/dataset\"\n", " )\n", " request = documentai.types.GetDocumentRequest(dataset=dataset, document_id=doc_id)\n", " operation = client.get_document(request)\n", " return operation.document\n", "\n", "\n", "def get_dataset_schema(\n", " project_id: str, processor_id: str, location: str\n", ") -> documentai.types.DatasetSchema:\n", " \"\"\"\n", " Retrieves the dataset schema for a specified Document AI processor.\n", "\n", " This function fetches the schema of the dataset associated with a given Document AI processor,\n", " which describes the structure and organization of the dataset.\n", "\n", " Args:\n", " project_id (str): The ID of the Google Cloud project.\n", " processor_id (str): The ID of the Document AI processor.\n", " location (str): The location of the Document AI 
processor (e.g., 'us', 'eu').\n", "\n", " Returns:\n", " documentai.types.DatasetSchema: The schema of the dataset associated with the processor.\n", "\n", " Example:\n", " schema = get_dataset_schema('my-project-id', 'my-processor-id', 'us')\n", " print(schema)\n", " \"\"\"\n", " # Create a client\n", " processor_name = (\n", " f\"projects/{project_id}/locations/{location}/processors/{processor_id}\"\n", " )\n", " client = documentai.DocumentServiceClient()\n", " request = documentai.GetDatasetSchemaRequest(\n", " name=processor_name + \"/dataset/datasetSchema\"\n", " )\n", " # Make the request\n", " response = client.get_dataset_schema(request=request)\n", "\n", " return response\n", "\n", "\n", "def upload_dataset_schema(\n", " schema: documentai.types.DatasetSchema,\n", ") -> documentai.types.DatasetSchema:\n", " \"\"\"\n", " Uploads a new or updated dataset schema to a Document AI processor.\n", "\n", " This function sends the provided dataset schema to the Document AI processor, allowing\n", " for the schema to be updated or created as necessary.\n", "\n", " Args:\n", " schema (documentai.types.DatasetSchema): The dataset schema to be uploaded.\n", "\n", " Returns:\n", " documentai.types.DatasetSchema: The updated dataset schema returned by the service.\n", "\n", " Example:\n", " from google.cloud import documentai_v1beta3 as documentai\n", "\n", " schema = documentai.DatasetSchema(\n", " # populate schema fields as necessary\n", " )\n", " updated_schema = upload_dataset_schema(schema)\n", " print(updated_schema)\n", " \"\"\"\n", " client = documentai.DocumentServiceClient()\n", " request = documentai.UpdateDatasetSchemaRequest(dataset_schema=schema)\n", " res = client.update_dataset_schema(request=request)\n", " return res\n", "\n", "\n", "def import_documents(\n", " project_id: str, processor_id: str, location: str, gcs_path: str\n", ") -> documentai.types.ImportDocumentsResponse:\n", " \"\"\"\n", " Imports documents from Google Cloud Storage (GCS) into a Document AI processor's dataset.\n", "\n", " This function imports documents into the dataset associated with a specified Document AI\n", " processor, organizing them into training, testing, and unassigned splits based on the\n", " provided GCS path.\n", "\n", " Args:\n", " project_id (str): The ID of the Google Cloud project.\n", " processor_id (str): The ID of the Document AI processor.\n", " location (str): The location of the Document AI processor (e.g., 'us', 'eu').\n", " gcs_path (str): The GCS path prefix where the documents are stored. 
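The prefix is expected to contain train/, test/ and unassigned/ subfolders, matching the splits written by the export step. 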
It should include\n", " the base path with trailing slash.\n", "\n", " Returns:\n", " documentai.types.ImportDocumentsResponse: The response from the import operation.\n", "\n", " Example:\n", " response = import_documents('my-project-id', 'my-processor-id', 'us', 'gs://my-bucket/documents/')\n", " print(response)\n", " \"\"\"\n", " client = documentai.DocumentServiceClient()\n", " dataset = (\n", " f\"projects/{project_id}/locations/{location}/processors/{processor_id}/dataset\"\n", " )\n", " request = documentai.ImportDocumentsRequest(\n", " dataset=dataset,\n", " batch_documents_import_configs=[\n", " {\n", " \"dataset_split\": \"DATASET_SPLIT_TRAIN\",\n", " \"batch_input_config\": {\n", " \"gcs_prefix\": {\"gcs_uri_prefix\": gcs_path + \"train/\"}\n", " },\n", " },\n", " {\n", " \"dataset_split\": \"DATASET_SPLIT_TEST\",\n", " \"batch_input_config\": {\n", " \"gcs_prefix\": {\"gcs_uri_prefix\": gcs_path + \"test/\"}\n", " },\n", " },\n", " {\n", " \"dataset_split\": \"DATASET_SPLIT_UNASSIGNED\",\n", " \"batch_input_config\": {\n", " \"gcs_prefix\": {\"gcs_uri_prefix\": gcs_path + \"unassigned/\"}\n", " },\n", " },\n", " ],\n", " )\n", " response = client.import_documents(request=request)\n", "\n", " return response\n", "\n", "\n", "def retry_function_with_internal_error_handling(\n", " func, max_retries=3, wait_time=10, *args, **kwargs\n", "):\n", " \"\"\"\n", " Runs a function with retry logic if a 500 Internal Server Error is encountered.\n", "\n", " Args:\n", " func (function): The function to be called.\n", " max_retries (int): Maximum number of retries (default is 3).\n", " wait_time (int): Time to wait between retries in seconds (default is 2 seconds).\n", " *args: Positional arguments to pass to the function.\n", " **kwargs: Keyword arguments to pass to the function.\n", "\n", " Returns:\n", " Result of the function if it succeeds within the retry attempts.\n", "\n", " Raises:\n", " Exception: If the function fails after max retries.\n", " \"\"\"\n", " attempt = 0\n", " while attempt < max_retries:\n", " try:\n", " # Call the function with passed arguments\n", " result = func(*args, **kwargs)\n", " return result # If successful, return the result immediately\n", " except Exception as e:\n", " if \"500\" in str(e): # Check if the error is a 500 error\n", " attempt += 1\n", " time.sleep(wait_time) # Wait for a short period before retrying\n", " else:\n", " raise # If it's not a 500 error, raise the error\n", " # If the function fails after all retries\n", " raise Exception(\n", " f\"Function failed after {max_retries} attempts due to repeated 500 errors.\"\n", " )" ] }, { "cell_type": "markdown", "id": "acb5c059-19f9-4582-b7de-08fe3174abf8", "metadata": {}, "source": [ "### Main Function" ] }, { "cell_type": "code", "execution_count": null, "id": "4753ddea-fca4-466f-8e02-eeba9b11dd6d", "metadata": { "tags": [] }, "outputs": [], "source": [ "def process_file(document_path):\n", " storage_client = storage.Client()\n", " source_bucket = storage_client.bucket(export_dataset_path.split(\"/\")[2])\n", " try:\n", " file_name = (\"/\").join(document_path.split(\"/\")[-2:])\n", " print(document_path)\n", " js_json = source_bucket.blob(document_path).download_as_string().decode(\"utf-8\")\n", " merged_pdf = create_pdf_bytes(js_json)\n", " js = json.loads(js_json)\n", " res = retry_function_with_internal_error_handling(\n", " process_document_sample,\n", " project_id=project_id,\n", " location=location,\n", " processor_id=processor_id,\n", " file_path=merged_pdf,\n", " 
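# Reprocess the reassembled PDF with the new foundation model version so the\n", " # resulting OCR (text and tokens) matches what the new processor's dataset expects.\n", " 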
processor_version_id=new_version_id,\n", " )\n", " if res.document.entities:\n", " del res.document.entities\n", " new_js = documentai.Document.to_dict(res.document)\n", " updated_entities = []\n", " for entity in js[\"entities\"]:\n", " # print(entity)\n", " temp_child = []\n", " ent = {}\n", " if \"properties\" in entity.keys() and len(entity[\"properties\"]) != 0:\n", " for child_item in entity[\"properties\"]:\n", " ent_ch = update_text_anchors_mention_text(child_item, js, new_js)\n", " if ent_ch is not None:\n", " temp_child.append(ent_ch)\n", " ent = make_parent_from_child_entities(copy.deepcopy(temp_child), new_js)\n", " ent[\"type\"] = entity[\"type\"]\n", " ent[\"properties\"] = temp_child\n", " else:\n", " ent = update_text_anchors_mention_text(entity, js, new_js)\n", " # pprint(ent)\n", " if ent is not None:\n", " updated_entities.append(ent)\n", " # pprint(updated_entities)\n", " new_js[\"entities\"] = updated_entities\n", " d = documentai.Document.from_json(json.dumps(new_js))\n", " output_bucket_path_prefix = \"/\".join(updated_ocr_files_path.split(\"/\")[3:])\n", " output_file_name = f\"{output_bucket_path_prefix}{file_name}\"\n", " # print(output_file_name)\n", " utilities.store_document_as_json(\n", " json.dumps(new_js), updated_ocr_files_path.split(\"/\")[2], output_file_name\n", " )\n", " except Exception as e:\n", " print(\n", " \">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\"\n", " + document_path\n", " + \" was not processed successfully!!!\"\n", " )\n", " print(e)\n", "\n", "\n", "def process_files_concurrently(file_list, max_workers=5):\n", " # The current limit of 5 parallel processes can be increased by adjusting the max_workers parameter\n", " results = []\n", "\n", " # Use ProcessPoolExecutor for CPU-bound tasks\n", " with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:\n", " for file in file_list:\n", " executor.submit(process_file, file)\n", "\n", "\n", "def main():\n", " results = list_documents(project_id, location, processor_id)\n", " document_list = results.document_metadata\n", " while len(document_list) != results.total_size:\n", " page_token = results.next_page_token\n", " results = list_documents(\n", " project_id, location, processor_id, page_token=page_token\n", " )\n", " document_list.extend(results.document_metadata)\n", " print(\"Exporting Dataset...\")\n", " for doc in tqdm(document_list):\n", " doc_id = doc.document_id\n", " split_type = doc.dataset_type\n", " if split_type == 3:\n", " split = \"unassigned\"\n", " elif split_type == 2:\n", " split = \"test\"\n", " elif split_type == 1:\n", " split = \"train\"\n", " else:\n", " split = \"unknown\"\n", " file_name = doc.display_name\n", " res = get_document(project_id, location, processor_id, doc_id)\n", " exported_path = (\"/\").join(export_dataset_path.split(\"/\")[3:])\n", " output_file_name = f\"{exported_path}/{split}/{file_name}.json\"\n", " json_data = documentai.Document.to_json(res)\n", " utilities.store_document_as_json(\n", " json_data, export_dataset_path.split(\"/\")[2], output_file_name\n", " )\n", "\n", " print(\"Exporting Dataset is completed...\")\n", " exported_schema = get_dataset_schema(project_id, processor_id, location)\n", " exported_schema.name = f\"projects/{project_id}/locations/{location}/processors/{new_processor_id}/dataset/datasetSchema\"\n", " import_schema = upload_dataset_schema(exported_schema)\n", "\n", " document_paths = list(utilities.file_names(export_dataset_path)[1].values())\n", " process_files_concurrently(document_paths)\n", " 
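# import_documents starts a server-side import of the updated JSONs into the new\n", " # processor's dataset; the fixed sleep below is only a rough heuristic wait.\n", " 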
print(f\"imporing updated OCR documents to {new_processor_id}\")\n", " res = import_documents(\n", " project_id, new_processor_id, location, updated_ocr_files_path\n", " )\n", " print(f\"Waiting for {len(document_paths)*1} seconds to import all documents\")\n", " time.sleep(len(document_paths) * 1)\n", " print(\"All documents have been imported.\")\n", "\n", "\n", "main()" ] } ], "metadata": { "environment": { "kernel": "conda-base-py", "name": "workbench-notebooks.m125", "type": "gcloud", "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m125" }, "kernelspec": { "display_name": "Python 3 (ipykernel) (Local)", "language": "python", "name": "conda-base-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 5 }