incubator-tools/line_item_comparision/line_item_comparision.ipynb

{ "cells": [ { "cell_type": "markdown", "id": "f3e25c70-df2b-40c3-83ed-90344f0915c0", "metadata": {}, "source": [ "# Line-Item Comparison Notebook" ] }, { "cell_type": "markdown", "id": "0c1b18bd-63ad-40dd-80e3-bf10fc5358f4", "metadata": {}, "source": [ "* Author: docai-incubator@google.com" ] }, { "cell_type": "markdown", "id": "493f2803-002d-440a-bb2b-f5b3a486634b", "metadata": {}, "source": [ "# Disclaimer\n", "This tool is not supported by the Google engineering team or product team. It is provided and supported on a best-effort basis by the DocAI Incubator Team. No guarantees of performance are implied." ] }, { "cell_type": "markdown", "id": "d7ab4feb-31da-4aaf-aa3b-964d5f26951d", "metadata": { "tags": [] }, "source": [ "# Objective\n", "This notebook is designed to compare JSON schemas using Google's Document AI and other processing tools. It includes functionality for fuzzy matching and schema comparison." ] }, { "cell_type": "markdown", "id": "741e11a1-a5d1-4d10-bf6d-409fc7c3ff35", "metadata": {}, "source": [ "# Prerequisite\n", "* Vertex AI Notebook\n", "* Parsed json files in a GCS Folder\n", "* GCS folders with Ground truth, parsed jsons and post processed jsons " ] }, { "cell_type": "markdown", "id": "daa663b0-0d5c-4550-b9b1-4a217636f1b0", "metadata": {}, "source": [ "# Step by Step procedure" ] }, { "cell_type": "markdown", "id": "383e3490-db7d-408d-9a9d-6fb3a8adc2ce", "metadata": {}, "source": [ "# 1. Imports\n", "\n", "Import necessary libraries for processing." ] }, { "cell_type": "code", "execution_count": null, "id": "ca64330b-923b-4227-a08c-d7da9fc603d3", "metadata": {}, "outputs": [], "source": [ "# Download incubator-tools utilities module to present-working-directory\n", "!wget https://raw.githubusercontent.com/GoogleCloudPlatform/document-ai-samples/main/incubator-tools/best-practices/utilities/utilities.py" ] }, { "cell_type": "code", "execution_count": null, "id": "53021ecc-8ebd-4e29-a32a-aa8db5c5b81d", "metadata": {}, "outputs": [], "source": [ "!pip install google-cloud-storage fuzzywuzzy pandas google-cloud-documentai -q" ] }, { "cell_type": "code", "execution_count": null, "id": "68c969cc-99f0-4a99-8098-ea12f0d8795c", "metadata": {}, "outputs": [], "source": [ "from google.cloud import storage\n", "from fuzzywuzzy import fuzz\n", "import pandas as pd\n", "from pprint import pprint\n", "import utilities\n", "from google.cloud import documentai_v1beta3 as documentai" ] }, { "cell_type": "markdown", "id": "56c2c7dc-d756-4471-b284-a795700e9c85", "metadata": {}, "source": [ "# 2. 
Input Details" ] }, { "cell_type": "markdown", "id": "177bec11-7c43-49a3-a7ff-d8384bfd662f", "metadata": {}, "source": [ "* **project_id** : Give your GCP Project ID\n", "* **gt_jsons_uri** : It is GCS path which contains ground-truth JSON files\n", "* **parsed_jsons_uri** : It is GCS path which contains document-processed JSON results\n", "* **post_processed_jsons_uri** : It is GCS path which contains document-processed JSON results\n", "\n", "**NOTE**:\n", "* Here all GCS paths should ends-with trailing-slash(`/`)\n", "* The file names have to be same in all the folders which contains Ground truth, parsed and post processed jsons" ] }, { "cell_type": "code", "execution_count": null, "id": "976aca5f-178e-4727-a82f-5ddb48c129a6", "metadata": {}, "outputs": [], "source": [ "project_id = \"xxxx-xxxx-xxxx\"\n", "gt_jsons_uri = \"gs://xx/xxx/xxxx/\"\n", "parsed_jsons_uri = \"gs://xx/xxxx/xxxx/xx/\"\n", "post_processed_jsons_uri = \"gs://xx/xxx/xxxx/xx/\"" ] }, { "cell_type": "markdown", "id": "5cb7cd21-98d4-4eea-91c2-da092097df7d", "metadata": {}, "source": [ "# 3. Script Execution" ] }, { "cell_type": "markdown", "id": "aba01db2-3ca6-4a5f-a21b-4601bccb9c6a", "metadata": {}, "source": [ "### Main Comparison Function\n", "\n", "This function compares two document line_items and returns a DataFrame with the comparison results.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b3635f82-bdbd-4054-98eb-e84cbb95b202", "metadata": {}, "outputs": [], "source": [ "def get_comparision_dataframe(doc_gt, doc_pp):\n", " \"\"\"\n", " Compares two document schemas and returns a DataFrame with the comparison results.\n", "\n", " Parameters:\n", " doc_gt (Document): Ground Truth document schema.\n", " doc_pp (Document): Post-processed document schema.\n", "\n", " Returns:\n", " DataFrame: A DataFrame containing the comparison results.\n", " \"\"\"\n", "\n", " def get_line_items(doc):\n", " line_items = []\n", " line_dict = {}\n", " sub_items = []\n", " import pandas as pd\n", "\n", " df = pd.DataFrame(columns=[\"line\", \"id\", \"type\", \"mentionText\", \"ver\"])\n", "\n", " for entity1 in doc.entities:\n", " if entity1.properties:\n", " if entity1.type == \"line_item\":\n", " line_items.append(entity1)\n", " for subitem in entity1.properties:\n", " sub_items.append(subitem)\n", "\n", " for i in range(len(line_items)):\n", " for ent in line_items[i].properties:\n", " if i in line_dict.keys():\n", " line_dict[i].append(ent)\n", " else:\n", " line_dict[i] = [ent]\n", "\n", " def get_min_max_ver(ent1):\n", " x1 = []\n", " y1 = []\n", " p = 0\n", " try:\n", " if ent1.page_anchor.page_refs[0].page:\n", " p = ent1.page_anchor.page_refs[0].page\n", "\n", " except:\n", " p = 0\n", " try:\n", " for ver in ent1.page_anchor.page_refs[\n", " 0\n", " ].bounding_poly.normalized_vertices:\n", " x1.append(ver.x)\n", " y1.append(ver.y)\n", " a = [\n", " {\"page\": p},\n", " {\"x\": min(x1), \"y\": min(y1)},\n", " {\"x\": max(x1), \"y\": max(y1)},\n", " ]\n", " except:\n", " pass\n", " return a\n", "\n", " dict_line_1 = {}\n", "\n", " for line_1, entities_1 in line_dict.items():\n", " for entity in entities_1:\n", " try:\n", " entity_id = entity.id\n", " except:\n", " entity_id = \"\"\n", "\n", " if line_1 in dict_line_1.keys():\n", " temp_df = pd.DataFrame(\n", " [\n", " [\n", " line_1,\n", " entity_id,\n", " entity.type,\n", " entity.mention_text,\n", " get_min_max_ver(entity),\n", " ]\n", " ],\n", " columns=df.columns,\n", " )\n", " df = pd.concat([df, temp_df])\n", " dict_line_1[line_1].append(\n", " {\n", " 
entity_id: [\n", " {\"type\": entity.type},\n", " {\"mentionText\": entity.mention_text},\n", " {\"ver\": get_min_max_ver(entity)},\n", " ]\n", " }\n", " )\n", " else:\n", " dict_line_1[line_1] = [\n", " {\n", " entity_id: [\n", " {\"type\": entity.type},\n", " {\"mentionText\": entity.mention_text},\n", " {\"ver\": get_min_max_ver(entity)},\n", " ]\n", " }\n", " ]\n", " temp_df = pd.DataFrame(\n", " [\n", " [\n", " line_1,\n", " entity_id,\n", " entity.type,\n", " entity.mention_text,\n", " get_min_max_ver(entity),\n", " ]\n", " ],\n", " columns=df.columns,\n", " )\n", " df = pd.concat([df, temp_df])\n", "\n", " return dict_line_1, df, line_dict, sub_items\n", "\n", " def BBoxOverlap(entity1, entity2):\n", " def valid_bbox_iou(gt_bbox, pred_bbox) -> bool:\n", " \"\"\"Returns true if two bbox overlap less than minimal_iou.\"\"\"\n", " if len(gt_bbox.normalized_vertices) != 4:\n", " return False\n", " if len(pred_bbox.normalized_vertices) != 4:\n", " return True\n", " # bbox represent as [x_min, x_max, y_min, y_max]\n", " bbox1 = get_bounding_bbox(gt_bbox)\n", " bbox2 = get_bounding_bbox(pred_bbox)\n", " xmin = max(bbox1[0], bbox2[0])\n", " xmax = min(bbox1[1], bbox2[1])\n", " ymin = max(bbox1[2], bbox2[2])\n", " ymax = min(bbox1[3], bbox2[3])\n", " intersection_area = max(xmax - xmin, 0.0) * max(ymax - ymin, 0.0)\n", " union_area = (\n", " (bbox1[1] - bbox1[0]) * (bbox1[3] - bbox1[2])\n", " + (bbox2[1] - bbox2[0]) * (bbox2[3] - bbox2[2])\n", " - intersection_area\n", " )\n", " if union_area < 1e-10:\n", " return True\n", " iou = intersection_area / union_area\n", " return xmax > xmin and ymax > ymin and iou >= 0.2\n", "\n", " def get_bounding_bbox(bbox):\n", " \"\"\"Returns the list representation for the bounding box.\"\"\"\n", " x_coordinates = get_bounding_poly_x(bbox)\n", " y_coordinates = get_bounding_poly_y(bbox)\n", " # bbox represent as [x_min, x_max, y_min, y_max]\n", " return [\n", " min(x_coordinates),\n", " max(x_coordinates),\n", " min(y_coordinates),\n", " max(y_coordinates),\n", " ]\n", "\n", " def get_bounding_poly_x(bounding_poly):\n", " \"\"\"Returns the list for x coordinates for the bounding poly.\"\"\"\n", " return [\n", " normalized_vertices.x\n", " for normalized_vertices in bounding_poly.normalized_vertices\n", " ]\n", "\n", " def get_bounding_poly_y(bounding_poly):\n", " \"\"\"Returns the list for y coordinates for the bounding poly.\"\"\"\n", " return [\n", " normalized_vertices.y\n", " for normalized_vertices in bounding_poly.normalized_vertices\n", " ]\n", "\n", " gt_bbox = entity1.page_anchor.page_refs[0].bounding_poly\n", " pred_bbox = entity2.page_anchor.page_refs[0].bounding_poly\n", " return valid_bbox_iou(gt_bbox, pred_bbox)\n", "\n", " dict_line_GT, df1, line_dict_GT, sub_items_GT = get_line_items(doc_gt)\n", " dict_line_PP, df2, line_dict_PP, sub_items_pp = get_line_items(doc_pp)\n", " df1.to_csv(\"GT_line.csv\")\n", " df2.to_csv(\"post_line.csv\")\n", "\n", " def check_page_match(ent1, ent2):\n", " p1 = \"\"\n", " p2 = \"\"\n", " try:\n", " if ent1.page_anchor.page_refs[0].page:\n", " p1 = ent1.page_anchor.page_refs[0].page\n", "\n", " except:\n", " p1 = 0\n", "\n", " try:\n", " if ent2.page_anchor.page_refs[0].page:\n", " p2 = ent2.page_anchor.page_refs[0].page\n", " except:\n", " p2 = 0\n", "\n", " if p1 == p2:\n", " return True\n", " elif p1 != p2:\n", " return False\n", "\n", " entities_match = []\n", " entities_nomatch = {}\n", " ent_1 = []\n", " ent_gt_matched = []\n", " ent_pp_matched = []\n", " for line_gt, ent_gt in line_dict_GT.items():\n", " 
for ent1 in ent_gt:\n", " for line_pp, ent_pp in line_dict_PP.items():\n", " for ent2 in ent_pp:\n", " # print(len(ent1),len(ent2))\n", " if (\n", " check_page_match(ent1, ent2) == True\n", " and BBoxOverlap(ent1, ent2) == True\n", " and ((fuzz.ratio(ent1.mention_text, ent2.mention_text)) / 100)\n", " > 0.8\n", " ):\n", " gt = {str(line_gt) + \"_GT\": [ent1]}\n", " pp = {str(line_pp) + \"_PP\": [ent2]}\n", " entities_match.append([gt, pp])\n", " ent_gt_matched.append(ent1)\n", " ent_pp_matched.append(ent2)\n", " # print(ent1['id'],ent2['id'])\n", " # ent_1.append({line_gt:ent1['id'],line_pp:ent2['id']})\n", " df_merge = pd.DataFrame(\n", " columns=[\n", " \"line_GT\",\n", " \"line_PP\",\n", " \"id_GT\",\n", " \"id_PP\",\n", " \"type_GT\",\n", " \"type_PP\",\n", " \"mentionText_GT\",\n", " \"mentionText_PP\",\n", " ]\n", " )\n", " for item in entities_match:\n", " for entity in item:\n", " for line, ent in entity.items():\n", " if \"_GT\" in line:\n", " line_GT = line\n", " id_GT = ent[0].id\n", " type_GT = ent[0].type\n", " mentionText_GT = ent[0].mention_text\n", " elif \"_PP\" in line:\n", " line_PP = line\n", " try:\n", " e_id = ent[0].id\n", " except:\n", " e_id = \"\"\n", " id_PP = e_id\n", " type_PP = ent[0].type\n", " mentionText_PP = ent[0].mention_text\n", " # print(line_GT)\n", " temp_df = pd.DataFrame(\n", " [\n", " [\n", " line_GT,\n", " line_PP,\n", " id_GT,\n", " id_PP,\n", " type_GT,\n", " type_PP,\n", " mentionText_GT,\n", " mentionText_PP,\n", " ]\n", " ],\n", " columns=df_merge.columns,\n", " )\n", " df_merge = pd.concat([df_merge, temp_df])\n", " left_over_GT = []\n", " for ent11 in sub_items_GT:\n", " if ent11 not in ent_gt_matched:\n", " left_over_GT.append(ent11)\n", "\n", " for line_gt, ent_gt in line_dict_GT.items():\n", " for ent1 in ent_gt:\n", " for ent2 in left_over_GT:\n", " if ent1 == ent2:\n", " line_GT = str(line_gt) + \"_GT\"\n", " line_PP = \"_____\"\n", " id_GT = ent1.id\n", " id_PP = \"_____\"\n", " type_GT = ent1.type\n", " type_PP = \"_____\"\n", " mentionText_GT = ent1.mention_text\n", " mentionText_PP = \"_____\"\n", " temp_df = pd.DataFrame(\n", " [\n", " [\n", " line_GT,\n", " line_PP,\n", " id_GT,\n", " id_PP,\n", " type_GT,\n", " type_PP,\n", " mentionText_GT,\n", " mentionText_PP,\n", " ]\n", " ],\n", " columns=df_merge.columns,\n", " )\n", " df_merge = pd.concat([df_merge, temp_df])\n", " left_over_pp = []\n", " for ent11 in sub_items_pp:\n", " if ent11 not in ent_pp_matched:\n", " left_over_pp.append(ent11)\n", "\n", " for line_pp, ent_pp in line_dict_PP.items():\n", " for ent1 in ent_pp:\n", " for ent2 in left_over_pp:\n", " if ent1 == ent2:\n", " line_GT = \"_____\"\n", " line_PP = str(line_pp) + \"_PP\"\n", " id_GT = \"_____\"\n", " try:\n", " en_id = ent1.id\n", " except:\n", " en_id = \"\"\n", " id_PP = en_id\n", " type_GT = \"_____\"\n", " type_PP = ent1.type\n", " mentionText_GT = \"_____\"\n", " mentionText_PP = ent1.mention_text\n", " temp_df = pd.DataFrame(\n", " [\n", " [\n", " line_GT,\n", " line_PP,\n", " id_GT,\n", " id_PP,\n", " type_GT,\n", " type_PP,\n", " mentionText_GT,\n", " mentionText_PP,\n", " ]\n", " ],\n", " columns=df_merge.columns,\n", " )\n", " df_merge = pd.concat([df_merge, temp_df])\n", "\n", " match = []\n", " for l1 in entities_match:\n", " k = []\n", " for item1 in l1:\n", " for lin1, en1 in item1.items():\n", " k.append(lin1)\n", " match.append(k)\n", "\n", " counts = {}\n", " for item in match:\n", " key = tuple(item)\n", " counts[key] = counts.get(key, 0) + 1\n", "\n", " line_change = {}\n", " for 
line_match, count in counts.items():\n", " l1 = int(line_match[0].split(\"_\")[0])\n", " l2 = int(line_match[1].split(\"_\")[0])\n", " # print(l1,l2)\n", "\n", " if line_match[0] in line_change.keys():\n", " if line_change[line_match[0]][\"count\"] >= count:\n", " pass\n", " else:\n", " line_change[line_match[0]] = {\"pp\": line_match[1], \"count\": count}\n", " else:\n", " line_change[line_match[0]] = {\"pp\": line_match[1], \"count\": count}\n", "\n", " for gt_line, value in line_change.items():\n", " df_merge[\"line_PP\"] = df_merge[\"line_PP\"].replace(value[\"pp\"], gt_line)\n", "\n", " def check_match(row):\n", " if (\n", " row[\"line_GT\"] == row[\"line_PP\"]\n", " and row[\"line_GT\"] != \"_____\"\n", " and row[\"line_PP\"] != \"_____\"\n", " ):\n", " return \"TP\"\n", " elif (\n", " row[\"line_GT\"] != row[\"line_PP\"]\n", " and row[\"line_GT\"] != \"_____\"\n", " and row[\"line_PP\"] != \"_____\"\n", " ):\n", " return \"FP\"\n", " elif row[\"line_GT\"] == \"_____\":\n", " return \"FN\"\n", " elif row[\"line_PP\"] == \"_____\":\n", " return \"FN\"\n", "\n", " # Add a new column 'Match' indicating the match\n", " df_merge[\"Match\"] = df_merge.apply(check_match, axis=1)\n", "\n", " return df_merge" ] }, { "cell_type": "markdown", "id": "0f9359e6-c57c-4f14-b4f0-002317e0f1be", "metadata": {}, "source": [ "Running the script to perform the Line-Item comparison using the defined functions.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6c18a905-f083-40cc-9b7d-8bec73783dc1", "metadata": {}, "outputs": [], "source": [ "GT_list, GT_file_dict = utilities.file_names(gt_jsons_uri)\n", "parsed_list, parsed_file_dict = utilities.file_names(parsed_jsons_uri)\n", "post_processed_list, post_processed_file_dict = utilities.file_names(\n", " post_processed_jsons_uri\n", ")\n", "\n", "GT_bucket = gt_jsons_uri.split(\"/\")[2]\n", "parsed_bucket = parsed_jsons_uri.split(\"/\")[2]\n", "post_processed_json_bucket = post_processed_jsons_uri.split(\"/\")[2]\n", "\n", "from fuzzywuzzy import fuzz\n", "\n", "df_compare_all_files = pd.DataFrame()\n", "df_compare_accuracy = pd.DataFrame()\n", "for GT_file, GT_file_path in GT_file_dict.items():\n", " # print(GT_file,\" : \",GT_file_path)\n", " # print(GT_bucket)\n", " doc_gt = utilities.documentai_json_proto_downloader(\n", " GT_bucket, GT_file_dict[GT_file]\n", " )\n", " doc_parser = utilities.documentai_json_proto_downloader(\n", " parsed_bucket, GT_file_dict[GT_file]\n", " )\n", " doc_pp = utilities.documentai_json_proto_downloader(\n", " post_processed_json_bucket, GT_file_dict[GT_file]\n", " )\n", "\n", " # break\n", " df_compare_gt_pp = get_comparision_dataframe(doc_gt, doc_pp)\n", " df_compare_gt_parser = get_comparision_dataframe(doc_gt, doc_parser)\n", " file_accuracy_pp = (df_compare_gt_pp[\"Match\"].value_counts().get(\"TP\", 0)) / (\n", " (df_compare_gt_pp[\"Match\"].value_counts().get(\"TP\", 0))\n", " + (df_compare_gt_pp[\"Match\"].value_counts().get(\"FP\", 0))\n", " + (df_compare_gt_pp[\"Match\"].value_counts().get(\"FN\", 0))\n", " )\n", " temp_df_compare_gt_pp = pd.DataFrame(\n", " [\n", " [\n", " GT_file,\n", " \"-\",\n", " \"Accuracy\",\n", " \"-\",\n", " \"GT\",\n", " \"/\",\n", " \"post-processed\",\n", " \"-\",\n", " round(file_accuracy_pp, 3),\n", " ]\n", " ],\n", " columns=df_compare_gt_pp.columns,\n", " )\n", " df_compare_gt_pp = pd.concat([df_compare_gt_pp, temp_df_compare_gt_pp])\n", " file_accuracy_p = (df_compare_gt_parser[\"Match\"].value_counts().get(\"TP\", 0)) / (\n", " 
(df_compare_gt_parser[\"Match\"].value_counts().get(\"TP\", 0))\n", " + (df_compare_gt_parser[\"Match\"].value_counts().get(\"FP\", 0))\n", " + (df_compare_gt_parser[\"Match\"].value_counts().get(\"FN\", 0))\n", " )\n", " temp_df_compare_gt_parser = pd.DataFrame(\n", " [\n", " [\n", " GT_file,\n", " \"-\",\n", " \"Accuracy\",\n", " \"-\",\n", " \"GT\",\n", " \"/\",\n", " \"parsed\",\n", " \"-\",\n", " round(file_accuracy_p, 3),\n", " ]\n", " ],\n", " columns=df_compare_gt_parser.columns,\n", " )\n", " df_compare_gt_parser = pd.concat([df_compare_gt_parser, temp_df_compare_gt_parser])\n", " frames = [df_compare_all_files, df_compare_gt_pp, df_compare_gt_parser]\n", " df_compare_all_files = pd.concat(frames)\n", " # break\n", "df_compare_all_files.to_csv(\"compare_all.csv\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d45a3328-3008-4291-a5fb-50d4a1173db1", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "a09d566b-7c76-4f97-a881-d07e8fff4595", "metadata": {}, "source": [ "# 4. Output" ] }, { "cell_type": "markdown", "id": "0b81dc64-dcd1-44d4-9c05-e5eae7419675", "metadata": {}, "source": [ "This gives the CSV file which shows the difference between Ground truth, parsed files and post processed files\n", "<img src=\"./images/sample_output.png\" width=800 height=400></img>\n", "\n", "\n", "- **GT (Ground Truth):** Represents the actual, original data or information.\n", "- **PP (Post Processed/Processed):** Refers to the data after it has undergone processing or post-processing.\n", "- **line-GT and line-PP:** These are specific line item numbers used to compare whether they are assigned to the same line items.\n", "- **Match:** Indicates whether the line items are correctly assigned.\n", " - **TP (True Positive):** When a match is found, indicating correct assignment.\n", " - **FP (False Positive):** When there is no match, indicating incorrect assignment.\n", " - **FN (False Negative):** Considered when a Ground Truth or Processed/Post Processed child item is missing, indicating a missing or overlooked item." ] } ], "metadata": { "environment": { "kernel": "python3", "name": "common-cpu.m104", "type": "gcloud", "uri": "gcr.io/deeplearning-platform-release/base-cpu:m104" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.12" } }, "nbformat": 4, "nbformat_minor": 5 }