# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import click
import fbgemm_gpu
import torch

# Configure the root logger at DEBUG so the INFO-level benchmark results
# logged by the commands below are emitted.
logging.basicConfig(level=logging.DEBUG)

# A missing `open_source` attribute on fbgemm_gpu is treated as False, which
# selects the `else` branch below.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if open_source:
    # bench_utils is imported as a top-level module here — presumably it is
    # resolved from the script's own directory; confirm against the OSS layout.
    # pyre-ignore[21]
    from bench_utils import benchmark_torch_function
else:
    from fbgemm_gpu.bench.bench_utils import benchmark_torch_function

    # Load the operator libraries explicitly on this path. NOTE(review): these
    # look like build-target paths providing the torch.ops.fbgemm.* ops used
    # below (GPU and CPU variants) — confirm.
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")


@click.group()
def cli() -> None:
    """Benchmarks for FBGEMM jagged-tensor to dense-tensor conversion ops."""
    # Intentionally empty: the group only acts as the parent of the
    # subcommands registered with @cli.command().


def _log_throughput(op_name: str, elapsed_sec: float, *tensors: torch.Tensor) -> None:
    """Log the elapsed time and effective bandwidth of one benchmarked op.

    Args:
        op_name: Label placed at the start of the log line.
        elapsed_sec: Time returned by the benchmark helper; reported as seconds.
        *tensors: Every tensor counted as memory traffic (inputs and output).
    """
    # Traffic is approximated as one full pass over each tensor's elements.
    num_bytes = sum(t.numel() * t.element_size() for t in tensors)
    logging.info(f"{op_name} {elapsed_sec} sec {num_bytes / elapsed_sec / 1e9} GB/s")


@cli.command()
@click.option("--batch-size", default=128)
@click.option("--embedding-dim", default=128)
@click.option("--max-len", default=128)
def device(
    batch_size: int,
    embedding_dim: int,
    max_len: int,
) -> None:
    """Benchmark jagged_2d_to_dense and jagged_1d_to_dense on random input.

    Generates --batch-size random row lengths below --max-len, builds the
    matching offsets and random values (with --embedding-dim columns for the
    2D case), and logs the time and GB/s of each op. Inputs are moved to the
    GPU when CUDA is available.
    """
    # Decide once so both benchmarks place their inputs consistently.
    use_cuda = torch.cuda.is_available()

    # torch.randint(high, size) samples from [0, high), so every row length is
    # strictly below max_len.
    lengths = torch.randint(max_len, size=(batch_size,))
    total_lengths = lengths.sum().item()
    offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
    values_2d = torch.rand(total_lengths, embedding_dim)

    if use_cuda:
        offsets = offsets.cuda()
        values_2d = values_2d.cuda()

    elapsed_2d, dense_2d = benchmark_torch_function(
        torch.ops.fbgemm.jagged_2d_to_dense,
        (values_2d, offsets, max_len),
    )
    _log_throughput("jagged_2d_to_dense", elapsed_2d, offsets, values_2d, dense_2d)

    # The 1D values are created only after the 2D benchmark has finished,
    # matching the original ordering of allocations.
    values_1d = torch.rand(total_lengths)
    if use_cuda:
        values_1d = values_1d.cuda()

    # Wrapped in a lambda so padding_value can be passed by keyword; the
    # benchmark helper is given an empty positional-args tuple instead.
    elapsed_1d, dense_1d = benchmark_torch_function(
        lambda: torch.ops.fbgemm.jagged_1d_to_dense(
            values_1d, offsets, max_len, padding_value=0
        ),
        (),
    )
    _log_throughput("jagged_1d_to_dense", elapsed_1d, offsets, values_1d, dense_1d)


# Dispatch to the click group only when executed as a script, so importing
# this module does not start the CLI.
if __name__ == "__main__":
    cli()
