bench/kernels/benchmark.py (86 lines of code) (raw):

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmark quanto kernels with extensions disabled vs. enabled.

Each kernel is run once inside ``disable_extensions()`` (reported as "python")
and once inside a no-op context (reported as "ext"), for a number of
iterations; the mean latency of each variant and their ratio are printed.
"""

import argparse
import time
from contextlib import nullcontext

import numpy as np
import torch
from tqdm.auto import tqdm

from optimum.quanto.library import disable_extensions


def get_unpack_bench(bits, device):
    """Build a closure that calls ``torch.ops.quanto.unpack`` on a fixed random tensor.

    Args:
        bits: bit width forwarded to the unpack op (the registry below uses 2 and 4).
        device: device the input tensor is moved to.

    Returns:
        A zero-argument callable performing one unpack call and returning its result.
    """
    qmax = 2**bits
    # uint8 values in [0, 2**bits): randint's upper bound is exclusive.
    # The tensor is created on the host, then moved to ``device``.
    a = torch.randint(0, qmax, [10240, 10240], dtype=torch.uint8).to(device)

    def bench_fn():
        return torch.ops.quanto.unpack(a, bits)

    return bench_fn


def _synchronize(device):
    """Block until all queued work on ``device`` has completed."""
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()
    elif device.type == "xpu":
        torch.xpu.synchronize()
    else:
        torch.cpu.synchronize()


class _CPUEvent:
    """Host-side stand-in for a device timing event, based on ``time.perf_counter``."""

    def __init__(self):
        self.time = None

    def record(self):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        self.time = time.perf_counter()

    def elapsed_time(self, other):
        """Return the milliseconds elapsed between this event and ``other``.

        Raises:
            RuntimeError: if either event has not been recorded.
        """
        if self.time is None or other.time is None:
            raise RuntimeError("Both events must be recorded before calling elapsed_time()")
        return (other.time - self.time) * 1000


def _timing_event(device):
    """Return a timing event for ``device``; falls back to a host-clock event."""
    if device.type == "cuda":
        return torch.cuda.Event(enable_timing=True)
    if device.type == "mps":
        return torch.mps.Event(enable_timing=True)
    if device.type == "xpu":
        return torch.xpu.Event(enable_timing=True)
    return _CPUEvent()


def timing(get_bench_func, device, iterations=10):
    """Time a benchmark with extensions disabled and enabled.

    Args:
        get_bench_func: callable taking ``device`` and returning a zero-argument
            benchmark function.
        device: ``torch.device`` to run on.
        iterations: number of timed runs per variant (must be >= 1).

    Returns:
        ``(disabled_ms, enabled_ms)``: mean latency in milliseconds with
        ``disable_extensions()`` active and with the default (extensions enabled) path.

    Raises:
        ValueError: if ``iterations`` is smaller than 1.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    _synchronize(device)
    bench_func = get_bench_func(device)
    # Warm up both code paths: the default one (which loads the extension
    # library) and the extensions-disabled one, so that first-call overhead
    # does not leak into the first timed sample of either variant.
    for context in (nullcontext(), disable_extensions()):
        with context:
            bench_func()
    _synchronize(device)

    latencies = np.empty((iterations, 2))
    for i in tqdm(range(iterations)):
        # Column 0: extensions disabled; column 1: default path.
        for j, context in enumerate([disable_extensions(), nullcontext()]):
            start_event = _timing_event(device)
            end_event = _timing_event(device)
            _synchronize(device)
            start_event.record()
            with context:
                bench_func()
            end_event.record()
            # Device events are only readable once the queued work has finished.
            _synchronize(device)
            latencies[i, j] = start_event.elapsed_time(end_event)
    return np.mean(latencies[:, 0]), np.mean(latencies[:, 1])


GET_BENCH_FUNCTIONS = {
    "unpack_2bit": lambda device: get_unpack_bench(2, device),
    "unpack_4bit": lambda device: get_unpack_bench(4, device),
}


def _default_device():
    """Pick the first available accelerator (cuda, mps, xpu), else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    # torch.xpu does not exist in older torch builds; probe before calling it.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")


def main():
    """Parse CLI arguments, run the selected kernel benchmarks and print results."""
    parser = argparse.ArgumentParser(description="Kernel benchmark")
    parser.add_argument(
        "--kernel",
        type=str,
        default=None,
        # Reject unknown kernels with a usage error instead of a KeyError later.
        choices=sorted(GET_BENCH_FUNCTIONS),
        help="The kernel to benchmark. None to test all of them",
    )
    parser.add_argument("--device", type=str, default=None, help="The device to use for benchmark.")
    parser.add_argument("--it", type=int, default=10, help="The number of benchmark iterations")
    args = parser.parse_args()

    device = _default_device() if args.device is None else torch.device(args.device)

    kernels = list(GET_BENCH_FUNCTIONS) if args.kernel is None else [args.kernel]
    for kernel in kernels:
        get_bench_fn = GET_BENCH_FUNCTIONS[kernel]
        python_ms, ext_ms = timing(get_bench_fn, device, iterations=args.it)
        ratio = python_ms / ext_ms
        print(f"\n{kernel}[{device.type}]: python = {python_ms:.3f} ms, ext = {ext_ms:.3f} ms, ratio = {ratio:.1f}x")


if __name__ == "__main__":
    main()