JobManager.py (160 lines of code) (raw):

# Copyright 2022-2023 Google, LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid, datetime, json, configparser
import constants
from google.cloud import firestore
from google.cloud.firestore_v1.base_query import FieldFilter
from google.cloud import tasks_v2
from google.api_core.client_info import ClientInfo

USER_AGENT = 'cloud-solutions/datacatalog-tag-engine-v2'


class JobManager:
    """Class for managing jobs for async task create and update requests.

    tag_engine_sa      = Tag Engine (Cloud Run) service account used as the
                         OIDC identity for queued HTTP tasks
    tag_engine_project = Project where the queue is based (e.g. tag-engine-project)
    tag_engine_region  = Region where the queue is based (e.g. us-central1)
    tag_engine_queue   = Name of the queue (e.g. tag-engine-queue)
    task_handler_uri   = task handler uri in the Flask app hosted by Cloud Run
    db_project         = Project hosting the Firestore database
    db_name            = Firestore database name
    """

    def __init__(self, tag_engine_sa, tag_engine_project, tag_engine_region,
                 tag_engine_queue, task_handler_uri, db_project, db_name):

        self.tag_engine_sa = tag_engine_sa
        self.tag_engine_project = tag_engine_project
        self.tag_engine_region = tag_engine_region
        self.tag_engine_queue = tag_engine_queue
        self.task_handler_uri = task_handler_uri

        # Firestore client used for all job bookkeeping (jobs, shards, job_metadata)
        self.db = firestore.Client(project=db_project, database=db_name,
                                   client_info=ClientInfo(user_agent=USER_AGENT))

    ##################### API METHODS #################

    def create_job(self, tag_creator_account, tag_invoker_account, config_uuid, config_type, metadata=None):
        """Create a job record (plus optional metadata record) and enqueue its task.

        Args:
            tag_creator_account: service account that owns the tag config
            tag_invoker_account: account that triggered the job
            config_uuid: uuid of the config this job runs
            config_type: type of the config (e.g. dynamic, static)
            metadata: optional dict stored in a companion job_metadata record

        Returns:
            The generated job_uuid.
        """
        job_uuid = self._create_job_record(config_uuid, config_type)

        # BUG FIX: identity comparison with None ("!= None" -> "is not None")
        if metadata is not None:
            self._create_job_metadata_record(job_uuid, config_uuid, config_type, metadata)

        self._create_job_task(tag_creator_account, tag_invoker_account, job_uuid, config_uuid, config_type)

        return job_uuid

    def update_job_running(self, job_uuid):
        """Mark the job as RUNNING in Firestore."""
        job_ref = self.db.collection('jobs').document(job_uuid)
        job_ref.update({'job_status': 'RUNNING'})
        print('Set job running.')

    def record_num_tasks(self, job_uuid, num_tasks):
        """Record the total number of tasks this job will run."""
        job_ref = self.db.collection('jobs').document(job_uuid)
        job_ref.update({'task_count': num_tasks})
        print('record_num_tasks')

    def calculate_job_completion(self, job_uuid):
        """Aggregate shard results and update the job's status and progress.

        Sums tasks_success/tasks_failed across the job's shards, then updates
        the job document: RUNNING while tasks remain, otherwise SUCCESS or
        ERROR (any failed task) with a completion_time.

        Returns:
            (tasks_success, tasks_failed, pct_complete) where pct_complete is
            a percentage rounded to 2 decimals (0 if the job record is missing).
        """
        tasks_success = self._get_tasks_success(job_uuid)
        tasks_failed = self._get_tasks_failed(job_uuid)
        tasks_ran = tasks_success + tasks_failed

        print('tasks_success:', tasks_success)
        print('tasks_failed:', tasks_failed)
        print('tasks_ran:', tasks_ran)

        # BUG FIX: pct_complete was unbound (UnboundLocalError on return)
        # whenever the job document did not exist; default it to 0.
        pct_complete = 0

        job_ref = self.db.collection('jobs').document(job_uuid)
        job = job_ref.get()

        if job.exists:
            job_dict = job.to_dict()
            task_count = job_dict['task_count']

            if task_count > tasks_ran:
                # job still running: refresh progress counters
                job_ref.update({
                    'tasks_ran': tasks_ran,
                    'job_status': 'RUNNING',
                    'tasks_success': tasks_success,
                    'tasks_failed': tasks_failed,
                })
                pct_complete = round(tasks_ran / task_count * 100, 2)
            else:
                # job completed: any failed task marks the whole job as ERROR
                job_status = 'ERROR' if tasks_failed > 0 else 'SUCCESS'
                job_ref.update({
                    'tasks_ran': tasks_ran,
                    'tasks_success': tasks_success,
                    'tasks_failed': tasks_failed,
                    'job_status': job_status,
                    'completion_time': datetime.datetime.utcnow()
                })
                pct_complete = 100

        return tasks_success, tasks_failed, pct_complete

    def get_job_status(self, job_uuid):
        """Return the job document as a dict, or None if it does not exist."""
        job = self.db.collection('jobs').document(job_uuid).get()

        if job.exists:
            return job.to_dict()

    def set_job_status(self, job_uuid, status):
        """Set the job's job_status field to the given value."""
        self.db.collection('jobs').document(job_uuid).update({
            'job_status': status
        })

    ################ INTERNAL PROCESSING METHODS #################

    def _create_job_record(self, config_uuid, config_type):
        """Create the initial PENDING job document and return its job_uuid."""
        print('*** _create_job_record ***')

        job_uuid = uuid.uuid1().hex

        job_ref = self.db.collection('jobs').document(job_uuid)
        job_ref.set({
            'job_uuid': job_uuid,
            'config_uuid': config_uuid,
            'config_type': config_type,
            'job_status': 'PENDING',
            'task_count': 0,
            'tasks_ran': 0,
            'tasks_success': 0,
            'tasks_failed': 0,
            'creation_time': datetime.datetime.utcnow()
        })

        print('Created job record.')

        return job_uuid

    def _create_job_task(self, tag_creator_account, tag_invoker_account, job_uuid, config_uuid, config_type):
        """Enqueue a Cloud Tasks HTTP task that POSTs the job payload to the
        task handler, authenticated with the Tag Engine SA's OIDC token.

        Returns:
            The created Task resource.
        """
        payload = {'job_uuid': job_uuid, 'config_uuid': config_uuid, 'config_type': config_type,
                   'tag_creator_account': tag_creator_account, 'tag_invoker_account': tag_invoker_account}

        task = {
            'http_request': {
                'http_method': 'POST',
                'url': self.task_handler_uri,
                'headers': {'content-type': 'application/json'},
                'body': json.dumps(payload).encode(),
                'oidc_token': {'service_account_email': self.tag_engine_sa,
                               'audience': self.task_handler_uri}
            }
        }

        print('task create:', task)

        client = tasks_v2.CloudTasksClient()
        parent = client.queue_path(self.tag_engine_project, self.tag_engine_region, self.tag_engine_queue)
        resp = client.create_task(parent=parent, task=task)
        print('task resp: ', resp)

        return resp

    # BUG FIX: method was declared without `self` yet used `self.db`,
    # so any call raised an error.
    def _get_task_count(self, job_uuid):
        """Return the job's task_count, or None if the job does not exist."""
        job = self.db.collection('jobs').document(job_uuid).get()

        if job.exists:
            return job.to_dict()['task_count']

    def _get_tasks_success(self, job_uuid):
        """Sum tasks_success over all shard documents belonging to this job."""
        tasks_success = 0

        shards = self.db.collection('shards').where(filter=FieldFilter('job_uuid', '==', job_uuid)).stream()
        for shard in shards:
            tasks_success += shard.to_dict().get('tasks_success', 0)

        return tasks_success

    def _get_tasks_failed(self, job_uuid):
        """Sum tasks_failed over all shard documents belonging to this job."""
        tasks_failed = 0

        shards = self.db.collection('shards').where(filter=FieldFilter('job_uuid', '==', job_uuid)).stream()
        for shard in shards:
            tasks_failed += shard.to_dict().get('tasks_failed', 0)

        return tasks_failed

    def _create_job_metadata_record(self, job_uuid, config_uuid, config_type, metadata):
        """Create the companion job_metadata document for a job."""
        print('*** _create_job_metadata_record ***')

        job_ref = self.db.collection('job_metadata').document(job_uuid)
        job_ref.set({
            'job_uuid': job_uuid,
            'config_uuid': config_uuid,
            'config_type': config_type,
            'metadata': metadata,
            'creation_time': datetime.datetime.utcnow()
        })

        print('Created job_metadata record.')


