# classify-split-extract-workflow/classify-job/config.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module handles the configuration for the workflow.
It includes functionality for loading configurations from Google Cloud Storage (GCS),
retrieving specific configuration elements.
"""
import datetime
import json
import os
from typing import Any, cast, Dict, Optional, Tuple

import google.auth
from google.api_core.exceptions import GoogleAPIError
from google.cloud import run_v2
from google.cloud import storage

from logging_handler import Logger

# pylint: disable=logging-fstring-interpolation,import-error,global-statement
logger = Logger.get_logger(__file__)

# Environment variables and default settings
PROJECT_ID = os.environ.get("PROJECT_ID") or google.auth.default()[1]
CLASSIFY_INPUT_BUCKET = os.environ.get(
"CLASSIFY_INPUT_BUCKET", f"{PROJECT_ID}-documents"
)
CLASSIFY_OUTPUT_BUCKET = os.environ.get(
"CLASSIFY_OUTPUT_BUCKET", f"{PROJECT_ID}-workflow"
)
INPUT_FILE = os.environ.get("INPUT_FILE")
GOOGLE_APPLICATION_CREDENTIALS = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
OUTPUT_FILE_JSON = os.environ.get("OUTPUT_FILE_JSON", "classify_output.json")
OUTPUT_FILE_CSV = os.environ.get("OUTPUT_FILE_CSV", "classify_output.csv")
CALL_BACK_URL = os.environ.get("CALL_BACK_URL")
BQ_DATASET_ID_PROCESSED_DOCS = os.environ.get(
"BQ_DATASET_ID_PROCESSED_DOCS", "processed_documents"
)
BQ_OBJECT_TABLE_RETENTION_DAYS = int(
    os.environ.get("BQ_OBJECT_TABLE_RETENTION_DAYS", "7")
)
BQ_DATASET_ID_MLOPS = os.environ.get("BQ_DATASET_ID_MLOPS", "mlops")
BQ_PROJECT_ID = os.environ.get("BQ_PROJECT_ID", PROJECT_ID)
BQ_REGION = os.environ.get("BQ_REGION", "us")
BQ_GCS_CONNECTION_NAME = os.environ.get("BQ_GCS_CONNECTION_NAME", "bq-connection-gcs")
START_PIPELINE_FILENAME = "START_PIPELINE"
CLASSIFIER = "classifier"
DOCAI_OUTPUT_BUCKET = os.environ.get(
"DOCAI_OUTPUT_BUCKET", f"{PROJECT_ID}-docai-output"
)
CONFIG_BUCKET = os.environ.get("CONFIG_BUCKET", f"{PROJECT_ID}-config")
CONFIG_FILE_NAME = "config.json"
CLASSIFICATION_UNDETECTABLE = "unclassified"
CLOUD_RUN_EXECUTION = os.environ.get("CLOUD_RUN_EXECUTION")
REGION = os.environ.get("REGION")
SPLITTER_OUTPUT_DIR = os.environ.get("SPLITTER_OUTPUT_DIR", "splitter_output")
PDF_EXTENSION = ".pdf"
PDF_MIME_TYPE = "application/pdf"
MIME_TYPES = [
PDF_MIME_TYPE,
]
OTHER_MIME_TYPES_TO_SUPPORT = [
"image/gif",
"image/tiff",
"image/jpeg",
"image/png",
"image/bmp",
"image/webp",
]
NO_CLASSIFIER_LABEL = "No Classifier"
METADATA_CONFIDENCE = "confidence"
METADATA_DOCUMENT_TYPE = "type"
CONFIG_JSON_DOCUMENT_TYPES_CONFIG = "document_types_config"
FULL_JOB_NAME = run_v2.ExecutionsClient.job_path(PROJECT_ID, REGION, "classify-job")
# Global variables
BUCKET: Optional[storage.Bucket] = None
LAST_MODIFIED_TIME_OF_CONFIG: Optional[datetime.datetime] = None
CONFIG_DATA: Optional[Dict[Any, Any]] = None

logger.info(
f"Settings used: CLASSIFY_INPUT_BUCKET=gs://{CLASSIFY_INPUT_BUCKET}, INPUT_FILE={INPUT_FILE}, "
f"CLASSIFY_OUTPUT_BUCKET=gs://{CLASSIFY_OUTPUT_BUCKET}, OUTPUT_FILE_JSON={OUTPUT_FILE_JSON}, "
f"OUTPUT_FILE_CSV={OUTPUT_FILE_CSV}, CALL_BACK_URL={CALL_BACK_URL}, "
f"BQ_DATASET_ID_PROCESSED_DOCS={BQ_DATASET_ID_PROCESSED_DOCS}, "
f"BQ_DATASET_ID_MLOPS={BQ_DATASET_ID_MLOPS}, "
f"BQ_PROJECT_ID={BQ_PROJECT_ID}, BQ_GCS_CONNECTION_NAME={BQ_GCS_CONNECTION_NAME}, "
f"DOCAI_OUTPUT_BUCKET={DOCAI_OUTPUT_BUCKET}"
)
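
# All of the settings above can be overridden per deployment through
# environment variables before the job starts, e.g. (illustrative values):
#   CLASSIFY_INPUT_BUCKET=my-project-documents
#   BQ_OBJECT_TABLE_RETENTION_DAYS=14

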
def init_bucket(bucket_name: str) -> Optional[storage.Bucket]:
"""
Initializes the GCS bucket.
Args:
bucket_name (str): The name of the bucket.
"""
storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)
if not bucket.exists():
logger.error(f"Bucket does not exist: gs://{bucket_name}")
return None # Return None to indicate failure
return bucket
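
# Minimal usage sketch (this is what load_config() below does):
#   bucket = init_bucket(CONFIG_BUCKET)
#   if bucket:
#       blob = bucket.get_blob(CONFIG_FILE_NAME)

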
def get_config(
config_name: Optional[str] = None, element_path: Optional[str] = None
) -> Optional[Dict[Any, Any]]:
"""
Retrieves the configuration data.
    Args:
        config_name (Optional[str]): The top-level configuration section name.
        element_path (Optional[str]): Dot-separated path to a nested element,
            for example "<doc_type>.parser".
Returns:
Optional[Dict[Any, Any]]: The configuration data.
"""
    global CONFIG_DATA
    # Go through load_config() on every call so config updates in GCS are
    # picked up; when nothing has changed it only fetches blob metadata and
    # returns the cached data.
    CONFIG_DATA = load_config(CONFIG_BUCKET, CONFIG_FILE_NAME)
    assert CONFIG_DATA, "Unable to load configuration data"
config_data_loaded = (
CONFIG_DATA.get(config_name, {}) if config_name else CONFIG_DATA
)
if element_path:
keys = element_path.split(".")
for key in keys:
if isinstance(config_data_loaded, dict):
config_data_loaded_new = config_data_loaded.get(key)
if config_data_loaded_new is None:
logger.error(
f"Key '{key}' not present in the "
f"configuration {json.dumps(config_data_loaded, indent=4)}"
)
return None
config_data_loaded = config_data_loaded_new
else:
logger.error(
f"Expected a dictionary at '{key}' but found a "
f"{type(config_data_loaded).__name__}"
)
return None
return config_data_loaded
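
# Usage sketch (mirrors get_parser_name_by_doc_type() below; "<doc_type>"
# stands for any key of document_types_config):
#   parser_name = get_config("document_types_config", "<doc_type>.parser")
#   full_config = get_config()

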
def get_parser_name_by_doc_type(doc_type: str) -> Optional[str]:
"""Retrieves the parser name based on the document type.
Args:
doc_type (str): The document type.
Returns:
Optional[str]: The parser name, or None if not found.
"""
return cast(
Optional[str],
get_config(CONFIG_JSON_DOCUMENT_TYPES_CONFIG, f"{doc_type}.parser"),
)
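
# E.g. with a (hypothetical) config entry
#   "document_types_config": {"invoice": {"parser": "invoice_parser"}}
# get_parser_name_by_doc_type("invoice") returns "invoice_parser".

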
def get_document_types_config() -> Optional[Dict[Any, Any]]:
"""
Retrieves the document types configuration.
Returns:
Optional[Dict[Any, Any]]: The document types configuration.
"""
return get_config(CONFIG_JSON_DOCUMENT_TYPES_CONFIG)


def get_parser_by_doc_type(doc_type: str) -> Optional[Dict[Any, Any]]:
"""
Retrieves the parser by document type.
Args:
doc_type (str): The document type.
Returns:
Optional[Dict[Any, Any]]: The parser configuration.
"""
parser_name = get_parser_name_by_doc_type(doc_type)
if parser_name:
return get_config("parser_config", parser_name)
return None
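
# Resolution is two-step: document type -> parser name via
# document_types_config, then parser name -> parser entry via a
# (hypothetical) "parser_config": {"invoice_parser": {...}} section.

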
def load_config(bucket_name: str, filename: str) -> Optional[Dict[Any, Any]]:
"""
Loads the configuration data from a GCS bucket or local file.
Args:
bucket_name (str): The GCS bucket name.
filename (str): The configuration file name.
Returns:
Optional[Dict[Any, Any]]: The configuration data.
"""
global BUCKET, CONFIG_DATA, LAST_MODIFIED_TIME_OF_CONFIG
if not BUCKET:
BUCKET = init_bucket(bucket_name)
if not BUCKET:
return None
blob = BUCKET.get_blob(filename)
if not blob:
logger.error(f"Error: file does not exist gs://{bucket_name}/{filename}")
return None
last_modified_time = blob.updated
if LAST_MODIFIED_TIME_OF_CONFIG == last_modified_time:
return CONFIG_DATA
logger.info(f"Reloading config from: {filename}")
try:
CONFIG_DATA = json.loads(blob.download_as_text(encoding="utf-8"))
LAST_MODIFIED_TIME_OF_CONFIG = last_modified_time
    except (json.JSONDecodeError, OSError, GoogleAPIError) as e:
        # GoogleAPIError covers GCS download failures that OSError alone
        # would not catch.
        logger.error(
            f"Error while obtaining file from GCS gs://{bucket_name}/{filename}: {e}"
        )
        logger.warning(f"Falling back to local {filename}")
try:
with open(
os.path.join(os.path.dirname(__file__), "config", filename),
encoding="utf-8",
) as json_file:
CONFIG_DATA = json.load(json_file)
except (FileNotFoundError, json.JSONDecodeError) as exc:
logger.error(f"Error loading local config file {filename}: {exc}")
return None
return CONFIG_DATA
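
# Illustrative top-level shape of config.json (only the keys this module
# reads; all concrete names and values below are examples, not a schema):
#   {
#     "settings_config": {
#       "classification_confidence_threshold": 0.7,
#       "classification_default_class": "generic_form"
#     },
#     "document_types_config": {
#       "invoice": {"parser": "invoice_parser", "classifier_label": "Invoice"}
#     },
#     "parser_config": {
#       "invoice_parser": {"model_name": "...", "out_table_name": "..."}
#     }
#   }

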
def get_docai_settings() -> Optional[Dict[Any, Any]]:
"""
Retrieves the Document AI settings configuration.
Returns:
Optional[Dict[Any, Any]]: The Document AI settings configuration.
"""
return get_config("settings_config")


def get_classification_confidence_threshold() -> float:
"""
Retrieves the classification confidence threshold.
Returns:
float: The classification confidence threshold.
"""
settings = get_docai_settings()
return (
float(settings.get("classification_confidence_threshold", 0)) if settings else 0
)
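
# E.g. with "settings_config": {"classification_confidence_threshold": 0.7}
# (illustrative value) this returns 0.7; without a settings section it
# returns 0.

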
def get_classification_default_class() -> str:
"""
Retrieves the default classification class.
Returns:
str: The default classification class.
"""
settings = get_docai_settings()
classification_default_class = (
settings.get("classification_default_class", CLASSIFICATION_UNDETECTABLE)
if settings
else CLASSIFICATION_UNDETECTABLE
)
parser = get_parser_by_doc_type(classification_default_class)
if parser:
return classification_default_class
    logger.warning(
        f"Classification default class {classification_default_class}"
        f" is not a valid label or is missing a corresponding"
        f" parser in parser_config"
    )
return CLASSIFICATION_UNDETECTABLE
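
# Sketch: with a (hypothetical) "classification_default_class": "generic_form"
# entry whose class has a parser configured, callers fall back to
# "generic_form"; otherwise the fallback is "unclassified".

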
def get_document_class_by_classifier_label(label_name: str) -> Optional[str]:
"""
Retrieves the document class by classifier label.
Args:
label_name (str): The classifier label name.
Returns:
Optional[str]: The document class.
"""
doc_types_config = get_document_types_config()
if doc_types_config:
for k, v in doc_types_config.items():
if v.get("classifier_label") == label_name:
return k
    logger.error(
        f"classifier_label={label_name} is not assigned to any document type in the config"
    )
return None
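
# E.g. a (hypothetical) entry {"invoice": {"classifier_label": "Invoice"}}
# makes get_document_class_by_classifier_label("Invoice") return "invoice".

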
def get_parser_by_name(parser_name: str) -> Optional[Dict[Any, Any]]:
"""
Retrieves the parser configuration by parser name.
Args:
parser_name (str): The parser name.
Returns:
Optional[Dict[Any, Any]]: The parser configuration.
"""
return get_config("parser_config", parser_name)


def get_model_name_table_name(
    document_type: str,
) -> Tuple[Optional[str], Optional[str]]:
    """
    Retrieves the BigQuery model name and output table name by document type.
    Args:
        document_type (str): The document type.
    Returns:
        Tuple[Optional[str], Optional[str]]: The model name and the output
        table name.
    """
parser_name = get_parser_name_by_doc_type(document_type)
if parser_name:
parser = get_parser_by_doc_type(document_type)
model_name = (
f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_MLOPS}."
f"{parser.get('model_name', parser_name.upper() + '_MODEL')}"
if parser
else None
)
out_table_name = (
f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_PROCESSED_DOCS}."
f"{parser.get('out_table_name', parser_name.upper() + '_DOCUMENTS')}"
if parser
else None
)
else:
logger.warning(f"No parser found for document type {document_type}")
return None, None
logger.info(f"model_name={model_name}, out_table_name={out_table_name}")
return model_name, out_table_name
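
# Example derivation (hypothetical parser entry with no explicit
# "model_name"/"out_table_name" overrides): document type "invoice"
# resolved to parser "invoice_parser" yields the defaults
#   model_name     = f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_MLOPS}.INVOICE_PARSER_MODEL"
#   out_table_name = f"{BQ_PROJECT_ID}.{BQ_DATASET_ID_PROCESSED_DOCS}.INVOICE_PARSER_DOCUMENTS"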