# Copyright 2025 PerfKitBenchmarker Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Class to represent a Cluster object."""
import typing
from typing import Callable, List, Tuple

from absl import flags
from absl import logging
from perfkitbenchmarker import disk
from perfkitbenchmarker import errors
from perfkitbenchmarker import linux_virtual_machine
from perfkitbenchmarker import resource
from perfkitbenchmarker import static_virtual_machine
from perfkitbenchmarker import virtual_machine
from perfkitbenchmarker import vm_util
from perfkitbenchmarker.configs import option_decoders
from perfkitbenchmarker.configs import spec
from perfkitbenchmarker.configs import vm_group_decoders


FLAGS = flags.FLAGS

# Optional override for the cluster template. When unset, BaseCluster uses its
# class-level DEFAULT_TEMPLATE instead.
TEMPLATE_FILE = flags.DEFINE_string(
    name='cluster_template_file',
    default=None,
    help=(
        'The template file to be used to create the cluster. '
        'None by default, each provider has a default template file.'
    ),
)

# When true, GetClusterClass returns the generic BaseCluster and the VM groups
# are exported for provisioning outside the cluster toolset.
UNMANAGED = flags.DEFINE_boolean(
    name='cluster_unmanaged_provision',
    default=False,
    help=(
        'Instead of creating with cluster toolset, relying on cloud provider'
        ' CLI. e.g. gcloud for gcp, awscli for aws.'
    ),
)

# Matched against BaseCluster.TYPE when looking up the cluster class.
TYPE = flags.DEFINE_string(
    name='cluster_type',
    default='default',
    help=(
        'Type of cluster to use. Chances are clusters vary quite differently'
        ' and may as well use its own template.'
    ),
)


class BaseClusterSpec(spec.BaseSpec):
  """Spec describing an HPC/ML cluster made of a headnode and worker VMs.

  Attributes:
    workers: VM group spec for the worker VMs.
    headnode: VM group spec for the headnode VM.
    cloud: Name of the cloud provider, or None if not configured.
    template: Path to the cluster template file, or None to use the
      provider's default template.
    unmanaged: Whether to provision VMs through the cloud provider CLI instead
      of the cluster toolset.
  """

  SPEC_TYPE = 'BaseClusterSpec'
  SPEC_ATTRS = ['CLOUD']
  CLOUD = None

  @classmethod
  def _ApplyFlags(cls, config_values, flag_values):
    """Overrides config values with flag values.

    Can be overridden by derived classes to add support for specific flags.
    Only flags explicitly present on the command line override the config.

    Args:
      config_values: dict mapping config option names to provided values. Is
        modified in place by this function.
      flag_values: flags.FlagValues. Runtime flags that may override the
        provided config values.
    """
    super()._ApplyFlags(config_values, flag_values)
    if flag_values['cloud'].present:
      config_values['cloud'] = flag_values.cloud
    if flag_values['cluster_template_file'].present:
      config_values['template'] = flag_values.cluster_template_file
    if flag_values['cluster_unmanaged_provision'].present:
      config_values['unmanaged'] = flag_values.cluster_unmanaged_provision
    # --num_vms only applies to the worker group.
    if flag_values['num_vms'].present:
      config_values['workers']['vm_count'] = flag_values['num_vms'].value
    # The remaining flags apply to both the workers and the headnode.
    if flag_values['zone'].present:
      # Resolve the cloud only when it is needed. A config may omit 'cloud'
      # (its decoder default is applied later), so fall back to the flag value
      # instead of raising KeyError.
      cloud = config_values.get('cloud', flag_values.cloud)
      zone = flag_values['zone'].value[0]
      for group in ('workers', 'headnode'):
        config_values[group]['vm_spec'][cloud]['zone'] = zone
    for flag_name in ('os_type', 'cloud'):
      if flag_values[flag_name].present:
        for group in ('workers', 'headnode'):
          config_values[group][flag_name] = flag_values[flag_name].value

  @classmethod
  def _GetOptionDecoderConstructions(cls):
    """Gets decoder classes and constructor args for each configurable option.

    Can be overridden by derived classes to add options or impose additional
    requirements on existing options.

    Returns:
      dict. Maps option name string to a (ConfigOptionDecoder class, dict) pair.
          The pair specifies a decoder class and its __init__() keyword
          arguments to construct in order to decode the named option.
    """
    result = super()._GetOptionDecoderConstructions()
    result.update({
        'workers': (vm_group_decoders.VmGroupSpecDecoder, {}),
        'headnode': (vm_group_decoders.VmGroupSpecDecoder, {}),
        'cloud': (option_decoders.StringDecoder, {'default': None}),
        'template': (option_decoders.StringDecoder, {'default': None}),
        'unmanaged': (option_decoders.BooleanDecoder, {'default': False}),
    })
    return result


def GetClusterSpecClass(cloud: str):
  """Looks up the BaseClusterSpec class to use for a cloud provider.

  Args:
    cloud: Cloud provider name, matched against the spec classes' CLOUD
      attribute.

  Returns:
    The spec class resolved by spec.GetSpecClass for `cloud`.
  """
  spec_class = spec.GetSpecClass(BaseClusterSpec, CLOUD=cloud)
  return spec_class


