fbgemm_gpu/bench/jagged_tensor_benchmark.py (58 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Command-line benchmarks for the FBGEMM jagged-to-dense operators."""

import logging

import click
import fbgemm_gpu
import torch

logging.basicConfig(level=logging.DEBUG)

# fbgemm_gpu exposes an `open_source` attribute in some builds; when it is
# absent we fall back to False and take the non-open-source import path.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if open_source:
    # pyre-ignore[21]
    from bench_utils import benchmark_torch_function
else:
    from fbgemm_gpu.bench.bench_utils import benchmark_torch_function

    # NOTE(review): these look like internal build-target paths, so they are
    # assumed to belong to the non-open-source branch only -- confirm.
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")


def _total_bytes(*tensors: torch.Tensor) -> int:
    """Return the summed element storage (numel * element_size) of *tensors*."""
    return sum(t.numel() * t.element_size() for t in tensors)


@click.group()
def cli() -> None:
    """Jagged tensor operator benchmarks."""


@cli.command()
@click.option("--batch-size", default=128)
@click.option("--embedding-dim", default=128)
@click.option("--max-len", default=128)
def device(
    batch_size: int,
    embedding_dim: int,
    max_len: int,
) -> None:
    """Benchmark jagged_2d_to_dense and jagged_1d_to_dense on random input.

    Random per-row lengths are drawn with ``torch.randint(max_len, ...)`` for
    ``batch_size`` rows, and random values are generated to match their sum.
    Inputs are moved to CUDA when it is available. For each operator the
    elapsed time returned by ``benchmark_torch_function`` is logged together
    with a bandwidth figure computed from the sizes of the offsets, the input
    values and the output.

    Args:
        batch_size: Number of jagged rows to generate.
        embedding_dim: Second dimension of the 2D values tensor.
        max_len: Exclusive upper bound for the random row lengths; also passed
            to both operators as their max-length argument.
    """
    # Query once so both benchmarks make the same placement decision.
    use_cuda = torch.cuda.is_available()

    lengths = torch.randint(max_len, size=(batch_size,))
    total_lengths = lengths.sum().item()
    offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
    values_2d = torch.rand(total_lengths, embedding_dim)

    if use_cuda:
        offsets = offsets.cuda()
        values_2d = values_2d.cuda()

    # `elapsed` rather than `time` to avoid shadowing the stdlib module name.
    elapsed, output = benchmark_torch_function(
        torch.ops.fbgemm.jagged_2d_to_dense,
        (values_2d, offsets, max_len),
    )

    num_bytes = _total_bytes(offsets, values_2d, output)
    logging.info(
        f"jagged_2d_to_dense {elapsed} sec {num_bytes / elapsed / 1e9} GB/s"
    )

    # Created after the 2D run so random-number consumption order is unchanged.
    values_1d = torch.rand(total_lengths)
    if use_cuda:
        values_1d = values_1d.cuda()

    # The op takes `padding_value` by keyword, so it is wrapped in a
    # zero-argument lambda and benchmarked with an empty argument tuple.
    elapsed, output = benchmark_torch_function(
        lambda: torch.ops.fbgemm.jagged_1d_to_dense(
            values_1d, offsets, max_len, padding_value=0
        ),
        (),
    )

    num_bytes = _total_bytes(offsets, values_1d, output)
    logging.info(
        f"jagged_1d_to_dense {elapsed} sec {num_bytes / elapsed / 1e9} GB/s"
    )


if __name__ == "__main__":
    cli()