fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py (1,916 lines of code) (raw):

#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import itertools import logging import math import random import statistics import time from typing import Callable, List, Optional, Tuple import click import numpy as np import torch haveAIBench = False try: from aibench_observer.utils.observer import emitMetric haveAIBench = True except Exception: haveAIBench = False from fbgemm_gpu.split_table_batched_embeddings_ops import ( BoundsCheckMode, CacheAlgorithm, ComputeDevice, DenseTableBatchedEmbeddingBagsCodegen, EmbeddingLocation, OptimType, SparseType, SplitTableBatchedEmbeddingBagsCodegen, IntNBitTableBatchedEmbeddingBagsCodegen, PoolingMode, ) from numpy.random import default_rng from torch import Tensor logging.basicConfig(level=logging.DEBUG) def round_up(a: int, b: int) -> int: return int((a + b - 1) // b) * b def get_device() -> torch.device: return ( torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu") ) # Merged indices with shape (T, B, L) -> (flattened indices with shape # (T * B * L), offsets with shape (T * B + 1)) def get_table_batched_offsets_from_dense( merged_indices: Tensor, ) -> Tuple[Tensor, Tensor]: (T, B, L) = merged_indices.size() lengths = np.ones((T, B)) * L flat_lengths = lengths.flatten() return ( merged_indices.long().contiguous().view(-1).to(get_device()), torch.tensor(([0] + np.cumsum(flat_lengths).tolist())).long().to(get_device()), ) def get_offsets_from_dense(indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: (B, L) = indices.size() return ( indices.contiguous().view(-1), torch.tensor( np.cumsum(np.asarray([0] + [L for _ in range(B)])[:-1]).astype(np.int64) ), ) def b_indices( b: Callable[..., torch.Tensor], x: torch.Tensor, per_sample_weights: Optional[torch.Tensor] = None, use_cpu: bool = False, do_pooling: bool = True, ) -> torch.Tensor: (indices, offsets) = get_offsets_from_dense(x) if do_pooling: return b( indices.cuda(), offsets.cuda(), per_sample_weights=per_sample_weights, ) else: return b(indices.cuda()) def generate_requests( iters: int, B: int, T: int, L: int, E: int, # inter-batch indices reuse rate reuse: float = 0.0, # alpha <= 1.0: use uniform distribution # alpha > 1.0: use zipf distribution alpha: float = 1.0, weights_precision: SparseType = SparseType.FP32, weighted: bool = False, requests_data_file: Optional[str] = None, # Comma-separated list of table numbers tables: Optional[str] = None, ) -> List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]]: if requests_data_file is not None: indices_tensor, offsets_tensor, lengths_tensor = torch.load(requests_data_file) average_L = 0 if tables is not None: emb_tables = tuple(int(x) for x in tables.split(",")) indices = torch.zeros(0, dtype=indices_tensor.dtype) offsets = torch.zeros(1, dtype=offsets_tensor.dtype) total_L = 0 for t in emb_tables: t_offsets = offsets_tensor[B * t : B * (t + 1) + 1] total_L += t_offsets[-1] - t_offsets[0] indices = torch.cat( (indices, indices_tensor[t_offsets[0] : t_offsets[-1]]) ) offsets = torch.cat( ( offsets, t_offsets[1:] - t_offsets[0] + offsets[-1], ) ) indices_tensor = indices offsets_tensor = offsets average_L = int(total_L / B) assert np.prod(offsets_tensor.size()) - 1 == np.prod((T, B)), ( f"Requested tables: {emb_tables} " f"does not conform to inputs (T, B) = ({T}, {B})." ) logging.warning( f"Using (indices = {indices_tensor.size()}, offsets = {offsets_tensor.size()}) based " f"on tables: {emb_tables}" ) else: average_L = int((offsets_tensor[-1] - offsets_tensor[0]) / B) assert (np.prod(offsets_tensor.size()) - 1) == np.prod((T, B)), ( f"Data file (indices = {indices_tensor.size()}, " f"offsets = {offsets_tensor.size()}, lengths = {lengths_tensor.size()}) " f"does not conform to inputs (T, B) = ({T}, {B})." ) assert ( L == average_L ), f"Requested L does not align with provided data file ({L} vs. {average_L})" assert E > max(indices_tensor), ( f"Number of embeddings is not enough to support maximum index " f"provided by data file {E} vs. {max(indices_tensor)}" ) weights_tensor = ( None if not weighted else torch.randn(indices_tensor.size(), device=get_device()) ) rs = [] for _ in range(iters): rs.append( ( indices_tensor.to(get_device()), offsets_tensor.to(get_device()), weights_tensor, ) ) return rs if alpha <= 1.0: all_indices = torch.randint( low=0, high=E, size=(iters, T, B, L), device=get_device(), dtype=torch.int32, ) # each bag is usually sorted (all_indices, _) = torch.sort(all_indices) all_indices = all_indices.reshape(iters, T, B * L) else: assert E >= L, "num-embeddings must be greater than equal to bag-size" # oversample and then remove duplicates to obtain sampling without # replacement all_indices = (np.random.zipf(a=alpha, size=(iters, T, B, 3 * L)) - 1) % E for index_tuple in itertools.product(range(iters), range(T), range(B)): # sample without replacement from # https://stats.stackexchange.com/questions/20590/how-do-i-sample-without-replacement-using-a-sampling-with-replacement-function r = set() for x in all_indices[index_tuple]: if x not in r: r.add(x) if len(r) == L: break assert (len(r)) == L, "too skewed distribution (alpha too big)" all_indices[index_tuple][:L] = list(r) # shuffle indices so we don't have unintended spatial locality all_indices = torch.as_tensor(all_indices[:, :, :, :L]) rng = default_rng() permutation = torch.as_tensor( rng.choice(E, size=all_indices.max().item() + 1, replace=False) ) all_indices = permutation.gather(0, all_indices.flatten()) all_indices = all_indices.to(get_device()).int().reshape(iters, T, B * L) for it in range(iters - 1): for t in range(T): reused_indices = torch.randperm(B * L, device=get_device())[ : int(B * L * reuse) ] all_indices[it + 1, t, reused_indices] = all_indices[it, t, reused_indices] rs = [] for it in range(iters): weights_tensor = ( None if not weighted else torch.randn(T * B * L, device=get_device()) ) rs.append( get_table_batched_offsets_from_dense(all_indices[it].view(T, B, L)) + (weights_tensor,) ) return rs def benchmark_requests( requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]], func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor], flush_gpu_cache_size_mb: int = 0, check_median: bool = False, ) -> float: times = [] if torch.cuda.is_available(): torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) for (indices, offsets, weights) in requests: start_time = time.time() if torch.cuda.is_available(): if flush_gpu_cache_size_mb: _ = torch.rand( flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float ) torch.cuda.synchronize() start_event.record() func(indices, offsets, weights) if torch.cuda.is_available(): end_event.record() torch.cuda.synchronize() it_time = start_event.elapsed_time(end_event) * 1.0e-3 times.append(it_time) else: it_time = time.time() - start_time times.append(it_time) avg_time = sum(times) / len(requests) median_time = statistics.median(times) return median_time if check_median else avg_time def benchmark_requests_refer( requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]], T: int, B: int, L: int, E: int, D: int, pooling_mode: str, weighted: bool, flush_gpu_cache_size_mb: int = 0, check_median: bool = False, ) -> float: do_pooling = pooling_mode in ["sum", "mean"] if do_pooling: nn_embedding_list = [ torch.nn.EmbeddingBag(E, D, mode=pooling_mode, sparse=True).cuda() ] * T else: nn_embedding_list = [torch.nn.Embedding(E, D, sparse=True).cuda()] * T times = [] if torch.cuda.is_available(): torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) for (indices, _, weights) in requests: indices_list = indices.view(T, B, L).split(1) if weighted: assert weights is not None weights_list = weights.view(T, B, L).split(1) start_time = time.time() if torch.cuda.is_available(): if flush_gpu_cache_size_mb: _ = torch.rand( flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float ) torch.cuda.synchronize() start_event.record() nn_embedding_output = ( [ b_indices(nn_embedding, x, use_cpu=False, do_pooling=do_pooling) for (nn_embedding, x) in zip(nn_embedding_list, indices_list) ] if not weighted else [ b_indices( nn_embedding, x, per_sample_weights=xw.view(-1), use_cpu=False, do_pooling=do_pooling, ) for (nn_embedding, x, xw) in zip( nn_embedding_list, indices_list, # pyre-fixme[61]: `weights_list` is undefined, or not always # defined. weights_list, ) ] ) if do_pooling: final_output = torch.cat( [f.view(B, -1) for f in nn_embedding_output], dim=1 ) else: final_output = torch.cat(nn_embedding_output, dim=0).view(-1, D) if torch.cuda.is_available(): end_event.record() torch.cuda.synchronize() it_time = start_event.elapsed_time(end_event) * 1.0e-3 times.append(it_time) else: it_time = time.time() - start_time times.append(it_time) avg_time = sum(times) / len(requests) median_time = statistics.median(times) return median_time if check_median else avg_time def benchmark_pipelined_requests( requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]], func1: Callable[[Tensor, Tensor, Optional[Tensor]], None], func2: Callable[[Tensor, Tensor, Optional[Tensor]], None], flush_gpu_cache_size_mb: int = 0, ) -> Tuple[float, float]: torch.cuda.synchronize() start_events = [ (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) for _ in requests ] end_events = [ (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) for _ in requests ] for ((indices, offsets, indices_weights), start_event, end_event) in zip( requests, start_events, end_events ): if flush_gpu_cache_size_mb: _ = torch.rand( flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float ) torch.cuda.synchronize() start_event[0].record() func1(indices, offsets, indices_weights) end_event[0].record() start_event[1].record() func2(indices, offsets, indices_weights) end_event[1].record() torch.cuda.synchronize() return ( sum( start_event[0].elapsed_time(end_event[0]) * 1.0e-3 for start_event, end_event in zip(start_events, end_events) ) / len(requests), sum( start_event[1].elapsed_time(end_event[1]) * 1.0e-3 for start_event, end_event in zip(start_events, end_events) ) / len(requests), ) @click.group() def cli() -> None: pass @cli.command() # recommended value: alpha=1.15 for training and alpha=1.09 for inference @click.option("--alpha", default=1.0) @click.option("--bag-size", default=20) @click.option("--batch-size", default=512) @click.option("--embedding-dim", default=128) @click.option("--weights-precision", type=SparseType, default=SparseType.FP32) @click.option("--stoc", is_flag=True, default=False) @click.option("--iters", default=100) @click.option("--managed", default="device") @click.option("--mixed", is_flag=True, default=False) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=32) @click.option("--reuse", default=0.0) @click.option("--row-wise/--no-row-wise", default=True) @click.option("--weighted", is_flag=True, default=False) @click.option("--weighted-num-requires-grad", type=int, default=None) @click.option("--flush-gpu-cache-size-mb", default=0) @click.option("--dense", is_flag=True, default=False) @click.option("--output-dtype", type=SparseType, default=SparseType.FP32) @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) def device( # noqa C901 alpha: float, bag_size: int, batch_size: int, embedding_dim: int, weights_precision: SparseType, stoc: bool, iters: int, managed: str, mixed: bool, num_embeddings: int, num_tables: int, reuse: float, row_wise: bool, weighted: bool, weighted_num_requires_grad: Optional[int], flush_gpu_cache_size_mb: int, dense: bool, output_dtype: SparseType, requests_data_file: Optional[str], tables: Optional[str], ) -> None: np.random.seed(42) torch.manual_seed(42) B = batch_size D = embedding_dim L = bag_size E = num_embeddings T = num_tables if weighted_num_requires_grad: assert weighted_num_requires_grad <= T weighted_requires_grad_tables = np.random.choice( T, replace=False, size=(weighted_num_requires_grad,) ).tolist() feature_requires_grad = ( torch.tensor( [1 if t in weighted_requires_grad_tables else 0 for t in range(T)] ) .to(get_device()) .int() ) else: feature_requires_grad = None if mixed: Ds = [ round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4) for _ in range(T) ] D = np.average(Ds) else: Ds = [D] * T optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD if managed == "device": managed_option = ( EmbeddingLocation.DEVICE if torch.cuda.is_available() else EmbeddingLocation.HOST ) else: managed_option = EmbeddingLocation.MANAGED if dense: emb = DenseTableBatchedEmbeddingBagsCodegen( [ ( E, d, ) for d in Ds ], use_cpu=not torch.cuda.is_available(), ) else: emb = SplitTableBatchedEmbeddingBagsCodegen( [ ( E, d, managed_option, ComputeDevice.CUDA if torch.cuda.is_available() else ComputeDevice.CPU, ) for d in Ds ], optimizer=optimizer, learning_rate=0.1, eps=0.1, weights_precision=weights_precision, stochastic_rounding=stoc, output_dtype=output_dtype, ) emb = emb.to(get_device()) if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 logging.info( f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, " f"{nparams * param_size_multiplier / 1.0e9: .2f} GB" ) logging.info( f"Accessed weights per batch: {B * sum(Ds) * L * param_size_multiplier / 1.0e9: .2f} GB" ) requests = generate_requests( iters, B, T, L, E, reuse=reuse, alpha=alpha, weights_precision=weights_precision, weighted=weighted, requests_data_file=requests_data_file, tables=tables, ) # forward time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb.forward( indices.long(), offsets.long(), per_sample_weights, feature_requires_grad=feature_requires_grad, ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) logging.info( f"Forward, B: {B}, " f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " f"BW: {param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"T: {time_per_iter * 1.0e6:.0f}us" ) if output_dtype == SparseType.INT8: # backward bench not representative return grad_output = torch.randn(B, sum(Ds)).to(get_device()) # backward time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb( indices.long(), offsets.long(), per_sample_weights, feature_requires_grad=feature_requires_grad, ).backward(grad_output), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) logging.info( f"ForwardBackward, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, " f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, " f"T: {time_per_iter * 1.0e6:.0f}us" ) @cli.command() @click.option("--alpha", default=1.0) @click.option("--bag-size", default=20) @click.option("--batch-size", default=512) @click.option("--embedding-dim", default=128) @click.option("--weights-precision", type=SparseType, default=SparseType.FP32) @click.option("--stoc", is_flag=True, default=False) @click.option("--iters", default=100) @click.option("--mixed", is_flag=True, default=False) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=32) @click.option("--reuse", default=0.1) @click.option("--uvm-tables", default=1) @click.option("--uvm-bag-size", default=1) @click.option("--weighted", is_flag=True, default=False) @click.option("--flush-gpu-cache-size-mb", default=0) @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) @click.option("--output-dtype", type=SparseType, default=SparseType.FP32) @click.option("--use-cache", is_flag=True, default=False) @click.option("--cache-algorithm", default="lru") @click.option("--cache-load-factor", default=0.2) @click.option("--enforce-hbm", is_flag=True, default=False) def uvm( alpha: bool, bag_size: int, batch_size: int, embedding_dim: int, weights_precision: SparseType, stoc: bool, iters: int, mixed: bool, num_embeddings: int, num_tables: int, reuse: float, uvm_tables: int, uvm_bag_size: int, weighted: bool, flush_gpu_cache_size_mb: int, requests_data_file: Optional[str], tables: Optional[str], output_dtype: SparseType, use_cache: bool, cache_algorithm: str, cache_load_factor: float, enforce_hbm: bool, ) -> None: np.random.seed(42) torch.manual_seed(42) B = batch_size D = embedding_dim L = bag_size E = num_embeddings T = num_tables T_uvm = uvm_tables assert T_uvm <= T assert ( T_uvm > 0 ), f"T_uvm specified {T_uvm} <= 0. If not testing UVM, please use device benchmark." T_gpu = T - T_uvm L_uvm = uvm_bag_size cache_alg = CacheAlgorithm.LRU if cache_algorithm == "lru" else CacheAlgorithm.LFU managed_type = ( EmbeddingLocation.MANAGED_CACHING if use_cache else EmbeddingLocation.MANAGED ) if mixed: Ds = [ round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4) for _ in range(T) ] D = np.average(Ds) else: Ds = [D] * T emb_uvm = SplitTableBatchedEmbeddingBagsCodegen( [ ( E, d, managed_type, ComputeDevice.CUDA, ) for d in Ds[:T_uvm] ], weights_precision=weights_precision, stochastic_rounding=stoc, output_dtype=output_dtype, cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, enforce_hbm=enforce_hbm, ).cuda() if weights_precision == SparseType.INT8: emb_uvm.init_embedding_weights_uniform(-0.0003, 0.0003) if T_gpu > 0: emb_gpu = SplitTableBatchedEmbeddingBagsCodegen( [ ( E, d, EmbeddingLocation.DEVICE, ComputeDevice.CUDA, ) for d in Ds[T_uvm:] ], weights_precision=weights_precision, stochastic_rounding=stoc, ).cuda() if weights_precision == SparseType.INT8: emb_gpu.init_embedding_weights_uniform(-0.0003, 0.0003) emb_mixed = SplitTableBatchedEmbeddingBagsCodegen( [ ( E, d, managed_option, ComputeDevice.CUDA, ) for (d, managed_option) in zip( Ds, [managed_type] * T_uvm + [EmbeddingLocation.DEVICE] * T_gpu, ) ], weights_precision=weights_precision, stochastic_rounding=stoc, output_dtype=output_dtype, cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, enforce_hbm=enforce_hbm, ).cuda() if weights_precision == SparseType.INT8: emb_mixed.init_embedding_weights_uniform(-0.0003, 0.0003) requests_uvm = generate_requests( iters, B, T_uvm, L_uvm, E, reuse=reuse, alpha=alpha, weights_precision=weights_precision, weighted=weighted, requests_data_file=requests_data_file, tables=tables, ) requests_gpu = None if T_gpu > 0: requests_gpu = generate_requests( iters, B, T_gpu, L, E, reuse=reuse, alpha=alpha, weights_precision=weights_precision, weighted=False, requests_data_file=requests_data_file, tables=tables, ) param_size_multiplier = weights_precision.bit_rate() / 8.0 output_size_multiplier = output_dtype.bit_rate() / 8.0 read_write_bytes_uvm = ( output_size_multiplier * B * sum(Ds[:T_uvm]) + param_size_multiplier * B * sum(Ds[:T_uvm]) * L_uvm ) time_per_iter = benchmark_requests( requests_uvm, lambda indices, offsets, per_sample_weights: emb_uvm.forward( indices.long(), offsets.long(), per_sample_weights, ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) logging.info( f"UVM Forward, B: {B}, " f"E: {E}, T: {T_uvm}, D: {D}, L: {L_uvm}, W: {weighted}, " f"BW: {read_write_bytes_uvm / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"T: {time_per_iter * 1.0e6:.0f}us" ) if T_gpu > 0: requests = [] assert requests_gpu is not None for rs_uvm, rs_gpu in zip(requests_uvm, requests_gpu): indices = torch.cat([rs_uvm[0], rs_gpu[0]]) lengths = [L_uvm] * (T_uvm * B) + [L] * (T_gpu * B) offsets = torch.tensor(([0] + np.cumsum(lengths).tolist())).int().cuda() per_sample_weights = None if weighted: assert (this_rs_uvm_weights := rs_uvm[2]) is not None assert (this_rs_gpu_weights := rs_gpu[2]) is not None per_sample_weights = torch.cat( [this_rs_uvm_weights, this_rs_gpu_weights] ) requests.append((indices, offsets, per_sample_weights)) # forward time_per_iter = benchmark_requests( requests_gpu, lambda indices, offsets, per_sample_weights: emb_gpu.forward( indices.long(), offsets.long(), per_sample_weights, ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) read_write_bytes_hbm = ( output_size_multiplier * B * sum(Ds[T_uvm:]) + param_size_multiplier * B * sum(Ds[T_uvm:]) * L ) logging.info( f"GPU Forward, B: {B}, " f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, " f"BW: {read_write_bytes_hbm / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"T: {time_per_iter * 1.0e6:.0f}us" ) time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb_mixed.forward( indices.long(), offsets.long(), per_sample_weights, ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) read_write_bytes_total = read_write_bytes_uvm + read_write_bytes_hbm logging.info( f"Mixed Forward, B: {B}, " f"E: {E}, T_GPU: {T_gpu}, T_UVM: {T_uvm}, D: {D}, L_GPU: {L}, L_UVM: {L_uvm}, W: {weighted}, " f"BW: {read_write_bytes_total / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"T: {time_per_iter * 1.0e6:.0f}us" ) @cli.command() @click.option("--alpha", default=1.0) @click.option("--bag-size", default=20) @click.option("--batch-size", default=512) @click.option("--cache-algorithm", default="lru") @click.option("--cache-load-factor", default=0.2) @click.option("--embedding-dim", default=128) @click.option("--weights-precision", type=SparseType, default=SparseType.FP32) @click.option("--stoc", is_flag=True, default=False) @click.option("--long-index", is_flag=True, default=False) @click.option("--iters", default=100) @click.option("--mixed", is_flag=True, default=False) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=32) @click.option("--reuse", default=0.1) @click.option("--weighted", is_flag=True, default=False) @click.option("--flush-gpu-cache-size-mb", default=0) @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) def cache( # noqa C901 alpha: float, bag_size: int, batch_size: int, cache_algorithm: str, cache_load_factor: float, embedding_dim: int, weights_precision: SparseType, stoc: bool, iters: int, long_index: bool, mixed: bool, num_embeddings: int, num_tables: int, reuse: float, weighted: bool, flush_gpu_cache_size_mb: int, requests_data_file: Optional[str], tables: Optional[str], ) -> None: np.random.seed(42) torch.manual_seed(42) optimizer = OptimType.EXACT_ROWWISE_ADAGRAD B = batch_size D = embedding_dim L = bag_size E = num_embeddings T = num_tables cache_alg = CacheAlgorithm.LRU if cache_algorithm == "lru" else CacheAlgorithm.LFU if mixed: Ds = [ round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4) for _ in range(T) ] D = np.average(Ds) else: Ds = [D] * T emb_nc = SplitTableBatchedEmbeddingBagsCodegen( [ ( E, d, EmbeddingLocation.MANAGED, ComputeDevice.CUDA, ) for d in Ds ], optimizer=optimizer, weights_precision=weights_precision, stochastic_rounding=stoc, ).cuda() if weights_precision == SparseType.INT8: emb_nc.init_embedding_weights_uniform(-0.0003, 0.0003) emb = SplitTableBatchedEmbeddingBagsCodegen( [ ( E, d, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA, ) for d in Ds ], optimizer=optimizer, weights_precision=weights_precision, stochastic_rounding=stoc, cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, ).cuda() if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 logging.info( f"Embedding tables: {E * T} rows, {nparams / 1.0e9: .2f} GParam, " f"{nparams * param_size_multiplier / 1.0e9: .2f} GB" ) logging.info( f"Accessed weights per batch: {B * T * L} rows, " f"{B * T * L * D * param_size_multiplier / 1.0e9: .2f} GB" ) requests = generate_requests( 2 * iters, B, T, L, E, reuse=reuse, alpha=alpha, weighted=weighted, requests_data_file=requests_data_file, tables=tables, ) warmup_requests, requests = requests[:iters], requests[iters:] grad_output = torch.randn(B, sum(Ds)).cuda() time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb_nc( indices.long(), offsets.long(), per_sample_weights ).backward(grad_output), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) logging.info( f"ForwardBackward (UVM), B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, " f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, " f"T: {time_per_iter * 1.0e6:.0f}us" ) # warm up for indices, offsets, _ in warmup_requests: emb.forward(indices.long(), offsets.long()) # get cache miss rate (forward and backward) and exchanged cache lines (prefetch) cache_misses = [] exchanged_cache_lines = [] NOT_FOUND = -1 for indices, offsets, _ in requests: # pyre-fixme[29]: # `Union[BoundMethod[typing.Callable(Tensor.clone)[[Named(self, # Variable[torch._TTensor (bound to Tensor)])], Variable[torch._TTensor (bound # to Tensor)]], Tensor], Tensor, torch.nn.Module]` is not a function. old_lxu_cache_state = emb.lxu_cache_state.clone() emb.prefetch(indices.long(), offsets.long()) exchanged_cache_lines.append( # pyre-fixme[16]: `bool` has no attribute `sum`. (emb.lxu_cache_state != old_lxu_cache_state) .sum() .item() ) cache_misses.append((emb.lxu_cache_locations_list[0] == NOT_FOUND).sum().item()) emb.forward(indices.long(), offsets.long()) logging.info( f"Exchanged cache lines -- mean: {sum(exchanged_cache_lines)/len(requests): .2f}, " f"max: {max(exchanged_cache_lines)}, min: {min(exchanged_cache_lines)}" ) logging.info( f"Cache miss -- mean: {sum(cache_misses)/len(requests)}, " f"max: {max(cache_misses)}, min: {min(cache_misses)}" ) # benchmark prefetch emb.reset_cache_states() for indices, offsets, _ in warmup_requests: emb.forward(indices, offsets) prefetch_time, forward_backward_time = benchmark_pipelined_requests( requests, lambda indices, offsets, indices_weights: emb.prefetch(indices, offsets), lambda indices, offsets, indices_weights: emb.forward( indices, offsets, indices_weights ).backward(grad_output), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) e2e_time = prefetch_time + forward_backward_time logging.info( f"ForwardBackward (LXU), reuse: {reuse}, alpha: {alpha}, B: {B}, " f"E: {E}, T: {T}, D: {D}, L: {L}, " f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / e2e_time / 1.0e9: .2f} GB/s, " f"Tprefetch: {prefetch_time * 1.0e6:.0f}us, " f"{2 * sum(exchanged_cache_lines) * param_size_multiplier * D / prefetch_time / len(requests) / 1.0e9: .2f} GB/s, " f"Tfwdbwd: {forward_backward_time * 1.0e6:.0f}us, " f"{3 * param_size_multiplier * B * sum(Ds) * L / forward_backward_time / 1.0e9: .2f} GB/s, " f"Te2e: {e2e_time * 1.0e6:.0f}us, " ) def benchmark_cpu_requests( requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[torch.Tensor]]], func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor], ) -> float: import time start_time = time.perf_counter() for (indices, offsets, weights) in requests: func(indices, offsets, weights) end_time = time.perf_counter() return (end_time - start_time) / len(requests) @cli.command() @click.option("--alpha", default=1.0) @click.option("--bag-size", default=20) @click.option("--batch-size", default=512) @click.option("--embedding-dim", default=128) @click.option("--weights-precision", type=SparseType, default=SparseType.INT4) @click.option("--stoc", is_flag=True, default=False) @click.option("--iters", default=100) @click.option("--managed", default="device") @click.option("--mixed", is_flag=True, default=False) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=32) @click.option("--reuse", default=0.0) @click.option("--row-wise/--no-row-wise", default=True) @click.option("--weighted", is_flag=True, default=False) @click.option("--index-remapping", is_flag=True, default=False) @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) @click.option("--output-dtype", type=SparseType, default=SparseType.FP16) def nbit_cpu( # noqa C901 alpha: float, bag_size: int, batch_size: int, embedding_dim: int, weights_precision: SparseType, stoc: bool, iters: int, managed: str, mixed: bool, num_embeddings: int, num_tables: int, reuse: float, row_wise: bool, weighted: bool, index_remapping: bool, requests_data_file: Optional[str], tables: Optional[str], output_dtype: SparseType, ) -> None: np.random.seed(42) torch.manual_seed(42) B = batch_size D = embedding_dim L = bag_size E = num_embeddings T = num_tables if mixed: Ds = [ # int4 table batched emb op can only handle mixed D where D is multiple of 8 round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 8) for _ in range(T) ] D = np.average(Ds) else: Ds = [D] * T emb = IntNBitTableBatchedEmbeddingBagsCodegen( [("", E, d, weights_precision, EmbeddingLocation.HOST) for d in Ds], device="cpu", index_remapping=[torch.arange(E) for _ in Ds] if index_remapping else None, output_dtype=output_dtype, ).cpu() emb.fill_random_weights() nparams_byte = sum(w.numel() for (w, _) in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 output_size_multiplier = output_dtype.bit_rate() / 8.0 read_write_bytes = ( output_size_multiplier * B * T * D + param_size_multiplier * B * T * L * D ) logging.info( f"{weights_precision} Embedding tables: {E * T} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, " f"{nparams_byte / 1.0e9: .2f} GB" # IntN TBE use byte for storage ) logging.info( f"Accessed weights per batch: {B * T * L} rows, " f"{B * T * L * D * param_size_multiplier / 1.0e9: .2f} GB" ) requests = generate_requests( iters, B, T, L, E, reuse=reuse, alpha=alpha, weights_precision=weights_precision, weighted=weighted, requests_data_file=requests_data_file, tables=tables, ) requests = [ (a.cpu().int(), b.cpu().int(), c.cpu() if c else None) for (a, b, c) in requests ] time_per_iter = benchmark_cpu_requests( requests, lambda indices, offsets, per_sample_weights: emb.forward( indices, offsets, per_sample_weights, ), ) logging.info( f"{weights_precision} Forward, B: {B}, " f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"T: {time_per_iter * 1.0e6:.0f}us" ) @cli.command() @click.option("--alpha", default=1.0) @click.option("--bag-size", default=20) @click.option("--batch-size", default=512) @click.option("--embedding-dim", default=128) @click.option("--weights-precision", type=SparseType, default=SparseType.INT4) @click.option("--stoc", is_flag=True, default=False) @click.option("--managed", default="device") @click.option("--mixed", is_flag=True, default=False) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=32) @click.option("--reuse", default=0.0) @click.option("--row-wise/--no-row-wise", default=True) @click.option("--weighted", is_flag=True, default=False) @click.option("--pooling", type=str, default="sum") @click.option("--weighted-num-requires-grad", type=int, default=None) @click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.WARNING.value) @click.option("--pruning-ratio", type=float, default=None) @click.option("--load-factor", default=0.75) @click.option("--use-array-for-index-remapping", is_flag=True, default=True) @click.option("--check-median", is_flag=True, default=True) @click.option("--iters", default=100) @click.option("--runs-of-iters", default=5) @click.option("--warmup-runs", default=2) @click.option("--output-dtype", type=SparseType, default=SparseType.FP16) @click.option("--report-aibench", is_flag=True) @click.option("--run-reference", is_flag=True, default=False) @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) def nbit_device( # noqa C901 alpha: float, bag_size: int, batch_size: int, embedding_dim: int, weights_precision: SparseType, stoc: bool, managed: str, mixed: bool, num_embeddings: int, num_tables: int, reuse: float, row_wise: bool, weighted: bool, pooling: str, weighted_num_requires_grad: Optional[int], bounds_check_mode: int, pruning_ratio: Optional[float], load_factor: float, use_array_for_index_remapping: bool, check_median: bool, iters: int, runs_of_iters: int, warmup_runs: int, output_dtype: SparseType, report_aibench: bool, run_reference: bool, requests_data_file: Optional[str], tables: Optional[str], ) -> None: np.random.seed(42) torch.manual_seed(42) B = batch_size D = embedding_dim L = bag_size E = num_embeddings original_E = E T = num_tables index_remapping = None if mixed: # int4 table batched emb op can only handle mixed D where D is multiple of 8 Ds = [ round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 8) for _ in range(T) ] D = np.average(Ds) else: Ds = [D] * T mem_for_pruning = 0 if pruning_ratio: assert pruning_ratio < 1 and pruning_ratio >= 0 E = math.ceil(E * (1.0 - pruning_ratio)) index_remapping = [] for _ in range(T): mapping = torch.tensor([-1] * original_E, dtype=torch.int32) selected_indices = random.sample(range(original_E), E) for i, idx in enumerate(selected_indices): mapping[idx] = i index_remapping.append(mapping) if use_array_for_index_remapping: mem_for_pruning += mapping.numel() * 4 else: mem_for_pruning += E / load_factor * 2 * 4 if managed == "device": managed_option = EmbeddingLocation.DEVICE else: managed_option = EmbeddingLocation.MANAGED if pooling is None or pooling == "sum": pooling = "sum" pooling_mode = PoolingMode.SUM do_pooling = True elif pooling == "mean": pooling_mode = PoolingMode.MEAN do_pooling = True else: # "none" pooling_mode = PoolingMode.NONE do_pooling = False emb = IntNBitTableBatchedEmbeddingBagsCodegen( [("", E, d, weights_precision, managed_option) for d in Ds], bounds_check_mode=BoundsCheckMode(bounds_check_mode), index_remapping=index_remapping, load_factor=load_factor, use_array_for_index_remapping=use_array_for_index_remapping, output_dtype=output_dtype, pooling_mode=pooling_mode, ).cuda() emb.fill_random_weights() nparams_byte = sum(w.numel() for (w, _) in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 output_size_multiplier = output_dtype.bit_rate() / 8.0 if do_pooling: read_write_bytes = ( output_size_multiplier * B * T * D + param_size_multiplier * B * T * L * D ) else: read_write_bytes = ( output_size_multiplier * B * T * L * D + param_size_multiplier * B * T * L * D ) logging.info( f"{weights_precision} Embedding tables: {E * T} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, " f"{nparams_byte / 1.0e9: .2f} GB" # IntN TBE use byte for storage ) logging.info( f"Accessed weights per batch: {B * T * L} rows, " f"{B * T * L * D * param_size_multiplier / 1.0e9: .2f} GB" ) times = [] for i in range(runs_of_iters): requests = generate_requests( iters, B, T, L, E, reuse=reuse, alpha=alpha, weights_precision=weights_precision, weighted=weighted, requests_data_file=requests_data_file, tables=tables, ) requests = [(a.int(), b.int(), c if c else None) for (a, b, c) in requests] # forward time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb.forward( indices.int(), offsets.int(), per_sample_weights, ), check_median=check_median, ) # free up GPU memory del requests logging.info( f"Iteration {i}: " f"{weights_precision} Forward, B: {B}, " f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"Time: {time_per_iter * 1.0e6:.0f}us, " f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB" ) if i >= warmup_runs: times.append(time_per_iter) time_per_iter = statistics.mean(times) bandwidth = read_write_bytes / time_per_iter / 1.0e9 logging.info( f"Average of all iterations: " f"{weights_precision} Forward, B: {B}, " f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " f"BW: {bandwidth: .2f} GB/s, " # noqa: B950 f"Time: {time_per_iter * 1.0e6:.0f}us, " f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB" ) if report_aibench and haveAIBench: print( emitMetric( type="NET", metric=f"bandwidth_{weights_precision}", unit="scalar", value=str(bandwidth), ) ) print( emitMetric( type="NET", metric=f"time_per_iter_{weights_precision}", unit="scalar", value=str(time_per_iter * 1.0e6), ) ) if run_reference: times = [] for i in range(runs_of_iters): requests = generate_requests( iters, B, T, L, E, reuse=reuse, alpha=alpha, weights_precision=weights_precision, weighted=weighted, requests_data_file=requests_data_file, tables=tables, ) requests = [(a.int(), b.int(), c if c else None) for (a, b, c) in requests] # forward time_per_iter_refer = benchmark_requests_refer( requests, T, B, L, E, D, pooling, weighted, check_median=check_median, ) # free up GPU memory del requests logging.info( f"Reference (nn.Embedding(Bag)) Iteration {i}: " f"Forward, B: {B}, " f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " f"BW: {read_write_bytes / time_per_iter_refer / 1.0e9: .2f} GB/s, " # noqa: B950 f"Time: {time_per_iter_refer * 1.0e6:.0f}us " ) if i >= warmup_runs: times.append(time_per_iter_refer) time_per_iter_refer = statistics.mean(times) bandwidth = read_write_bytes / time_per_iter_refer / 1.0e9 logging.info( f"Average of all iterations: " f"Forward, B: {B}, " f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " f"Effective BW: {bandwidth: .2f} GB/s, " # noqa: B950 f"Time: {time_per_iter_refer * 1.0e6:.0f}us " ) @cli.command() @click.option("--alpha", default=1.0) @click.option("--bag-size", default=20) @click.option("--batch-size", default=512) @click.option("--embedding-dim", default=128) @click.option("--weights-precision", type=SparseType, default=SparseType.INT4) @click.option("--iters", default=100) @click.option("--mixed", is_flag=True, default=False) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=32) @click.option("--reuse", default=0.1) @click.option("--uvm-num-embeddings", default=int(1e5)) @click.option("--uvm-tables", default=1) @click.option("--uvm-bag-size", default=1) @click.option("--weighted", is_flag=True, default=False) @click.option("--flush-gpu-cache-size-mb", default=0) @click.option("--output-dtype", type=SparseType, default=SparseType.FP16) @click.option("--use-cache", is_flag=True, default=False) @click.option("--cache-algorithm", default="lru") @click.option("--cache-load-factor", default=0.2) @click.option("--enforce-hbm", is_flag=True, default=False) def nbit_uvm( alpha: bool, bag_size: int, batch_size: int, embedding_dim: int, weights_precision: SparseType, iters: int, mixed: bool, num_embeddings: int, num_tables: int, reuse: float, uvm_num_embeddings: int, uvm_tables: int, uvm_bag_size: int, weighted: bool, flush_gpu_cache_size_mb: int, output_dtype: SparseType, use_cache: bool, cache_algorithm: str, cache_load_factor: float, enforce_hbm: bool, ) -> None: np.random.seed(42) torch.manual_seed(42) B = batch_size D = embedding_dim L = bag_size E = num_embeddings E_uvm = uvm_num_embeddings T = num_tables T_uvm = uvm_tables assert T_uvm <= T assert ( T_uvm > 0 ), f"T_uvm specified {T_uvm} <= 0. If not testing UVM, please use device benchmark." T_gpu = T - T_uvm L_uvm = uvm_bag_size cache_alg = CacheAlgorithm.LRU if cache_algorithm == "lru" else CacheAlgorithm.LFU managed_type = ( EmbeddingLocation.MANAGED_CACHING if use_cache else EmbeddingLocation.MANAGED ) logging.info(f"T: {T}, T_uvm: {T_uvm}, T_gpu: {T_gpu}") if mixed: Ds = [ round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4) for _ in range(T) ] D = np.average(Ds) else: Ds = [D] * T emb_uvm = IntNBitTableBatchedEmbeddingBagsCodegen( [ ( "", E_uvm, d, weights_precision, managed_type, ) for d in Ds[:T_uvm] ], output_dtype=output_dtype, cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, enforce_hbm=enforce_hbm, ).cuda() emb_uvm.fill_random_weights() if T_gpu > 0: emb_gpu = IntNBitTableBatchedEmbeddingBagsCodegen( [ ( "", E, d, weights_precision, EmbeddingLocation.DEVICE, ) for d in Ds[T_uvm:] ], output_dtype=output_dtype, ).cuda() emb_gpu.fill_random_weights() emb_mixed = IntNBitTableBatchedEmbeddingBagsCodegen( [ ( "", e, d, weights_precision, managed_option, ) for (e, d, managed_option) in zip( [E_uvm] * T_uvm + [E] * T_gpu, Ds, [managed_type] * T_uvm + [EmbeddingLocation.DEVICE] * T_gpu, ) ], output_dtype=output_dtype, cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, enforce_hbm=enforce_hbm, ).cuda() emb_mixed.fill_random_weights() requests_uvm = generate_requests( iters, B, T_uvm, L_uvm, E_uvm, reuse=reuse, alpha=alpha, weights_precision=weights_precision, weighted=weighted, ) requests_uvm = [(a.int(), b.int(), c if c else None) for (a, b, c) in requests_uvm] requests_gpu = None if T_gpu > 0: requests_gpu = generate_requests( iters, B, T_gpu, L, E, reuse=reuse, alpha=alpha, weights_precision=weights_precision, weighted=False, ) requests_gpu = [ (a.int(), b.int(), c if c else None) for (a, b, c) in requests_gpu ] param_size_multiplier = weights_precision.bit_rate() / 8.0 output_size_multiplier = output_dtype.bit_rate() / 8.0 read_write_bytes_uvm = ( output_size_multiplier * B * sum(Ds[:T_uvm]) + param_size_multiplier * B * sum(Ds[:T_uvm]) * L_uvm ) if T_gpu > 0: nparams_byte = sum(w.numel() for (w, _) in emb_mixed.split_embedding_weights()) logging.info( f"{weights_precision} Embedding tables: {E * T + E_uvm * T_uvm} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, " f"{nparams_byte / 1.0e9: .2f} GB" # IntN TBE use byte for storage ) logging.info( f"Accessed weights per batch: {B * (T * L + T_uvm * L_uvm)} rows, " f"{B * (T * L * sum(Ds[T_uvm:]) + T_uvm * L_uvm * sum(Ds[:T_uvm])) * param_size_multiplier / 1.0e9: .2f} GB" ) time_per_iter = benchmark_requests( requests_uvm, lambda indices, offsets, per_sample_weights: emb_uvm.forward( indices.int(), offsets.int(), per_sample_weights, ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) logging.info( f"UVM NBit Forward, {weights_precision}, B: {B}, " f"E_uvm: {E_uvm}, T: {T_uvm}, D: {D}, L: {L_uvm}, W: {weighted}, " f"BW: {read_write_bytes_uvm / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"Time: {time_per_iter * 1.0e6:.0f}us" ) if T_gpu > 0: requests = [] assert requests_gpu is not None for rs_uvm, rs_gpu in zip(requests_uvm, requests_gpu): indices = torch.cat([rs_uvm[0], rs_gpu[0]]) lengths = [L_uvm] * (T_uvm * B) + [L] * (T_gpu * B) offsets = torch.tensor(([0] + np.cumsum(lengths).tolist())).int().cuda() per_sample_weights = None if weighted: assert (this_rs_uvm_weights := rs_uvm[2]) is not None assert (this_rs_gpu_weights := rs_gpu[2]) is not None per_sample_weights = torch.cat( [this_rs_uvm_weights, this_rs_gpu_weights] ) requests.append((indices, offsets, per_sample_weights)) # forward time_per_iter = benchmark_requests( requests_gpu, lambda indices, offsets, per_sample_weights: emb_gpu.forward( indices.int(), offsets.int(), per_sample_weights, ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) read_write_bytes_hbm = ( output_size_multiplier * B * sum(Ds[T_uvm:]) + param_size_multiplier * B * sum(Ds[T_uvm:]) * L ) logging.info( f"GPU NBit Forward, {weights_precision}, B: {B}, " f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, " f"BW: {read_write_bytes_hbm / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"Time: {time_per_iter * 1.0e6:.0f}us" ) time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb_mixed.forward( indices.int(), offsets.int(), per_sample_weights, ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) read_write_bytes_total = read_write_bytes_uvm + read_write_bytes_hbm logging.info( f"Mixed NBit Forward, {weights_precision}, B: {B}, " f"E_GPU: {E}, E_UVM: {E_uvm}, T_GPU: {T_gpu}, T_UVM: {T_uvm}, D: {D}, L_GPU: {L}, L_UVM: {L_uvm}, W: {weighted}, " f"BW: {read_write_bytes_total / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"Time: {time_per_iter * 1.0e6:.0f}us" ) # benchmark prefetch emb_mixed.reset_cache_states() for indices, offsets, _ in requests: emb_mixed.forward(indices, offsets) prefetch_time, forward_time = benchmark_pipelined_requests( requests, lambda indices, offsets, indices_weights: emb_mixed.prefetch( indices, offsets, ), # pyre-fixme[6]: Expected `(Tensor, Tensor, Optional[Tensor]) -> None` for # 3rd param but got `(indices: Any, offsets: Any, indices_weights: Any) -> # Tensor`. lambda indices, offsets, indices_weights: emb_mixed.forward( indices, offsets, indices_weights, ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) e2e_time = prefetch_time + forward_time logging.info( f"Forward(LXU) {weights_precision}, reuse: {reuse}, alpha: {alpha}, B: {B}, " f"E: {E}, T: {T}, D: {D}, L: {L}, " f"Te2e: {e2e_time * 1.0e6:.0f}us, " f"e2e BW: {read_write_bytes_total / e2e_time / 1.0e9: .2f} GB/s, " f"Tprefetch: {prefetch_time * 1.0e6:.0f}us, " f"TfwdTime: {forward_time * 1.0e6:.0f}us, " f"{read_write_bytes_total / forward_time / 1.0e9: .2f} GB/s" ) @cli.command() @click.option("--alpha", default=1.0) @click.option("--bag-size", default=20) @click.option("--batch-size", default=512) @click.option("--cache-algorithm", default="lru") @click.option("--cache-load-factor", default=0.2) @click.option("--embedding-dim", default=128) @click.option("--weights-precision", type=SparseType, default=SparseType.INT4) @click.option("--iters", default=100) @click.option("--mixed", is_flag=True, default=False) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=32) @click.option("--reuse", default=0.1) @click.option("--weighted", is_flag=True, default=False) @click.option("--flush-gpu-cache-size-mb", default=0) @click.option("--output-dtype", type=SparseType, default=SparseType.FP16) @click.option("--enforce-hbm", is_flag=True, default=False) def nbit_cache( # noqa C901 alpha: float, bag_size: int, batch_size: int, cache_algorithm: str, cache_load_factor: float, embedding_dim: int, weights_precision: SparseType, iters: int, mixed: bool, num_embeddings: int, num_tables: int, reuse: float, weighted: bool, flush_gpu_cache_size_mb: int, output_dtype: SparseType, enforce_hbm: bool, ) -> None: np.random.seed(42) torch.manual_seed(42) B = batch_size D = embedding_dim L = bag_size E = num_embeddings T = num_tables cache_alg = CacheAlgorithm.LRU if cache_algorithm == "lru" else CacheAlgorithm.LFU if mixed: Ds = [ round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4) for _ in range(T) ] D = np.average(Ds) else: Ds = [D] * T emb_nc = IntNBitTableBatchedEmbeddingBagsCodegen( [ ( "", E, d, weights_precision, EmbeddingLocation.MANAGED, ) for d in Ds ], output_dtype=output_dtype, enforce_hbm=enforce_hbm, ).cuda() emb_nc.fill_random_weights() emb = IntNBitTableBatchedEmbeddingBagsCodegen( [ ( "", E, d, weights_precision, EmbeddingLocation.MANAGED_CACHING, ) for d in Ds ], cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, output_dtype=output_dtype, enforce_hbm=enforce_hbm, ).cuda() emb.fill_random_weights() nparams_byte = sum(w.numel() for (w, _) in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 output_size_multiplier = output_dtype.bit_rate() / 8.0 read_write_bytes = ( output_size_multiplier * B * sum(Ds) + param_size_multiplier * B * sum(Ds) * L ) logging.info( f"{weights_precision} Embedding tables: {E * T} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, " f"{nparams_byte / 1.0e9: .2f} GB" # IntN TBE use byte for storage ) logging.info( f"Accessed weights per batch: {B * T * L} rows, " f"{B * T * L * D * param_size_multiplier / 1.0e9: .2f} GB" ) requests = generate_requests( 2 * iters, B, T, L, E, reuse=reuse, alpha=alpha, weighted=weighted ) requests = [(a.int(), b.int(), c if c else None) for (a, b, c) in requests] warmup_requests, requests = requests[:iters], requests[iters:] time_per_iter = benchmark_requests( requests, lambda indices, offsets, per_sample_weights: emb_nc( indices.int(), offsets.int(), per_sample_weights ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) logging.info( f"Forward (UVM) {weights_precision}, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, " f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"T: {time_per_iter * 1.0e6:.0f}us" ) # exchanged_cache_lines = [100] # warm up for indices, offsets, _ in warmup_requests: emb.forward(indices.int(), offsets.int()) # get cache miss rate (forward only) and exchanged cache lines (prefetch) cache_misses = [] exchanged_cache_lines = [] NOT_FOUND = -1 for indices, offsets, _ in requests: # pyre-fixme[29]: # `Union[BoundMethod[typing.Callable(Tensor.clone)[[Named(self, # Variable[torch._TTensor (bound to Tensor)])], Variable[torch._TTensor (bound # to Tensor)]], Tensor], Tensor, torch.nn.Module]` is not a function. old_lxu_cache_state = emb.lxu_cache_state.clone() emb.prefetch(indices, offsets) exchanged_cache_lines.append( # pyre-fixme[16]: `bool` has no attribute `sum`. (emb.lxu_cache_state != old_lxu_cache_state) .sum() .item() ) cache_misses.append( (emb.lxu_cache_locations_list.top() == NOT_FOUND).sum().item() ) emb.forward(indices, offsets) logging.info( f"Exchanged cache lines -- mean: {sum(exchanged_cache_lines)/len(requests): .2f}, " f"max: {max(exchanged_cache_lines)}, min: {min(exchanged_cache_lines)}" ) logging.info( f"Cache miss -- mean: {sum(cache_misses)/len(requests)}, " f"max: {max(cache_misses)}, min: {min(cache_misses)}" ) # benchmark prefetch emb.reset_cache_states() for indices, offsets, _ in warmup_requests: emb.forward(indices, offsets) prefetch_time, forward_time = benchmark_pipelined_requests( requests, lambda indices, offsets, indices_weights: emb.prefetch( indices, offsets, ), # pyre-fixme[6]: Expected `(Tensor, Tensor, Optional[Tensor]) -> None` for # 3rd param but got `(indices: Any, offsets: Any, indices_weights: Any) -> # Tensor`. lambda indices, offsets, indices_weights: emb.forward( indices, offsets, indices_weights, ), flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, ) e2e_time = prefetch_time + forward_time logging.info( f"Forward(LXU) {weights_precision}, reuse: {reuse}, alpha: {alpha}, B: {B}, " f"E: {E}, T: {T}, D: {D}, L: {L}, " f"Te2e: {e2e_time * 1.0e6:.0f}us, " f"e2e BW: {read_write_bytes / e2e_time / 1.0e9: .2f} GB/s, " f"Tprefetch: {prefetch_time * 1.0e6:.0f}us, " f"{2 * sum(exchanged_cache_lines) * param_size_multiplier * D / prefetch_time / len(requests) / 1.0e9: .2f} GB/s, " f"TfwdTime: {forward_time * 1.0e6:.0f}us, " f"{read_write_bytes / forward_time / 1.0e9: .2f} GB/s" ) @cli.command() @click.option("--bag-size", default=20) @click.option("--batch-size", default=2048) @click.option("--iters", default=10) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=100) @click.option("--load-factor", default=0.75) @click.option("--hit-rate", default=0.9) @click.option("--use-cpu", is_flag=True, default=False) @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) def hashtable( # noqa C901 bag_size: int, batch_size: int, iters: int, num_embeddings: int, num_tables: int, load_factor: float, hit_rate: float, use_cpu: bool, requests_data_file: Optional[str], tables: Optional[str], ) -> None: B = batch_size T = num_tables L = bag_size E = num_embeddings np.random.seed(42) torch.manual_seed(42) if hit_rate == 1.0: chosen_indices = torch.cat([torch.arange(E) for _ in range(T)], dim=0).int() else: chosen_indices = ( torch.randint(low=0, high=int(E * 1.0 / hit_rate), size=(E * T,)) .view(-1) .int() ) dense_indices = torch.cat([torch.arange(E) for _ in range(T)], dim=0).int() offsets = torch.tensor([E * t for t in range(T + 1)]).int() assert offsets[-1] == chosen_indices.numel() assert offsets.numel() == T + 1 assert (offsets.numel() - 1) // T == 1 capacities = [round_up(int(E / load_factor), 32) for _ in range(T)] hash_table = torch.zeros( (sum(capacities), 2), dtype=torch.int32, ) hash_table_offsets = torch.tensor([0] + np.cumsum(capacities).tolist()).long() assert hash_table.numel() * 4 < 2 ** 32 # initialize hash_table[:, :] = -1 torch.ops.fbgemm.pruned_hashmap_insert( chosen_indices, dense_indices, offsets, hash_table, hash_table_offsets ) requests = generate_requests( iters, B, T, L, E, requests_data_file=requests_data_file, tables=tables, ) if not use_cpu: hash_table = hash_table.cuda() hash_table_offsets = hash_table_offsets.cuda() requests = [(a.cuda().int(), b.cuda().int(), c) for (a, b, c) in requests] else: requests = [(a.int().cpu(), b.int().cpu(), c) for (a, b, c) in requests] empirical_hit_rate = np.mean( [ torch.ops.fbgemm.pruned_hashmap_lookup( indices, offsets, hash_table, hash_table_offsets ) .ne(-1) .sum() .item() / indices.numel() for indices, offsets, _ in requests ] ) time_per_iter = benchmark_requests( requests, lambda indices, offsets, _: torch.ops.fbgemm.pruned_hashmap_lookup( indices, offsets, hash_table, hash_table_offsets ), ) logging.info( f"LinearTable: B: {B}, T: {T}, L: {L}, E: {E}, QPS: {B * T * L / time_per_iter / 1.0e9:.2f}B QPS/s, " f"T: {time_per_iter * 1.0e6:.0f}us, load factor: {E * T / hash_table.shape[0] * 100:.1f}%, hit rate: {empirical_hit_rate * 100:.2f}%, Table size: {hash_table.numel() * 4 / 1.0e9:.0f} GB" ) if use_cpu: ht = torch.classes.fb.PrunedMapCPU() ht.insert(chosen_indices, dense_indices, offsets, T) time_per_iter = benchmark_requests( requests, lambda indices, offsets, _: ht.lookup(indices, offsets), ) logging.info( f"HashTable: B: {B}, T: {T}, L: {L}, E: {E}, QPS: {B * T * L / time_per_iter / 1.0e9:.2f}B QPS/s, " f"T: {time_per_iter * 1.0e6:.0f}us, load factor: {E * T / hash_table.shape[0] * 100:.1f}%, hit rate: {empirical_hit_rate * 100:.2f}%, Table size: {hash_table.numel() * 4 / 1.0e9:.0f} GB" ) @cli.command() @click.option("--bag-size", default=20) @click.option("--batch-size", default=2048) @click.option("--iters", default=100) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=100) @click.option("--pruning-ratio", default=0.9) @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) def pruned_array( # noqa C901 bag_size: int, batch_size: int, iters: int, num_embeddings: int, num_tables: int, pruning_ratio: float, requests_data_file: Optional[str], tables: Optional[str], ) -> None: B = batch_size T = num_tables L = bag_size E = num_embeddings np.random.seed(42) torch.manual_seed(42) assert pruning_ratio > 0 and pruning_ratio <= 1 original_E = int(E / (1.0 - pruning_ratio)) index_remappings = torch.tensor( [-1] * original_E * T, dtype=torch.int32, device="cuda" ) index_remappings_offsets = torch.empty(T + 1, dtype=torch.int32, device="cuda") index_remappings_offsets[0] = 0 dense_indices = torch.tensor(range(E), dtype=torch.int32, device="cuda") for t in range(T): selected_indices = torch.add( torch.randperm(original_E, device="cuda"), t * original_E )[:E] index_remappings[selected_indices] = dense_indices index_remappings_offsets[t + 1] = index_remappings_offsets[t] + original_E requests = generate_requests( iters, B, T, L, E, requests_data_file=requests_data_file, tables=tables, ) requests = [(a.cuda().int(), b.cuda().int(), c) for (a, b, c) in requests] time_per_iter = benchmark_requests( requests, lambda indices, offsets, _: torch.ops.fbgemm.pruned_array_lookup( indices, offsets, index_remappings, index_remappings_offsets, ), ) logging.info( f"LinearTable: B: {B}, T: {T}, L: {L}, E: {E}, QPS: {B * T * L / time_per_iter / 1.0e9:.2f}B QPS/s, " f"T: {time_per_iter * 1.0e6:.0f}us, Pruning Ratio: {pruning_ratio * 100:.2f}%, Table size: {original_E * T * 4 / 1.0e9:.0f} GB" ) @cli.command() @click.option("--bag-size", default=20) @click.option("--batch-size", default=512) @click.option("--iters", default=100) @click.option("--num-embeddings", default=int(1e5)) @click.option("--num-tables", default=32) @click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.WARNING.value) @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) def bounds_check_indices( # noqa C901 bag_size: int, batch_size: int, iters: int, num_embeddings: int, num_tables: int, bounds_check_mode: int, requests_data_file: Optional[str], tables: Optional[str], ) -> None: np.random.seed(42) torch.manual_seed(42) B = batch_size L = bag_size E = num_embeddings T = num_tables requests = generate_requests( iters, B, T, L, E, requests_data_file=requests_data_file, tables=tables, ) # requests = [(a.int(), b.int(), c if c else None) for (a, b, c) in requests] warning = torch.tensor([0]).long().to(get_device()) rows_per_table = torch.tensor([E for _ in range(T)]).long().to(get_device()) # forward time_per_iter = benchmark_requests( requests, lambda indices, offsets, _: torch.ops.fbgemm.bounds_check_indices( rows_per_table, indices, offsets, BoundsCheckMode(bounds_check_mode), warning, ), ) logging.info( f"Bounds Check Indices: B: {B}, " f"E: {E}, T: {T}, L: {L}, " f"BW: {(8 * B * T * L + 8 * (B * T + 1)) / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 f"T: {time_per_iter * 1.0e6:.0f}us" ) if __name__ == "__main__": cli()