def main()

in bench/kernels/benchmark.py [0:0]


def main():
    """Command-line entry point: time the registered kernels and print a summary.

    Options:
        --kernel  Key of a single entry in ``GET_BENCH_FUNCTIONS`` to benchmark;
                  when omitted, every registered kernel is benchmarked.
        --device  Device string handed to ``torch.device``; when omitted the
                  first available of cuda, mps, xpu is used, else cpu.
        --it      Iteration count forwarded to ``timing`` (default 10).

    For each selected kernel, one line is printed with the two timings returned
    by ``timing`` (labelled "python" and "ext", in ms) and their ratio.

    An unknown ``--kernel`` value is reported as a usage error (exit status 2)
    rather than an uncaught ``KeyError``.
    """
    parser = argparse.ArgumentParser(description="Kernel benchmark")
    parser.add_argument("--kernel", type=str, default=None, help="The kernel to benchmark. None to test all of them")
    parser.add_argument("--device", type=str, default=None, help="The device to use for benchmark.")
    parser.add_argument("--it", type=int, default=10, help="The number of benchmark iterations")
    args = parser.parse_args()

    # Validate the kernel name up front so a typo yields a readable usage
    # message listing the valid names instead of a KeyError traceback.
    if args.kernel is not None and args.kernel not in GET_BENCH_FUNCTIONS:
        available = ", ".join(map(str, GET_BENCH_FUNCTIONS))
        parser.error(f"unknown kernel '{args.kernel}' (available: {available})")

    if args.device is not None:
        device = torch.device(args.device)
    else:
        # Not every PyTorch build exposes the mps / xpu modules, so look them
        # up defensively and fall through to the CPU when they are missing.
        mps_backend = getattr(torch.backends, "mps", None)
        xpu_module = getattr(torch, "xpu", None)
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif mps_backend is not None and mps_backend.is_available():
            device = torch.device("mps")
        elif xpu_module is not None and xpu_module.is_available():
            device = torch.device("xpu")
        else:
            device = torch.device("cpu")

    kernels = list(GET_BENCH_FUNCTIONS) if args.kernel is None else [args.kernel]
    for kernel in kernels:
        get_bench_fn = GET_BENCH_FUNCTIONS[kernel]
        python_ms, ext_ms = timing(get_bench_fn, device, iterations=args.it)
        # A timing of exactly 0 ms would otherwise raise ZeroDivisionError;
        # report an infinite speed-up instead of aborting the remaining kernels.
        ratio = python_ms / ext_ms if ext_ms else float("inf")
        print(f"\n{kernel}[{device.type}]: python = {python_ms:.3f} ms, ext = {ext_ms:.3f} ms, ratio = {ratio:.1f}x")