if __name__ == '__main__':
    # Ad-hoc smoke test: read settings from tagengine.ini and enqueue one job.
    config = configparser.ConfigParser()
    config.read("tagengine.ini")

    queue_project = config['DEFAULT']['QUEUE_PROJECT']
    queue_region = config['DEFAULT']['QUEUE_REGION']
    queue_name = config['DEFAULT']['INJECTOR_QUEUE']
    task_handler_uri = '/_split_work'
    db_project = config['DEFAULT']['FIRESTORE_PROJECT']
    db_name = config['DEFAULT']['FIRESTORE_DB']

    # BUG FIX: JobManager requires the service account as its first argument;
    # the original call omitted it entirely.
    # NOTE(review): assumes the SA is stored under CLOUD_RUN_ACCOUNT -- confirm
    # the key name against tagengine.ini.
    tag_engine_sa = config['DEFAULT'].get('CLOUD_RUN_ACCOUNT', '')

    jm = JobManager(tag_engine_sa, queue_project, queue_region, queue_name,
                    task_handler_uri, db_project, db_name)

    config_uuid = '1f1b4720839c11eca541e1ad551502cb'
    # BUG FIX: create_async_job() does not exist on JobManager; create_job()
    # is the public API. Creator/invoker accounts and config_type are
    # placeholders for this smoke test -- adjust as needed.
    jm.create_job(tag_engine_sa, tag_engine_sa, config_uuid, 'dynamic_table')
    print('done')