def timing()

in bench/kernels/benchmark.py


# Assumed module-level imports used by this function; `disable_extensions`
# is a project-internal context manager imported elsewhere in benchmark.py.
import time
from contextlib import nullcontext

import numpy as np
import torch
from tqdm import tqdm


def timing(get_bench_func, device, iterations=10):
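    """Benchmark the callable returned by `get_bench_func(device)`.

    Each of the `iterations` runs is timed twice: once with the project's
    extensions disabled and once with them enabled. Returns the two mean
    latencies in milliseconds, in that order.
    """
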
    def synchronize(device):
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "mps":
            torch.mps.synchronize()
        elif device.type == "xpu":
            torch.xpu.synchronize()
        else:
            torch.cpu.synchronize()

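    # Use the backend's native events so time is measured on the device itself.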
    def timing_event(device):
        if device.type == "cuda":
            return torch.cuda.Event(enable_timing=True)
        elif device.type == "mps":
            return torch.mps.Event(enable_timing=True)
        elif device.type == "xpu":
            return torch.xpu.Event(enable_timing=True)

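        # Backends without native events fall back to a wall-clock stand-in
        # that mimics the Event API (elapsed_time in milliseconds).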
        class CPUEvent:
            def __init__(self):
                self.time = None

            def record(self):
                self.time = time.time()

            def elapsed_time(self, other):
                assert self.time is not None
                assert other.time is not None
                return (other.time - self.time) * 1000

        return CPUEvent()

    synchronize(device)

    bench_func = get_bench_func(device)
    # Warmup call so that library loading is not included in the timed runs
    bench_func()
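    # One row per iteration; column 0 is timed with extensions disabled,
    # column 1 with extensions enabled (nullcontext).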
    latencies = np.empty((iterations, 2))
    for i in tqdm(range(iterations)):
        for j, context in enumerate([disable_extensions(), nullcontext()]):
            start_event = timing_event(device)
            end_event = timing_event(device)
            synchronize(device)
            start_event.record()
            with context:
                bench_func()
            end_event.record()
            synchronize(device)
            latencies[i, j] = start_event.elapsed_time(end_event)
    return np.mean(latencies[:, 0]), np.mean(latencies[:, 1])
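
A minimal usage sketch, assuming the imports above; the get_matmul_bench factory and its matmul workload are hypothetical stand-ins for the kernels actually benchmarked by the script:


def get_matmul_bench(device):
    # Hypothetical workload: return a closure running a fixed-size matmul
    # on the target device, matching the get_bench_func(device) contract.
    a = torch.rand(1024, 1024, device=device)
    b = torch.rand(1024, 1024, device=device)

    def bench():
        return a @ b

    return bench


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
no_ext_ms, ext_ms = timing(get_matmul_bench, device, iterations=10)
print(f"extensions disabled: {no_ext_ms:.3f} ms, enabled: {ext_ms:.3f} ms")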