#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
import typer
from airavata_mft_sdk import mft_client
from airavata_mft_sdk.common import StorageCommon_pb2
from airavata_mft_sdk import MFTTransferApi_pb2
from rich.console import Console
from rich.table import Table
import time
import sys
sys.path.append('.')
from . import config as configcli
import pandas as pd

def fetch_storage_and_secret_ids(storage_name):
  """Resolve a storage name (or storage id) to its storage id and secret id.

  The storage is first searched by name; if nothing matches, ``storage_name``
  is retried as a storage id.

  Args:
    storage_name: Storage name or storage id registered with MFT.

  Returns:
    A ``(storage_id, secret_id)`` tuple. For LOCAL storages the secret id is
    the empty string, since no secret lookup is performed.

  Raises:
    typer.Abort: if no storage, or more than one storage, matches.
  """
  client = mft_client.MFTClient(transfer_api_port = configcli.transfer_api_port,
                                transfer_api_secured = configcli.transfer_api_secured,
                                resource_service_host = configcli.resource_service_host,
                                resource_service_port = configcli.resource_service_port,
                                resource_service_secured = configcli.resource_service_secured,
                                secret_service_host = configcli.secret_service_host,
                                secret_service_port = configcli.secret_service_port)
  search_req = StorageCommon_pb2.StorageSearchRequest(storageName=storage_name)
  storages = client.common_api.searchStorages(search_req)

  # No match by name: fall back to treating the argument as a storage id.
  if len(storages.storageList) == 0:
    search_req = StorageCommon_pb2.StorageSearchRequest(storageId=storage_name)
    storages = client.common_api.searchStorages(search_req)

  if len(storages.storageList) == 0:
    print("No storage with name or id " + storage_name + " was found. Please register the storage with command mft-cli storage add")
    raise typer.Abort()

  if len(storages.storageList) > 1:
    print("More than one storage with name " + storage_name + " was found. Please use the storage id. You can fetch it from mft-cli storage list")
    raise typer.Abort()

  storage = storages.storageList[0]

  # LOCAL storages have no associated secret.
  if storage.storageType == StorageCommon_pb2.StorageType.LOCAL:
    return storage.storageId, ''

  sec_req = StorageCommon_pb2.SecretForStorageGetRequest(storageId = storage.storageId)
  sec_resp = client.common_api.getSecretForStorage(sec_req)
  if sec_resp.error != 0:
    # NOTE(review): execution continues after this message and returns whatever
    # ids the response carries -- confirm whether this should abort like the
    # lookup failures above.
    print("Could not fetch the secret for storage " + storage.storageId)

  return sec_resp.storageId, sec_resp.secretId

def get_resource_metadata(storage_path, recursive_search = False):
  """Fetch MFT resource metadata for a ``<storage>/<path>`` string.

  Args:
    storage_path: Path whose first ``/``-separated component is the storage
      name or id; everything after the first ``/`` is the path inside that
      storage (empty when there is no ``/``).
    recursive_search: Accepted for interface compatibility; not used here.

  Returns:
    The response returned by ``transfer_api.resourceMetadata``.
  """
  # partition() splits on the first "/" only, matching split("/")[0] for the
  # storage and the remainder (or "") for the in-storage path.
  storage_name, _, resource_path = storage_path.partition("/")

  storage_id, secret_id = fetch_storage_and_secret_ids(storage_name)

  metadata_request = MFTTransferApi_pb2.FetchResourceMetadataRequest(
      idRequest = MFTTransferApi_pb2.GetResourceMetadataFromIDsRequest(
          storageId = storage_id,
          secretId = secret_id,
          resourcePath = resource_path))

  client = mft_client.MFTClient(transfer_api_port = configcli.transfer_api_port,
                                transfer_api_secured = configcli.transfer_api_secured,
                                resource_service_host = configcli.resource_service_host,
                                resource_service_port = configcli.resource_service_port,
                                resource_service_secured = configcli.resource_service_secured,
                                secret_service_host = configcli.secret_service_host,
                                secret_service_port = configcli.secret_service_port)

  return client.transfer_api.resourceMetadata(metadata_request)

def list(storage_path):
  """Print a Name/Type/Size table of the entries at ``storage_path``.

  For a directory, sub-directories are listed before files. For a single
  file, one row is shown. If the response carries an error, the error is
  printed and the (empty) table is still rendered.
  """
  metadata_resp = get_resource_metadata(storage_path)

  console = Console()
  table = Table()
  for header, alignment in (('Name', 'left'), ('Type', 'center'), ('Size', 'center')):
    table.add_column(header, justify=alignment)

  # NOTE: `list` is shadowed by this function at module level, so only list
  # literals/comprehensions are used below.
  kind = metadata_resp.WhichOneof('metadata')
  if kind == 'directory':
    listing = metadata_resp.directory
    rows = [(sub.friendlyName, 'DIR', '') for sub in listing.directories]
    rows += [(entry.friendlyName, 'FILE', str(entry.resourceSize)) for entry in listing.files]
  elif kind == 'file':
    single = metadata_resp.file
    rows = [(single.friendlyName, 'FILE', str(single.resourceSize))]
  else:
    rows = []
    if kind == 'error':
      print(metadata_resp.error)

  for name, entry_type, size in rows:
    table.add_row('[bold]' + name + '[/bold]', entry_type, size)

  console.print(table)

