# fbgemm_gpu/bench/split_embeddings_cache_benchmark.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import random
from typing import List, Tuple
import click
import numpy as np
import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops import (
CacheAlgorithm,
EmbeddingLocation,
IntNBitTableBatchedEmbeddingBagsCodegen,
)
from torch import Tensor, nn
logging.basicConfig(level=logging.DEBUG)
try:
# pyre-ignore[21]
from fbgemm_gpu import open_source # noqa: F401
except Exception:
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings"
)
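# Associativity of the LXU cache (number of slots per cache set).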
ASSOC: int = 32
# pyre-ignore
def benchmark_same_input(iters: int, f, *args) -> float:
"""
Returns average execution time in milliseconds across "iters".
"""
# Warm-up
f(*args)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(iters):
f(*args)
end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / iters
# pyre-ignore
def benchmark_different_inputs(f, args) -> float:
"""
Returns average execution time in milliseconds across "iters".
"""
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for arg in args:
f(arg)
end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / len(args)
def get_num_cached_tables(num_tables: int, cached_tables_ratio: float) -> int:
"""
Controls how # of cached tables are determined based on parameters.
"""
return round(num_tables * cached_tables_ratio)
def create_table_offsets(
num_tables: int, cached_tables_ratio: float, num_embeddings: int
) -> Tensor:
"""
Returns "table size cumsum", which is information of UVM caching for tables.
"""
num_cached_tables = get_num_cached_tables(num_tables, cached_tables_ratio)
np_list = np.arange(0, num_embeddings * num_cached_tables, num_embeddings)
num_uncached_tables = num_tables - num_cached_tables
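    # Cached tables get distinct, increasing cumsum entries; uncached tables are
    # represented by repeating an existing entry (i.e., a zero-sized span). For
    # example, with 2 cached tables, 1 uncached table, and num_embeddings = 10,
    # the result could be [0, 0, 10] or [0, 10, 10].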
while num_uncached_tables > 0:
added = random.randint(1, num_uncached_tables)
pos = random.randint(0, len(np_list) - 1)
np_list = np.insert(np_list, pos, [np_list[pos]] * added)
num_uncached_tables -= added
cache_hash_size_cumsum: Tensor = torch.tensor(np_list).cuda()
return cache_hash_size_cumsum
def create_embedding_specs(
num_tables: int,
cached_tables_ratio: float,
num_embeddings: int,
embedding_dims: int,
) -> List[Tuple[str, int, int, SparseType, EmbeddingLocation]]:
"""
Returns embedding specs to be used with IntNBitTableBatchedEmbeddingBagsCodegen.
"""
num_cached_tables = get_num_cached_tables(num_tables, cached_tables_ratio)
num_uncached_tables = num_tables - num_cached_tables
embedding_specs = []
for _ in range(min(num_cached_tables, num_uncached_tables)):
embedding_specs.append(
(
"",
num_embeddings,
embedding_dims,
SparseType.INT8,
EmbeddingLocation.DEVICE,
)
)
embedding_specs.append(
(
"",
num_embeddings,
embedding_dims,
SparseType.INT8,
EmbeddingLocation.MANAGED_CACHING,
)
)
if num_cached_tables > num_uncached_tables:
for _ in range(num_cached_tables - num_uncached_tables):
embedding_specs.append(
(
"",
num_embeddings,
embedding_dims,
SparseType.INT8,
EmbeddingLocation.MANAGED_CACHING,
)
)
else:
for _ in range(num_uncached_tables - num_cached_tables):
embedding_specs.append(
(
"",
num_embeddings,
embedding_dims,
SparseType.INT8,
EmbeddingLocation.DEVICE,
)
)
return embedding_specs
def create_request(
num_tables: int, num_embeddings: int, batch: int, avg_pooling_factor: int
) -> Tuple[Tensor, Tensor]:
"""
Returns [indices, offsets], which are inputs of embedding bags.
"""
indices: Tensor = torch.randint(
0, num_embeddings, (num_tables * batch * avg_pooling_factor,), dtype=torch.int32
).cuda()
    # Pooling factors are intentionally diversified among
    # [1, pf / 2, pf, pf * 2, pf * 4, pf * 8], where pf == avg_pooling_factor.
pooling_factors = []
for _ in range(num_tables - 1):
half_avg_pooling_factor = avg_pooling_factor // 2
if half_avg_pooling_factor > 0:
pooling_factors.append(
random.choices(
[
1,
half_avg_pooling_factor,
avg_pooling_factor,
2 * avg_pooling_factor,
4 * avg_pooling_factor,
8 * avg_pooling_factor,
],
weights=[5, 10, 15, 1, 1, 3],
)[0]
)
else:
pooling_factors.append(
random.choices(
[1, avg_pooling_factor, 2 * avg_pooling_factor], weights=[2, 20, 1]
)[0]
)
    # The last table gets whatever remains so the total equals num_tables * avg_pooling_factor.
curr_total_pooling_factors = sum(pooling_factors)
pooling_factors.append(num_tables * avg_pooling_factor - curr_total_pooling_factors)
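    # Offsets are the cumulative sum of per-bag lengths: one entry per (table, batch)
    # pair plus the leading 0, i.e. num_tables * batch + 1 entries in total.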
offsets_list = [0]
for pooling_factor in pooling_factors:
        if pooling_factor == 1:
            # Each bag holds exactly one index; keep the offsets cumulative.
            for _ in range(batch):
                offsets_list.append(offsets_list[-1] + 1)
else:
finish_offset = offsets_list[-1] + pooling_factor * batch
for _ in range(batch - 1):
selected = max(
int(random.gauss(pooling_factor, 0.1 * pooling_factor)), 1
)
last_offset = offsets_list[-1]
offsets_list.append(last_offset + selected)
offsets_list.append(finish_offset)
offsets: Tensor = torch.tensor(offsets_list, dtype=torch.int32).cuda()
return (indices, offsets)
@click.group()
def cli() -> None:
pass
@cli.command()
@click.option("--iters", default=100)
@click.option("--num-tables", default=50)
@click.option("--cached-tables-ratio", default=1.0)
@click.option("--batch", default=100)
@click.option("--avg-pooling-factor", default=100)
def linearize_cache_indices(
iters: int,
num_tables: int,
cached_tables_ratio: float,
batch: int,
avg_pooling_factor: int,
) -> None:
num_embeddings: int = 1000000
cache_hash_size_cumsum = create_table_offsets(
num_tables, cached_tables_ratio, num_embeddings
)
indices, offsets = create_request(
num_tables, num_embeddings, batch, avg_pooling_factor
)
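    # linearize_cache_indices maps per-table indices into a single global index space
    # using the table-size cumsum, so cache lookups can treat all tables uniformly.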
t_ms = benchmark_same_input(
iters,
lambda indices, offsets: torch.ops.fbgemm.linearize_cache_indices(
cache_hash_size_cumsum, indices, offsets
),
indices,
offsets,
)
logging.info(
f"Across {iters} runs, T: {num_tables}, Cached T: {get_num_cached_tables(num_tables, cached_tables_ratio)}, BS: {batch}, {t_ms * 1.0e3:.0f}us"
)
@cli.command()
@click.option("--iters", default=100)
@click.option("--num-tables", default=50)
@click.option("--cached-tables-ratio", default=1.0)
@click.option("--batch", default=100)
@click.option("--avg-pooling-factor", default=100)
@click.option("--cache-load-factor", default=0.2)
def lxu_cache_lookup(
iters: int,
num_tables: int,
cached_tables_ratio: float,
batch: int,
avg_pooling_factor: int,
cache_load_factor: float,
) -> None:
num_embeddings: int = 1000000
embedding_dims: int = 128
embedding_specs = create_embedding_specs(
num_tables, cached_tables_ratio, num_embeddings, embedding_dims
)
tbe: nn.Module = IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs, cache_load_factor=cache_load_factor
)
tbe.fill_random_weights()
# Imitate execution flow by performing prefetching once.
indices, offsets = create_request(
num_tables, num_embeddings, batch, avg_pooling_factor
)
tbe.prefetch(indices, offsets)
linearized_indices = torch.ops.fbgemm.linearize_cache_indices(
tbe.cache_hash_size_cumsum, indices, offsets
)
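    # lxu_cache_lookup returns, for each linearized index, the slot in lxu_cache_state
    # holding that row, or -1 on a cache miss.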
t_ms = benchmark_same_input(
iters,
lambda linearized_indices, lxu_cache_state: torch.ops.fbgemm.lxu_cache_lookup(
linearized_indices, lxu_cache_state, tbe.total_cache_hash_size
),
linearized_indices,
tbe.lxu_cache_state,
)
# Run once again to obtain cache miss ratio.
locations = torch.ops.fbgemm.lxu_cache_lookup(
linearized_indices, tbe.lxu_cache_state, tbe.total_cache_hash_size
)
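    # Indices equal to total_cache_hash_size come from tables that are not cached at
    # all; exclude them from the miss-ratio denominator.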
num_invalid_accesses = torch.sum(linearized_indices == tbe.total_cache_hash_size)
num_valid_accesses = linearized_indices.numel() - num_invalid_accesses
num_misses = torch.sum(locations == -1) - num_invalid_accesses
logging.info(
f"Across {iters} runs, T: {num_tables}, Cached T: {get_num_cached_tables(num_tables, cached_tables_ratio)}, "
f"BS: {batch}, cache_load_factor: {cache_load_factor}, {t_ms * 1.0e3:.0f}us, "
f"cache miss: {num_misses.item() / num_valid_accesses * 100}%"
)
@cli.command()
@click.option("--iters", default=100)
@click.option("--num-tables", default=50)
@click.option("--cached-tables-ratio", default=1.0)
@click.option("--batch", default=100)
@click.option("--avg-pooling-factor", default=100)
@click.option("--cache-load-factor", default=0.2)
def lru_cache_populate_byte(
iters: int,
num_tables: int,
cached_tables_ratio: float,
batch: int,
avg_pooling_factor: int,
cache_load_factor: float,
) -> None:
num_warm_ups: int = 5
num_embeddings: int = 1000000
embedding_dims: int = 128
embedding_specs = create_embedding_specs(
num_tables, cached_tables_ratio, num_embeddings, embedding_dims
)
cc: nn.Module = IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs, cache_load_factor=cache_load_factor
)
cc.fill_random_weights()
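    # Warm-up requests pre-populate the cache so the timed iterations below measure
    # steady-state behavior rather than a cold cache.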
warm_up_requests = []
for _ in range(num_warm_ups):
indices, offsets = create_request(
num_tables, num_embeddings, batch, avg_pooling_factor
)
warm_up_requests.append(
torch.ops.fbgemm.linearize_cache_indices(
cc.cache_hash_size_cumsum, indices, offsets
)
)
requests = []
for _ in range(iters):
indices, offsets = create_request(
num_tables, num_embeddings, batch, avg_pooling_factor
)
requests.append(
torch.ops.fbgemm.linearize_cache_indices(
cc.cache_hash_size_cumsum, indices, offsets
)
)
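    # lru_cache_populate_byte copies missing rows from the UVM-resident weights into
    # lxu_cache_weights and records the timestep in lxu_state for LRU eviction.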
timestep: int = 1
def populate(linear_indices: Tensor) -> None:
nonlocal timestep
torch.ops.fbgemm.lru_cache_populate_byte(
cc.weights_uvm,
cc.cache_hash_size_cumsum,
cc.total_cache_hash_size,
cc.cache_index_table_map,
cc.weights_offsets,
cc.weights_tys,
cc.D_offsets,
linear_indices,
cc.lxu_cache_state,
cc.lxu_cache_weights,
timestep,
cc.lxu_state,
)
timestep += 1
for warm_up_request in warm_up_requests:
populate(warm_up_request)
t_ms = benchmark_different_inputs(
populate,
requests,
)
# Replay to figure out UVM access BW, which would be PCIe bound.
replay_cc: nn.Module = IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs, cache_load_factor=cache_load_factor
)
replay_cc.fill_random_weights()
replay_timestep: int = 1
def replay_populate(linear_indices: Tensor) -> None:
nonlocal replay_timestep
torch.ops.fbgemm.lru_cache_populate_byte(
replay_cc.weights_uvm,
replay_cc.cache_hash_size_cumsum,
replay_cc.total_cache_hash_size,
replay_cc.cache_index_table_map,
replay_cc.weights_offsets,
replay_cc.weights_tys,
replay_cc.D_offsets,
linear_indices,
replay_cc.lxu_cache_state,
replay_cc.lxu_cache_weights,
replay_timestep,
replay_cc.lxu_state,
)
replay_timestep += 1
for warm_up_request in warm_up_requests:
replay_populate(warm_up_request)
total_rows = 0
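    # Count cache-state entries that changed during each populate call; each change
    # corresponds to a row newly brought in from UVM, which approximates PCIe traffic.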
for request in requests:
# pyre-ignore
prev = replay_cc.lxu_cache_state.clone().detach()
replay_populate(request)
# pyre-ignore
after = replay_cc.lxu_cache_state.clone().detach()
diff = after - prev
total_rows += diff.count_nonzero().item()
logging.info(
f"Across {iters} runs, T: {num_tables}, Cached T: {get_num_cached_tables(num_tables, cached_tables_ratio)}, "
f"BS: {batch}, cache_load_factor: {cache_load_factor}, {t_ms * 1.0e3:.0f}us, "
f"BW (just UVM accesses): {total_rows * embedding_dims / iters / t_ms * 1000 / 1024 / 1024} MB/s"
)
@cli.command()
@click.option("--iters", default=100)
@click.option("--num-tables", default=50)
@click.option("--cached-tables-ratio", default=1.0)
@click.option("--batch", default=100)
@click.option("--avg-pooling-factor", default=100)
@click.option("--cache-load-factor", default=0.2)
def lfu_cache_populate_byte(
iters: int,
num_tables: int,
cached_tables_ratio: float,
batch: int,
avg_pooling_factor: int,
cache_load_factor: float,
) -> None:
num_warm_ups: int = 5
num_embeddings: int = 1000000
embedding_dims: int = 128
embedding_specs = create_embedding_specs(
num_tables, cached_tables_ratio, num_embeddings, embedding_dims
)
cc: nn.Module = IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs,
cache_load_factor=cache_load_factor,
cache_algorithm=CacheAlgorithm.LFU,
)
cc.fill_random_weights()
warm_up_requests = []
for _ in range(num_warm_ups):
indices, offsets = create_request(
num_tables, num_embeddings, batch, avg_pooling_factor
)
warm_up_requests.append(
torch.ops.fbgemm.linearize_cache_indices(
cc.cache_hash_size_cumsum, indices, offsets
)
)
requests = []
for _ in range(iters):
indices, offsets = create_request(
num_tables, num_embeddings, batch, avg_pooling_factor
)
requests.append(
torch.ops.fbgemm.linearize_cache_indices(
cc.cache_hash_size_cumsum, indices, offsets
)
)
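    # The LFU variant tracks access frequencies in lxu_state, so unlike the LRU path
    # there is no timestep argument to pass.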
def populate(linear_indices: Tensor) -> None:
torch.ops.fbgemm.lfu_cache_populate_byte(
cc.weights_uvm,
cc.cache_hash_size_cumsum,
cc.total_cache_hash_size,
cc.cache_index_table_map,
cc.weights_offsets,
cc.weights_tys,
cc.D_offsets,
linear_indices,
cc.lxu_cache_state,
cc.lxu_cache_weights,
cc.lxu_state,
)
for warm_up_request in warm_up_requests:
populate(warm_up_request)
t_ms = benchmark_different_inputs(
populate,
requests,
)
# Replay to figure out UVM access BW, which would be PCIe bound.
replay_cc: nn.Module = IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs,
cache_load_factor=cache_load_factor,
cache_algorithm=CacheAlgorithm.LFU,
)
replay_cc.fill_random_weights()
def replay_populate(linear_indices: Tensor) -> None:
torch.ops.fbgemm.lfu_cache_populate_byte(
replay_cc.weights_uvm,
replay_cc.cache_hash_size_cumsum,
replay_cc.total_cache_hash_size,
replay_cc.cache_index_table_map,
replay_cc.weights_offsets,
replay_cc.weights_tys,
replay_cc.D_offsets,
linear_indices,
replay_cc.lxu_cache_state,
replay_cc.lxu_cache_weights,
replay_cc.lxu_state,
)
for warm_up_request in warm_up_requests:
replay_populate(warm_up_request)
total_rows = 0
for request in requests:
# pyre-ignore
prev = replay_cc.lxu_cache_state.clone().detach()
replay_populate(request)
# pyre-ignore
after = replay_cc.lxu_cache_state.clone().detach()
diff = after - prev
total_rows += diff.count_nonzero().item()
logging.info(
f"Across {iters} runs, T: {num_tables}, Cached T: {get_num_cached_tables(num_tables, cached_tables_ratio)}, "
f"BS: {batch}, cache_load_factor: {cache_load_factor}, {t_ms * 1.0e3:.0f}us, "
f"BW (just UVM accesses): {total_rows * embedding_dims / iters / t_ms * 1000 / 1024 / 1024} MB/s"
)
if __name__ == "__main__":
cli()