# TaskManager.py
# Copyright 2022-2023 Google, LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid, hashlib, datetime, json, configparser, math
import constants
from google.cloud import firestore
from google.cloud import tasks_v2
from google.protobuf import duration_pb2
from google.api_core.client_info import ClientInfo
USER_AGENT = 'cloud-solutions/datacatalog-tag-engine-v2'
class TaskManager:
    """Creates and manages work requests in the form of Cloud Tasks.

    Work items (uris or tag extracts) are fanned out into Cloud Tasks and
    tracked in Firestore: one 'shards' document per group of at most
    tasks_per_shard tasks, with one task document per work item underneath it.

    Args:
        tag_engine_sa: Tag Engine (Cloud Run) service account email
        tag_engine_project: Cloud Run project id (e.g. tag-engine-project)
        tag_engine_region: Cloud Run region (e.g. us-central1)
        tag_engine_queue: Cloud Tasks queue (e.g. tag-engine-queue)
        task_handler_uri: task handler uri in the Flask app hosted by Cloud Run
        db_project: Firestore project id
        db_name: Firestore database name
    """
    def __init__(self,
                tag_engine_sa,
                tag_engine_project,
                tag_engine_region,
                tag_engine_queue,
                task_handler_uri,
                db_project,
                db_name):

        self.tag_engine_sa = tag_engine_sa
        self.tag_engine_project = tag_engine_project
        self.tag_engine_region = tag_engine_region
        self.tag_engine_queue = tag_engine_queue
        self.task_handler_uri = task_handler_uri
        self.db = firestore.Client(project=db_project, database=db_name, client_info=ClientInfo(user_agent=USER_AGENT))
        self.tasks_per_shard = 1000
        # BUG FIX: Duration().FromSeconds(1800) mutates the message in place and
        # returns None, so the deadline was silently never applied. Construct the
        # Duration directly instead. 30 minutes per task (max supported duration).
        self.task_deadline = duration_pb2.Duration(seconds=1800)

    ##################### API METHODS #################

    def create_config_uuid_tasks(self, tag_creator_account, tag_invoker_account, job_uuid, config_uuid, config_type, uris):
        """Creates one cloud task (and Firestore task record) per uri, sharded
        into groups of at most self.tasks_per_shard tasks."""

        def seed(uri):
            # uri is a tuple when it contains a gcs path
            if isinstance(uri, tuple):
                return ''.join(uri)
            # ROBUSTNESS FIX: the original left task_id_raw unbound (NameError)
            # for any uri that was neither str nor tuple; fall back to str().
            return uri if isinstance(uri, str) else str(uri)

        def record(shard_uuid, task_id, uri):
            return self._record_config_uuid_task(job_uuid, shard_uuid, task_id, config_uuid, config_type, uri)

        def enqueue(shard_uuid, task_uuid, task_id, uri):
            self._create_config_uuid_task(tag_creator_account, tag_invoker_account, job_uuid, shard_uuid, \
                                          task_uuid, task_id, config_uuid, config_type, uri)

        self._dispatch_in_shards(job_uuid, uris, seed, record, enqueue)

    def create_tag_extract_tasks(self, tag_creator_account, tag_invoker_account, job_uuid, config_uuid, config_type, tag_extract_list):
        """Creates one cloud task (and Firestore task record) per tag extract,
        sharded into groups of at most self.tasks_per_shard tasks."""

        def seed(extract):
            # NOTE: the original wrote ''.join(str(extract)), which is a no-op
            # join over a single string and equals str(extract).
            return str(extract)

        def record(shard_uuid, task_id, extract):
            return self._record_tag_extract_task(job_uuid, shard_uuid, task_id, config_uuid, config_type, extract)

        def enqueue(shard_uuid, task_uuid, task_id, extract):
            self._create_tag_extract_task(tag_creator_account, tag_invoker_account, job_uuid, shard_uuid, \
                                          task_uuid, task_id, config_uuid, config_type, extract)

        self._dispatch_in_shards(job_uuid, tag_extract_list, seed, record, enqueue)

    def update_task_status(self, shard_uuid, task_uuid, status):
        """Records a task status transition ('RUNNING' | 'SUCCESS' | 'ERROR') on
        both the task document and the shard-level rollup counters."""
        if status == 'RUNNING':
            self._set_task_running(shard_uuid, task_uuid)
            self._set_rollup_tasks_running(shard_uuid)
        elif status == 'SUCCESS':
            self._set_task_success(shard_uuid, task_uuid)
            self._set_rollup_tasks_success(shard_uuid)
        elif status == 'ERROR':
            self._set_task_failed(shard_uuid, task_uuid)
            self._set_rollup_tasks_failed(shard_uuid)

    ################ INTERNAL PROCESSING METHODS #################

    def _dispatch_in_shards(self, job_uuid, items, seed_fn, record_fn, enqueue_fn):
        """Shared fan-out driver: shards the items into groups of at most
        self.tasks_per_shard and creates one cloud task per item.

        Args:
            job_uuid: job which the tasks belong to
            items: list of work items (uris or tag extracts)
            seed_fn: item -> string combined with job_uuid + timestamp to derive the cloud task id
            record_fn: (shard_uuid, task_id, item) -> task_uuid, writes the Firestore task record
            enqueue_fn: (shard_uuid, task_uuid, task_id, item), enqueues the cloud task
        """
        # an empty item list still produces a single (empty) shard, as before
        shards = max(1, math.ceil(len(items) / self.tasks_per_shard))
        task_running_total = 0  # index of the next unprocessed item across all shards
        task_counter = 0        # number of tasks placed into the current shard

        for shard_index in range(shards):
            shard_uuid = hashlib.md5((job_uuid + str(shard_index)).encode()).hexdigest()
            self._create_shard(job_uuid, shard_uuid)

            for item in items[task_running_total:]:
                # the timestamp keeps the cloud task id unique across re-runs of the same item
                task_id_raw = job_uuid + seed_fn(item) + str(datetime.datetime.utcnow())
                task_id = hashlib.md5(task_id_raw.encode()).hexdigest()
                task_uuid = record_fn(shard_uuid, task_id, item)
                enqueue_fn(shard_uuid, task_uuid, task_id, item)
                task_counter += 1
                task_running_total += 1
                if task_counter == self.tasks_per_shard:
                    # shard is full: persist its task count and move on to the next shard
                    self._update_shard_tasks(job_uuid, shard_uuid, task_counter)
                    task_counter = 0
                    break

            # persist the partial task count of the last shard
            if task_counter > 0:
                self._update_shard_tasks(job_uuid, shard_uuid, task_counter)

    def _create_shard(self, job_uuid, shard_uuid):
        """Writes a new shard document with zeroed rollup counters."""
        print('*** _create_shard ***')
        print('job_uuid: ' + job_uuid + ', shard_uuid: ' + shard_uuid)

        shard_ref = self.db.collection('shards').document(shard_uuid)
        shard_ref.set({
            'shard_uuid': shard_uuid,
            'job_uuid': job_uuid,
            'tasks_ran': 0,
            'tasks_success': 0,
            'tasks_failed': 0,
            'creation_time': datetime.datetime.utcnow()
        })

    def _update_shard_tasks(self, job_uuid, shard_uuid, task_counter):
        """Records how many tasks were placed into the shard."""
        self.db.collection('shards').document(shard_uuid).update({'task_count': task_counter})

    def _record_task(self, job_uuid, shard_uuid, task_id, config_uuid, config_type, extra_fields):
        """Writes a PENDING task document under its shard and returns the new task_uuid.

        extra_fields carries the payload-specific field ('uri' or 'tag_extract')."""
        task_uuid = uuid.uuid1().hex
        task_ref = self.db.collection('shards').document(shard_uuid).collection('tasks').document(task_uuid)
        record = {
            'task_uuid': task_uuid,    # task identifier in Firestore
            'task_id': task_id,        # cloud task identifier, based on uri
            'shard_uuid': shard_uuid,  # shard which this task belongs to
            'job_uuid': job_uuid,
            'config_uuid': config_uuid,
            'config_type': config_type,
            'status': 'PENDING',
            'creation_time': datetime.datetime.utcnow()
        }
        record.update(extra_fields)
        task_ref.set(record)
        return task_uuid

    def _record_config_uuid_task(self, job_uuid, shard_uuid, task_id, config_uuid, config_type, uri):
        """Records a PENDING uri-based task; returns the new task_uuid."""
        return self._record_task(job_uuid, shard_uuid, task_id, config_uuid, config_type, {'uri': uri})

    def _record_tag_extract_task(self, job_uuid, shard_uuid, task_id, config_uuid, config_type, extract):
        """Records a PENDING tag-extract task; returns the new task_uuid."""
        return self._record_task(job_uuid, shard_uuid, task_id, config_uuid, config_type, {'tag_extract': extract})

    def _enqueue_cloud_task(self, payload, shard_uuid, task_uuid, task_id):
        """Enqueues one Cloud Task that POSTs the json payload to the task handler.

        Marks the Firestore task record failed (and returns False) if the task
        cannot be created; returns True on success."""
        client = tasks_v2.CloudTasksClient()
        parent = client.queue_path(self.tag_engine_project, self.tag_engine_region, self.tag_engine_queue)

        task = {
            'name': parent + '/tasks/' + task_id,
            'dispatch_deadline': self.task_deadline,
            'http_request': {
                'http_method': 'POST',
                'url': self.task_handler_uri,
                'headers': {'content-type': 'application/json'},
                'body': json.dumps(payload).encode(),
                # handler authenticates the call via an OIDC token for the Tag Engine SA
                'oidc_token': {'service_account_email': self.tag_engine_sa, 'audience': self.task_handler_uri}
            }
        }
        print('task request:', task)

        try:
            client.create_task(parent=parent, task=task)
        except Exception as e:
            print('Error: could not create task for uri', self.task_handler_uri, '. Error: ', e)
            self._set_task_failed(shard_uuid, task_uuid)
            return False

        return True

    def _create_config_uuid_task(self, tag_creator_account, tag_invoker_account, job_uuid, shard_uuid, task_uuid, task_id, \
                                 config_uuid, config_type, uri):
        """Enqueues a uri-based cloud task; returns True on success."""
        payload = {'job_uuid': job_uuid, 'shard_uuid': shard_uuid, 'task_uuid': task_uuid, 'config_uuid': config_uuid, \
                   'config_type': config_type, 'uri': uri, 'tag_creator_account': tag_creator_account, \
                   'tag_invoker_account': tag_invoker_account}
        return self._enqueue_cloud_task(payload, shard_uuid, task_uuid, task_id)

    def _create_tag_extract_task(self, tag_creator_account, tag_invoker_account, job_uuid, shard_uuid, task_uuid, task_id, config_uuid, config_type, extract):
        """Enqueues a tag-extract cloud task; returns True on success."""
        payload = {'job_uuid': job_uuid, 'shard_uuid': shard_uuid, 'task_uuid': task_uuid, 'config_uuid': config_uuid, \
                   'config_type': config_type, 'tag_extract': extract, 'tag_creator_account': tag_creator_account, \
                   'tag_invoker_account': tag_invoker_account}
        return self._enqueue_cloud_task(payload, shard_uuid, task_uuid, task_id)

    def _set_task_running(self, shard_uuid, task_uuid):
        """Marks the task record RUNNING and stamps its start time."""
        task_ref = self.db.collection('shards').document(shard_uuid).collection('tasks').document(task_uuid)
        task_ref.set({
            'status': 'RUNNING',
            'start_time': datetime.datetime.utcnow()
        }, merge=True)
        print('set task running.')

    def _set_rollup_tasks_running(self, shard_uuid):
        """Atomically bumps the shard's running-task counter."""
        shard_ref = self.db.collection('shards').document(shard_uuid)
        shard_ref.update({'tasks_running': firestore.Increment(1)})

    def _set_task_success(self, shard_uuid, task_uuid):
        """Marks the task record SUCCESS and stamps its end time."""
        print('*** _set_task_success ***')
        task_ref = self.db.collection('shards').document(shard_uuid).collection('tasks').document(task_uuid)
        task_ref.set({
            'status': 'SUCCESS',
            'end_time': datetime.datetime.utcnow()
        }, merge=True)
        print('set task success.')

    def _set_rollup_tasks_success(self, shard_uuid):
        """Moves one task from running to success in the shard rollup counters."""
        shard_ref = self.db.collection('shards').document(shard_uuid)
        shard_ref.update({'tasks_success': firestore.Increment(1), 'tasks_running': firestore.Increment(-1)})

    def _set_task_failed(self, shard_uuid, task_uuid):
        """Marks the task record ERROR and stamps its end time."""
        print('*** _set_task_failed ***')
        task_ref = self.db.collection('shards').document(shard_uuid).collection('tasks').document(task_uuid)
        task_ref.set({
            'status': 'ERROR',
            'end_time': datetime.datetime.utcnow()
        }, merge=True)
        print('set task failed.')

    def _set_rollup_tasks_failed(self, shard_uuid):
        """Moves one task from running to failed in the shard rollup counters."""
        shard_ref = self.db.collection('shards').document(shard_uuid)
        shard_ref.update({'tasks_failed': firestore.Increment(1), 'tasks_running': firestore.Increment(-1)})