import json
import os
import shlex
import subprocess
import datetime
import time
import csv
from Infra import tools
from prettytable import PrettyTable

class GEMMCublastLt:
    def __init__(self, path: str, machine: str, b: int = 1, i: int = 1000, w: int = 10000):
        self.name = "GEMMCublasLt"
        config = self.get_config(path)
        self.m, self.n, self.k, self.duration, self.datatype = self.config_conversion(config)
        self.b = b
        self.i = i
        self.w = w
        self.bindir = ''
        self.machine_name = machine
        self.buffer = []

        # A100 does not support fp8
        if "A100" in machine:
            self.datatype = "fp16"

    def get_config(self, path: str):
        file = open(path)
        data = json.load(file)
        file.close()
        try:
            return data[self.name]
        except KeyError:
            raise KeyError("no value found")

    def parse_json(self, config, var):
        if var == "duration":
            return config["inputs"]["duration"]
        if var == "datatype":
            return config["inputs"]["datatype"]
        start = config["inputs"][var]["start"]
        end = config["inputs"][var]["end"]
        interval = config["inputs"][var]["interval"]
        data = [a for a in range(start, end, interval)]
        if not data or data[-1] < end:
            data.append(end)
        return data

    def config_conversion(self, config):
        m = self.parse_json(config, "m")
        n = self.parse_json(config, "n")
        k = self.parse_json(config, "k")
        duration = self.parse_json(config, "duration")
        datatype = self.parse_json(config, "datatype")
        return m, n, k, duration, datatype

    def build(self):
        bindir = tools.create_dir("bin")
        self.bindir = bindir
        path = "superbenchmark"
        isdir = os.path.isdir(path)
        if not isdir:
            results = subprocess.run(
                [
                    "git",
                    "clone",
                    "https://github.com/gitaumark/superbenchmark",
                    path,
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            tools.write_log(tools.check_error(results))
        current = os.getcwd()
        build_path = os.path.join(
            current,
            "superbenchmark/superbench/benchmarks/micro_benchmarks/cublaslt_gemm",
        )
        os.chdir(build_path)

        results = subprocess.run(
            ["cmake", "-S", "./"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        tools.write_log(tools.check_error(results))

        results = subprocess.run(
            ["make"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        tools.write_log(tools.check_error(results))
        print(results.stderr.decode('utf-8'))
        results = subprocess.run(
            ["mv", "cublaslt_gemm", bindir],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        os.chdir(current)

    # run GEMM with predetermined matrix sizes that are commonly used in transformers
    def run_model_sizes(self):
        print("Running CublasLt...")
        current = os.getcwd()
        if self.datatype == "fp8e4m3":
            m_dims = [1024, 2048, 4096, 8192, 16384, 32768, 1024, 6144, 802816, 802816]
            n_dims = [1024, 2048, 4096, 8192, 16384, 32768, 2145, 12288, 192, 192]
            k_dims = [1024, 2048, 4096, 8192, 16384, 32768, 1024, 12288, 192, 768]
        else:
            m_dims = [1024, 2048, 4096, 8192, 16384, 1024, 6144, 802816, 802816]
            n_dims = [1024, 2048, 4096, 8192, 16384, 2145, 12288, 192, 192]
            k_dims = [1024, 2048, 4096, 8192, 16384, 1024, 12288, 192, 768]
        os.chdir(self.bindir)
        buffer = []
        for i in range(len(m_dims)):
            results = subprocess.run(
                [
                    "./cublaslt_gemm",
                    "-m",
                    str(m_dims[i]),
                    "-n",
                    str(n_dims[i]),
                    "-k",
                    str(k_dims[i]),
                    "-b",
                    str(self.b),
                    "-i",
                    str(self.i),
                    "-w",
                    str(self.w),
                    "-t",
                    self.datatype,
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            log = results.stdout.decode('utf-8').split()
            buffer.append(log)
            tools.write_log(tools.check_error(results))
        table1 = PrettyTable()

        with open('../Outputs/GEMMCublasLt_Performance_' + self.machine_name + '_' + self.datatype+'.csv', 'w') as csvFile:
            writer = csv.writer(csvFile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
            writer.writerow(["M", "N", "K", "Batch", "Time(us)", "TFLOPS"])
            table1.field_names = ["M", "N", "K", "Batch Size", "Time(us)", "TFLOPS"]
            for item in buffer:
                writer.writerow(item)
                table1.add_row(item)
        print(table1)
        os.chdir(current)
