fbgemm_gpu/bench/merge_embeddings_benchmark.py
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import contextlib
import logging
import signal
from typing import List, Tuple

import click
import fbgemm_gpu
import numpy as np
import tabulate
import torch
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
if open_source:
# pyre-ignore[21]
from bench_utils import benchmark_torch_function
else:
from fbgemm_gpu.bench.bench_utils import benchmark_torch_function
from fbgemm_gpu.split_table_batched_embeddings_ops import (
SparseType,
BoundsCheckMode,
IntNBitTableBatchedEmbeddingBagsCodegen,
EmbeddingLocation,
)
from torch.profiler import ProfilerActivity, profile
def get_gpu_device(gpu_num: int) -> torch.device:
return torch.device(f"cuda:{gpu_num}")
# Merged indices with shape (T, B, L) -> (flattened indices with shape
# (T * B * L), offsets with shape (T * B + 1)).
# Reference: https://fburl.com/code/5ueyfv5j
def get_table_batched_offsets_from_dense(
merged_indices: torch.Tensor,
    gpu_num: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
(T, B, L) = merged_indices.size()
lengths = np.ones((T, B)) * L
flat_lengths = lengths.flatten()
return (
merged_indices.int().contiguous().view(-1).to(device=get_gpu_device(gpu_num)),
torch.tensor(
([0] + np.cumsum(flat_lengths).tolist()), device=get_gpu_device(gpu_num)
).int(),
)
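
# Worked example (illustrative, not executed): with T=2 tables, B=2 samples, and
# pooling factor L=3, every bag has length 3, so the flattened indices tensor
# holds T * B * L = 12 entries and the offsets tensor is
#   [0, 3, 6, 9, 12]   (T * B + 1 = 5 entries),
# i.e. bag k owns indices[offsets[k]:offsets[k + 1]].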
# Reference: https://fburl.com/code/o5600si0
def generate_requests(
num_gpus: int,
B: int,
T: int,
L: int,
E: int,
    # inter-batch indices reuse rate (accepted for interface parity; not
    # applied in this benchmark)
    reuse: float = 0.0,
) -> List[Tuple[torch.IntTensor, torch.IntTensor]]:
rs = []
for gpu_num in range(num_gpus):
all_indices = torch.randint(
low=0,
high=E,
size=(T, B, L),
device=get_gpu_device(gpu_num),
dtype=torch.int32,
)
# each bag is usually sorted
(all_indices, _) = torch.sort(all_indices)
all_indices = all_indices.reshape(T, B * L)
rs.append(
get_table_batched_offsets_from_dense(all_indices.view(T, B, L), gpu_num)
)
return rs
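
# Note: each list entry corresponds to one GPU and holds the (indices, offsets)
# pair described above. torch.sort() defaults to the last dimension, so indices
# are ordered only within each bag of length L.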
def _get_random_tensor(
num_ads: int,
embedding_dimension: int,
ads_tables: int,
data_type: str,
gpu_idx: int,
include_quantization: bool,
):
if data_type == "FP16" or include_quantization:
result_tensor = torch.randn(
num_ads,
embedding_dimension * ads_tables,
dtype=torch.float16,
device=torch.device(f"cuda:{gpu_idx}"),
)
elif data_type == "INT8":
assert (
embedding_dimension % 2
) == 0, "needs to align to 2 bytes (half type size) for INT8"
result_tensor = torch.randint(
0,
255,
# 2 FP16 numbers for scale and bias, total of 4 bytes overhead
size=(num_ads, (embedding_dimension + 4) * ads_tables),
dtype=torch.uint8,
device=torch.device(f"cuda:{gpu_idx}"),
)
elif data_type == "INT4":
assert (
embedding_dimension % 4
) == 0, "needs to align to 2 bytes (half type size) for INT4"
result_tensor = torch.randint(
0,
255,
# Using torch.uint8 for int4 storage
size=(num_ads, (embedding_dimension // 2 + 4) * ads_tables),
dtype=torch.uint8,
device=torch.device(f"cuda:{gpu_idx}"),
)
else:
        raise ValueError(f"Unsupported data type: {data_type}")
return result_tensor
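
# Illustrative per-table row sizing from the branches above, assuming the
# default --embedding_dimension of 300:
#   FP16: 300 half values              -> 600 bytes per row
#   INT8: (300 + 4) uint8 columns      -> 304 bytes per row (4-byte FP16 scale/bias)
#   INT4: (300 // 2 + 4) uint8 columns -> 154 bytes per row (two values per byte)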
def generate_tbe(
batch_indices,
num_ads: int,
embedding_dimension: int,
num_of_embeddings: int,
pooling_factor: int,
ads_tables: int,
fused_tbe: bool,
data_type: str,
num_gpus: int,
):
B = num_ads
D = embedding_dimension
E = num_of_embeddings
L = pooling_factor
T = ads_tables
Ds = [D] * T
managed_option = EmbeddingLocation.DEVICE
output_dtype = SparseType.FP16
if fused_tbe:
assert data_type == "INT8" # INT4 not implemented yet
output_dtype = SparseType.INT8
emb = [
IntNBitTableBatchedEmbeddingBagsCodegen(
[
(
str(idx),
E,
d,
SparseType.INT4,
managed_option,
)
for d in Ds
],
output_dtype=output_dtype,
device=get_gpu_device(idx),
bounds_check_mode=BoundsCheckMode.NONE,
)
for idx in range(num_gpus)
]
for e in emb:
e.fill_random_weights()
requests = generate_requests(num_gpus, B, T, L, E)
# https://fburl.com/code/doxxjc8c
SIZE_OF_FLOAT = 4
num_elem_per_byte = 1 if data_type == "INT8" else 2
assert embedding_dimension % (2 * num_elem_per_byte) == 0
col_sizes = (
[
(embedding_dimension + num_elem_per_byte - 1) // num_elem_per_byte
+ 2 * SIZE_OF_FLOAT
]
* ads_tables
* num_gpus
)
offset = torch.tensor([0] + col_sizes, device=batch_indices.device)
tbe_offset = torch.cumsum(offset, dim=0).to(torch.int).cuda()
return emb, requests, tbe_offset
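
# Illustrative offsets, assuming --embedding_dimension 300 and --ads_tables 2 on
# 2 GPUs with INT4 data (num_elem_per_byte = 2): each table contributes
# 300 // 2 + 2 * SIZE_OF_FLOAT = 158 columns, so
#   tbe_offset = [0, 158, 316, 474, 632].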
def print_p2p_bandwidth(
num_gpus, iters, pooled_ad_embeddings, bytes_per_element
) -> None:
print("Pairwise GPU Copy Bandwidth (GB/s)")
p2p_copy_bw = np.zeros((num_gpus, num_gpus))
for i in range(num_gpus):
for j in range(num_gpus):
with torch.cuda.device(i):
t, _ = benchmark_torch_function(
lambda: pooled_ad_embeddings[i].copy_(pooled_ad_embeddings[j])
if i != j
else pooled_ad_embeddings[i].clone(),
(),
flush_gpu_cache_size_mb=0,
iters=iters,
)
p2p_copy_bw[i, j] = (
pooled_ad_embeddings[i].numel() * bytes_per_element / t / 1.0e9
)
table = tabulate.tabulate(
p2p_copy_bw,
headers=[f"GPU {i}" for i in range(num_gpus)],
tablefmt="fancy_grid",
floatfmt=".0f",
)
print(table)
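
# The reported figure is numel * bytes_per_element / elapsed_seconds / 1e9
# (GB/s); diagonal entries time a same-device clone() rather than a peer copy.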
def benchmark(
all_to_one_only: bool,
num_ads: int,
embedding_dimension: int,
ads_tables: int,
iters: int = 10,
p2p_bw: bool = False,
dst_device: int = 0,
data_type: str = "FP16",
mode: str = "P2P",
skip_dequantization: bool = False,
num_of_embeddings: int = 10000,
pooling_factor: int = 25,
) -> str:
assert torch.cuda.is_available()
torch.cuda.set_device(dst_device)
num_gpus = torch.cuda.device_count()
batch_indices = torch.zeros(num_ads).long().cuda()
    include_quantization = mode != "P2P"
    # INT8/INT4 payloads are stored as uint8 (1 byte per element); FP16, and any
    # quantized-communication flow whose inputs start as FP16, count 2 bytes.
    bytes_per_element = 2 if (data_type == "FP16" or include_quantization) else 1
total_elements = num_ads * embedding_dimension * ads_tables * num_gpus
logging.debug(
f"B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Data Type: {data_type}, Num GPUs: {num_gpus}, Destination GPU: {dst_device}"
)
fused_tbe = mode == "P2P_FUSED_TBE"
include_tbe = fused_tbe or mode == "P2P_TBE"
if include_tbe:
emb, requests, tbe_offset = generate_tbe(
batch_indices,
num_ads,
embedding_dimension,
num_of_embeddings,
pooling_factor,
ads_tables,
fused_tbe,
data_type,
num_gpus,
)
pooled_ad_embeddings = [
_get_random_tensor(
num_ads,
embedding_dimension,
ads_tables,
data_type,
gpu_idx,
include_quantization,
)
for gpu_idx in range(num_gpus)
]
if p2p_bw:
print_p2p_bandwidth(num_gpus, iters, pooled_ad_embeddings, bytes_per_element)
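
    # The helper below runs one end-to-end iteration of the selected mode: an
    # optional per-GPU TBE lookup, optional row-wise quantization per table, a
    # merge (or all-to-one copy) onto the destination device, and an optional
    # dequantization back to FP16.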
def pool_func_with_quantization(
batch_indices,
include_quantization,
include_tbe,
fused_tbe,
skip_dequantization,
data_type,
):
if include_tbe:
embedding_results = []
for idx, (indices, offsets) in enumerate(requests):
with torch.cuda.device(idx):
embedding_results.append(emb[idx].forward(indices, offsets))
else:
embedding_results = pooled_ad_embeddings
if data_type == "FP16" or (not fused_tbe and not include_quantization):
            if all_to_one_only:
                # Use the per-GPU embedding results computed above (identical
                # to pooled_ad_embeddings when no TBE lookup is involved).
                return torch.ops.fbgemm.all_to_one_device(
                    embedding_results, batch_indices.device
                )
else:
return torch.ops.fbgemm.merge_pooled_embeddings(
embedding_results, batch_indices.size(0), batch_indices.device
)
assert data_type == "INT8" or data_type == "INT4"
assert not all_to_one_only # not supported
if fused_tbe:
pooled_quantized_result = torch.ops.fbgemm.merge_pooled_embeddings(
embedding_results, batch_indices.size(0), batch_indices.device
)
else:
quantized = []
for t in embedding_results:
t_split_by_table = torch.split(t, embedding_dimension, dim=1)
quantized_split_by_table = [
torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(t.float())
if data_type == "INT8"
else torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
t.float(), 4
)
for t in t_split_by_table
]
result = torch.cat(quantized_split_by_table, dim=1)
quantized.append(result)
pooled_quantized_result = torch.ops.fbgemm.merge_pooled_embeddings(
quantized, batch_indices.size(0), batch_indices.device
)
if skip_dequantization:
return pooled_quantized_result
PooledEmbeddingDequantizeDataTypeFP16 = 1
if data_type == "INT8":
return torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatMixedDim(
pooled_quantized_result,
tbe_offset,
PooledEmbeddingDequantizeDataTypeFP16,
)
else:
# TODO: the result here is wrong. Once MixedDim version for FusedNBit quantization is done, switch to that.
# Since their performance is similar, keep using Fused8BitRowwiseQuantizedToHalf for now.
return torch.ops.fbgemm.Fused8BitRowwiseQuantizedToHalf(
pooled_quantized_result
).half()
streams = [torch.cuda.Stream(device=i) for i in range(num_gpus)]
with contextlib.ExitStack() as stack:
for stream in streams:
stack.enter_context(torch.cuda.stream(stream))
# warm up
merged = pool_func_with_quantization(
batch_indices,
include_quantization,
include_tbe,
fused_tbe,
skip_dequantization,
data_type,
)
t, _ = benchmark_torch_function(
pool_func_with_quantization,
(
batch_indices,
include_quantization,
include_tbe,
fused_tbe,
skip_dequantization,
data_type,
),
flush_gpu_cache_size_mb=0,
iters=iters,
)
with profile(activities=[ProfilerActivity.CUDA]) as prof:
pool_func_with_quantization(
batch_indices,
include_quantization,
include_tbe,
fused_tbe,
skip_dequantization,
data_type,
)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
logging.debug(
f"Mode: {mode}, Data Type: {data_type}, B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Num GPUs: {num_gpus}, Destination GPU: {dst_device}, "
f"Number of elements: {total_elements / 1.0e6:.2f} Million, Number of elements per GPU: {total_elements / 1.0e6 / num_gpus:.2f}, Billion elements per sec: {total_elements / t / 1.0e9:.1f}, "
f"Output Size: {merged.numel() * bytes_per_element / 1.0e6:.0f}MB, BW: {merged.numel() * bytes_per_element / t / 1.0e9:.1f}GB/s, "
f"t: {t * 1.0e3:.2f}ms"
)
# return result in CSV format
return (
f"{mode}, {data_type}, {num_ads}, {embedding_dimension}, {ads_tables}, {num_gpus}, {dst_device}, "
f"{total_elements / 1.0e6:.2f}, {total_elements / 1.0e6 / num_gpus:.2f}, {total_elements / 1.0e9 / t:.1f}, "
f"{merged.numel() * bytes_per_element / 1.0e6:.0f}, {merged.numel() * bytes_per_element / 1.0e9 / t:.1f}, "
f"{t * 1.0e3:.2f}"
)
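
# The returned string is a single CSV row; its columns line up with the
# csv_header printed by main().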
@click.command()
@click.option("--all-to-one-only", is_flag=True, default=False)
@click.option("--num_ads", default=1024, type=int)
@click.option("--embedding_dimension", default=300, type=int)
@click.option("--ads_tables", default=100, type=int)
@click.option("--iters", default=10, type=int)
@click.option("--p2p_bw", is_flag=True, default=False)
@click.option("--dst_device", default=0, type=int)
@click.option(
"--data_type",
type=click.Choice(["FP16", "INT8", "INT4"]),
default="FP16",
)
# P2P: merge_pooled_embeddings() or all_to_one_device() for tensor with "--data_type"
# P2P_QUANT: for INT8/INT4 data type, start with FP16, then quantize -> P2P -> dequantize to FP16
# P2P_TBE: add TBE in front of P2P_QUANT. When "--data_type" is FP16, the flow is TBE -> P2P; for INT8/INT4, the flow is TBE -> quantize -> P2P -> dequantize
# P2P_FUSED_TBE: similar to P2P_TBE except fuse the quantization into TBE
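# Example invocation (illustrative): python merge_embeddings_benchmark.py \
#     --mode P2P_TBE --data_type INT8 --num_ads 1024 --ads_tables 100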
@click.option(
"--mode",
type=click.Choice(["P2P", "P2P_QUANT", "P2P_TBE", "P2P_FUSED_TBE"]),
default="P2P",
)
# For quantized communication, whether to skip the final dequantization back to FP16.
@click.option("--skip_dequantization", is_flag=True, default=False)
@click.option("--num_of_embeddings", default=100000, type=int)
@click.option("--pooling_factor", default=25, type=int)
@click.option("--sweep", is_flag=True, default=False)
def main(
all_to_one_only: bool,
num_ads: int,
embedding_dimension: int,
ads_tables: int,
iters: int,
p2p_bw: bool,
dst_device: int,
data_type: str,
mode: str,
skip_dequantization: bool,
num_of_embeddings: int,
pooling_factor: int,
sweep: bool,
) -> None:
csv_header = (
"mode, data_type, num_ads, embedding_dimension, ads_tables, num_gpus, "
"dst_device, number of elements (Million), number of elements per GPU (Million), throughput (billion elements per sec), "
"output size (MB), BW (GB/s), t (ms)"
)
if sweep:
def handler(signum, frame):
logging.error("timeout")
raise TimeoutError()
results = []
num_gpu = torch.cuda.device_count()
for num_ads in [128, 256, 512, 1024, 2048]:
            # Scale num_ads so that runs with different GPU counts sweep the same total number of elements
num_ads *= 8 // num_gpu
for embedding_dimension in [16, 64, 112, 304]:
for ads_tables in [25, 50, 100, 400, 800]:
if num_ads * embedding_dimension * ads_tables > 983040000:
continue # Skip tests that are too large
                    # signal.alarm() delivers SIGALRM, so register the timeout
                    # handler for that signal.
                    signal.signal(signal.SIGALRM, handler)
                    signal.alarm(600)
logging.info(
f"config: num_ads: {num_ads}, embedding_dimension: {embedding_dimension}, ads_tables: {ads_tables}"
)
try:
result = benchmark(
all_to_one_only,
num_ads,
embedding_dimension,
ads_tables,
iters,
p2p_bw,
dst_device,
data_type,
mode,
skip_dequantization,
num_of_embeddings,
pooling_factor,
)
results.append(result)
except (TimeoutError, RuntimeError) as err:
logging.error(
f"B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Data Type: {data_type}, Num GPU: {num_gpu}, time out or failed: {err}"
)
print(csv_header)
print(*results, sep="\n")
return
result = benchmark(
all_to_one_only,
num_ads,
embedding_dimension,
ads_tables,
iters,
p2p_bw,
dst_device,
data_type,
mode,
skip_dequantization,
num_of_embeddings,
pooling_factor,
)
print(csv_header)
print(result)
if __name__ == "__main__":
main()