#!/usr/bin/env python3

import ctypes
import logging
import pandas as pd
import numpy as np
import os
import pathlib
from enum import Enum
import functools


class NCCLDebugLogLevel(int, Enum):
    """Severity levels received by the tuner's debug-logger callback.

    ``Tuner._debug_logger_callback`` looks the incoming integer level up
    among these members to pick a Python ``logging`` level.

    NOTE(review): the integer values presumably mirror NCCL's
    ``ncclDebugLogLevel`` C enum -- confirm against the NCCL headers before
    changing them.
    """

    NONE = 0
    ERROR = 1
    WARN = 2
    INFO = 3
    DEBUG = 4
    TRACE = 5


class NCCLFunc(int, Enum):
    """Collective operations that can be queried through the tuner.

    A member's integer value is passed as the collective-type argument of the
    plugin's ``getCollInfo`` entry point (see ``Tuner.get_coll_info``), and
    ``Tuner.analyze_all`` sweeps every member defined here.

    NOTE(review): the integer values presumably mirror NCCL's ``ncclFunc_t``
    C enum -- confirm against the NCCL headers before changing them.
    """

    Broadcast = 0
    Reduce = 1
    AllGather = 2
    ReduceScatter = 3
    AllReduce = 4
    SendRecv = 5
    Send = 6
    # FIXME: Recv is deliberately left out because the tuner does not handle
    # it right; re-enabling it would also add it to every analyze_all() sweep.
    # Recv = 7


class NCCLAlgo(int, Enum):
    """Algorithm identifiers; the value is the algorithm's cost-table row.

    ``Tuner.get_coll_info`` reads the flat cost table at index
    ``algo.value * len(NCCLProto) + proto.value`` and passes
    ``len(NCCLAlgo)`` to the plugin, so the values must stay contiguous and
    start at 0.

    NOTE(review): the values presumably match NCCL's ``NCCL_ALGO_*``
    constants -- confirm against the NCCL version the plugin targets.
    """

    TREE = 0
    RING = 1
    COLLNET_DIRECT = 2
    COLLNET_CHAIN = 3
    NVLS = 4
    NVLS_TREE = 5
    PAT = 6


class NCCLProto(int, Enum):
    """Protocol identifiers; the value is the protocol's cost-table column.

    ``Tuner.get_coll_info`` reads the flat cost table at index
    ``algo.value * len(NCCLProto) + proto.value`` and passes
    ``len(NCCLProto)`` to the plugin, so the values must stay contiguous and
    start at 0.

    NOTE(review): the values presumably match NCCL's ``NCCL_PROTO_*``
    constants -- confirm against the NCCL version the plugin targets.
    """

    LL = 0
    LL128 = 1
    SIMPLE = 2

class TunerPlatform(str, Enum):
    """Platform names the tuner can be asked to emulate.

    ``Tuner.__init__`` writes the member's string value into the
    ``OFI_NCCL_FORCE_PRODUCT_NAME`` environment variable before the plugin is
    loaded.  The values look like EC2 instance type names (TODO confirm the
    exact strings the plugin expects).
    """

    P5en = "p5en.48xlarge"
    P5 = "p5.48xlarge"


