# Copyright 2022-2023 Google, LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid, hashlib, datetime, json, configparser, math
import constants
from google.cloud import firestore
from google.cloud import tasks_v2
from google.protobuf import duration_pb2
from google.api_core.client_info import ClientInfo

USER_AGENT = 'cloud-solutions/datacatalog-tag-engine-v2'

class TaskManager:
    """Class for creating and managing work requests in the form of cloud tasks 
    
    cloud_run_sa = Cloud Run service account
    project = Cloud Run project id (e.g. tag-engine-project)
    region = Cloud Run region (e.g. us-central1)
    queue_name = Cloud Task queue (e.g. tag-engine-queue)
    task_handler_uri = task handler uri in the Flask app hosted by Cloud Run 
    """
    def __init__(self,
                tag_engine_sa,
                tag_engine_project,
                tag_engine_region,
                tag_engine_queue, 
                task_handler_uri,
                db_project,
                db_name):

        self.tag_engine_sa = tag_engine_sa
        self.tag_engine_project = tag_engine_project
        self.tag_engine_region = tag_engine_region
        self.tag_engine_queue = tag_engine_queue
        self.task_handler_uri = task_handler_uri
        
        self.db = firestore.Client(project=db_project, database=db_name, client_info=ClientInfo(user_agent=USER_AGENT))
        self.tasks_per_shard = 1000
        self.task_deadline = duration_pb2.Duration().FromSeconds(1800) # 30 minutes per task (max supported duration)

##################### API METHODS #################
        
    def create_config_uuid_tasks(self, tag_creator_account, tag_invoker_account, job_uuid, config_uuid, config_type, uris):
        
        # create shards of 1000 tasks
        if len(uris) > self.tasks_per_shard:
            shards = math.ceil(len(uris) / self.tasks_per_shard)
        else:
            shards = 1
            
        task_running_total = 0
        task_counter = 0
        
        for shard_index in range(0, shards):
            
            shard_id_raw = job_uuid + str(shard_index)
            shard_uuid = hashlib.md5(shard_id_raw.encode()).hexdigest()
            self._create_shard(job_uuid, shard_uuid)

            for uri_index, uri_val in enumerate(uris[task_running_total:], task_running_total):

                # create the task
                if isinstance(uri_val, str):
                    task_id_raw = job_uuid + uri_val + str(datetime.datetime.utcnow())
                    
                if isinstance(uri_val, tuple):
                    task_id_raw = job_uuid + ''.join(uri_val) + str(datetime.datetime.utcnow()) # uri_val is a tuple when it contains a gcs path
                
                task_id = hashlib.md5(task_id_raw.encode()).hexdigest()
            
                task_uuid = self._record_config_uuid_task(job_uuid, shard_uuid, task_id, config_uuid, config_type, uri_val)
                self._create_config_uuid_task(tag_creator_account, tag_invoker_account, job_uuid, shard_uuid, task_uuid, task_id, config_uuid, config_type, uri_val)
                
                task_counter += 1
                task_running_total += 1
                
                if task_counter == self.tasks_per_shard:
                    self._update_shard_tasks(job_uuid, shard_uuid, task_counter)
                    task_counter = 0
                    break
        
        # update shard with last task_counter
        if task_counter > 0:    
            self._update_shard_tasks(job_uuid, shard_uuid, task_counter)

    
    def create_tag_extract_tasks(self, tag_creator_account, tag_invoker_account, job_uuid, config_uuid, config_type, tag_extract_list):
        
        # create shards of 5000 records
        if len(tag_extract_list) > self.tasks_per_shard:
            shards = math.ceil(len(tag_extract_list) / self.tasks_per_shard)
        else:
            shards = 1
            
        task_running_total = 0
        task_counter = 0
        
        for shard_index in range(0, shards):
            
            shard_id_raw = job_uuid + str(shard_index)
            shard_uuid = hashlib.md5(shard_id_raw.encode()).hexdigest()
            self._create_shard(job_uuid, shard_uuid)

            for extract_index, extract_val in enumerate(tag_extract_list[task_running_total:], task_running_total):
                
                print('task_running_total: ', task_running_total)
                print('extract_index: ', extract_index)
                print('extract_val: ', extract_val)
                
                # create the task
                task_id_raw = job_uuid + ''.join(str(extract_val)) + str(datetime.datetime.utcnow())
                task_id = hashlib.md5(task_id_raw.encode()).hexdigest()
                
                #print('task_id: ', task_id)

                task_uuid = self._record_tag_extract_task(job_uuid, shard_uuid, task_id, config_uuid, config_type, extract_val)
                self._create_tag_extract_task(tag_creator_account, tag_invoker_account, job_uuid, shard_uuid, task_uuid, task_id, config_uuid, config_type, extract_val)
                
                task_counter += 1
                task_running_total += 1
                
                if task_counter == self.tasks_per_shard:
                    self._update_shard_tasks(job_uuid, shard_uuid, task_counter)
                    task_counter = 0
                    break
        
        # update shard with last task_counter
        if task_counter > 0:    
            self._update_shard_tasks(job_uuid, shard_uuid, task_counter)

         
    def update_task_status(self, shard_uuid, task_uuid, status):     

        if status == 'RUNNING':
            self._set_task_running(shard_uuid, task_uuid)
            self._set_rollup_tasks_running(shard_uuid)
        
        if status == 'SUCCESS':
            self._set_task_success(shard_uuid, task_uuid)
            self._set_rollup_tasks_success(shard_uuid)
            
        if status == 'ERROR':
            self._set_task_failed(shard_uuid, task_uuid)
            self._set_rollup_tasks_failed(shard_uuid)


