def run_model_sizes()

in Benchmarks/NVIDIA/GEMMCublasLt.py [0:0]


    def run_model_sizes(self):
        """Run ``./cublaslt_gemm`` over a fixed set of GEMM shapes and report results.

        The binary is run from ``self.bindir`` once per (M, N, K) shape, with
        ``self.b``, ``self.i`` and ``self.w`` passed as ``-b``/``-i``/``-w``
        (presumably batch, iterations and warmup -- confirm against the
        binary's usage text) and ``self.datatype`` passed as ``-t``.

        Each run's stdout is split on whitespace and kept as one result row.
        The rows are written to
        ``../Outputs/GEMMCublasLt_Performance_<machine_name>_<datatype>.csv``
        (relative to ``self.bindir``) and printed as a PrettyTable.
        ``tools.write_log(tools.check_error(...))`` is called on every run.

        The original working directory is restored before returning, including
        when a benchmark run or the report writing raises.
        """
        print("Running CublasLt...")

        # (M, N, K) problem sizes shared by every data type.
        shapes = [
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (8192, 8192, 8192),
            (16384, 16384, 16384),
            (1024, 2145, 1024),
            (6144, 12288, 12288),
            (802816, 192, 192),
            (802816, 192, 768),
        ]
        # The 32768^3 shape is only benchmarked for fp8e4m3; it sits right after
        # the square 16384 case so the output row order is unchanged.
        if self.datatype == "fp8e4m3":
            shapes.insert(5, (32768, 32768, 32768))

        current = os.getcwd()
        os.chdir(self.bindir)
        try:
            buffer = []
            for m, n, k in shapes:
                results = subprocess.run(
                    [
                        "./cublaslt_gemm",
                        "-m", str(m),
                        "-n", str(n),
                        "-k", str(k),
                        "-b", str(self.b),
                        "-i", str(self.i),
                        "-w", str(self.w),
                        "-t", self.datatype,
                    ],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
                # Every whitespace-separated stdout token becomes one column of
                # the report row; assumes the binary prints exactly one row
                # matching the header below -- TODO confirm.
                buffer.append(results.stdout.decode('utf-8').split())
                tools.write_log(tools.check_error(results))

            table1 = PrettyTable()
            table1.field_names = ["M", "N", "K", "Batch Size", "Time(us)", "TFLOPS"]

            # newline='' lets the csv writer control line endings itself
            # (otherwise blank lines appear between rows on Windows).
            with open('../Outputs/GEMMCublasLt_Performance_' + self.machine_name + '_' + self.datatype+'.csv', 'w', newline='') as csvFile:
                writer = csv.writer(csvFile, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
                writer.writerow(["M", "N", "K", "Batch", "Time(us)", "TFLOPS"])
                for item in buffer:
                    writer.writerow(item)
                    table1.add_row(item)
            print(table1)
        finally:
            # Always undo the chdir so a failed run does not leave the caller
            # in self.bindir.
            os.chdir(current)