class Tuner:
    """Drive an NCCL tuner plugin shared object through ``ctypes``.

    Construction loads the plugin, resolves its ``ncclTunerPlugin_v3`` symbol
    and calls the plugin's ``init`` entry point.  Afterwards
    ``get_coll_info`` asks the plugin which (algorithm, protocol) pair it
    selects for one collective/message size, and ``analyze_message_range`` /
    ``analyze_all`` sweep message sizes to find where that selection changes.

    The plugin context is released by ``close()``, which is also invoked from
    ``__del__``.  Query results are memoised per instance, so the cache never
    keeps an instance alive after its last external reference is dropped.
    """

    # Placeholder written into every cost-table slot before the plugin is
    # queried; a slot that no longer holds this value was touched by the plugin.
    _UNTOUCHED_COST = float(1337)

    def _debug_logger_callback(self, level, flags, file, line, fmt, args):
        """Forward one log record emitted by the plugin to ``self.logger``.

        Args:
            level: Integer severity; looked up among ``NCCLDebugLogLevel``
                members, unknown values are logged at ``logging.DEBUG``.
            flags: Unused.
            file: Source file name as ``bytes``, or ``None`` for a NULL pointer.
            line: Source line number.
            fmt: Message format as ``bytes``, or ``None`` for a NULL pointer.
                It is logged verbatim; no argument substitution is performed.
            args: Raw value of the trailing pointer-sized argument, appended
                to the message as-is.
        """
        level_map = {
            NCCLDebugLogLevel.NONE: logging.NOTSET,
            NCCLDebugLogLevel.ERROR: logging.ERROR,
            NCCLDebugLogLevel.WARN: logging.WARNING,
            NCCLDebugLogLevel.INFO: logging.INFO,
            NCCLDebugLogLevel.DEBUG: logging.DEBUG,
            NCCLDebugLogLevel.TRACE: logging.DEBUG,  # Python doesn't have a TRACE level
        }
        py_level = level_map.get(level, logging.DEBUG)
        # ctypes hands a NULL ``char *`` over as None, and an exception raised
        # here would only be printed by ctypes and the record lost, so never
        # let decoding fail.
        file_name = "<unknown>" if file is None else file.decode("utf-8", errors="replace")
        fmt_text = "" if fmt is None else fmt.decode("utf-8", errors="replace")
        message = f"NCCL [{file_name}:{line}]: {fmt_text} {args}"
        self.logger.log(py_level, message)

    def __init__(
            self, tuner_dso: pathlib.Path, nranks: int, nnodes: int, platform: TunerPlatform, log_level=logging.DEBUG
    ):
        """Load the tuner plugin and initialise it.

        Args:
            tuner_dso: Path of the plugin shared object to load.
            nranks: Rank count passed to the plugin's ``init``.
            nnodes: Node count passed to the plugin's ``init``.
            platform: Platform whose value is exported through the
                ``OFI_NCCL_FORCE_PRODUCT_NAME`` environment variable.
            log_level: Level set on the ``"NCCLTuner"`` logger.

        Raises:
            RuntimeError: If the plugin's ``init`` returns a non-zero code.
            OSError: If a shared object cannot be loaded.
        """
        # Set first so close()/__del__ stay safe when construction fails
        # part-way (e.g. the shared object cannot be loaded).
        self._initialized = False
        # Per-instance memo for get_coll_info(); deliberately not a
        # class-level lru_cache, which would pin every instance in memory.
        self._coll_info_cache = {}

        self.nranks = nranks
        self.nnodes = nnodes
        self.platform = platform
        os.environ["OFI_NCCL_FORCE_PRODUCT_NAME"] = str(platform.value)

        self.logger = logging.getLogger("NCCLTuner")
        self.logger.setLevel(log_level)
        # Signature of the logger callback handed to the plugin:
        # void (*)(int level, unsigned long flags, const char *file,
        #          int line, const char *fmt, void *args)
        self.ncclDebugLogger_t = ctypes.CFUNCTYPE(
            None,
            ctypes.c_int,
            ctypes.c_ulong,
            ctypes.c_char_p,
            ctypes.c_int,
            ctypes.c_char_p,
            ctypes.c_void_p,
        )
        num_entries = len(NCCLAlgo) * len(NCCLProto)
        # NOTE(review): field order and types presumably mirror the plugin's C
        # ``ncclTuner_v3_t`` struct -- confirm against the plugin headers.
        self.ncclTuner_v3_t = type(
            "ncclTuner_v3_t",
            (ctypes.Structure,),
            {
                "_fields_": [
                    ("name", ctypes.c_char_p),
                    (
                        "init",
                        ctypes.CFUNCTYPE(
                            ctypes.c_int,
                            ctypes.c_size_t,
                            ctypes.c_size_t,
                            ctypes.c_void_p,
                            ctypes.POINTER(ctypes.c_void_p),
                        ),
                    ),
                    (
                        "getCollInfo",
                        ctypes.CFUNCTYPE(
                            ctypes.c_int,
                            ctypes.c_void_p,
                            ctypes.c_int,
                            ctypes.c_size_t,
                            ctypes.c_int,
                            ctypes.POINTER(ctypes.c_float * num_entries),
                            ctypes.c_int,
                            ctypes.c_int,
                            ctypes.POINTER(ctypes.c_int),
                        ),
                    ),
                    ("destroy", ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)),
                ]
            },
        )

        callback_func = self.ncclDebugLogger_t(self._debug_logger_callback)
        # Stored on the instance so the ctypes callback object outlives this
        # constructor for as long as the plugin may call it.
        self._callback_func = callback_func
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        self.context = ctypes.c_void_p()
        # os.fspath: ctypes only documents path-like names from Python 3.12 on.
        self.lib = ctypes.CDLL(os.fspath(tuner_dso))
        self.libc = ctypes.CDLL("libc.so.6")
        self.tuner = self.ncclTuner_v3_t.in_dll(self.lib, "ncclTunerPlugin_v3")
        result = self.tuner.init(
            nranks, nnodes, callback_func, ctypes.byref(self.context)
        )
        if result != 0:
            raise RuntimeError(f"Failed to initialize NCCL tuner. Error code: {result}")
        self._initialized = True

    def get_coll_info(self, coll_type: NCCLFunc, msg_size: int, num_pipe_ops: int = 1):
        """Ask the plugin which algorithm/protocol it selects.

        Results are memoised per instance on
        ``(coll_type, msg_size, num_pipe_ops)``; repeated calls return the
        same dict object.

        Args:
            coll_type: Collective to query.
            msg_size: Message size passed to the plugin as ``size_t``.
            num_pipe_ops: Pipelined-operation count passed to the plugin.

        Returns:
            ``{"algo": ..., "proto": ...}`` holding the ``NCCLAlgo`` /
            ``NCCLProto`` whose cost-table slot the plugin changed, or
            ``numpy.nan`` for both when it changed none.

        Raises:
            RuntimeError: If the tuner is not initialised (or already closed),
                or if the plugin returns a non-zero code.
        """
        if not getattr(self, "_initialized", False):
            raise RuntimeError("NCCL tuner not initialized or already closed.")

        cache_key = (coll_type, msg_size, num_pipe_ops)
        cached = self._coll_info_cache.get(cache_key)
        if cached is not None:
            return cached

        nBytes = ctypes.c_size_t(msg_size)
        numPipeOps = ctypes.c_int(num_pipe_ops)
        num_entries = len(NCCLAlgo) * len(NCCLProto)
        # Pre-fill every slot with the sentinel so plugin writes are detectable.
        cost_table_array = (ctypes.c_float * num_entries)(
            *([self._UNTOUCHED_COST] * num_entries)
        )

        nChannels = ctypes.c_int(0)
        result = self.tuner.getCollInfo(
            self.context,
            coll_type.value,
            nBytes,
            numPipeOps,
            ctypes.byref(cost_table_array),
            len(NCCLAlgo),
            len(NCCLProto),
            ctypes.byref(nChannels),
        )

        if result != 0:
            raise RuntimeError(f"Failed to get collective info. Error code: {result}")

        # The table is a flat row-major [algo][proto] array.  If the plugin
        # changed several slots, the last one in iteration order wins.
        # NOTE(review): confirm the plugin only ever rewrites the slot it picks.
        decision = (np.nan, np.nan)
        for algo in NCCLAlgo:
            for proto in NCCLProto:
                cost = cost_table_array[algo.value * len(NCCLProto) + proto.value]
                if cost != self._UNTOUCHED_COST:
                    decision = (algo, proto)

        info = {
            "algo": decision[0],
            "proto": decision[1],
        }
        self._coll_info_cache[cache_key] = info
        return info

    def analyze_message_range(self, coll_type: NCCLFunc, min_size: int, max_size: int):
        """Locate the message sizes at which the plugin's decision changes.

        The range is bisected recursively.  An interval whose two endpoints
        yield the same decision is not subdivided further, i.e. the decision
        is assumed not to change and change back inside such an interval.

        Args:
            coll_type: Collective to analyse.
            min_size: Smallest message size probed.
            max_size: Largest message size probed.

        Returns:
            A ``pandas.DataFrame`` with one row for ``min_size`` and one row
            for every probed size whose decision differs from the previous
            row; columns are ``collective``, ``message_size``, ``algo``,
            ``proto``, ``ranks``, ``nodes`` and ``platform``.
        """
        def bisect(lo, hi):
            lo_decision = self.get_coll_info(coll_type, lo)
            hi_decision = self.get_coll_info(coll_type, hi)
            # Stop when the endpoints agree or the interval cannot be split.
            if lo_decision == hi_decision or hi - lo <= 1:
                return [(lo, lo_decision), (hi, hi_decision)]

            mid = (lo + hi) // 2
            left_results = bisect(lo, mid)
            right_results = bisect(mid, hi)

            # Both halves contain ``mid``; drop the duplicate, then keep only
            # entries whose decision differs from their predecessor.
            combined = left_results + right_results[1:]
            return [combined[0]] + [
                combined[i]
                for i in range(1, len(combined))
                if combined[i][1] != combined[i - 1][1]
            ]

        edges = bisect(min_size, max_size)
        return pd.DataFrame(
            [
                {
                    "collective": coll_type.name,
                    "message_size": size,
                    "algo": decision["algo"],
                    "proto": decision["proto"],
                    "ranks": self.nranks,
                    "nodes": self.nnodes,
                    "platform": self.platform,
                }
                for size, decision in edges
            ]
        )

    def analyze_all(self, min_size: int = 32, max_size: int = 32 * 1024 * 1024 * 1024):
        """Run ``analyze_message_range`` for every ``NCCLFunc`` member.

        Args:
            min_size: Smallest message size probed (default 32).
            max_size: Largest message size probed (default 32 * 1024**3).

        Returns:
            The per-collective data frames concatenated with ``pandas.concat``
            (each frame keeps its own index).
        """
        return pd.concat(
            [self.analyze_message_range(c, min_size, max_size) for c in NCCLFunc]
        )

    def close(self):
        """Release the plugin context by calling the plugin's ``destroy``.

        Safe to call repeatedly and on an instance whose construction failed;
        in both cases it does nothing.  After closing, ``get_coll_info``
        raises ``RuntimeError``.
        """
        if not getattr(self, "_initialized", False):
            return
        # Flip the flag first so a failure in destroy() is never retried.
        self._initialized = False
        self.tuner.destroy(self.context)
        del self.context
        del self.tuner

    def __del__(self):
        """Release the plugin context when the instance is garbage collected."""
        self.close()