################ INTERNAL PROCESSING METHODS #################

    def _create_shard(self, job_uuid, shard_uuid):
        
        print('*** _create_shard ***')
        print('job_uuid: ' + job_uuid + ', shard_uuid: ' + shard_uuid)
        
        shard_ref = self.db.collection('shards').document(shard_uuid)
        
        shard_ref.set({
            'shard_uuid': shard_uuid,   
            'job_uuid': job_uuid,
            'tasks_ran': 0,
            'tasks_success': 0,
            'tasks_failed': 0,
            'creation_time': datetime.datetime.utcnow()
        })
        
    
    def _update_shard_tasks(self, job_uuid, shard_uuid, task_counter):
        
        #print('*** _update_shard ***')

        self.db.collection('shards').document(shard_uuid).update({'task_count': task_counter});
        

    def _record_config_uuid_task(self, job_uuid, shard_uuid, task_id, config_uuid, config_type, uri):
        
        #print('*** _record_config_uuid_task ***')
        
        task_uuid = uuid.uuid1().hex
        
        task_ref = self.db.collection('shards').document(shard_uuid).collection('tasks').document(task_uuid)

        task_ref.set({
            'task_uuid': task_uuid,     # task identifier in Firestore
            'task_id': task_id,         # cloud task identifier, based on uri
            'shard_uuid': shard_uuid,   # shard which this task belongs to
            'job_uuid': job_uuid,
            'config_uuid': config_uuid,
            'config_type': config_type,
            'uri': uri,
            'status':  'PENDING',
            'creation_time': datetime.datetime.utcnow()
        })
        
        #print('created task record ' + task_uuid + ' in shard ' + shard_uuid)
         
        return task_uuid    
    
    
    def _record_tag_extract_task(self, job_uuid, shard_uuid, task_id, config_uuid, config_type, extract):
        
        print('*** _record_task ***')
        
        task_uuid = uuid.uuid1().hex
        
        task_ref = self.db.collection('shards').document(shard_uuid).collection('tasks').document(task_uuid)

        task_ref.set({
            'task_uuid': task_uuid,     # task identifier in Firestore
            'task_id': task_id,         # cloud task identifier, based on uri
            'shard_uuid': shard_uuid,   # shard which this task belongs to
            'job_uuid': job_uuid,
            'config_uuid': config_uuid,
            'config_type': config_type,
            'tag_extract': extract,
            'status':  'PENDING',
            'creation_time': datetime.datetime.utcnow()
        })
        
        #print('created task record ' + task_uuid + ' in shard ' + shard_uuid)
         
        return task_uuid
    
    
    def _create_config_uuid_task(self, tag_creator_account, tag_invoker_account, job_uuid, shard_uuid, task_uuid, task_id, \
                                 config_uuid, config_type, uri):
        
        success = True
        
        payload = {'job_uuid': job_uuid, 'shard_uuid': shard_uuid, 'task_uuid': task_uuid, 'config_uuid': config_uuid, \
                   'config_type': config_type, 'uri': uri, 'tag_creator_account': tag_creator_account, \
                   'tag_invoker_account': tag_invoker_account}
        
        client = tasks_v2.CloudTasksClient()
        parent = client.queue_path(self.tag_engine_project, self.tag_engine_region, self.tag_engine_queue)
                
        task = {
            'name': parent + '/tasks/' + task_id,
            'dispatch_deadline': self.task_deadline,
            'http_request': {  
                'http_method': 'POST',
                'url': self.task_handler_uri,
                'headers': {'content-type': 'application/json'},
                'body': json.dumps(payload).encode(),
                'oidc_token': {'service_account_email': self.tag_engine_sa, 'audience': self.task_handler_uri}
            }
        }
        print('task request:', task)

        try:
            task = client.create_task(parent=parent, task=task)
            #print('task response:', task)
        
        except Exception as e:
            print('Error: could not create task for uri', self.task_handler_uri, '. Error: ', e)
            self._set_task_failed(shard_uuid, task_uuid)
            success = False
            
        return success
                  
        
    def _create_tag_extract_task(self, tag_creator_account, tag_invoker_account, job_uuid, shard_uuid, task_uuid, task_id, config_uuid, config_type, extract):
        
        success = True
    
        payload = {'job_uuid': job_uuid, 'shard_uuid': shard_uuid, 'task_uuid': task_uuid, 'config_uuid': config_uuid, \
                   'config_type': config_type, 'tag_extract': extract, 'tag_creator_account': tag_creator_account, \
                   'tag_invoker_account': tag_invoker_account}
        
        client = tasks_v2.CloudTasksClient()
        parent = client.queue_path(self.tag_engine_project, self.tag_engine_region, self.tag_engine_queue)
        
        task = {
            'name': parent + '/tasks/' + task_id,
            'dispatch_deadline': self.task_deadline,
            'http_request': {  
                'http_method': 'POST',
                'url': self.task_handler_uri,
                'headers': {'content-type': 'application/json'},
                'body': json.dumps(payload).encode(),
                'oidc_token': {'service_account_email': self.tag_engine_sa, 'audience': self.task_handler_uri}
            }
        }
        
        print('task request:', task)

        try:
            task = client.create_task(parent=parent, task=task)
            print('task response:', task)
        
        except Exception as e:
            print('Error: could not create task for uri', self.task_handler_uri, '. Error: ', e)
            self._set_task_failed(shard_uuid, task_uuid)
            success = False
            
        return success
            
    
    def _set_task_running(self, shard_uuid, task_uuid):
        
        #print('*** _set_task_running ***')
        
        task_ref = self.db.collection('shards').document(shard_uuid).collection('tasks').document(task_uuid)
        
        task_ref.set({
            'status':  'RUNNING',
            'start_time': datetime.datetime.utcnow()
        }, merge=True)
        
        print('set task running.')
    
    
    def _set_rollup_tasks_running(self, shard_uuid):
        
        shard_ref = self.db.collection('shards').document(shard_uuid)
        shard_ref.update({'tasks_running': firestore.Increment(1)})
    
    
    def _set_task_success(self, shard_uuid, task_uuid):
        
        print('*** _set_task_success ***')
        
        task_ref = self.db.collection('shards').document(shard_uuid).collection('tasks').document(task_uuid)

        task_ref.set({
            'status':  'SUCCESS',
            'end_time': datetime.datetime.utcnow()
        }, merge=True)
        
        print('set task success.')


    def _set_rollup_tasks_success(self, shard_uuid):
        
        shard_ref = self.db.collection('shards').document(shard_uuid)
        shard_ref.update({'tasks_success': firestore.Increment(1), 'tasks_running': firestore.Increment(-1)})
        
        
    def _set_task_failed(self, shard_uuid, task_uuid):
        
        print('*** _set_task_failed ***')
        
        task_ref = self.db.collection('shards').document(shard_uuid).collection('tasks').document(task_uuid)

        task_ref.set({
            'status':  'ERROR',
            'end_time': datetime.datetime.utcnow()
        }, merge=True)
        
        print('set task failed.')
        
    
    def _set_rollup_tasks_failed(self, shard_uuid):
        
        shard_ref = self.db.collection('shards').document(shard_uuid)
        shard_ref.update({'tasks_failed': firestore.Increment(1), 'tasks_running': firestore.Increment(-1)})
    




    

    

            
        
