# fbgemm_gpu/bench/quantize_ops_benchmark.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import logging
import random

import click
import fbgemm_gpu
import hypothesis.strategies as st
import torch
from hypothesis import given, settings

logging.basicConfig(level=logging.DEBUG)

open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if open_source:
    # pyre-ignore[21]
    from bench_utils import benchmark_torch_function
else:
    from fbgemm_gpu.bench.bench_utils import benchmark_torch_function

    # Internal (Buck) builds load the sparse ops libraries by target path.
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")


@click.group()
def cli() -> None:
    pass


@cli.command()
@click.option("--flush-gpu-cache-size-mb", default=0)
@click.option("--iters", default=100)
@click.option("--warmup-runs", default=2)
@settings(max_examples=10, deadline=None)
# pyre-ignore
@given(
    num_columns=st.sampled_from([2 ** n for n in range(4, 10)]),
    num_rows=st.sampled_from([2 ** n for n in range(4, 10)]),
)
def bench(
    flush_gpu_cache_size_mb: int,
    iters: int,
    num_columns: int,
    num_rows: int,
    warmup_runs: int,
) -> None:
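    """Benchmark FBGEMM rowwise quantize/dequantize ops (int8/int4/int2 and
    HFP8 "143"/"152") on a random (num_rows, num_columns) float tensor and
    log the average time per iteration for each op."""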
    average_time = {
        "int8_quant": 0.0,
        "int4_quant": 0.0,
        "int2_quant": 0.0,
        "fp8_143_quant": 0.0,
        "fp8_152_quant": 0.0,
        "int8_dequant": 0.0,
        "int4_dequant": 0.0,
        "int2_dequant": 0.0,
        "fp8_143_dequant": 0.0,
        "fp8_152_dequant": 0.0,
    }
    benchmark = functools.partial(
        benchmark_torch_function,
        flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
        iters=iters,
        num_warmups=warmup_runs,
    )
    input_data = torch.rand(num_rows, num_columns).float()
    if torch.cuda.is_available():
        # Move the input to the GPU up front so that the quantize and
        # dequantize benchmarks below all run on the same device.
        input_data = input_data.cuda()

    quant_data_8bit = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(input_data)
    quant_data_4bit = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
        input_data, 4
    )
    quant_data_2bit = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
        input_data, 2
    )
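    # HFP8 arguments are (input, ebits, exponent_bias, max_pos): "143" is the
    # 1-sign/4-exponent/3-mantissa format and "152" the 1/5/2 format.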
    quant_data_fp8_143 = torch.ops.fbgemm.FloatToHFP8Quantized(
        input_data.contiguous(), 4, 14, (2 - 2 ** (-3))
    )
    quant_data_fp8_152 = torch.ops.fbgemm.FloatToHFP8Quantized(
        input_data, 5, 30, (2 - 2 ** (-2))
    )
average_time["int8_quant"], _ = benchmark(
torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized,
(input_data,),
)
average_time["int4_quant"], _ = benchmark(
torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf,
(input_data, 4),
)
average_time["int2_quant"], _ = benchmark(
torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf,
(input_data, 2),
)
average_time["fp8_143_quant"], _ = benchmark(
torch.ops.fbgemm.FloatToHFP8Quantized,
(input_data, 4, 14, (2 - 2 ** (-3))),
)
average_time["fp8_152_quant"], _ = benchmark(
torch.ops.fbgemm.FloatToHFP8Quantized,
(input_data, 5, 30, (2 - 2 ** (-2))),
)
average_time["int8_dequant"], _ = benchmark(
torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat,
(quant_data_8bit,),
)
average_time["int4_dequant"], _ = benchmark(
torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloat,
(quant_data_4bit, 4),
)
average_time["int2_dequant"], _ = benchmark(
torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloat,
(quant_data_2bit, 2),
)
average_time["fp8_143_dequant"], _ = benchmark(
torch.ops.fbgemm.HFP8QuantizedToFloat,
(quant_data_fp8_143, 4, 14),
)
average_time["fp8_152_dequant"], _ = benchmark(
torch.ops.fbgemm.HFP8QuantizedToFloat,
(quant_data_fp8_152, 5, 30),
)
logging.info(f"-------------- ncols={num_columns}, nrows={num_rows}-------------")
for k, t_time in average_time.items():
logging.info(f"{k} time per iter: {t_time * 1.0e6:.0f}us")


@cli.command()
@click.option("--flush-gpu-cache-size-mb", default=0)
@click.option("--iters", default=100)
@click.option("--batch_size", default=512)
@click.option("--num_tables", default=256)
@click.option("--min_dim", default=1)
@click.option("--max_dim", default=128)
@click.option("--warmup-runs", default=2)
def mixdim(
    flush_gpu_cache_size_mb: int,
    iters: int,
    batch_size: int,
    num_tables: int,
    min_dim: int,
    max_dim: int,
    warmup_runs: int,
) -> None:
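    """Benchmark mixed-dim int8 rowwise dequantization: num_tables tables with
    random dimensions are quantized, concatenated, and dequantized with the
    mixed-dim kernel (FP32 and FP16 output) and the single-dim kernel."""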
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available.")

    random.seed(0)
    table_dims = [
        random.randint(min_dim, max_dim) * 8 for _ in range(num_tables)
    ]  # assume table dimensions are multiples of 8
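    # Each fused int8 row appends 8 bytes of quantization parameters (an fp32
    # scale and an fp32 bias), hence the extra 8 columns per table.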
    table_dims_with_qparams = [d + 8 for d in table_dims]
    # Prefix sum of the padded dims gives each table's column offset into the
    # concatenated quantized buffer.
    D_offsets = (
        torch.cumsum(torch.tensor([0] + table_dims_with_qparams), dim=0)
        .to(torch.int)
        .cuda()
    )
    input_refs = [torch.randn((batch_size, d)).cuda() for d in table_dims]
    input_refs_int8 = [
        torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(t) for t in input_refs
    ]
    input_data = torch.concat(input_refs_int8, dim=1).contiguous()
    benchmark = functools.partial(
        benchmark_torch_function,
        flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
        iters=iters,
        num_warmups=warmup_runs,
    )

    average_time_mixed_dim_fp32, _ = benchmark(
        torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatMixedDim,
        (
            input_data,
            D_offsets,
            0,
        ),
    )  # output is FP32
    average_time_mixed_dim_fp16, _ = benchmark(
        torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatMixedDim,
        (
            input_data,
            D_offsets,
            1,
        ),
    )  # output is FP16
    average_time_single_dim, _ = benchmark(
        torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat,
        (input_data,),
    )  # output is FP32
    # Bandwidth counts only the quantized input bytes read per iteration.
    print(
        f"Input tensor batch_size: {batch_size}, num_tables: {num_tables}, "
        f"tensor_size: {input_data.numel() / (1 << 30):.4f} GB, "
        f"average table dimension: {sum(table_dims) * 1.0 / num_tables:.1f}."
    )
    print(
        f"Mixed dim dequantize average time per iter FP32: "
        f"{average_time_mixed_dim_fp32} s, bandwidth: "
        f"{input_data.numel() / (1 << 30) / average_time_mixed_dim_fp32:.2f} GB/s."
    )
    print(
        f"Mixed dim dequantize average time per iter FP16: "
        f"{average_time_mixed_dim_fp16} s, bandwidth: "
        f"{input_data.numel() / (1 << 30) / average_time_mixed_dim_fp16:.2f} GB/s."
    )
    print(
        f"Single dim dequantize average time per iter FP32: "
        f"{average_time_single_dim} s, bandwidth: "
        f"{input_data.numel() / (1 << 30) / average_time_single_dim:.2f} GB/s."
    )


if __name__ == "__main__":
    cli()
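

# Example invocations (assuming fbgemm_gpu and its bench utilities are
# importable; the flags and defaults are the click options defined above):
#
#   python quantize_ops_benchmark.py bench --iters 100 --warmup-runs 2
#   python quantize_ops_benchmark.py mixdim --batch_size 512 --num_tables 256 --min_dim 1 --max_dim 128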