def flatten_directories(directory, parent_path, file_list):
  """Append ``(file, relative_path)`` pairs for every file under ``directory``.

  Sub-directories are visited depth-first before the directory's own files,
  so files of nested directories appear first in ``file_list``. Relative
  paths are built by joining ``friendlyName`` components with ``/`` onto
  ``parent_path``. ``file_list`` is mutated in place; nothing is returned.
  """
  for child in directory.directories:
    child_prefix = parent_path + child.friendlyName + "/"
    flatten_directories(child, child_prefix, file_list)

  file_list.extend((entry, parent_path + entry.friendlyName) for entry in directory.files)

def copy(source, destination):
  """Copy a file or a directory tree between two registered storages.

  ``source`` and ``destination`` are ``<storage>/<path>`` strings whose first
  component is a storage name or id. When the source is a directory, the
  destination must end with ``/`` and every file under the source tree is
  transferred to the same relative path under the destination. When the
  source is a file and the destination ends with ``/``, the source file name
  is appended to the destination.

  The user is asked to confirm before submission; the transfer state is then
  polled once per second until it reports 100% or a FAILED state.

  Raises:
    typer.Abort: if the destination is not a directory path for a directory
      source, the source metadata reports an error, the user declines, or
      the transfer fails.
  """
  source_storage_id, source_secret_id = fetch_storage_and_secret_ids(source.split("/")[0])
  dest_storage_id, dest_secret_id = fetch_storage_and_secret_ids(destination.split("/")[0])

  ## TODO : Check agent availability and deploy cloud agents if required

  source_metadata = get_resource_metadata(source)
  endpoint_paths = []
  total_volume = 0

  transfer_request = MFTTransferApi_pb2.TransferApiRequest(sourceStorageId = source_storage_id,
                                                           sourceSecretId = source_secret_id,
                                                           destinationStorageId = dest_storage_id,
                                                           destinationSecretId = dest_secret_id,
                                                           optimizeTransferPath = False)

  metadata_kind = source_metadata.WhichOneof('metadata')

  if metadata_kind == 'directory':
    # endswith() also handles an empty destination without an IndexError.
    if not destination.endswith("/"):
      print("Source is a directory path so destination path should end with /")
      raise typer.Abort()

    file_list = []
    flatten_directories(source_metadata.directory, "", file_list)

    # Path inside the destination storage (storage component stripped).
    dest_dir = destination[len(destination.split("/")[0]) + 1:]
    for file, relative_path in file_list:
      endpoint_paths.append(MFTTransferApi_pb2.EndpointPaths(
          sourcePath = file.resourcePath,
          destinationPath = dest_dir + relative_path))
      total_volume += file.resourceSize

  elif metadata_kind == 'file':
    source_file = source_metadata.file

    # A trailing "/" means "copy into this directory, keeping the file name".
    if destination.endswith("/"):
      destination = destination + source_file.friendlyName

    endpoint_paths.append(MFTTransferApi_pb2.EndpointPaths(
        sourcePath = source_file.resourcePath,
        destinationPath = destination[len(destination.split("/")[0]) + 1:]))

    total_volume += source_file.resourceSize

  elif metadata_kind == 'error':
    print("Failed while fetching source details")
    print(source_metadata.error)
    raise typer.Abort()

  transfer_request.endpointPaths.extend(endpoint_paths)

  confirm = typer.confirm("Total number of " + str(len(endpoint_paths)) +
                          " files to be transferred. Total volume is " + str(total_volume)
                          + " bytes. Do you want to start the transfer? ", True)

  if not confirm:
    raise typer.Abort()

  client = mft_client.MFTClient(transfer_api_port = configcli.transfer_api_port,
                                transfer_api_secured = configcli.transfer_api_secured,
                                resource_service_host = configcli.resource_service_host,
                                resource_service_port = configcli.resource_service_port,
                                resource_service_secured = configcli.resource_service_secured,
                                secret_service_host = configcli.secret_service_host,
                                secret_service_port = configcli.secret_service_port)

  transfer_resp = client.transfer_api.submitTransfer(transfer_request)

  state_request = MFTTransferApi_pb2.TransferStateApiRequest(transferId=transfer_resp.transferId)

  ## TODO: This has to be optimized and avoid frequent polling of all transfer ids in each iteration
  ## Possible fix is to introduce a parent batch transfer id at the API level and fetch child transfer id
  # summaries in a single API call

  completed = 0
  failed = 0
  # progress.update() takes a delta, so remember the last reported percentage.
  # It must be initialised before the first poll.
  prev_percentage = 0

  with typer.progressbar(length=100) as progress:

    while True:
      state_resp = client.transfer_api.getTransferStateSummary(state_request)

      progress_percentage = int(state_resp.percentage * 100)
      progress.update(progress_percentage - prev_percentage)
      prev_percentage = progress_percentage

      # >= guards against a float slightly above 1.0 never matching exactly.
      if state_resp.percentage >= 1.0:
        completed = len(state_resp.completed)
        failed = len(state_resp.failed)
        break

      if state_resp.state == "FAILED":
        print("Transfer failed. Reason: " + state_resp.description)
        raise typer.Abort()
      time.sleep(1)

  print(f"Processed {completed + failed} files. Completed {completed}, Failed {failed}.")

