{
"cells": [
{
"cell_type": "markdown",
"id": "c2c9060d",
"metadata": {},
"source": [
"# Document AI Processor Uptraining using Python\n",
"\n",
"<table align=\"left\">\n",
"\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/document-ai-samples/blob/main/uptraining_docai_processor_using_python/docai-uptraining.ipynb\">\n",
" <img src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Colab logo\"> Run in Colab\n",
" </a>\n",
" </td>\n",
"\n",
"</table>\n"
]
},
{
"cell_type": "markdown",
"id": "KQafCrzUnS0s",
"metadata": {
"id": "KQafCrzUnS0s"
},
"source": [
"# Overview\n",
"\n",
"[Document AI](https://cloud.google.com/document-ai/docs) is a document understanding solution that takes unstructured data (e.g. documents, emails, invoices, forms, etc.) and makes the data easier to understand, analyze, and consume. The API provides structure through content classification, entity extraction, advanced searching, and more. With [Uptraining](https://cloud.google.com/document-ai/docs/workbench/uptrain-processor), you can achieve higher document processing accuracy by providing additional labeled examples for Specialized Document Types and creating a new model version.\n",
"\n",
"In this notebook, you will create an Invoice Parser processor, configure the processor for uptraining, label example documents(optional), and uptrain the processor.\n",
"\n",
"The document dataset used in this lab consists of randomly-generated invoices for a fictional piping company.\n",
"\n",
"Note: This notebook is a python version of the exisiting [Qwiklab](https://www.cloudskillsboost.google/focuses/67858?parent=catalog).\n",
"\n",
"<hr/>"
]
},
{
"cell_type": "markdown",
"id": "ctREoch1N05a",
"metadata": {
"id": "ctREoch1N05a"
},
"source": [
"## User Authentication"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "A1imEFqlNzxj",
"metadata": {
"id": "A1imEFqlNzxj"
},
"outputs": [],
"source": [
"from google.colab import auth as google_auth\n",
"\n",
"google_auth.authenticate_user()"
]
},
{
"cell_type": "markdown",
"id": "PQxzX50EN_VS",
"metadata": {
"id": "PQxzX50EN_VS"
},
"source": [
"## Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25T_i43nN-T8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "25T_i43nN-T8",
"outputId": "cf8b11bd-42d0-419d-d4fe-149b151c1fb2"
},
"outputs": [],
"source": [
"!pip install google-cloud-documentai google-cloud-storage -q"
]
},
{
"cell_type": "markdown",
"id": "759N6wBgOCTg",
"metadata": {
"id": "759N6wBgOCTg"
},
"source": [
"## Restart the runtime"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "THlm8i-mN-a9",
"metadata": {
"id": "THlm8i-mN-a9"
},
"outputs": [],
"source": [
"%%capture\n",
"\n",
"import IPython\n",
"\n",
"app = IPython.Application.instance()\n",
"app.kernel.do_shutdown(True)"
]
},
{
"cell_type": "markdown",
"id": "e35c6d50-691f-4a47-a698-15270b3d75d1",
"metadata": {
"id": "e35c6d50-691f-4a47-a698-15270b3d75d1"
},
"source": [
"## Import libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "219793b9-6e0a-4956-9d92-8965c8079374",
"metadata": {
"id": "219793b9-6e0a-4956-9d92-8965c8079374"
},
"outputs": [],
"source": [
"from google.cloud import documentai_v1beta3 as documentai\n",
"from google.longrunning.operations_pb2 import GetOperationRequest\n",
"from google.api_core.client_options import ClientOptions\n",
"import google.auth.transport.requests\n",
"from google import auth\n",
"from google.cloud import storage\n",
"\n",
"import requests\n",
"import re\n",
"import time\n",
"from time import sleep\n",
"import json\n",
"from tqdm.auto import tqdm"
]
},
{
"cell_type": "markdown",
"id": "8df9586d-f656-437d-bd70-7b1168ee9363",
"metadata": {
"id": "8df9586d-f656-437d-bd70-7b1168ee9363"
},
"source": [
"## Initialize variables"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "094a148d-fb5b-4b0b-ba10-b400d89e8e63",
"metadata": {
"id": "094a148d-fb5b-4b0b-ba10-b400d89e8e63"
},
"outputs": [],
"source": [
"project_id = \"<YOUR PROJECT ID>\" # @param {type:\"string\"}\n",
"location = \"us\" # @param {type:\"string\"}\n",
"processor_type = \"INVOICE_PROCESSOR\" # @param {type:\"string\"}\n",
"\n",
"# Processor display name\n",
"processor_display_name = \"<DISPLAY NAME eg. invoice-test>\" # @param {type:\"string\"}\n",
"\n",
"# GCS bucket path, to store the data\n",
"dataset_gcs_uri = \"<GCS BUCKET URI eg. gs://invoice-api-test>\" # @param {type:\"string\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "tcax_uwpfgBI",
"metadata": {
"id": "tcax_uwpfgBI"
},
"outputs": [],
"source": [
"!gcloud auth application-default login"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "vWsy-W1GbgOm",
"metadata": {
"id": "vWsy-W1GbgOm"
},
"outputs": [],
"source": [
"!gcloud config set project $project_id"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ntQPSoNgV0cb",
"metadata": {
"id": "ntQPSoNgV0cb"
},
"outputs": [],
"source": [
"!gcloud auth application-default set-quota-project $project_id"
]
},
{
"cell_type": "markdown",
"id": "691a3a32-a0d9-464e-9fe5-3a4d40696288",
"metadata": {
"id": "691a3a32-a0d9-464e-9fe5-3a4d40696288"
},
"source": [
"## Create Processor"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8860f51-a168-4567-9754-f3c45b34df75",
"metadata": {
"id": "b8860f51-a168-4567-9754-f3c45b34df75"
},
"outputs": [],
"source": [
"def create_processor(project_id, location, processor_type, processor_display_name):\n",
" \"\"\"\n",
" Function creates a Document AI processor,\n",
" based on the provided processor type.\n",
" \"\"\"\n",
" # Create a client\n",
" client = documentai.DocumentProcessorServiceClient()\n",
"\n",
" parent = client.common_location_path(project_id, location)\n",
"\n",
" processor = documentai.Processor(\n",
" type_=processor_type, display_name=processor_display_name\n",
" )\n",
" # Initialize request argument(s)\n",
" request = documentai.CreateProcessorRequest(parent=parent, processor=processor)\n",
"\n",
" # Make the request\n",
" response = client.create_processor(request=request)\n",
"\n",
" # Handle the response\n",
" # print(response)\n",
"\n",
" return response"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "TRPOmjRJWdb5",
"metadata": {
"id": "TRPOmjRJWdb5"
},
"outputs": [],
"source": [
"response = create_processor(\n",
" project_id, location, processor_type, processor_display_name\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4142d47d-1977-4176-8b90-64b4dc6de220",
"metadata": {
"id": "4142d47d-1977-4176-8b90-64b4dc6de220"
},
"outputs": [],
"source": [
"# Get processor_resource_name\n",
"processor_name = response.name\n",
"processor_name"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pZGzaUPTYAPP",
"metadata": {
"id": "pZGzaUPTYAPP"
},
"outputs": [],
"source": [
"# Get default processor version, it'll be used as a base for uptraining.\n",
"base_version = response.default_processor_version\n",
"base_version"
]
},
{
"cell_type": "markdown",
"id": "ba9692c4-c0d3-457b-b54d-ac6c360ec596",
"metadata": {
"id": "ba9692c4-c0d3-457b-b54d-ac6c360ec596"
},
"source": [
"## Create Dataset\n",
"\n",
"In order to train your processor, you will have to create a dataset with training and testing data to help the processor identify the entities you want to extract."
]
},
{
"cell_type": "markdown",
"id": "W43BbahYrCWQ",
"metadata": {
"id": "W43BbahYrCWQ"
},
"source": [
"You will need to create a new bucket in Cloud Storage to store the dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25b1c890-d3e8-47ed-a1ad-47e3536f5c39",
"metadata": {
"id": "25b1c890-d3e8-47ed-a1ad-47e3536f5c39"
},
"outputs": [],
"source": [
"def create_dataset_bucket(project_id, dataset_gcs_uri):\n",
" \"\"\"\n",
" Function to create a GCS bucket,\n",
" if it does not exist.\n",
" \"\"\"\n",
" client = storage.Client(project=project_id)\n",
" bucket = client.bucket(dataset_gcs_uri.split(\"//\")[1])\n",
" if not bucket.exists():\n",
" tqdm.write(f\"Creating bucket {bucket.name}\")\n",
" client.create_bucket(bucket)\n",
" tqdm.write(f\"Bucket {bucket.name} created\")\n",
" else:\n",
" tqdm.write(f\"Bucket {bucket.name} already exists\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3koHSBqugEN9",
"metadata": {
"id": "3koHSBqugEN9"
},
"outputs": [],
"source": [
"# create dataset_gcs_uri bucket if not exists\n",
"create_dataset_bucket(project_id, dataset_gcs_uri)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "mHfLbY0TZLaK",
"metadata": {
"id": "mHfLbY0TZLaK"
},
"outputs": [],
"source": [
"def poll_operation(operation_name, location):\n",
" \"\"\"\n",
" Function to check status of long running operations.\n",
" \"\"\"\n",
" # You must set the `api_endpoint` if you use a location other than \"us\".\n",
" opts = ClientOptions(api_endpoint=f\"{location}-documentai.googleapis.com\")\n",
" client = documentai.DocumentProcessorServiceClient(client_options=opts)\n",
"\n",
" request = GetOperationRequest(name=operation_name)\n",
"\n",
" while True:\n",
" # Make GetOperation request\n",
" operation = client.get_operation(request=request)\n",
"\n",
" # Stop polling when Operation is no longer running\n",
" if operation.done:\n",
" break\n",
"\n",
" tqdm.write(\".\", end=\"\")\n",
" # Wait 10 seconds before polling again\n",
" sleep(10)\n",
"\n",
" tqdm.write(\"\")\n",
" return operation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7362243-06a5-4032-9da2-c5f5be751010",
"metadata": {
"id": "b7362243-06a5-4032-9da2-c5f5be751010"
},
"outputs": [],
"source": [
"def add_processor_dataset(processor_name, dataset_gcs_uri, project_id, location):\n",
" \"\"\"\n",
" Function to add dataset information to a processor.\n",
" \"\"\"\n",
" # Create a client\n",
" client = documentai.DocumentServiceClient()\n",
"\n",
" # User managed dataset, for Document AI service manage dataset refer\n",
" # https://cloud.google.com/python/docs/reference/documentai/latest/google.cloud.documentai_v1beta3.types.Dataset\n",
" gcs_managed_config = documentai.Dataset.GCSManagedConfig(\n",
" gcs_prefix=documentai.GcsPrefix(gcs_uri_prefix=dataset_gcs_uri)\n",
" )\n",
"\n",
" spanner_indexing_config = documentai.Dataset.SpannerIndexingConfig()\n",
"\n",
" # Initialize request argument(s)\n",
" dataset = documentai.Dataset(\n",
" name=client.dataset_path(project_id, location, processor_name.split(\"/\")[-1]),\n",
" gcs_managed_config=gcs_managed_config,\n",
" spanner_indexing_config=spanner_indexing_config,\n",
" )\n",
"\n",
" request = documentai.UpdateDatasetRequest(dataset=dataset)\n",
"\n",
" # Make the request\n",
" operation = client.update_dataset(request=request)\n",
"\n",
" response = operation.result()\n",
"\n",
" # Handle the response\n",
" # print(response)\n",
"\n",
" return response"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "oNzer_Wmf7oz",
"metadata": {
"id": "oNzer_Wmf7oz"
},
"outputs": [],
"source": [
"response = add_processor_dataset(processor_name, dataset_gcs_uri, project_id, location)"
]
},
{
"cell_type": "markdown",
"id": "029498b2-38e0-4d8a-b366-8710a5c762eb",
"metadata": {
"id": "029498b2-38e0-4d8a-b366-8710a5c762eb"
},
"source": [
"## Import Documents"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3d2218b-20e2-4a84-9d99-df00bb6e8c4f",
"metadata": {
"id": "c3d2218b-20e2-4a84-9d99-df00bb6e8c4f"
},
"outputs": [],
"source": [
"def import_documents(processor_name, location, gcs_uri_prefix, train_split):\n",
" \"\"\"\n",
" Function to import documents to a processor,\n",
" provided dataset uri and train_split ratio to split\n",
" the data into train and test.\n",
" \"\"\"\n",
" # Create a client\n",
" client = documentai.DocumentServiceClient()\n",
"\n",
" # Initialize request argument(s)\n",
" batch_input_config = documentai.BatchDocumentsInputConfig(\n",
" gcs_prefix=documentai.GcsPrefix(gcs_uri_prefix=gcs_uri_prefix)\n",
" )\n",
"\n",
" batch_documents_import_configs = (\n",
" documentai.ImportDocumentsRequest.BatchDocumentsImportConfig(\n",
" batch_input_config=batch_input_config\n",
" )\n",
" )\n",
"\n",
" if not isinstance(train_split, float):\n",
" batch_documents_import_configs.dataset_split = train_split\n",
" else:\n",
" batch_documents_import_configs.auto_split_config.training_split_ratio = (\n",
" train_split\n",
" )\n",
"\n",
" dataset = client.dataset_path(project_id, location, processor_name.split(\"/\")[-1])\n",
" request = documentai.ImportDocumentsRequest(\n",
" dataset=dataset,\n",
" batch_documents_import_configs=[batch_documents_import_configs],\n",
" )\n",
"\n",
" # Make the request\n",
" operation = client.import_documents(request=request)\n",
"\n",
" print(\"Waiting for operation to complete...\")\n",
" operation = poll_operation(operation.operation.name, location)\n",
"\n",
" print(\"Documents are imported successfully\")\n",
" return operation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4e8a751-3ace-4369-8f9a-e8185ad71a30",
"metadata": {
"id": "e4e8a751-3ace-4369-8f9a-e8185ad71a30"
},
"outputs": [],
"source": [
"## Import a sample document\n",
"sample_doc_gcs_uri_prefix = (\n",
" \"gs://cloud-samples-data/documentai/codelabs/uptraining/pdfs\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "VtoxQYHg0J-A",
"metadata": {
"id": "VtoxQYHg0J-A"
},
"outputs": [],
"source": [
"operation = import_documents(\n",
" processor_name,\n",
" location,\n",
" sample_doc_gcs_uri_prefix,\n",
" train_split=documentai.DatasetSplitType.DATASET_SPLIT_UNASSIGNED,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "sopOwbPrrj8p",
"metadata": {
"id": "sopOwbPrrj8p"
},
"source": [
"#### [OPTIONAL] Label the test document\n",
"\n",
"Follow the [instructions](https://www.cloudskillsboost.google/focuses/67858?parent=catalog#step9) to label the sample document in the Docuemnt AI console"
]
},
{
"cell_type": "markdown",
"id": "U9i5TihYsV-k",
"metadata": {
"id": "U9i5TihYsV-k"
},
"source": [
"### Import Pre-Labeled Data\n",
"\n",
"Document AI Uptraining requires a minimum of 10 documents in both the training and test sets, along with 10 instances of each label in each set. It's recommended to have at least 50 documents in each set with 50 instances of each label for best performance. More training data generally equates to higher accuracy.\n",
"\n",
"It will take a long time to manually label 100 documents, so we have some pre-labeled documents that you can import for this lab. You can import pre-labeled document files in the [Document.json](https://cloud.google.com/document-ai/docs/reference/rest/v1/Document) format. These can be results from calling a processor and verifying the accuracy using [Human in the Loop (HITL)](https://cloud.google.com/document-ai/hitl)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47c2ffee-94e4-48ee-aaaf-e5aca0da470b",
"metadata": {
"id": "47c2ffee-94e4-48ee-aaaf-e5aca0da470b"
},
"outputs": [],
"source": [
"## Import all documents\n",
"docs_gcs_uri_prefix = \"gs://cloud-samples-data/documentai/Custom/Invoices/JSON\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "izA4zqQ31l7L",
"metadata": {
"id": "izA4zqQ31l7L"
},
"outputs": [],
"source": [
"batch_operation = import_documents(\n",
" processor_name, location, docs_gcs_uri_prefix, train_split=0.8\n",
")"
]
},
{
"cell_type": "markdown",
"id": "e0c4148f-d3f0-4d71-b15b-35b39f5eb85d",
"metadata": {
"id": "e0c4148f-d3f0-4d71-b15b-35b39f5eb85d"
},
"source": [
"## Update schema with required fields\n",
"\n",
"The sample documents we are using for this example do not contain every label supported by the Invoice Parser. We will need to mark the labels we are not using as inactive before training. You can also follow similar steps to add a custom label before Uptraining."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8643150f-6e3b-451f-8fcf-1a9e56fa583a",
"metadata": {
"id": "8643150f-6e3b-451f-8fcf-1a9e56fa583a"
},
"outputs": [],
"source": [
"def get_dataset_schema(processor_name):\n",
" \"\"\"\n",
" Function to get the existing processor schema.\n",
" \"\"\"\n",
" # Create a client\n",
" client = documentai.DocumentServiceClient()\n",
"\n",
" # Initialize request argument(s)\n",
" request = documentai.GetDatasetSchemaRequest(\n",
" name=client.dataset_schema_path(\n",
" project_id, location, processor_name.split(\"/\")[-1]\n",
" ),\n",
" )\n",
"\n",
" # Make the request\n",
" response = client.get_dataset_schema(request=request)\n",
"\n",
" return response"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5m6S-GBy2tXW",
"metadata": {
"id": "5m6S-GBy2tXW"
},
"outputs": [],
"source": [
"schema = get_dataset_schema(processor_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82bfc33c-9801-406d-9e42-41b065139a7d",
"metadata": {
"id": "82bfc33c-9801-406d-9e42-41b065139a7d"
},
"outputs": [],
"source": [
"## Fields which needs to be enable based on the imported dataset\n",
"enable_fields = [\n",
" \"invoice_date\",\n",
" \"line_item/amount\",\n",
" \"line_item/description\",\n",
" \"line_item/quantity\",\n",
" \"amount\",\n",
" \"description\",\n",
" \"receiver_address\",\n",
" \"receiver_name\",\n",
" \"supplier_address\",\n",
" \"supplier_name\",\n",
" \"total_amount\",\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ad1b451-d5f6-4487-819c-ed10efbf284a",
"metadata": {
"id": "1ad1b451-d5f6-4487-819c-ed10efbf284a",
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"def update_schema_fields(schema, enable_fields):\n",
" \"\"\"\n",
" Function to update the schema with required fields.\n",
" \"\"\"\n",
" for entity in schema.document_schema.entity_types:\n",
" for prop in entity.properties:\n",
" if prop.name not in enable_fields:\n",
" prop.property_metadata = {\"inactive\": True}\n",
" return schema"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7d5ff97-7c08-44cc-aae2-5a36bb29b38c",
"metadata": {
"id": "c7d5ff97-7c08-44cc-aae2-5a36bb29b38c"
},
"outputs": [],
"source": [
"def update_dataset_schema(schema):\n",
" \"\"\"\n",
" Function to update the dataset schema,\n",
" with the updated schema fields.\n",
" \"\"\"\n",
" # Create a client\n",
" client = documentai.DocumentServiceClient()\n",
"\n",
" # Initialize request argument(s)\n",
" dataset_schema = documentai.DatasetSchema(\n",
" name=schema.name,\n",
" document_schema=schema.document_schema,\n",
" )\n",
"\n",
" request = documentai.UpdateDatasetSchemaRequest(dataset_schema=dataset_schema)\n",
"\n",
" # Make the request\n",
" response = client.update_dataset_schema(request=request)\n",
"\n",
" # Handle the response\n",
" # print(response)\n",
"\n",
" return response"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2f989c6-2aaf-4063-a31f-9579a2c6a7e9",
"metadata": {
"id": "d2f989c6-2aaf-4063-a31f-9579a2c6a7e9",
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"schema = update_schema_fields(schema, enable_fields)\n",
"response = update_dataset_schema(schema)"
]
},
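{
"cell_type": "markdown",
"id": "custom-label-sketch-md",
"metadata": {},
"source": [
"#### [OPTIONAL] Add a custom label\n",
"\n",
"As mentioned above, you can follow similar steps to add a custom label before uptraining. The cell below is a minimal, illustrative sketch: it builds a hypothetical `purchase_order_id` property (an assumed example name and type, not part of the lab dataset) that could be appended to an entity type before pushing the updated schema.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "custom-label-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch, not required for this lab: define a hypothetical custom label.\n",
"# \"purchase_order_id\" is an assumed example name, not in the sample dataset.\n",
"custom_property = documentai.DocumentSchema.EntityType.Property(\n",
"    name=\"purchase_order_id\",\n",
"    value_type=\"string\",\n",
"    occurrence_type=documentai.DocumentSchema.EntityType.Property.OccurrenceType.OPTIONAL_ONCE,\n",
")\n",
"\n",
"# Uncomment to append the property to the first entity type and push the schema\n",
"# schema.document_schema.entity_types[0].properties.append(custom_property)\n",
"# response = update_dataset_schema(schema)"
]
},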
{
"cell_type": "markdown",
"id": "wEHkZ78Ks9Hj",
"metadata": {
"id": "wEHkZ78Ks9Hj"
},
"source": [
"## [OPTIONAL] Auto-label newly imported documents\n",
"\n",
"When importing unlabeled documents for a processor with an existing deployed processor version, you can use [Auto-labeling](https://cloud.google.com/document-ai/docs/workbench/label-documents#auto-label) to save time on labeling."
]
},
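{
"cell_type": "markdown",
"id": "auto-label-sketch-md",
"metadata": {},
"source": [
"The cell below is a minimal sketch of one way to approximate that flow from Python: run the deployed base version over unlabeled PDFs with batch processing, then import the resulting Document JSON into the dataset as pre-labeled data. The GCS URIs are hypothetical placeholders; the console flow linked above is the supported path.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "auto-label-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch, assuming hypothetical GCS locations for unlabeled PDFs\n",
"# and for the batch-processing output (Document JSON).\n",
"unlabeled_gcs_uri = \"gs://your-bucket/unlabeled-pdfs/\"  # hypothetical\n",
"output_gcs_uri = \"gs://your-bucket/auto-labeled-json/\"  # hypothetical\n",
"\n",
"opts = ClientOptions(api_endpoint=f\"{location}-documentai.googleapis.com\")\n",
"client = documentai.DocumentProcessorServiceClient(client_options=opts)\n",
"\n",
"request = documentai.BatchProcessRequest(\n",
"    name=base_version,  # deployed version used to pre-label the documents\n",
"    input_documents=documentai.BatchDocumentsInputConfig(\n",
"        gcs_prefix=documentai.GcsPrefix(gcs_uri_prefix=unlabeled_gcs_uri)\n",
"    ),\n",
"    document_output_config=documentai.DocumentOutputConfig(\n",
"        gcs_output_config=documentai.DocumentOutputConfig.GcsOutputConfig(\n",
"            gcs_uri=output_gcs_uri\n",
"        )\n",
"    ),\n",
")\n",
"\n",
"# Uncomment to run batch processing and import the results as labeled data\n",
"# operation = client.batch_process_documents(request=request)\n",
"# poll_operation(operation.operation.name, location)\n",
"# import_documents(\n",
"#     processor_name,\n",
"#     location,\n",
"#     output_gcs_uri,\n",
"#     train_split=documentai.DatasetSplitType.DATASET_SPLIT_UNASSIGNED,\n",
"# )"
]
},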
{
"cell_type": "markdown",
"id": "f8a80228-d34a-426b-bf30-92885e68e4e2",
"metadata": {
"id": "f8a80228-d34a-426b-bf30-92885e68e4e2"
},
"source": [
"## Processor Uptraining"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b22c1db-8e5a-4a07-b238-cd5d0957c3ab",
"metadata": {
"id": "4b22c1db-8e5a-4a07-b238-cd5d0957c3ab"
},
"outputs": [],
"source": [
"def train_processor_version(processor_name, display_name, base_version):\n",
" \"\"\"\n",
" Function to train the new processor version,\n",
" provided base version, the new version will be built from.\n",
" \"\"\"\n",
" # Create a client\n",
" client = documentai.DocumentProcessorServiceClient()\n",
"\n",
" # Initialize request argument(s)\n",
" processor_version = documentai.ProcessorVersion(\n",
" display_name=display_name,\n",
" )\n",
"\n",
" request = documentai.TrainProcessorVersionRequest(\n",
" parent=processor_name,\n",
" processor_version=processor_version,\n",
" base_processor_version=base_version,\n",
" )\n",
"\n",
" # Make the request\n",
" operation = client.train_processor_version(request=request)\n",
"\n",
" print(\"Training job is triggered\")\n",
" return operation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sUnFdph3_xcm",
"metadata": {
"id": "sUnFdph3_xcm"
},
"outputs": [],
"source": [
"display_name = \"lab-uptraining-test-1\" # @param {type:\"string\"}\n",
"\n",
"operation = train_processor_version(processor_name, display_name, base_version)\n",
"uptrained_version = operation.metadata.common_metadata.resource\n",
"uptrained_version"
]
},
{
"cell_type": "markdown",
"id": "hKcKNKgmt8Yw",
"metadata": {
"id": "hKcKNKgmt8Yw"
},
"source": [
"#### Note: The training job will take around an hour, so come back later to proceed further."
]
},
{
"cell_type": "markdown",
"id": "23bf5b46-85c4-4390-b510-5c450ca34f52",
"metadata": {
"id": "23bf5b46-85c4-4390-b510-5c450ca34f52"
},
"source": [
"## Get Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d123a180-2963-4715-8f20-e593da2e3fb1",
"metadata": {
"id": "d123a180-2963-4715-8f20-e593da2e3fb1"
},
"outputs": [],
"source": [
"def get_processor_version_info(processor_name):\n",
" \"\"\"\n",
" Function to get the processor version info.\n",
" \"\"\"\n",
" # Create a client\n",
" client = documentai.DocumentProcessorServiceClient()\n",
"\n",
" # Initialize request argument(s)\n",
" request = documentai.GetProcessorVersionRequest(\n",
" name=processor_name,\n",
" )\n",
"\n",
" # Make the request\n",
" response = client.get_processor_version(request=request)\n",
"\n",
" # Handle the response\n",
" return response"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "s4CHscILtAnf",
"metadata": {
"id": "s4CHscILtAnf"
},
"outputs": [],
"source": [
"def get_f1score(operation, uptrained_version):\n",
" \"\"\"\n",
" Function to get the F1 score of the newly trained model.\n",
" \"\"\"\n",
" op_response = poll_operation(operation.name, location)\n",
" processor_version_info = get_processor_version_info(uptrained_version)\n",
" f1_score = processor_version_info.latest_evaluation.aggregate_metrics.f1_score\n",
" return f1_score"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab650731-2482-4f6e-870b-d1fc7e158f73",
"metadata": {
"id": "ab650731-2482-4f6e-870b-d1fc7e158f73"
},
"outputs": [],
"source": [
"f1_score = get_f1score(uptrained_version)\n",
"f1_score"
]
},
{
"cell_type": "markdown",
"id": "49d0b14b-0013-4060-b66e-abd3abed00e4",
"metadata": {
"id": "49d0b14b-0013-4060-b66e-abd3abed00e4"
},
"source": [
"## Deploy trained processor\n",
"\n",
"Once the model is trained, you can deploy it to use in the document processing workflow based on the F1 score criteria. If the newly trained model meets our desired accuracy (F1 score), you can deploy the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01d8d4d7-b1fd-4ac5-9c8d-7a76531c4d54",
"metadata": {
"id": "01d8d4d7-b1fd-4ac5-9c8d-7a76531c4d54"
},
"outputs": [],
"source": [
"def deploy_processor_version(processor_version):\n",
" \"\"\"\n",
" Function to deploy the processor version.\n",
" \"\"\"\n",
" # Create a client\n",
" client = documentai.DocumentProcessorServiceClient()\n",
"\n",
" # Initialize request argument(s)\n",
" request = documentai.DeployProcessorVersionRequest(\n",
" name=processor_version,\n",
" )\n",
"\n",
" # Make the request\n",
" operation = client.deploy_processor_version(request=request)\n",
"\n",
" print(\"Waiting for operation to complete...\")\n",
"\n",
" return operation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b52dc00e-d935-4f70-83dd-a24e991b1102",
"metadata": {
"id": "b52dc00e-d935-4f70-83dd-a24e991b1102"
},
"outputs": [],
"source": [
"# Set threhsold for F1 score\n",
"threshold = 0.8 # @param {type:\"number\"}\n",
"\n",
"# Deploy if F1 score of newly trained model is greater than threshold\n",
"if f1_score >= threshold:\n",
" operation = deploy_processor_version(uptrained_version)\n",
" op_name = operation.operation.name\n",
" op_response = poll_operation(op_name, location)\n",
" print(\"Processor is deployed\")\n",
"else:\n",
" print(\"The F1 score is below threshold\")"
]
},
{
"cell_type": "markdown",
"id": "K-G1QebGu08e",
"metadata": {
"id": "K-G1QebGu08e"
},
"source": [
"## Document Processing\n",
"\n",
"Once the model is deployed, you can use it in the document processing workflow by using the provided sample [code](https://cloud.google.com/document-ai/docs/process-documents-client-libraries#client-libraries-usage-python)."
]
},
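{
"cell_type": "markdown",
"id": "process-doc-sketch-md",
"metadata": {},
"source": [
"The cell below is a minimal sketch of calling the uptrained version directly, assuming a local sample file named `invoice.pdf` (a hypothetical path).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "process-doc-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch: process a local PDF with the uptrained, deployed version.\n",
"# \"invoice.pdf\" is a hypothetical path; point it at your own sample document.\n",
"opts = ClientOptions(api_endpoint=f\"{location}-documentai.googleapis.com\")\n",
"client = documentai.DocumentProcessorServiceClient(client_options=opts)\n",
"\n",
"with open(\"invoice.pdf\", \"rb\") as f:\n",
"    raw_document = documentai.RawDocument(\n",
"        content=f.read(), mime_type=\"application/pdf\"\n",
"    )\n",
"\n",
"request = documentai.ProcessRequest(name=uptrained_version, raw_document=raw_document)\n",
"result = client.process_document(request=request)\n",
"\n",
"# Print the extracted entities\n",
"for entity in result.document.entities:\n",
"    print(f\"{entity.type_}: {entity.mention_text}\")"
]
}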
],
"metadata": {
"colab": {
"provenance": []
},
"environment": {
"kernel": "python3",
"name": "common-cpu.m108",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/base-cpu:m108"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}