# fbgemm_gpu/bench/sparse_ops_benchmark.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import random
import click
import fbgemm_gpu
import torch
# Verbose logging so benchmark results (emitted via logging.info) are visible.
logging.basicConfig(level=logging.DEBUG)
# fbgemm_gpu sets the "open_source" attribute in its OSS build; absence of the
# attribute means we are running inside the internal (fbcode) build.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
if open_source:
    # OSS layout: bench_utils sits next to this script on sys.path.
    # pyre-ignore[21]
    from bench_utils import benchmark_torch_function
else:
    # Internal layout: bench_utils is importable as a package module.
    from fbgemm_gpu.bench.bench_utils import benchmark_torch_function
# Register the fbgemm sparse ops (GPU and CPU variants) under torch.ops.fbgemm.
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
@click.group()
def cli() -> None:
    """Root command group for the sparse-ops benchmarks; subcommands register via @cli.command()."""
@cli.command()
@click.option("--world-size", default=128)
@click.option("--num-tables", default=10)
@click.option("--min-len", default=10000)
@click.option("--max-len", default=20000)
def device(
    world_size: int,
    num_tables: int,
    min_len: int,
    max_len: int,
) -> None:
    """Benchmark torch.ops.fbgemm.expand_into_jagged_permute.

    Builds world_size * num_tables jagged segments with random lengths in
    [min_len, max_len), permutes them randomly, then times the permute op
    and reports achieved memory bandwidth in GB/s.
    """
    num_segments = num_tables * world_size
    # Random per-segment lengths and their exclusive-scan offsets.
    lengths = torch.randint(min_len, max_len, size=(num_segments,))
    offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
    # Random permutation of segment ids, built on CPU.
    permutation = list(range(num_segments))
    random.shuffle(permutation)
    permute_tensor = torch.tensor(permutation)
    permuted_lengths = torch.index_select(lengths, 0, permute_tensor)
    permuted_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(permuted_lengths)
    # Total number of jagged elements (last entry of the complete cumsum).
    jagged_size = offsets[-1]
    # Run on GPU when one is present; otherwise the op dispatches on CPU.
    if torch.cuda.is_available():
        permute_tensor, offsets, permuted_offsets = (
            t.cuda() for t in (permute_tensor, offsets, permuted_offsets)
        )
    time, output = benchmark_torch_function(
        torch.ops.fbgemm.expand_into_jagged_permute,
        (permute_tensor, offsets, permuted_offsets, jagged_size),
    )
    # Bytes moved = all inputs read + output written.
    num_bytes = sum(
        t.numel() * t.element_size()
        for t in (permute_tensor, offsets, permuted_offsets, output)
    )
    logging.info(f"expand_into_jagged_permute {time} sec {num_bytes / time / 1e9} GB/s")
if __name__ == "__main__":
    cli()  # dispatch to the selected click subcommand (e.g. `device`)