# Paragraph Separation Script

* Author: docai-incubator@google.com

## Disclaimer

This tool is not supported by the Google engineering team or product team. It is provided and supported on a best-effort basis by the **DocAI Incubator Team**. No guarantees of performance are implied.


## Objective

This notebook corrects paragraphs that were merged during OCR in Document AI output. A merged paragraph is split wherever an enumeration marker such as (i), (ii), (iii), (a), (b), 1., or 1.1 starts a new item.


## Prerequisites
* Vertex AI Notebook
* Document AI JSON output files in a GCS folder
* An output GCS bucket/folder for the fixed documents


## Step-by-Step Procedure

### 1. Import the Required Modules

In [None]:
!wget https://raw.githubusercontent.com/GoogleCloudPlatform/document-ai-samples/main/incubator-tools/best-practices/utilities/utilities.py

In [None]:
import json
import os
import re
import time
import warnings
import utilities
import io
import base64
import gcsfs
import numpy as np
import pandas as pd
import itertools

from itertools import cycle
from PIL import Image, ImageDraw, ImageFont
from PyPDF2 import PdfFileReader
from google.auth import credentials
from google.cloud import documentai_v1beta3 as documentai
from google.cloud import storage
from tqdm import tqdm
from io import BytesIO
from pathlib import Path
from pprint import pprint
from typing import (
    Container,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
)

### 2. Set Up the Inputs

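* `input_uri`: GCS folder path (for example `gs://bucket/folder/`) that contains the input Document AI JSON files.
* `output_bucket_name`: Name of the bucket where the fixed documents are written.
* `base_file_path`: Folder path within the output bucket for the fixed documents. The file name is appended directly, so end this path with `/`.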

In [None]:
# ---- Input parameters (replace the placeholders before running) ----
# GCS folder with the Document AI JSON files to fix. The input bucket name is
# taken from the third "/"-separated component of this URI.
input_uri: str = "gs://xxxxxxx/xxxxxxxxxx/xxxxxxxxx/xxxxxxxx/"
# Bucket that receives the fixed JSON files.
output_bucket_name: str = "xxxxxxxxxx"
# Folder inside output_bucket_name for the fixed files. The file name is
# appended directly to this string, so keep the trailing "/".
base_file_path: str = "xxxxxx/xxxxxxxx/"

### 3. Run the Functions Used by This Tool

In [None]:
def convert_base64_to_image(
    base64_text: Union[bytes, str]
) -> "Optional[Image.Image]":
    """
    Decode page-image data into a PIL image.

    Raw image bytes (as read from ``page.image.content`` in this notebook) are
    opened directly. A base64-encoded ``str`` is also accepted, optionally
    starting with a data-URI header such as 'data:image/png;base64,'; it is
    base64-decoded before being opened.

    Args:
    base64_text (bytes | str): Raw image bytes, or base64-encoded image text.

    Returns:
    Image: The decoded image, or None if the data could not be decoded or is
    not a recognizable image (an error message is printed in that case).
    """
    image_bytes = base64_text
    if isinstance(image_bytes, str):
        # Drop an optional "data:<mime>;base64," header before decoding.
        if image_bytes.startswith("data:") and "," in image_bytes:
            image_bytes = image_bytes.split(",", 1)[1]
        try:
            image_bytes = base64.b64decode(image_bytes)
        except ValueError:  # binascii.Error is a subclass of ValueError
            print("Error in loading the image. The image data might be corrupted.")
            return None
    try:
        return Image.open(io.BytesIO(image_bytes))
    except IOError:
        print("Error in loading the image. The image data might be corrupted.")
        return None


def highlight_text_in_images(json_data: object) -> None:
    """
    Outline every paragraph on each page image and display the pages.

    Used to visually check the paragraph split. Each paragraph's normalized
    bounding polygon is scaled to the page image size and drawn with a colour
    taken in turn from a fixed palette. Pages whose image cannot be decoded
    are skipped (``convert_base64_to_image`` prints a message for them).

    Args:
    json_data (documentai.Document): Document whose pages carry an ``image``
        and ``paragraphs`` with ``layout.bounding_poly.normalized_vertices``.

    Returns:
    None. Images are shown with ``display`` (the IPython/Jupyter display
    function available in notebook sessions).
    """
    border_width = 4
    image_pages = []
    for page in json_data.pages:
        image = convert_base64_to_image(page.image.content)
        if image is None:
            continue
        draw = ImageDraw.Draw(image)

        # The palette restarts on every page.
        color_iterator = itertools.cycle(["red", "green", "blue", "purple", "orange"])

        for paragraph in page.paragraphs:
            vertices = [
                (v.x * image.width, v.y * image.height)
                for v in paragraph.layout.bounding_poly.normalized_vertices
            ]
            # Take the colour before any skip so colours stay in step with
            # the paragraph order.
            border_color = next(color_iterator)
            if not vertices:
                continue

            # Thicken the outline by drawing the polygon several times, each
            # copy shifted one more pixel up and to the left.
            for offset in range(1, border_width + 1):
                draw.polygon(
                    [(x - offset, y - offset) for x, y in vertices],
                    outline=border_color,
                )

        image_pages.append(image)

    for img in image_pages:
        display(img)


# Enumeration markers that start a new paragraph. split_into_paragraphs uses
# this with re.VERBOSE, so whitespace and "#" comments inside the pattern are
# ignored. Examples of markers: "(1)", "(iv)", "3.", "ii.", "(b)", "(1.A)",
# "c.", "2.3". The look-behind / look-ahead assertions are zero-width, so a
# match never includes the character before the marker (e.g. a preceding
# space or newline); it may include trailing whitespace via "\s*".
# Alternatives are tried in order, so "(i)" is matched by the Roman-numeral
# branch before the single-letter branch.
pattern = r"""
(?<!\w)  # Negative lookbehind to ensure the start of a new line or a space before the bullet point
(        # Start capturing group for bullet points
    \(\d+\)|                 # Number in parentheses, e.g., (1)
    \([ivxlcIVXLC]+\)|       # Roman numeral in parentheses
    \d+\.\s*|                # Number followed by dot and optional space, e.g., 1.
    [ivxlcIVXLC]+\.\s*|      # Roman numeral followed by dot and optional space
    \([a-zA-Z]\)|            # Single letter in parentheses, e.g., (a)
    \(\d+\.[A-Z]+\)|         # Number dot and uppercase letter in parentheses, e.g., (1.A)
    [a-zA-Z]\.\s*|           # Single letter followed by dot and optional space, e.g., a.
    \d+\.\d+                 # Decimal numbering, e.g., 1.1, 2.3
)
(?!\w)   # Negative lookahead to ensure a non-word character follows the bullet point
"""


def split_into_paragraphs(text: str) -> List:
    """
    Compute (start, end) character ranges that split ``text`` at enumeration markers.

    Each match of the module-level ``pattern`` (e.g. "(a)", "(ii)", "1.",
    "2.3") opens a new range. The boundary is placed one character AFTER the
    start of the marker. For "intro (a) foo (b) bar" the ranges are
    (0, 7) "intro (", (7, 15) "a) foo (" and (15, 21) "b) bar". The caller in
    this notebook trims 3 characters from each end offset, and ``get_token``
    matches tokens with a 2-character tolerance, which appears to absorb
    this offset.

    Args:
    text (str): Full document text.

    Returns:
    List[Tuple[int, int]]: Ranges in text order. Zero-length ranges are
    omitted, and the last range ends at ``len(text)``. The list is empty
    when ``text`` contains no marker at all.
    """
    boundaries = [
        found.start() + 1 for found in re.finditer(pattern, text, re.VERBOSE)
    ]
    if not boundaries:
        return []

    ranges = []
    previous = 0
    for boundary in boundaries:
        if boundary != previous:
            ranges.append((previous, boundary))
        previous = boundary
    ranges.append((previous, len(text)))
    return ranges


def _min_max_box(xs: List[float], ys: List[float]) -> Dict[str, float]:
    """Return the axis-aligned box spanning the given coordinates (ValueError if empty)."""
    return {"min_x": min(xs), "min_y": min(ys), "max_x": max(xs), "max_y": max(ys)}


def get_token(
    doc: object, page: int, text_anchor: object
) -> Tuple[Dict[str, float], Dict[str, float], List[object], float]:
    """
    Collect the page tokens inside a character range and merge their layout.

    A token is selected when the first segment of its text anchor lies within
    the first segment of ``text_anchor``, widened by 2 characters on each
    side. Tokens whose text is at most 2 characters long and contains a
    newline are ignored, so bare line-break tokens do not stretch the box.

    Args:
    - doc (documentai.Document): Loaded Document AI document.
    - page (int): Zero-based index into ``doc.pages``.
    - text_anchor (documentai.Document.TextAnchor): Anchor whose first text
      segment gives the target character range.

    Returns:
    - Tuple of:
        - dict with "min_x", "min_y", "max_x", "max_y" spanning the selected
          tokens' pixel ``bounding_poly.vertices``;
        - the same keys spanning their ``bounding_poly.normalized_vertices``;
        - all text segments of the selected tokens, sorted by ``end_index``;
        - the lowest ``layout.confidence`` among the selected tokens.

    Raises:
    - ValueError: if ``text_anchor`` has no segments, if no token on the page
      is selected, or if the selected tokens have no pixel or no normalized
      vertices.
    """
    if not text_anchor.text_segments:
        raise ValueError("text_anchor has no text segments")
    target = text_anchor.text_segments[0]
    # Widen the target range by 2 characters on each side so tokens sitting
    # right at the (trimmed) paragraph boundaries are still picked up.
    low = int(target.start_index) - 2
    high = int(target.end_index) + 2

    pixel_xs: List[float] = []
    pixel_ys: List[float] = []
    norm_xs: List[float] = []
    norm_ys: List[float] = []
    segments: List[object] = []
    confidences: List[float] = []

    for token in doc.pages[page].tokens:
        token_segments = token.layout.text_anchor.text_segments
        if not token_segments:
            continue
        start = int(token_segments[0].start_index)
        end = int(token_segments[0].end_index)
        if start < low or end > high:
            continue
        # Skip short tokens containing a line break (e.g. a bare "\n").
        token_text = doc.text[start:end]
        if len(token_text) <= 2 and "\n" in token_text:
            continue

        poly = token.layout.bounding_poly
        pixel_xs.extend(v.x for v in poly.vertices)
        pixel_ys.extend(v.y for v in poly.vertices)
        norm_xs.extend(v.x for v in poly.normalized_vertices)
        norm_ys.extend(v.y for v in poly.normalized_vertices)
        segments.extend(token_segments)
        confidences.append(token.layout.confidence)

    if not confidences:
        raise ValueError(
            f"No tokens on page {page} inside characters "
            f"[{int(target.start_index)}, {int(target.end_index)})"
        )

    return (
        _min_max_box(pixel_xs, pixel_ys),
        _min_max_box(norm_xs, norm_ys),
        sorted(segments, key=lambda segment: segment.end_index),
        min(confidences),
    )

In [None]:
# For every input document: split its paragraphs at enumeration markers,
# preview the result, and upload the corrected JSON to the output bucket.
list_of_files, file_name_dict = utilities.file_names(input_uri)
# "gs://<bucket>/<prefix>/" -> "<bucket>"
input_bucket_name = input_uri.split("/")[2]
for file_key in list_of_files:
    doc = utilities.documentai_json_proto_downloader(
        input_bucket_name, file_name_dict[file_key]
    )
    text = doc.text
    # The ranges depend only on the document text, so compute them once per
    # document instead of once per page.
    paragraph_indices = split_into_paragraphs(text)

    # A single range (or none) means there is nothing to split.
    if len(paragraph_indices) > 1:
        for page_number, page in enumerate(doc.pages):
            new_paragraphs = []
            for start_index, range_end in paragraph_indices:
                try:
                    # Trim 3 characters off the range end; get_token widens the
                    # range by 2 characters when matching tokens, so the tokens
                    # of the following marker are not pulled into this paragraph.
                    text_segment = documentai.Document.TextAnchor.TextSegment()
                    text_segment.start_index = start_index
                    text_segment.end_index = range_end - 3
                    new_paragraph = documentai.Document.Page.Paragraph()
                    new_paragraph.layout.text_anchor.text_segments = [text_segment]
                    (
                        vertices,
                        normalized_vertices,
                        text_segments,
                        confidence,
                    ) = get_token(doc, page_number, new_paragraph.layout.text_anchor)
                    new_paragraph.layout.text_anchor.text_segments = text_segments
                    new_paragraph.layout.bounding_poly.vertices = [
                        {"x": vertices["min_x"], "y": vertices["min_y"]},
                        {"x": vertices["max_x"], "y": vertices["min_y"]},
                        {"x": vertices["max_x"], "y": vertices["max_y"]},
                        {"x": vertices["min_x"], "y": vertices["max_y"]},
                    ]
                    new_paragraph.layout.bounding_poly.normalized_vertices = [
                        {
                            "x": normalized_vertices["min_x"],
                            "y": normalized_vertices["min_y"],
                        },
                        {
                            "x": normalized_vertices["max_x"],
                            "y": normalized_vertices["min_y"],
                        },
                        {
                            "x": normalized_vertices["max_x"],
                            "y": normalized_vertices["max_y"],
                        },
                        {
                            "x": normalized_vertices["min_x"],
                            "y": normalized_vertices["max_y"],
                        },
                    ]
                    # Keep the lowest confidence of the merged tokens instead
                    # of leaving the new paragraph without a confidence.
                    new_paragraph.layout.confidence = confidence
                    new_paragraphs.append(new_paragraph)
                except Exception:
                    # Best effort: every document-level range is tried on every
                    # page, and ranges not located on this page raise. Skip them.
                    continue

            # Keep the page's original OCR paragraphs if none of the split
            # paragraphs could be located on it.
            if new_paragraphs:
                page.paragraphs.clear()
                page.paragraphs.extend(new_paragraphs)

    highlight_text_in_images(doc)
    file_name_only = file_name_dict[file_key].split("/")[-1]
    full_file_path = base_file_path + file_name_only
    utilities.store_document_as_json(
        documentai.Document.to_json(doc), output_bucket_name, full_file_path
    )

## Results

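The fixed documents are saved to the output bucket you configured, under `base_file_path`, with their original file names. The folder structure below `input_uri` is not reproduced; every file is written directly into `base_file_path`.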

<img src="./Images/paragraph_1.png" width=800 height=400></img>
<img src="./Images/paragraph_2.png" width=800 height=400></img>