classify-split-extract-workflow/classify-job/main.py:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=logging-fstring-interpolation,broad-exception-caught
""" Main function to run Cloud Run job Classification/Splitting tasks """
import os
import config
from config import CLASSIFIER
from config import FULL_JOB_NAME
from docai_helper import get_processor_and_client
from gcs_helper import get_list_of_uris
from logging_handler import Logger
from split_and_classify import batch_classification
from split_and_classify import handle_no_classifier
from split_and_classify import save_classification_results
from split_and_classify import stream_classification_results

logger = Logger.get_logger(__file__)


def process():
"""Main function for the Classifier/Splitter Cloud Run Job"""
input_bucket = config.CLASSIFY_INPUT_BUCKET
input_file = config.INPUT_FILE
out_bucket_name = None
out_file_name = None
logger.info(f"Processing documents on event gs://{input_bucket}/{input_file} ")
logger.info(f"FULL_JOB_NAME={FULL_JOB_NAME}")
logger.info(f"CLOUD_RUN_EXECUTION={os.getenv('CLOUD_RUN_EXECUTION')}")
    try:
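        # Look up the configured Document AI classifier processor and its client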
        processor, dai_client = get_processor_and_client(CLASSIFIER)
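        # Resolve the batch input file into the list of document URIs to classify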
        f_uris = get_list_of_uris(input_bucket, input_file)

        # When the classifier/splitter is not set up
        if not processor:
            logger.info(f"{CLASSIFIER} processor not found in config.json")
            classified_items = handle_no_classifier(f_uris)
        else:
            # Run the classification job
            logger.info(
                f"Using {processor.name} processor for Classification/Splitting"
            )
            classified_items = batch_classification(processor, dai_client, f_uris)
logger.info(f"Classified items: {classified_items}")
        out_bucket_name, out_file_name = save_classification_results(classified_items)
    except Exception as e:
        logger.error(f"Error during batch classification: {e}")
    stream_classification_results(
        call_back_url=config.CALL_BACK_URL,
        bucket_name=out_bucket_name,
        file_name=out_file_name,
    )
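

# A minimal local-run sketch (assumes config.py reads these settings from the
# environment; the variable names mirror the config attributes used above and
# are illustrative, not a documented contract):
#
#   CLASSIFY_INPUT_BUCKET=<bucket> INPUT_FILE=<batch-file> \
#   CALL_BACK_URL=<url> python main.py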
if __name__ == "__main__":
    process()