def GetClusterClass(cloud: str):
  """Returns the cluster resource class (not the spec class) for a cloud.

  When --cluster_unmanaged_provision is set, the generic BaseCluster is
  returned because the VMs are provisioned outside the cluster toolset.
  Otherwise, the class is resolved by cloud and by the --cluster_type flag.

  Args:
    cloud: Cloud provider name, matched against the classes' CLOUD attribute.

  Returns:
    A BaseCluster class (not an instance).
  """
  if UNMANAGED.value:
    return BaseCluster
  return resource.GetResourceClass(BaseCluster, CLOUD=cloud, TYPE=TYPE.value)


class BaseCluster(resource.BaseResource):
  """Base class for cluster resources.

  This class holds cluster-level methods and attributes. A cluster has one
  headnode VM and a group of worker VMs. RemoteCommand dispatches commands to
  the workers through `srun` on the headnode.

  Attributes:
    zone: The region / zone the worker VMs are launched in.
    machine_type: The provider-specific instance type for worker VMs.
    worker_machine_type: Same as machine_type.
    headnode_machine_type: The provider-specific instance type for the
      headnode VM.
    image: The disk image the worker VMs boot from.
    template: Cluster template file; falls back to DEFAULT_TEMPLATE.
    unmanaged: Whether VMs are provisioned outside the cluster toolset.
    num_workers: Number of worker VMs.
    headnode_vm: The headnode VM, or None until populated.
    worker_vms: List of worker VMs.
    vms: The headnode VM followed by the worker VMs.
  """

  RESOURCE_TYPE = 'BaseCluster'
  TYPE = 'default'
  REQUIRED_ATTRS = ['CLOUD', 'TYPE']
  DEFAULT_TEMPLATE = ''

  def __init__(self, cluster_spec: BaseClusterSpec):
    """Initialize BaseCluster class.

    Args:
      cluster_spec: cluster.BaseClusterSpec object describing the headnode and
        worker VM groups.
    """
    super().__init__()
    self.zone: str = cluster_spec.workers.vm_spec.zone
    self.machine_type: str = cluster_spec.workers.vm_spec.machine_type
    self.template: str = cluster_spec.template or self.DEFAULT_TEMPLATE
    self.unmanaged: bool = cluster_spec.unmanaged
    self.spec: BaseClusterSpec = cluster_spec
    self.worker_machine_type: str = self.machine_type
    self.headnode_machine_type: str = cluster_spec.headnode.vm_spec.machine_type
    self.headnode_spec: virtual_machine.BaseVmSpec = (
        cluster_spec.headnode.vm_spec
    )
    self.image: str = cluster_spec.workers.vm_spec.image
    self.workers_spec: virtual_machine.BaseVmSpec = cluster_spec.workers.vm_spec
    self.workers_static_disk_spec: disk.BaseDiskSpec | None = (
        cluster_spec.workers.disk_spec
    )
    # Only wrap the disk spec in a StaticDisk when the worker group has one.
    self.workers_static_disk: static_virtual_machine.StaticDisk | None = (
        static_virtual_machine.StaticDisk(self.workers_static_disk_spec)
        if self.workers_static_disk_spec
        else None
    )
    self.os_type: str = cluster_spec.workers.os_type
    self.num_workers: int = cluster_spec.workers.vm_count
    # VM objects are filled in later (e.g. by ImportVmGroups).
    self.vms: List[linux_virtual_machine.BaseLinuxVirtualMachine] = []
    self.headnode_vm: linux_virtual_machine.BaseLinuxVirtualMachine | None = (
        None
    )
    self.worker_vms: List[linux_virtual_machine.BaseLinuxVirtualMachine] = []
    # Run-specific name, truncated to at most 10 characters.
    self.name: str = f'pkb{FLAGS.run_uri}'[:10]
    self.nfs_path: str | None = None

  def GetResourceMetadata(self):
    """Returns a dict describing the cluster configuration."""
    return {
        'zone': self.zone,
        'machine_type': self.machine_type,
        'worker_machine_type': self.worker_machine_type,
        'headnode_machine_type': self.headnode_machine_type,
        'image': self.image,
        'os_type': self.os_type,
        'num_workers': self.num_workers,
        'template': self.template,
        'unmanaged': self.unmanaged,
    }

  def __repr__(self):
    # Use the concrete class name so subclasses are not reported as
    # BaseCluster.
    return f'<{type(self).__name__} [name={self.name}]>'

  # TODO(yuyanting) Move common logic here after having concrete implementation.
  def _RenderClusterConfig(self):
    """Render the config file that will be used to create the cluster."""
    pass

  def RemoteCommand(
      self,
      command: str,
      ignore_failure: bool = False,
      timeout: float | None = None,
      **kwargs,
  ) -> Tuple[str, str]:
    """Runs a command across the worker nodes via `srun` on the headnode.

    The command is issued on the headnode VM as
    `srun -N <num_workers> <command>`. Derived classes may add additional
    kwargs if necessary, but they should not be used outside of the class
    itself since they are non standard.

    Args:
      command: A valid bash command.
      ignore_failure: Ignore any failure if set to true.
      timeout: The time to wait in seconds for the command before exiting. None
        means no timeout.
      **kwargs: Additional arguments forwarded to the headnode VM's
        RemoteCommand.

    Returns:
      A tuple of stdout and stderr from running the command.

    Raises:
      RemoteCommandError: If there was a problem issuing the command.
    """
    return self.headnode_vm.RemoteCommand(
        f'srun -N {self.num_workers} {command}',
        ignore_failure=ignore_failure,
        timeout=timeout,
        **kwargs,
    )

  def RobustRemoteCommand(
      self,
      command: str,
      timeout: float | None = None,
      ignore_failure: bool = False,
  ) -> Tuple[str, str]:
    """Runs a command on the headnode using its RobustRemoteCommand.

    Unlike RemoteCommand, the command is not wrapped in `srun`; it runs on the
    headnode VM itself. Derived classes may override this.

    Args:
      command: The command to run.
      timeout: The timeout for the command in seconds.
      ignore_failure: Ignore any failure if set to true.

    Returns:
      A tuple of stdout, stderr from running the command.

    Raises:
      RemoteCommandError: If there was a problem establishing the connection, or
          the command fails.
    """
    return self.headnode_vm.RobustRemoteCommand(
        command, ignore_failure=ignore_failure, timeout=timeout
    )

  def TryRemoteCommand(self, command: str, **kwargs):
    """Runs `command` via RemoteCommand; returns True iff it succeeded."""
    try:
      self.RemoteCommand(command, **kwargs)
      return True
    except errors.VirtualMachine.RemoteCommandError:
      return False

  def InstantiateVm(self, vm_spec):
    """Creates a VM object of the class matching the spec's cloud and os_type."""
    vm_class = virtual_machine.GetVmClass(vm_spec.CLOUD, self.os_type)
    return vm_class(vm_spec)

  def BackfillVm(
      self,
      vm_spec: virtual_machine.BaseVmSpec,
      fn: Callable[[virtual_machine.BaseVirtualMachine], None],
  ):
    """Create and backfill a VM object created using cluster resource.

    Args:
      vm_spec: VM spec to be used to find corresponding VM class.
      fn: The function to be called on the newly created VM.

    Returns:
      The newly created VM object.
    """
    vm = self.InstantiateVm(vm_spec)
    fn(vm)
    vm.disks = []
    # The cluster resource provisioned the machine, so run post-create hooks
    # and mark the object as created.
    vm._PostCreate()  # pylint: disable=protected-access
    vm.created = True
    return vm

  def AuthenticateVM(self):
    """Calls AuthenticateVm on every VM in the cluster."""
    for vm in self.vms:
      vm.AuthenticateVm()

  def ExportVmGroupsForUnmanagedProvision(self):
    """Export VmGroups for unmanaged provisioning.

    Returns:
      Dictionary of VmGroupSpec for provisioning in BenchmarkSpec object, or an
      empty dict when the cluster is managed.
    """
    if not self.unmanaged:
      return {}
    logging.info('Provisioning cluster resources with unmanaged VM creation.')
    return {
        'headnode': self.spec.headnode,
        'workers': self.spec.workers,
    }

  def ImportVmGroups(self, headnode, workers):
    """Imports VMGroups from unmanaged provision.

    After VMs are created through the unmanaged code path, attach the
    corresponding VM groups to the cluster object so benchmarks do not need to
    care how the underlying resources were created.

    Args:
      headnode: VirtualMachine object representing a headnode.
      workers: Sequence of VirtualMachine objects representing workers.
    """
    self.headnode_vm = headnode
    # Copy so any sequence type works and the caller's list is not aliased.
    self.worker_vms = list(workers)
    self.vms = [self.headnode_vm] + self.worker_vms

  def _Create(self):
    """No-op in the base class."""
    pass

  def _Delete(self):
    """No-op in the base class."""
    pass

  def Create(self):
    """Creates the cluster unless its VMs are provisioned unmanaged."""
    if self.unmanaged:
      return
    super().Create()

  def Delete(self):
    """Deletes the cluster unless its VMs are provisioned unmanaged."""
    if self.unmanaged:
      return
    super().Delete()

  @vm_util.Retry(
      fuzz=0,
      timeout=1800,
      max_retries=5,
      retryable_exceptions=(errors.Resource.RetryableCreationError,),
  )
  def _WaitUntilReady(self):
    """Waits until `srun` can run `hostname` on the workers.

    Raises:
      errors.Resource.RetryableCreationError: If the probe command fails; the
        Retry decorator then retries it.
    """
    if self.unmanaged:
      return
    if not self.headnode_vm.TryRemoteCommand(
        f'srun -N {self.num_workers} hostname'
    ):
      raise errors.Resource.RetryableCreationError('Cluster not ready.')


# Type variable for code that is generic over BaseCluster subclasses.
Cluster = typing.TypeVar('Cluster', bound=BaseCluster)
