# bench/kernels/benchmark_marlin_fp8.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
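
"""Benchmark comparing quanto's Marlin FP8 GEMM kernel against a plain matmul baseline.

Example usage:

    python bench/kernels/benchmark_marlin_fp8.py                         # sweep the default shapes
    python bench/kernels/benchmark_marlin_fp8.py --m 16 --n 4096 --k 4096
"""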

import argparse
from typing import Optional

import numpy as np
import torch

from optimum.quanto.tensor.weights.marlin.packed import pack_fp8_as_int32
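
# Default shape sweep: vary the token dimension m while keeping out_features (n)
# and in_features (k) fixed at a typical hidden size.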
M_SHAPES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
N_SHAPES = [4096]
K_SHAPES = [4096]


def run_benchmark(
    m: Optional[int],
    n: Optional[int],
    k: Optional[int],
    n_runs: int,
    n_warmup: int,
    dtype: torch.dtype = torch.float16,
):
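    """Benchmark the Marlin FP8 GEMM kernel and a plain matmul for one (m, n, k) shape.

    Returns the mean latencies in milliseconds of the torch baseline and of the
    Marlin FP8 kernel, averaged over the iterations after warmup.
    """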
print(f"\n----------- m={m}, n={n}, k={k}")
n_tokens = m
in_features = k
out_features = n
assert m is not None
device = torch.device("cuda")
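
    # Random half-precision activations A of shape (m, k).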
inputs = torch.rand(n_tokens, in_features, dtype=dtype, device=device)
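
    # Quantize random weights to fp8 (e4m3) and pack four fp8 values into each
    # int32, the packed layout expected by the repacking kernel below.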
other_shape = (in_features, out_features)
other_data = torch.rand(other_shape, dtype=dtype, device=device).to(torch.float8_e4m3fn)
other_data_int32 = pack_fp8_as_int32(other_data)
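
    # Repack the int32-packed weights into the Marlin tile layout (no permutation
    # is supplied here).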
perm = torch.empty(0, dtype=torch.int, device=device)
other_data_repack = torch.ops.quanto.gptq_marlin_repack(
b_q_weight=other_data_int32, perm=perm, size_k=in_features, size_n=out_features, num_bits=8
)
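
    # A single random scale, broadcast to one scale per output feature.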
other_scale = torch.rand(1, dtype=dtype, device=device)
other_scale = other_scale.repeat(1, out_features)
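
    # Scratch workspace required by the Marlin kernel.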
workspace = torch.zeros(out_features // 64 * 16, dtype=torch.int, device=device)

    latencies_marlin_fp8 = []
latencies_torch = []
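
    # Time both implementations with CUDA events; the first n_warmup iterations
    # are excluded from the statistics.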
with torch.no_grad():
for i in range(n_runs):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize(device)
start_event.record()
_ = torch.ops.quanto.fp8_marlin_gemm(
a=inputs,
b_q_weight=other_data_repack,
b_scales=other_scale,
workspace=workspace,
num_bits=8,
size_m=n_tokens,
size_n=out_features,
size_k=in_features,
)
end_event.record()
torch.cuda.synchronize(device)
latency_ms = start_event.elapsed_time(end_event)
if i >= n_warmup:
latencies_marlin_fp8.append(latency_ms)
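
            # Repeat the measurement for the torch matmul baseline.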
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize(device)
start_event.record()
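            # Dequantize the fp8 weights back to `dtype` before the matmul, so the
            # baseline includes the on-the-fly dequantization cost.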
other = other_data.to(dtype) * other_scale
_ = torch.matmul(inputs, other)
end_event.record()
torch.cuda.synchronize(device)
latency_ms = start_event.elapsed_time(end_event)
if i >= n_warmup:
latencies_torch.append(latency_ms)

    mean_latency_torch = np.mean(latencies_torch)
    mean_latency_marlin_fp8 = np.mean(latencies_marlin_fp8)
    print(f"mean_latency_torch: {mean_latency_torch:.4f} ms")
    print(f"mean_latency_marlin_fp8: {mean_latency_marlin_fp8:.4f} ms")
    return mean_latency_torch, mean_latency_marlin_fp8


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Marlin FP8 kernel benchmark")
parser.add_argument("--nruns", type=int, default=20, help="The number of benchmark iterations")
parser.add_argument("--nwarmup", type=int, default=2, help="The number of warmup iterations (deducted from nruns)")
    parser.add_argument(
        "--m",
        type=int,
        help="m dimension of A (m x k), i.e. the number of tokens",
        default=None,
    )
    parser.add_argument(
        "--n",
        type=int,
        help="n dimension of B (k x n), i.e. out_features",
        default=None,
    )
    parser.add_argument(
        "--k",
        type=int,
        help="k dimension of A (m x k) and B (k x n), i.e. in_features / hidden_size",
        default=None,
    )
args = parser.parse_args()
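
    # Benchmark either the single user-provided shape (--m requires --n and --k)
    # or the full default sweep.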
if args.m is not None:
def shape_generator():
yield (args.m, args.n, args.k)
else:
def shape_generator():
for m in M_SHAPES:
for n in N_SHAPES:
for k in K_SHAPES:
yield (m, n, k)
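
    # Accumulate results as CSV (latencies in milliseconds).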
result = "m,n_out,k_in,torch_latency_ms,marlin_fp8_latency_ms\n"
for m, n, k in shape_generator():
mean_latency_torch, mean_latency_marlin_fp8 = run_benchmark(m, n, k, args.nruns, args.nwarmup)
result += (
",".join(
[
str(m),
str(n),
str(k),
f"{mean_latency_torch:.4f}",
f"{mean_latency_marlin_fp8:.4f}",
]
)
+ "\n"
)
print("\nResults:")
print(result)