"""Module providing slurm topology.conf for IB SHARP-enabled Cluster"""
import os
import sys
import logging
from pathlib import Path
import subprocess as subprocesslib
import datetime
from . import util as slutil

log = logging.getLogger('topology')

class Topology:
    """
    A class to represent and manage the topology of a Slurm cluster.
    Attributes:
    """

    def __init__(self, partition, output, directory):
        self.timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self.output_dir = f"{directory}/.topology/topology_output_{self.timestamp}"
        Path(self.output_dir).mkdir(parents=True, exist_ok=True)
        self.guid_to_host_map = {}
        self.device_guids_per_switch = []
        self.host_to_torset_map = {}
        self.torsets = {}
        self.partition = partition
        self.hosts = []
        self.sharp_cmd_path = None
        self.guids_file = f"{self.output_dir}/guids.txt"
        self.topo_file = f"{self.output_dir}/topology.txt"
        self.slurm_top_file = output

    def get_hostnames(self) -> None:
        """
        Validates partition and retrieves a list of hostnames from the SLURM scheduler based on the provided parition.
        It also checks for hosts that are not idle and powered on and filters them out.
            None
        Raises:
            SystemExit: If the number of valid and powered-on hosts is less than 2.
        Logs:
            Warnings for invalid or powered-down nodes.
            Debug information for the list of hosts and the filtered valid hosts.
            Error if the number of valid and powered-on hosts is less than 2.
        """
        def validate_partition(partition) -> None:
            try:
                output = slutil.run("sinfo -o %P | tr -d '*'", shell=True)
            except subprocesslib.CalledProcessError:
                log.error("Failed to list partitions with sinfo")
                sys.exit(1)
            except subprocesslib.TimeoutExpired:
                log.error("Timed out listing partitions with sinfo")
                sys.exit(1)
            # Skip the "PARTITION" header line and the trailing empty line.
            partitions = set(output.stdout.split('\n')[1:-1])
            log.debug("Valid Partitions: %s", partitions)
            if partition not in partitions:
                log.error("Partition %s does not exist", partition)
                sys.exit(1)
            log.debug("Partition %s exists", partition)
        def get_hostlist(cmd) -> set:
            try:
                output = slutil.run(cmd, shell=True)
            except subprocesslib.CalledProcessError:
                log.error("Failed to run: %s", cmd)
                sys.exit(1)
            except subprocesslib.TimeoutExpired:
                log.error("Timed out running: %s", cmd)
                sys.exit(1)
            # Drop the trailing empty line; a set enables the set arithmetic below.
            return set(output.stdout.split('\n')[:-1])
        validate_partition(self.partition)
        partition_cmd = f'-p {self.partition} '
        host_cmd = f'scontrol show hostnames $(sinfo -p {self.partition} -o "%N" -h)'
        partition_states = "powered_down,powering_up,powering_down,power_down,drain,drained,draining,unknown,down,no_respond,fail,reboot"
        sinfo_cmd = f'sinfo {partition_cmd}-t {partition_states} -o "%N" -h'
        down_cmd = f'scontrol show hostnames $({sinfo_cmd})'
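        # Example (illustrative): for a partition "hpc" whose node list is
        # "hpc-[1-4]", `sinfo -p hpc -o "%N" -h` prints "hpc-[1-4]" and
        # `scontrol show hostnames` expands it to one hostname per line
        # (hpc-1 ... hpc-4).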
        hosts = get_hostlist(host_cmd)
        down_hosts = get_hostlist(down_cmd)
        self.hosts = list(hosts - down_hosts)
        if len(self.hosts) < len(hosts):
            log.warning(
                "Some nodes were not fully powered up and idle; "
                "running on the subset of nodes that are powered on and idle"
            )
            log.warning("Excluded Nodes: %s", down_hosts)
        log.debug("Original hosts: %s", hosts)
        log.debug("Powered On and Idle Hosts: %s", self.hosts)
        if len(self.hosts) < 2:
            log.error(
                "Need at least 2 nodes to create the slurm topology; "
                "fewer than 2 nodes were powered up and idle."
            )
            sys.exit(1)

    def get_os_name(self):
        """
        Retrieves the operating system name from the first host in self.hosts.

        This method runs a command on the first host to extract the OS ID from the
        /etc/os-release file. It uses the `grep` command to find the line starting
        with 'ID=' and then cuts the value after the '=' character.

        Returns:
            str: The operating system ID if the command is successful.

        Raises:
            SystemExit: If the command fails, logs the error and exits the program.
        """
        cmd = "grep '^ID=' /etc/os-release | cut -d'=' -f2"
        try:
            output = slutil.srun([self.hosts[0]], cmd, shell=True, partition=self.partition)
            exit_code = output.returncode
            stdout = output.stdout
        except slutil.SrunExitCodeException as e:
            log.error("Error running get_os_id command on host %s", self.hosts[0])
            if e.stderr_content:
                log.error(e.stderr_content)
            log.error(e.stderr)
            sys.exit(e.returncode)
        except subprocesslib.TimeoutExpired:
            log.error("Timed out running get_os_id command on host %s", self.hosts[0])
            sys.exit(1)
        if exit_code == 0:
            os_id = stdout.strip().strip('"')
            log.debug("OS ID for host %s: %s", self.hosts[0], os_id)
            return os_id

    def get_sharp_cmd(self):
        """
        Determines the appropriate SHARP command based on the operating system.

        Returns:
            str: The path to the SHARP command for the detected operating system.

        Raises:
            SystemExit: If the operating system is not supported.
        """
        os_id = self.get_os_name()
        if os_id == "ubuntu":
            path = "/opt/hpcx-v2.18-gcc-mlnx_ofed-ubuntu22.04-cuda12-x86_64/"
        elif os_id == "almalinux":
            path = "/opt/hpcx-v2.18-gcc-mlnx_ofed-redhat8-cuda12-x86_64/"
        else:
            log.error("OS %s is not supported, exiting", os_id)
            sys.exit(1)
        log.debug("sharp_cmd_path: %s", path)
        return path


    def check_sharp_hello(self):
        """
        Executes the sharp_hello command on the first host in self.hosts and logs the output.

        This method constructs a command to run the `sharp_hello` executable located
        in the `sharp_cmd_path` directory and executes it on the first host in
        self.hosts via `slutil.srun`.

        The standard output of the command is logged at the debug level. If the
        command fails (i.e., the exit code is not 0), the standard error output is
        logged at the error level and the program exits with the same exit code. If
        the command succeeds, a debug message is logged and the method returns 0.

        Returns:
            int: 0 if the sharp_hello command passes successfully.

        Raises:
            SystemExit: If the sharp_hello command fails, the program exits with the 
                        corresponding exit code.
        """
        cmd = f"{self.sharp_cmd_path}sharp/bin/sharp_hello"
        try:
            output = slutil.srun([self.hosts[0]], cmd, partition=self.partition)
            log.debug(output.stdout)
        except slutil.SrunExitCodeException as e:
            log.error("SHARP is disabled on cluster")
            if e.stderr_content:
                log.error(e.stderr_content)
            log.error(e.stderr)
            sys.exit(e.returncode)
        except subprocesslib.TimeoutExpired:
            log.error("Timed out running sharp_hello on host %s", self.hosts[0])
            sys.exit(1)
        if output.returncode==0:
            log.debug("sharp_hello command passed")
            return 0


    def check_ibstatus(self) -> None:
        """
        Checks the availability of the 'ibstatus' command on the first host in self.hosts.

        This method runs a Python command to check if 'ibstatus' is available on the first host.
        If 'ibstatus' is not found, it logs an error message and exits the program.
        If 'ibstatus' is found, it logs a debug message indicating its availability.

        Returns:
            None

        Raises:
            SystemExit: If 'ibstatus' is not available on the first host.
        """
        cmd ="python3 -c \"import shutil; print(shutil.which('ibstatus'))\""
        try:
            output = slutil.srun([self.hosts[0]], cmd, partition=self.partition)
            path = output.stdout.strip()
            log.debug(path)
        except slutil.SrunExitCodeException as e:
            log.error("Error running check_ibstatus command on host %s",self.hosts[0])
            if e.stderr_content:
                log.error(e.stderr_content)
            log.error(e.stderr)
            sys.exit(e.returncode)
        except subprocesslib.TimeoutExpired:
            log.error("Timed out checking for 'ibstatus' on host %s", self.hosts[0])
            sys.exit(1)
        if path == "None":
            log.error("The 'ibstatus' command is not available")
            sys.exit(1)
        log.debug("The 'ibstatus' command is available.")

    def retrieve_guids(self) -> None:
        """
        Retrieve GUIDs (Globally Unique Identifiers) from the hosts.
        This method runs a command on self.hosts to retrieve the Port GUIDs
        from the InfiniBand status. The command extracts the GUIDs using a series
        of shell commands and processes the output to map each GUID to its
        corresponding host.
        The GUIDs are stored in the `guid_to_host_map` attribute.
        """
        cmd = (
            'ibstatus | grep mlx5_ib | cut -d" " -f3 | '
            'xargs -I% ibstat "%" | grep "Port GUID" | cut -d: -f2 | '
            'while IFS= read -r line; do echo \"$(hostname): $line\"; done'
        )
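        # Pipeline stages (illustrative):
        #   ibstatus | grep mlx5_ib | cut -d" " -f3   -> the IB device names
        #   xargs -I% ibstat "%" | grep "Port GUID"   -> 'Port GUID: 0x00...' lines
        #   cut -d: -f2                               -> the GUID values
        #   echo "$(hostname): $line"                 -> 'host: 0x00...' per GUID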
        try:
            output = slutil.srun(self.hosts, cmd, shell=True, partition=self.partition)
        except slutil.SrunExitCodeException as e:
            log.error("Error running retrieve_guids command on hosts")
            if e.stderr_content:
                log.error(e.stderr_content)
            log.error(e.stderr)
            sys.exit(e.returncode)
        except subprocesslib.TimeoutExpired:
            log.error("Timed out running retrieve_guids command on hosts")
            sys.exit(1)
        lines = output.stdout.split('\n')[:-1]
        for line in lines:
            # Querying GUIDs from ibstat yields the pattern 0x0099999999999999,
            # but SHARP returns 0x99999999999999, so the leading 00 after 0x
            # is removed here.
            node, guid = line.split(':', 1)
            self.guid_to_host_map[guid.replace('0x00', '0x').strip()] = node.strip()

    def write_guids_to_file(self) -> None:
        """
        Writes the GUIDs from the guid_to_host_map to a file.

        This method opens the file specified by self.guids_file in write mode
        with UTF-8 encoding and writes each GUID from the guid_to_host_map
        to the file, each on a new line.
        """
        with open(self.guids_file, 'w', encoding="utf-8") as f:
            for guid in self.guid_to_host_map:
                f.write(f"{guid}\n")

    def generate_topo_file(self) -> None:
        """
        Generates the topology file for SHARP (Scalable Hierarchical Aggregation and Reduction Protocol).

        This method sets up the environment variables and constructs the command to generate the topology file
        using the SHARP command-line tool. The command is executed, and the output is logged to a file.

        Environment Variables:
            SHARP_SMX_UCX_INTERFACE: Set to "mlx5_ib0:1".
            SHARP_CMD: Optional. If set, it is used as the base path for the SHARP command.

        Attributes:
            sharp_cmd_path (str): The base path for the SHARP command if SHARP_CMD is not set in the environment.
            guids_file (str): The path to the GUIDs file.
            topo_file (str): The path to the output topology file.
            output_dir (str): The directory where the log file will be saved.

        Raises:
            SystemExit: If running sharp_cmd fails or times out.
        """
        env = os.environ.copy()
        # Prefer an explicitly configured SHARP_CMD base path; otherwise fall
        # back to the path detected from the OS.
        base_path = env.get('SHARP_CMD', self.sharp_cmd_path)
        command = (
            f"SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 "
            f"{base_path}sharp/bin/sharp_cmd topology "
            f"--ib-dev mlx5_ib0:1 "
            f"--guids_file {self.guids_file} "
            f"--topology_file {self.topo_file}"
        )
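        # Example rendered command (illustrative paths):
        #   SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \
        #   /opt/hpcx-.../sharp/bin/sharp_cmd topology --ib-dev mlx5_ib0:1 \
        #   --guids_file .../guids.txt --topology_file .../topology.txt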

        try:
            output = slutil.srun([self.hosts[0]], command, shell=True, partition=self.partition)
            log.debug(output.stdout)
        except slutil.SrunExitCodeException as e:
            log.error("Error running sharp_cmd on host %s", self.hosts[0])
            if e.stderr_content:
                log.error(e.stderr_content)
            log.error(e.stderr)
            sys.exit(e.returncode)
        except subprocesslib.TimeoutExpired:
            log.error("Timed out running sharp_cmd on host %s", self.hosts[0])
            sys.exit(1)

    def group_guids_per_switch(self) -> list:
        """
        Parses the topology file and groups GUIDs per switch.

        This method reads the topology file specified by `self.topo_file`, 
        extracts lines containing 'Nodes=', and retrieves the GUIDs associated 
        with each switch. The GUIDs are then grouped and returned as a list.

        Returns:
            list: One comma-separated string of GUIDs per switch.
        """
        guids_per_switch = []
        with open(self.topo_file, 'r', encoding="utf-8") as f:
            for line in f:
                if 'Nodes=' not in line:
                    continue
                # 'SwitchName=ibsw2 Nodes=0x155dfffd341acb,0x155dfffd341b0b'
                guids_per_switch.append(line.strip().split(' ')[1].split('=')[1])
        return guids_per_switch

    def identify_torsets(self) -> dict:
        """
        Identify and map hosts to torsets based on device GUIDs per switch.

        This method processes the device GUIDs for each switch, assigns a torset index
        to each unique set of hosts, and maps each host to a torset identifier in the
        format "torset-XX", where XX is a zero-padded index.

        Returns:
            dict: A dictionary mapping each host to its corresponding torset identifier.
        """
        host_to_torset_map = {}
        for device_guids_one_switch in self.device_guids_per_switch:
            device_guids = device_guids_one_switch.strip().split(",")
            # increment torset index for each new torset
            torset_index = len(set(host_to_torset_map.values()))
            for guid in device_guids:
                host = self.guid_to_host_map[guid]
                if host in host_to_torset_map:
                    continue
                host_to_torset_map[host] = f"torset-{torset_index:02}"
        return host_to_torset_map

    def group_hosts_by_torset(self) -> dict:
        """
        Groups hosts by their torset.

        This method iterates over the `host_to_torset_map` dictionary and groups
        hosts based on their associated torset. It returns a dictionary where the
        keys are torset identifiers and the values are lists of hosts that belong
        to each torset.

        Returns:
            dict: A dictionary with torset identifiers as keys and lists of hosts
                  as values.
        """
        torsets = {}
        for host, torset in self.host_to_torset_map.items():
            torsets.setdefault(torset, []).append(host)
        return torsets

    def write_slurm_topology(self) -> None:
        """
        Writes the SLURM topology configuration to a file or prints it to the console.

        This method generates the SLURM topology configuration based on the `torsets`
        attribute and either writes it to the file specified by `self.slurm_top_file`
        or, if `self.slurm_top_file` is not set, only prints it to the console.

        Returns:
            None
        """
        switches = []
        if self.slurm_top_file:
            with open(self.slurm_top_file, 'w', encoding="utf-8") as file:
                for torset, hosts in self.torsets.items():
                    torset_index = torset[-2:]
                    num_nodes = len(hosts)
                    file.write(f"# Number of Nodes in sw{torset_index}: {num_nodes}\n")
                    print(f"# Number of Nodes in sw{torset_index}: {num_nodes}\n")
                    file.write(f"SwitchName=sw{torset_index} Nodes={','.join(hosts)}\n")
                    print(f"SwitchName=sw{torset_index} Nodes={','.join(hosts)}\n")
                    switches.append(f"sw{torset_index}")
                if len(self.torsets) > 1:
                    # Add a top-level switch connecting all leaf switches.
                    switch_name = int(torset_index) + 1
                    file.write(f"SwitchName=sw{switch_name:02} Switches={','.join(switches)}\n")
                    print(f"SwitchName=sw{switch_name:02} Switches={','.join(switches)}\n")
        else:
            for torset, hosts in self.torsets.items():
                torset_index = torset[-2:]
                num_nodes = len(hosts)
                print(f"# Number of Nodes in sw{torset_index}: {num_nodes}\n")
                print(f"SwitchName=sw{torset_index} Nodes={','.join(hosts)}\n")
                switches.append(f"sw{torset_index}")
            if len(self.torsets) > 1:
                switch_name = int(torset_index) + 1
                print(f"SwitchName=sw{switch_name:02} Switches={','.join(switches)}\n")
    def run(self):
        """
        Executes the sequence of steps to generate and write the SLURM topology.
        Returns:
            None
        """
        log.debug("Retrieving hostnames")
        self.get_hostnames()
        log.debug("Retrieving sharp_cmd directory")
        self.sharp_cmd_path=self.get_sharp_cmd()
        log.debug("checking that sharp_hello_works")
        self.check_sharp_hello()
        log.debug("Checking ibstat can be run on all hosts")
        self.check_ibstatus()
        log.debug("Running ibstat on hosts to collect InfiniBand device GUIDs")
        self.retrieve_guids()
        log.debug("Finished collecting InfiniBand device GUIDs from hosts")
        self.write_guids_to_file()
        log.debug("Finished writing guids to %s", self.guids_file)
        self.generate_topo_file()
        log.debug("Topology file generated at %s", self.topo_file)
        self.device_guids_per_switch = self.group_guids_per_switch()
        log.debug("Finished grouping device guids per switch")
        self.host_to_torset_map = self.identify_torsets()
        log.debug("Identified torsets for hosts")
        self.torsets = self.group_hosts_by_torset()
        log.debug("Finished grouping hosts by torsets")
        self.write_slurm_topology()
        if self.slurm_top_file:
            log.info("Finished writing slurm topology from torsets to %s",
                     self.slurm_top_file)
        else:
            log.info("Printed slurm topology")