def copy_list(source_storage_id, dest_storage_id, list_file):
  """Transfer an explicit list of files between two storages.

  Args:
    source_storage_id: Name or id of the source storage.
    dest_storage_id: Name or id of the destination storage.
    list_file: Headerless CSV (path or buffer accepted by ``pd.read_csv``)
      whose first column is the source path and whose second column is the
      destination path.

  The user is asked to confirm before submission; the transfer state is then
  polled once per second until it reports 100% or a FAILED state.

  Raises:
    typer.Abort: if a CSV row is missing a path, the user declines, or the
      transfer fails.
  """
  source_storage_id, source_secret_id = fetch_storage_and_secret_ids(source_storage_id)
  dest_storage_id, dest_secret_id = fetch_storage_and_secret_ids(dest_storage_id)

  ## TODO : Check agent availability and deploy cloud agents if required

  transfer_request = MFTTransferApi_pb2.TransferApiRequest(sourceStorageId = source_storage_id,
                                                           sourceSecretId = source_secret_id,
                                                           destinationStorageId = dest_storage_id,
                                                           destinationSecretId = dest_secret_id,
                                                           optimizeTransferPath = False)

  df = pd.read_csv(list_file, header=None, dtype=str, names=['source', 'destination'])

  endpoint_paths = []
  for row_number, (source_path, destination_path) in enumerate(
      zip(df['source'], df['destination']), start=1):
    # read_csv yields NaN (a float) for empty cells even with dtype=str; protobuf
    # would otherwise reject it with an opaque TypeError.
    if pd.isna(source_path) or pd.isna(destination_path):
      print(f"Row {row_number} of the list file must contain both a source and a destination path")
      raise typer.Abort()

    endpoint_paths.append(MFTTransferApi_pb2.EndpointPaths(sourcePath = source_path,
                                                           destinationPath = destination_path))

  transfer_request.endpointPaths.extend(endpoint_paths)

  confirm = typer.confirm("Total number of " + str(len(endpoint_paths)) +
                          " files to be transferred. Do you want to start the transfer? ", True)

  if not confirm:
    raise typer.Abort()

  client = mft_client.MFTClient(transfer_api_port = configcli.transfer_api_port,
                                transfer_api_secured = configcli.transfer_api_secured,
                                resource_service_host = configcli.resource_service_host,
                                resource_service_port = configcli.resource_service_port,
                                resource_service_secured = configcli.resource_service_secured,
                                secret_service_host = configcli.secret_service_host,
                                secret_service_port = configcli.secret_service_port)

  transfer_resp = client.transfer_api.submitTransfer(transfer_request)

  state_request = MFTTransferApi_pb2.TransferStateApiRequest(transferId=transfer_resp.transferId)

  ## TODO: This has to be optimized and avoid frequent polling of all transfer ids in each iteration
  ## Possible fix is to introduce a parent batch transfer id at the API level and fetch child transfer id
  # summaries in a single API call

  completed = 0
  failed = 0
  # progress.update() takes a delta, so remember the last reported percentage.
  prev_percentage = 0

  with typer.progressbar(length=100) as progress:

    while True:
      state_resp = client.transfer_api.getTransferStateSummary(state_request)

      progress_percentage = int(state_resp.percentage * 100)
      progress.update(progress_percentage - prev_percentage)
      prev_percentage = progress_percentage

      # >= guards against a float slightly above 1.0 never matching exactly.
      if state_resp.percentage >= 1.0:
        completed = len(state_resp.completed)
        failed = len(state_resp.failed)
        break

      if state_resp.state == "FAILED":
        print("Transfer failed. Reason: " + state_resp.description)
        raise typer.Abort()
      time.sleep(1)

  print(f"Processed {completed + failed} files. Completed {completed}, Failed {failed}.")
