{
"cells": [
{
"cell_type": "markdown",
"id": "d02cac17-a1c9-4856-85cb-0c5e0b59d4b6",
"metadata": {},
"source": [
"# Paragraph Separation Script"
]
},
{
"cell_type": "markdown",
"id": "cbe6a452-6756-44c0-a139-85a9148ba21b",
"metadata": {},
"source": [
"* Author: docai-incubator@google.com"
]
},
{
"cell_type": "markdown",
"id": "05bfea53-ddc7-40c4-b6c8-9261e20b12a2",
"metadata": {},
"source": [
"## Disclaimer\n",
"\n",
"This tool is not supported by the Google engineering team or product team. It is provided and supported on a best-effort basis by the **DocAI Incubator Team**. No guarantees of performance are implied.\n"
]
},
{
"cell_type": "markdown",
"id": "7143432a-f3d6-4b0d-bd76-ca2092b9eb55",
"metadata": {},
"source": [
"## Objective\n",
"\n",
"This document provides instructions for correcting merged paragraphs identified during the OCR process. The separation is achieved based on specific characters such as (i), (ii), (iii), (a), (b), and so on.\n"
]
},
{
"cell_type": "markdown",
"id": "319ce84a-f609-42ed-8487-259bc5ddbfd7",
"metadata": {},
"source": [
"## Prerequisite\n",
"* Vertex AI Notebook\n",
"* Documents in GCS Folder\n",
"* Output folder to upload fixed documents\n"
]
},
{
"cell_type": "markdown",
"id": "0197c79c-60b8-4d28-b8a8-93388f628729",
"metadata": {},
"source": [
"## Step by Step procedure"
]
},
{
"cell_type": "markdown",
"id": "219968c8-43ad-4514-889c-081c5639e69c",
"metadata": {},
"source": [
"### 1.Importing Required Modules"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6980aec4-321f-4f61-a149-e2dc449ae28f",
"metadata": {},
"outputs": [],
"source": [
"!wget https://raw.githubusercontent.com/GoogleCloudPlatform/document-ai-samples/main/incubator-tools/best-practices/utilities/utilities.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5518f964-bac3-4b5b-809f-fda815d97826",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"import re\n",
"import time\n",
"import warnings\n",
"import utilities\n",
"import io\n",
"import base64\n",
"import gcsfs\n",
"import numpy as np\n",
"import pandas as pd\n",
"import itertools\n",
"\n",
"from itertools import cycle\n",
"from PIL import Image, ImageDraw, ImageFont\n",
"from PyPDF2 import PdfFileReader\n",
"from google.auth import credentials\n",
"from google.cloud import documentai_v1beta3 as documentai\n",
"from google.cloud import storage\n",
"from tqdm import tqdm\n",
"from io import BytesIO\n",
"from pathlib import Path\n",
"from pprint import pprint\n",
"from typing import (\n",
" Container,\n",
" Dict,\n",
" Iterable,\n",
" Iterator,\n",
" List,\n",
" Mapping,\n",
" Optional,\n",
" Sequence,\n",
" Tuple,\n",
" Union,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "238958bf-a224-4e83-a300-03851100b909",
"metadata": {},
"source": [
"### 2.Setup the Inputs\n",
"\n",
"* `input_uri`: This contains the storage bucket path of the input files. \n",
"* `output_bucket_name`: Your output bucket name.\n",
"* `base_file_path`: Base path within the bucket for storing output."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08906484-2fb8-4c13-a34e-eb9729bd5a81",
"metadata": {},
"outputs": [],
"source": [
"# Input parameters:\n",
"input_uri = \"gs://xxxxxxx/xxxxxxxxxx/xxxxxxxxx/xxxxxxxx/\"\n",
"output_bucket_name = \"xxxxxxxxxx\"\n",
"base_file_path = \"xxxxxx/xxxxxxxx/\" # Base path within the bucket"
]
},
{
"cell_type": "markdown",
"id": "d76fc7b4-2146-4280-85f5-80fd8a5a0fb0",
"metadata": {},
"source": [
"### 3.Run the below functions used in this tool"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0506b57-6f71-482f-adcb-1e3bd75c2e8c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def convert_base64_to_image(base64_text: str):\n",
" \"\"\"\n",
" Converts a base64 encoded text to an image.\n",
"\n",
" Args:\n",
" base64_text (str): A string containing the base64 encoded data of an image.\n",
" It can optionally start with 'data:image/png;base64,'.\n",
"\n",
" Returns:\n",
" Image: An image object created from the base64 encoded data.\n",
" \"\"\"\n",
" try:\n",
" image = Image.open(io.BytesIO(base64_text))\n",
" return image\n",
" except IOError:\n",
" print(\"Error in loading the image. The image data might be corrupted.\")\n",
" return None\n",
"\n",
"\n",
"def highlight_text_in_images(json_data: object) -> None:\n",
" \"\"\"\n",
" Process JSON data to extract images and highlight text segments.\n",
" \"\"\"\n",
" image_pages = []\n",
" for page in json_data.pages:\n",
" tokens = page.paragraphs\n",
" base64_text = page.image.content\n",
" image = convert_base64_to_image(base64_text)\n",
" draw = ImageDraw.Draw(image)\n",
" border_width = 4\n",
" text = json_data.text\n",
"\n",
" color_iterator = itertools.cycle(\n",
" [\"red\", \"green\", \"blue\", \"purple\", \"orange\"]\n",
" ) # Example colors\n",
"\n",
" for entity in tokens:\n",
" try:\n",
" # Initialize variables to store the minimum start index and maximum end index\n",
" min_start_index = float(\"inf\")\n",
" max_end_index = -1\n",
"\n",
" # Iterate over all text segments to find the min start index and max end index\n",
" for segment in entity.layout.text_anchor.text_segments:\n",
" start_index = int(segment.start_index)\n",
" end_index = int(segment.end_index)\n",
" min_start_index = min(min_start_index, start_index)\n",
" max_end_index = max(max_end_index, end_index)\n",
"\n",
" # Extract and clean the substring\n",
" substring = text[min_start_index : max_end_index - 2]\n",
" substring = \"\".join(\n",
" ch for ch in substring if ord(ch) < 128\n",
" ) # Keep only ASCII characters\n",
"\n",
" vertices = [\n",
" (v.x * image.width, v.y * image.height)\n",
" for v in entity.layout.bounding_poly.normalized_vertices\n",
" ]\n",
"\n",
" # Get the next color from the iterator\n",
" border_color = next(color_iterator)\n",
"\n",
" # Draw a border with the selected color\n",
" for i in range(border_width):\n",
" border_vertices = [(v[0] - i - 1, v[1] - i - 1) for v in vertices]\n",
" draw.polygon(border_vertices, outline=border_color)\n",
"\n",
" except KeyError:\n",
" pass\n",
"\n",
" image_pages.append(image)\n",
"\n",
" # Display each image\n",
" for img in image_pages:\n",
" display(img)\n",
"\n",
"\n",
"pattern = r\"\"\"\n",
"(?<!\\w) # Negative lookbehind to ensure the start of a new line or a space before the bullet point\n",
"( # Start capturing group for bullet points\n",
" \\(\\d+\\)| # Number in parentheses, e.g., (1)\n",
" \\([ivxlcIVXLC]+\\)| # Roman numeral in parentheses\n",
" \\d+\\.\\s*| # Number followed by dot and optional space, e.g., 1.\n",
" [ivxlcIVXLC]+\\.\\s*| # Roman numeral followed by dot and optional space\n",
" \\([a-zA-Z]\\)| # Single letter in parentheses, e.g., (a)\n",
" \\(\\d+\\.[A-Z]+\\)| # Number dot and uppercase letter in parentheses, e.g., (1.A)\n",
" [a-zA-Z]\\.\\s*| # Single letter followed by dot and optional space, e.g., a.\n",
" \\d+\\.\\d+ # Decimal numbering, e.g., 1.1, 2.3\n",
")\n",
"(?!\\w) # Negative lookahead to ensure a non-word character follows the bullet point\n",
"\"\"\"\n",
"\n",
"\n",
"def split_into_paragraphs(text: str) -> List:\n",
" matches = list(re.finditer(pattern, text, re.VERBOSE))\n",
" # Initialize a list to store the resulting paragraph indices\n",
" paragraphs_list = []\n",
"\n",
" # Check for matches\n",
" if matches:\n",
" # The start of the first paragraph is the start of the text\n",
" start = 0\n",
" # Loop over the matches\n",
" for match in matches:\n",
" # The end of the current paragraph is the start of the next bullet point\n",
" end = (\n",
" match.start() + 1\n",
" ) # we add 1 because we want to ignore the '\\n' that's captured in the regex\n",
" # Append the current start and end indices to the list if they are not the same\n",
" if start != end:\n",
" paragraphs_list.append((start, end))\n",
" # The start of the next paragraph is the start of the current bullet point\n",
" start = match.start() + 1 # again, we add 1 to ignore the '\\n'\n",
"\n",
" # The end of the last paragraph is the end of the text\n",
" paragraphs_list.append((start, len(text)))\n",
" # print(paragraphs_list)\n",
" else:\n",
" # If no bullet points are found, the entire text is one paragraph\n",
" # paragraphs_list.append((0, len(text)))\n",
" pass\n",
"\n",
" return paragraphs_list\n",
"\n",
"\n",
"def get_token(\n",
" doc: object, page: int, text_anchor: List\n",
") -> Tuple[Dict[str, object], Dict[str, object]]:\n",
" \"\"\"\n",
" Uses loaded JSON, page number, and text anchors as input and gives the text anchors and page anchors.\n",
"\n",
" Args:\n",
" - json_dict (Any): Loaded JSON.\n",
" - page (int): Page number.\n",
" - text_anchors_check (List): List of text anchors.\n",
"\n",
" Returns:\n",
" - Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing:\n",
" - text_anchors (Dict[str, Any]): Text anchors.\n",
" - page_anchors (Dict[str, Any]): Page anchors.\n",
" \"\"\"\n",
" min_x_normalized = float(\"inf\")\n",
" min_x = float(\"inf\")\n",
" temp_ver_normalized = {\"x\": [], \"y\": []}\n",
" temp_ver = {\"x\": [], \"y\": []}\n",
" temp_text_anc = documentai.Document.TextAnchor()\n",
" temp_confidence = []\n",
" for token in doc.pages[page].tokens:\n",
" if not token.layout.text_anchor.text_segments[0].start_index:\n",
" token.layout.text_anchor.text_segments[0].start_index = 0\n",
" token_anc = token.layout.text_anchor.text_segments[0]\n",
" if token.layout.text_anchor.text_segments == text_anchor.text_segments:\n",
" text_temp = doc.text[\n",
" int(token.layout.text_anchor.text_segments[0].start_index) : int(\n",
" token.layout.text_anchor.text_segments[0].end_index\n",
" )\n",
" ]\n",
" if len(text_temp) > 2 or (\"\\n\" not in text_temp and len(text_temp) <= 2):\n",
" vertices = token.layout.bounding_poly\n",
" min_x_normalized = min(\n",
" vertex.x for vertex in vertices.normalized_vertices\n",
" )\n",
" min_y_normalized = min(\n",
" vertex.y for vertex in vertices.normalized_vertices\n",
" )\n",
" max_x_normalized = max(\n",
" vertex.x for vertex in vertices.normalized_vertices\n",
" )\n",
" max_y_normalized = max(\n",
" vertex.y for vertex in vertices.normalized_vertices\n",
" )\n",
" min_x = min(vertex.x for vertex in vertices.vertices)\n",
" min_y = min(vertex.y for vertex in vertices.vertices)\n",
" max_x = max(vertex.x for vertex in vertices.vertices)\n",
" max_y = max(vertex.y for vertex in vertices.vertices)\n",
" confidence = token.layout.confidence\n",
" temp_text_anc.text_segments = token.layout.text_anchor.text_segments\n",
" elif (\n",
" int(token_anc.start_index)\n",
" >= int(text_anchor.text_segments[0].start_index) - 2\n",
" and int(token_anc.end_index)\n",
" <= int(text_anchor.text_segments[0].end_index) + 2\n",
" ):\n",
" text_temp = doc.text[\n",
" int(token.layout.text_anchor.text_segments[0].start_index) : int(\n",
" token.layout.text_anchor.text_segments[0].end_index\n",
" )\n",
" ]\n",
" if len(text_temp) > 2 or (\"\\n\" not in text_temp and len(text_temp) <= 2):\n",
" vertices = token.layout.bounding_poly\n",
" min_x_normalized = min(\n",
" vertex.x for vertex in vertices.normalized_vertices\n",
" )\n",
" min_y_normalized = min(\n",
" vertex.y for vertex in vertices.normalized_vertices\n",
" )\n",
" max_x_normalized = max(\n",
" vertex.x for vertex in vertices.normalized_vertices\n",
" )\n",
" max_y_normalized = max(\n",
" vertex.y for vertex in vertices.normalized_vertices\n",
" )\n",
" min_x = min(vertex.x for vertex in vertices.vertices)\n",
" min_y = min(vertex.y for vertex in vertices.vertices)\n",
" max_x = max(vertex.x for vertex in vertices.vertices)\n",
" max_y = max(vertex.y for vertex in vertices.vertices)\n",
" temp_ver_normalized[\"x\"].extend([min_x_normalized, max_x_normalized])\n",
" temp_ver_normalized[\"y\"].extend([min_y_normalized, max_y_normalized])\n",
" temp_ver[\"x\"].extend([min_x, max_x])\n",
" temp_ver[\"y\"].extend([min_y, max_y])\n",
" text_anc_token = token.layout.text_anchor.text_segments\n",
" for an1 in text_anc_token:\n",
" temp_text_anc.text_segments.append(an1)\n",
" confidence = token.layout.confidence\n",
" temp_confidence.append(confidence)\n",
" if min_x_normalized == float(\"inf\") or min_x == float(\"inf\"):\n",
" for token in doc.pages[page].tokens:\n",
" if not token.layout.text_anchor.text_segments[0].start_index:\n",
" token.layout.text_anchor.text_segments[0].start_index = 0\n",
" if (\n",
" abs(\n",
" int(token.layout.text_anchor.text_segments[0].start_index)\n",
" - int(token.layout.text_anchor.text_segments[0].end_index)\n",
" )\n",
" <= 2\n",
" ):\n",
" text_temp = doc.text[\n",
" int(token.layout.text_anchor.text_segments[0].start_index) : int(\n",
" token.layout.text_anchor.text_segments[0].end_index\n",
" )\n",
" ]\n",
" vertices = token.layout.bounding_poly\n",
" min_x_normalized = min(\n",
" vertex.x for vertex in vertices.normalized_vertices\n",
" )\n",
" min_y_normalized = min(\n",
" vertex.y for vertex in vertices.normalized_vertices\n",
" )\n",
" max_x_normalized = max(\n",
" vertex.x for vertex in vertices.normalized_vertices\n",
" )\n",
" max_y_normalized = max(\n",
" vertex.y for vertex in vertices.normalized_vertices\n",
" )\n",
" min_x = min(vertex.x for vertex in vertices.vertices)\n",
" min_y = min(vertex.y for vertex in vertices.vertices)\n",
" max_x = max(vertex.x for vertex in vertices.vertices)\n",
" max_y = max(vertex.y for vertex in vertices.vertices)\n",
" temp_text_anc.text_segments = token.layout.text_anchor.text_segments\n",
" confidence = token.layout.confidence\n",
" if len(temp_text_anc.text_segments) != 0:\n",
" final_ver_normalized = {\n",
" \"min_x\": min(temp_ver_normalized[\"x\"]),\n",
" \"min_y\": min(temp_ver_normalized[\"y\"]),\n",
" \"max_x\": max(temp_ver_normalized[\"x\"]),\n",
" \"max_y\": max(temp_ver_normalized[\"y\"]),\n",
" }\n",
" final_ver = {\n",
" \"min_x\": min(temp_ver[\"x\"]),\n",
" \"min_y\": min(temp_ver[\"y\"]),\n",
" \"max_x\": max(temp_ver[\"x\"]),\n",
" \"max_y\": max(temp_ver[\"y\"]),\n",
" }\n",
" final_confidence = min(temp_confidence)\n",
" final_text_anc = sorted(temp_text_anc.text_segments, key=lambda x: x.end_index)\n",
" return final_ver, final_ver_normalized, final_text_anc, final_confidence\n",
" else:\n",
" return (\n",
" {\"min_x\": min_x, \"min_y\": min_y, \"max_x\": max_x, \"max_y\": max_y},\n",
" {\n",
" \"min_x\": min_x_normalized,\n",
" \"min_y\": min_y_normalized,\n",
" \"max_x\": max_x_normalized,\n",
" \"max_y\": max_y_normalized,\n",
" },\n",
" text_anc_token,\n",
" confidence,\n",
" )"
]
},
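{
"cell_type": "markdown",
"id": "a1b2c3d4-e5f6-4a7b-8c9d-0e1f2a3b4c5d",
"metadata": {},
"source": [
"Optionally, run a quick sanity check of `split_into_paragraphs` on a small made-up sample string to confirm the bullet-point regex splits where you expect before processing real documents:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2c3d4e5-f6a7-4b8c-9d0e-1f2a3b4c5d6e",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check on a made-up sample (not part of any real document)\n",
"sample_text = \"Intro text.\\n(i) first clause\\n(ii) second clause\\n(a) sub-clause\"\n",
"for start, end in split_into_paragraphs(sample_text):\n",
"    print(repr(sample_text[start:end]))"
]
},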
{
"cell_type": "code",
"execution_count": null,
"id": "e5c44925-6156-4a09-8f7a-14fe569eb8cc",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"list_of_files, file_name_dict = utilities.file_names(input_uri)\n",
"input_bucket_name = input_uri.split(\"/\")[2]\n",
"for i in list_of_files:\n",
" doc = utilities.documentai_json_proto_downloader(\n",
" input_bucket_name, file_name_dict[i]\n",
" )\n",
" text = doc.text\n",
" for page_number, page in enumerate(doc.pages):\n",
" new_paragraphs = []\n",
" paragraph_indices = split_into_paragraphs(text)\n",
"\n",
" if len(paragraph_indices) > 1:\n",
" for index in paragraph_indices:\n",
" try:\n",
" start_index = index[0]\n",
" end_index = index[1] - 3\n",
" new_paragraph = documentai.Document.Page.Paragraph()\n",
" text_segment = documentai.Document.TextAnchor.TextSegment()\n",
" text_segment.start_index = start_index\n",
" text_segment.end_index = end_index\n",
" new_paragraph.layout.text_anchor.text_segments = [text_segment]\n",
" (\n",
" vertices,\n",
" normalized_vertices,\n",
" text_segments,\n",
" confidence,\n",
" ) = get_token(doc, page_number, new_paragraph.layout.text_anchor)\n",
" new_paragraph.layout.text_anchor.text_segments = text_segments\n",
" new_paragraph.layout.bounding_poly.vertices = [\n",
" {\"x\": vertices[\"min_x\"], \"y\": vertices[\"min_y\"]},\n",
" {\"x\": vertices[\"max_x\"], \"y\": vertices[\"min_y\"]},\n",
" {\"x\": vertices[\"max_x\"], \"y\": vertices[\"max_y\"]},\n",
" {\"x\": vertices[\"min_x\"], \"y\": vertices[\"max_y\"]},\n",
" ]\n",
" new_paragraph.layout.bounding_poly.normalized_vertices = [\n",
" {\n",
" \"x\": normalized_vertices[\"min_x\"],\n",
" \"y\": normalized_vertices[\"min_y\"],\n",
" },\n",
" {\n",
" \"x\": normalized_vertices[\"max_x\"],\n",
" \"y\": normalized_vertices[\"min_y\"],\n",
" },\n",
" {\n",
" \"x\": normalized_vertices[\"max_x\"],\n",
" \"y\": normalized_vertices[\"max_y\"],\n",
" },\n",
" {\n",
" \"x\": normalized_vertices[\"min_x\"],\n",
" \"y\": normalized_vertices[\"max_y\"],\n",
" },\n",
" ]\n",
" new_paragraphs.append(new_paragraph)\n",
" except:\n",
" pass\n",
"\n",
" page.paragraphs.clear()\n",
" page.paragraphs.extend(new_paragraphs)\n",
"\n",
" highlight_text_in_images(doc)\n",
" file_name_only = file_name_dict[i].split(\"/\")[-1]\n",
" full_file_path = base_file_path + file_name_only\n",
" utilities.store_document_as_json(\n",
" documentai.Document.to_json(doc), output_bucket_name, full_file_path\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "1b5df15c-9839-4fd7-9ee1-6282da97b12d",
"metadata": {},
"source": [
"## Results\n",
"\n",
"The fixed documents are saved in the output bucket which you have provided in the script with the same folder structure in input URI."
]
},
{
"cell_type": "markdown",
"id": "8609a1df-16e6-45d3-984e-3f5754c49f07",
"metadata": {
"tags": []
},
"source": [
"<img src=\"./Images/paragraph_1.png\" width=800 height=400></img>\n",
"<img src=\"./Images/paragraph_2.png\" width=800 height=400></img>"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12bbd529-2a33-430f-994e-b137c69c9916",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"environment": {
"kernel": "conda-root-py",
"name": "workbench-notebooks.m113",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/workbench-notebooks:m113"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel) (Local)",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}