#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import math
import random
import statistics
import time
from typing import Callable, List, Optional, Tuple

import click
import numpy as np
import torch

haveAIBench = False
try:
    from aibench_observer.utils.observer import emitMetric

    haveAIBench = True
except Exception:
    haveAIBench = False

from fbgemm_gpu.split_table_batched_embeddings_ops import (
    BoundsCheckMode,
    CacheAlgorithm,
    ComputeDevice,
    DenseTableBatchedEmbeddingBagsCodegen,
    EmbeddingLocation,
    OptimType,
    SparseType,
    SplitTableBatchedEmbeddingBagsCodegen,
    IntNBitTableBatchedEmbeddingBagsCodegen,
    PoolingMode,
)
from numpy.random import default_rng
from torch import Tensor

# Configure the root logger at DEBUG level; every benchmark below reports its
# results through logging.info / logging.warning.
logging.basicConfig(level=logging.DEBUG)


def round_up(a: int, b: int) -> int:
    """Round ``a`` up to the nearest multiple of ``b`` (for positive ``b``).

    The result is always a Python ``int``, even for numpy integer inputs.
    """
    num_chunks = (a + b - 1) // b
    return int(num_chunks) * b


def get_device() -> torch.device:
    """Return the device the benchmarks run on.

    Returns:
        The current CUDA device when CUDA is available, otherwise the CPU.
    """
    if torch.cuda.is_available():
        # torch.cuda.current_device() returns a bare int ordinal; wrap it so the
        # function really returns the torch.device its signature promises.
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")


# Merged indices with shape (T, B, L) -> (flattened indices with shape
# (T * B * L), offsets with shape (T * B + 1))
def get_table_batched_offsets_from_dense(
    merged_indices: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Flatten dense ``(T, B, L)`` indices into table-batched ``(indices, offsets)``.

    Every bag holds exactly ``L`` indices, so the offsets are the arithmetic
    sequence ``0, L, 2L, ..., T * B * L``.

    Args:
        merged_indices: 3-D tensor of shape ``(T, B, L)``.

    Returns:
        ``(indices, offsets)``, both int64 and placed on ``get_device()``:
        ``indices`` has ``T * B * L`` entries and ``offsets`` has ``T * B + 1``.
    """
    (T, B, L) = merged_indices.size()
    device = get_device()
    # Compute offsets with integer arithmetic. Building them from a float
    # cumsum and then torch.tensor(list_of_floats) produced float32 values, which
    # cannot represent every integer above 2**24 and silently corrupted offsets
    # for large T * B * L.
    offsets = torch.arange(T * B + 1, dtype=torch.long, device=device) * L
    return (
        merged_indices.long().contiguous().view(-1).to(device),
        offsets,
    )


def get_offsets_from_dense(indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Flatten a dense ``(B, L)`` index matrix into an ``(indices, offsets)`` pair.

    Returns:
        ``flat_indices`` with ``B * L`` entries, and an int64 CPU tensor holding
        the ``B`` bag start positions ``0, L, ..., (B - 1) * L``. There is no
        trailing end offset.
    """
    num_bags, bag_size = indices.size()
    flat_indices = indices.contiguous().view(-1)
    bag_starts = torch.arange(num_bags, dtype=torch.int64) * bag_size
    return flat_indices, bag_starts


def b_indices(
    b: Callable[..., torch.Tensor],
    x: torch.Tensor,
    per_sample_weights: Optional[torch.Tensor] = None,
    use_cpu: bool = False,
    do_pooling: bool = True,
) -> torch.Tensor:
    """Run a reference embedding module ``b`` on dense ``(B, L)`` indices ``x``.

    Args:
        b: module/callable to run. With pooling it is called as
            ``b(indices, offsets, per_sample_weights=...)``, otherwise as
            ``b(indices)``.
        x: 2-D tensor of shape ``(B, L)``; every bag holds ``L`` indices.
        per_sample_weights: forwarded to ``b`` only when pooling. It must
            already live on the same device as the inputs.
        use_cpu: keep the inputs on the CPU. When False (the default), move
            them to the current CUDA device.
        do_pooling: whether ``b`` takes offsets (pooled) or not.

    Returns:
        Whatever ``b`` returns.
    """
    (indices, offsets) = get_offsets_from_dense(x)
    # ``use_cpu`` used to be accepted but ignored (inputs always went to CUDA).
    device = "cpu" if use_cpu else "cuda"
    indices = indices.to(device)
    if do_pooling:
        return b(
            indices,
            offsets.to(device),
            per_sample_weights=per_sample_weights,
        )
    return b(indices)


def _load_requests_from_file(
    iters: int,
    B: int,
    T: int,
    L: int,
    E: int,
    weighted: bool,
    requests_data_file: str,
    tables: Optional[str],
) -> List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]]:
    # Load one (indices, offsets, lengths) request from a file, validate it
    # against (T, B, L, E), and return it replayed `iters` times.
    # NOTE(security): torch.load unpickles its input; only use trusted files.
    indices_tensor, offsets_tensor, lengths_tensor = torch.load(requests_data_file)

    if tables is not None:
        emb_tables = tuple(int(x) for x in tables.split(","))
        indices = torch.zeros(0, dtype=indices_tensor.dtype)
        offsets = torch.zeros(1, dtype=offsets_tensor.dtype)
        for t in emb_tables:
            # The file is assumed to hold B bags per table: table t owns offsets
            # [B * t, B * (t + 1)] (inclusive, to keep its closing offset).
            t_offsets = offsets_tensor[B * t : B * (t + 1) + 1]
            indices = torch.cat(
                (indices, indices_tensor[t_offsets[0] : t_offsets[-1]])
            )
            # Rebase this table's offsets to continue from the running total.
            offsets = torch.cat(
                (
                    offsets,
                    t_offsets[1:] - t_offsets[0] + offsets[-1],
                )
            )
        indices_tensor = indices
        offsets_tensor = offsets

        assert np.prod(offsets_tensor.size()) - 1 == np.prod((T, B)), (
            f"Requested tables: {emb_tables} "
            f"does not conform to inputs (T, B) = ({T}, {B})."
        )
        logging.warning(
            f"Using (indices = {indices_tensor.size()}, offsets = {offsets_tensor.size()}) based "
            f"on tables: {emb_tables}"
        )
    else:
        assert (np.prod(offsets_tensor.size()) - 1) == np.prod((T, B)), (
            f"Data file (indices = {indices_tensor.size()}, "
            f"offsets = {offsets_tensor.size()}, lengths = {lengths_tensor.size()}) "
            f"does not conform to inputs (T, B) = ({T}, {B})."
        )

    # Average bag size over all T * B bags. Dividing by B alone (as before)
    # gave T * L and rejected every multi-table data file.
    average_L = int((offsets_tensor[-1] - offsets_tensor[0]) / max(T * B, 1))
    assert (
        L == average_L
    ), f"Requested L does not align with provided data file ({L} vs. {average_L})"

    # One vectorized reduction instead of Python's builtin max() iterating the
    # tensor element by element (previously evaluated twice).
    max_index = int(indices_tensor.max()) if indices_tensor.numel() > 0 else -1
    assert E > max_index, (
        f"Number of embeddings is not enough to support maximum index "
        f"provided by data file {E} vs. {max_index}"
    )

    device = get_device()
    weights_tensor = (
        None if not weighted else torch.randn(indices_tensor.size(), device=device)
    )
    # Every iteration replays the same request, so move it to the device once.
    indices_tensor = indices_tensor.to(device)
    offsets_tensor = offsets_tensor.to(device)
    return [(indices_tensor, offsets_tensor, weights_tensor) for _ in range(iters)]


def _sample_zipf_indices(
    iters: int, T: int, B: int, L: int, E: int, alpha: float
) -> Tensor:
    # Draw an (iters, T, B * L) int32 tensor in which every bag holds L distinct
    # indices sampled from a Zipf(alpha) distribution over [0, E).
    assert E >= L, "num-embeddings must be greater than equal to bag-size"
    # oversample and then remove duplicates to obtain sampling without
    # replacement
    all_indices = (np.random.zipf(a=alpha, size=(iters, T, B, 3 * L)) - 1) % E
    for index_tuple in itertools.product(range(iters), range(T), range(B)):
        # sample without replacement from
        # https://stats.stackexchange.com/questions/20590/how-do-i-sample-without-replacement-using-a-sampling-with-replacement-function
        r = set()
        for x in all_indices[index_tuple]:
            if x not in r:
                r.add(x)
                if len(r) == L:
                    break
        assert (len(r)) == L, "too skewed distribution (alpha too big)"
        all_indices[index_tuple][:L] = list(r)
    # shuffle indices so we don't have unintended spatial locality
    all_indices = torch.as_tensor(all_indices[:, :, :, :L])
    # NOTE: default_rng() is unseeded, so this permutation differs between runs
    # even when callers seed np.random / torch.
    rng = default_rng()
    permutation = torch.as_tensor(
        rng.choice(E, size=all_indices.max().item() + 1, replace=False)
    )
    all_indices = permutation.gather(0, all_indices.flatten())
    return all_indices.to(get_device()).int().reshape(iters, T, B * L)


def generate_requests(
    iters: int,
    B: int,
    T: int,
    L: int,
    E: int,
    # inter-batch indices reuse rate
    reuse: float = 0.0,
    # alpha <= 1.0: use uniform distribution
    # alpha > 1.0: use zipf distribution
    alpha: float = 1.0,
    weights_precision: SparseType = SparseType.FP32,
    weighted: bool = False,
    requests_data_file: Optional[str] = None,
    # Comma-separated list of table numbers
    tables: Optional[str] = None,
) -> List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]]:
    """Build ``iters`` table-batched embedding requests.

    Each request is ``(indices, offsets, per_sample_weights)``. It covers
    ``T * B`` bags of ``L`` indices each, with ``offsets`` of length
    ``T * B + 1``.

    Args:
        iters: number of requests to produce.
        B: batch size (bags per table).
        T: number of tables.
        L: indices per bag.
        E: rows per table; indices lie in ``[0, E)``.
        reuse: fraction of each table's ``B * L`` positions that request
            ``it + 1`` copies from request ``it``.
        alpha: ``<= 1.0`` samples uniformly. ``> 1.0`` samples each bag's
            ``L`` distinct indices from a Zipf(alpha) distribution.
        weights_precision: accepted for interface compatibility; not used here.
        weighted: attach random-normal per-sample weights.
        requests_data_file: if set, load ``(indices, offsets, lengths)`` with
            ``torch.load`` instead of sampling. All requests are then the same
            loaded data.
        tables: with ``requests_data_file``, comma-separated table ids to
            extract from the file.
    """
    if requests_data_file is not None:
        return _load_requests_from_file(
            iters, B, T, L, E, weighted, requests_data_file, tables
        )

    if alpha <= 1.0:
        all_indices = torch.randint(
            low=0,
            high=E,
            size=(iters, T, B, L),
            device=get_device(),
            dtype=torch.int32,
        )
        # each bag is usually sorted
        (all_indices, _) = torch.sort(all_indices)
        all_indices = all_indices.reshape(iters, T, B * L)
    else:
        all_indices = _sample_zipf_indices(iters, T, B, L, E, alpha)

    # Copy a random `reuse` fraction of positions from each request into the
    # next one (per table) to model inter-batch index reuse.
    for it in range(iters - 1):
        for t in range(T):
            reused_indices = torch.randperm(B * L, device=get_device())[
                : int(B * L * reuse)
            ]
            all_indices[it + 1, t, reused_indices] = all_indices[it, t, reused_indices]

    rs = []
    for it in range(iters):
        weights_tensor = (
            None if not weighted else torch.randn(T * B * L, device=get_device())
        )
        rs.append(
            get_table_batched_offsets_from_dense(all_indices[it].view(T, B, L))
            + (weights_tensor,)
        )
    return rs


def benchmark_requests(
    requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]],
    func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor],
    flush_gpu_cache_size_mb: int = 0,
    check_median: bool = False,
) -> float:
    """Time ``func`` once per request and summarize the per-call latency.

    Each ``(indices, offsets, weights)`` request is passed positionally to
    ``func``. With CUDA, every call is bracketed by CUDA events. It can
    optionally be preceded by writing ``flush_gpu_cache_size_mb`` MiB of
    scratch GPU memory to flush the cache. Without CUDA,
    ``time.perf_counter`` measures wall-clock time.

    Args:
        requests: non-empty list of request tuples.
        func: callable invoked once per request; its result is ignored.
        flush_gpu_cache_size_mb: scratch size written before each timed call
            (0 disables it). Only used when CUDA is available.
        check_median: return the median instead of the mean.

    Returns:
        Mean (or median) seconds per call.

    Raises:
        ValueError: if ``requests`` is empty.
    """
    if not requests:
        raise ValueError("benchmark_requests requires at least one request")
    times = []
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
    for (indices, offsets, weights) in requests:
        if use_cuda:
            if flush_gpu_cache_size_mb:
                # Allocate on the GPU: the previous host-side tensor never
                # touched GPU memory and so flushed nothing.
                _ = torch.rand(
                    flush_gpu_cache_size_mb * 1024 * 1024 // 4,
                    dtype=torch.float,
                    device="cuda",
                )
                torch.cuda.synchronize()
            start_event.record()
            func(indices, offsets, weights)
            end_event.record()
            torch.cuda.synchronize()
            times.append(start_event.elapsed_time(end_event) * 1.0e-3)
        else:
            start_time = time.perf_counter()
            func(indices, offsets, weights)
            times.append(time.perf_counter() - start_time)
    return statistics.median(times) if check_median else sum(times) / len(times)


def benchmark_requests_refer(
    requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]],
    T: int,
    B: int,
    L: int,
    E: int,
    D: int,
    pooling_mode: str,
    weighted: bool,
    flush_gpu_cache_size_mb: int = 0,
    check_median: bool = False,
) -> float:
    """Time a plain-PyTorch reference (one nn module call per table) on ``requests``.

    ``"sum"``/``"mean"`` pooling uses ``torch.nn.EmbeddingBag``. Any other
    ``pooling_mode`` uses ``torch.nn.Embedding`` without pooling. The modules
    are created with ``.cuda()``, so a CUDA device is required.

    Args:
        requests: non-empty list of ``(indices, offsets, weights)``. ``indices``
            (and ``weights`` when ``weighted``) must hold ``T * B * L``
            elements; ``offsets`` is ignored because every bag has ``L`` entries.
        T, B, L, E, D: tables, batch size, bag size, rows and embedding dim.
        pooling_mode: ``"sum"``, ``"mean"``, or anything else for no pooling.
        weighted: pass per-sample weights to the modules.
        flush_gpu_cache_size_mb: scratch GPU memory written before each call.
        check_median: return the median instead of the mean.

    Returns:
        Mean (or median) seconds per request, including the output concat.

    Raises:
        ValueError: if ``requests`` is empty.
    """
    if not requests:
        raise ValueError("benchmark_requests_refer requires at least one request")
    do_pooling = pooling_mode in ["sum", "mean"]
    # NOTE: `[module] * T` reuses ONE module for all T tables (shared weights),
    # which keeps memory low but gives better locality than T real tables.
    if do_pooling:
        nn_embedding_list = [
            torch.nn.EmbeddingBag(E, D, mode=pooling_mode, sparse=True).cuda()
        ] * T
    else:
        nn_embedding_list = [torch.nn.Embedding(E, D, sparse=True).cuda()] * T

    times = []
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
    for (indices, _, weights) in requests:
        # unbind(0) yields T tensors of shape (B, L), the 2-D layout
        # get_offsets_from_dense unpacks; split(1) kept a leading size-1 dim.
        indices_list = indices.view(T, B, L).unbind(0)
        if weighted:
            assert weights is not None
            weights_list = [w.view(-1) for w in weights.view(T, B, L).unbind(0)]
        else:
            weights_list = [None] * T

        start_time = time.perf_counter()
        if use_cuda:
            if flush_gpu_cache_size_mb:
                # Allocate on the GPU so the scratch write actually flushes it.
                _ = torch.rand(
                    flush_gpu_cache_size_mb * 1024 * 1024 // 4,
                    dtype=torch.float,
                    device="cuda",
                )
                torch.cuda.synchronize()
            start_event.record()

        nn_embedding_output = [
            b_indices(
                nn_embedding,
                x,
                per_sample_weights=xw,
                use_cpu=False,
                do_pooling=do_pooling,
            )
            for (nn_embedding, x, xw) in zip(
                nn_embedding_list, indices_list, weights_list
            )
        ]
        # Assemble a single output tensor inside the timed region.
        if do_pooling:
            _ = torch.cat([f.view(B, -1) for f in nn_embedding_output], dim=1)
        else:
            _ = torch.cat(nn_embedding_output, dim=0).view(-1, D)

        if use_cuda:
            end_event.record()
            torch.cuda.synchronize()
            times.append(start_event.elapsed_time(end_event) * 1.0e-3)
        else:
            times.append(time.perf_counter() - start_time)
    return statistics.median(times) if check_median else sum(times) / len(times)


def benchmark_pipelined_requests(
    requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]],
    func1: Callable[[Tensor, Tensor, Optional[Tensor]], None],
    func2: Callable[[Tensor, Tensor, Optional[Tensor]], None],
    flush_gpu_cache_size_mb: int = 0,
) -> Tuple[float, float]:
    """Time two back-to-back stages per request with CUDA events (CUDA required).

    For each request, ``func1`` and then ``func2`` run on the same
    ``(indices, offsets, weights)`` tuple. Each stage is bracketed by its own
    event pair. The events are read only after a final synchronize, so the
    requests are issued without per-request host synchronization (except
    when flushing).

    Args:
        requests: non-empty list of request tuples.
        func1: first stage.
        func2: second stage.
        flush_gpu_cache_size_mb: scratch GPU memory written before each request
            (0 disables it).

    Returns:
        ``(mean seconds of func1, mean seconds of func2)`` per request.

    Raises:
        ValueError: if ``requests`` is empty.
    """
    if not requests:
        raise ValueError("benchmark_pipelined_requests requires at least one request")
    torch.cuda.synchronize()
    start_events = [
        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
        for _ in requests
    ]
    end_events = [
        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
        for _ in requests
    ]
    for ((indices, offsets, indices_weights), start_event, end_event) in zip(
        requests, start_events, end_events
    ):
        if flush_gpu_cache_size_mb:
            # Allocate on the GPU: a host-side tensor does not flush the GPU cache.
            _ = torch.rand(
                flush_gpu_cache_size_mb * 1024 * 1024 // 4,
                dtype=torch.float,
                device="cuda",
            )
            torch.cuda.synchronize()
        start_event[0].record()
        func1(indices, offsets, indices_weights)
        end_event[0].record()
        start_event[1].record()
        func2(indices, offsets, indices_weights)
        end_event[1].record()
    torch.cuda.synchronize()
    num_requests = len(requests)
    return (
        sum(
            start_event[0].elapsed_time(end_event[0]) * 1.0e-3
            for start_event, end_event in zip(start_events, end_events)
        )
        / num_requests,
        sum(
            start_event[1].elapsed_time(end_event[1]) * 1.0e-3
            for start_event, end_event in zip(start_events, end_events)
        )
        / num_requests,
    )


@click.group()
def cli() -> None:
    # Root click group; the benchmark subcommands in this file attach to it via
    # @cli.command(). Kept as a comment because click would show a docstring as
    # the group's --help text.
    pass


@cli.command()
# recommended value: alpha=1.15 for training and alpha=1.09 for inference
@click.option("--alpha", default=1.0)
@click.option("--bag-size", default=20)
@click.option("--batch-size", default=512)
@click.option("--embedding-dim", default=128)
@click.option("--weights-precision", type=SparseType, default=SparseType.FP32)
@click.option("--stoc", is_flag=True, default=False)
@click.option("--iters", default=100)
@click.option("--managed", default="device")
@click.option("--mixed", is_flag=True, default=False)
@click.option("--num-embeddings", default=int(1e5))
@click.option("--num-tables", default=32)
@click.option("--reuse", default=0.0)
@click.option("--row-wise/--no-row-wise", default=True)
@click.option("--weighted", is_flag=True, default=False)
@click.option("--weighted-num-requires-grad", type=int, default=None)
@click.option("--flush-gpu-cache-size-mb", default=0)
@click.option("--dense", is_flag=True, default=False)
@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
def device(  # noqa C901
    alpha: float,
    bag_size: int,
    batch_size: int,
    embedding_dim: int,
    weights_precision: SparseType,
    stoc: bool,
    iters: int,
    managed: str,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    row_wise: bool,
    weighted: bool,
    weighted_num_requires_grad: Optional[int],
    flush_gpu_cache_size_mb: int,
    dense: bool,
    output_dtype: SparseType,
    requests_data_file: Optional[str],
    tables: Optional[str],
) -> None:
    # Benchmark forward and forward+backward of a single table-batched embedding
    # op holding all T tables on get_device() (CUDA when available, else CPU).
    # Kept as comments: a docstring here would become the click --help text.
    # Seeds come first; the np.random draws below (requires-grad tables, mixed
    # dims) depend on this exact order.
    np.random.seed(42)
    torch.manual_seed(42)
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    if weighted_num_requires_grad:
        # Build a 0/1 int mask over the T tables with `weighted_num_requires_grad`
        # randomly chosen ones; it is passed to forward as feature_requires_grad.
        assert weighted_num_requires_grad <= T
        weighted_requires_grad_tables = np.random.choice(
            T, replace=False, size=(weighted_num_requires_grad,)
        ).tolist()
        feature_requires_grad = (
            torch.tensor(
                [1 if t in weighted_requires_grad_tables else 0 for t in range(T)]
            )
            .to(get_device())
            .int()
        )
    else:
        feature_requires_grad = None
    if mixed:
        # Per-table dims drawn from [0.5 * D, 1.5 * D) and rounded up to a
        # multiple of 4; D becomes their mean and is only used in log output.
        Ds = [
            round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T
    optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD

    if managed == "device":
        # Keep weights in device memory, falling back to host memory without CUDA.
        managed_option = (
            EmbeddingLocation.DEVICE
            if torch.cuda.is_available()
            else EmbeddingLocation.HOST
        )
    else:
        # Any other --managed value places the weights in managed (UVM) memory.
        managed_option = EmbeddingLocation.MANAGED

    if dense:
        # Dense variant: no optimizer/location/precision arguments are passed.
        emb = DenseTableBatchedEmbeddingBagsCodegen(
            [
                (
                    E,
                    d,
                )
                for d in Ds
            ],
            use_cpu=not torch.cuda.is_available(),
        )
    else:
        emb = SplitTableBatchedEmbeddingBagsCodegen(
            [
                (
                    E,
                    d,
                    managed_option,
                    ComputeDevice.CUDA
                    if torch.cuda.is_available()
                    else ComputeDevice.CPU,
                )
                for d in Ds
            ],
            optimizer=optimizer,
            learning_rate=0.1,
            eps=0.1,
            weights_precision=weights_precision,
            stochastic_rounding=stoc,
            output_dtype=output_dtype,
        )
    emb = emb.to(get_device())

    if weights_precision == SparseType.INT8:
        # NOTE(review): also reached with --dense; confirm the dense op provides
        # init_embedding_weights_uniform.
        emb.init_embedding_weights_uniform(-0.0003, 0.0003)

    nparams = sum(w.numel() for w in emb.split_embedding_weights())
    # Bytes per stored weight element for the chosen precision.
    param_size_multiplier = weights_precision.bit_rate() / 8.0
    logging.info(
        f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
        f"{nparams * param_size_multiplier / 1.0e9: .2f} GB"
    )
    logging.info(
        f"Accessed weights per batch: {B * sum(Ds) * L * param_size_multiplier / 1.0e9: .2f} GB"
    )

    requests = generate_requests(
        iters,
        B,
        T,
        L,
        E,
        reuse=reuse,
        alpha=alpha,
        weights_precision=weights_precision,
        weighted=weighted,
        requests_data_file=requests_data_file,
        tables=tables,
    )

    # forward
    # The BW figure counts only the weight bytes read: B * sum(Ds) * L elements.
    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
            feature_requires_grad=feature_requires_grad,
        ),
        flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
    )
    logging.info(
        f"Forward, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )

    if output_dtype == SparseType.INT8:
        # backward bench not representative
        return

    grad_output = torch.randn(B, sum(Ds)).to(get_device())
    # backward
    # Each timed call runs forward then backward(grad_output). The BW figure
    # counts 3x the forward weight bytes (presumably forward read plus the
    # update's read and write -- not stated in the code).
    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb(
            indices.long(),
            offsets.long(),
            per_sample_weights,
            feature_requires_grad=feature_requires_grad,
        ).backward(grad_output),
        flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
    )
    logging.info(
        f"ForwardBackward, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, "
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )


@cli.command()
@click.option("--alpha", default=1.0)
@click.option("--bag-size", default=20)
@click.option("--batch-size", default=512)
@click.option("--embedding-dim", default=128)
@click.option("--weights-precision", type=SparseType, default=SparseType.FP32)
@click.option("--stoc", is_flag=True, default=False)
@click.option("--iters", default=100)
@click.option("--mixed", is_flag=True, default=False)
@click.option("--num-embeddings", default=int(1e5))
@click.option("--num-tables", default=32)
@click.option("--reuse", default=0.1)
@click.option("--uvm-tables", default=1)
@click.option("--uvm-bag-size", default=1)
@click.option("--weighted", is_flag=True, default=False)
@click.option("--flush-gpu-cache-size-mb", default=0)
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
@click.option("--use-cache", is_flag=True, default=False)
@click.option("--cache-algorithm", default="lru")
@click.option("--cache-load-factor", default=0.2)
@click.option("--enforce-hbm", is_flag=True, default=False)
def uvm(
    alpha: float,
    bag_size: int,
    batch_size: int,
    embedding_dim: int,
    weights_precision: SparseType,
    stoc: bool,
    iters: int,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    uvm_tables: int,
    uvm_bag_size: int,
    weighted: bool,
    flush_gpu_cache_size_mb: int,
    requests_data_file: Optional[str],
    tables: Optional[str],
    output_dtype: SparseType,
    use_cache: bool,
    cache_algorithm: str,
    cache_load_factor: float,
    enforce_hbm: bool,
) -> None:
    # Benchmark forward passes with the first `uvm_tables` tables in managed
    # memory (optionally cached, bag size `uvm_bag_size`) and the remaining
    # tables in device memory (bag size `bag_size`). Three passes are timed:
    # UVM tables alone, device tables alone, and all tables in one mixed op.
    # Requires CUDA. Kept as comments: a docstring would become --help text.
    np.random.seed(42)
    torch.manual_seed(42)
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    T_uvm = uvm_tables
    assert T_uvm <= T
    assert (
        T_uvm > 0
    ), f"T_uvm specified {T_uvm} <= 0. If not testing UVM, please use device benchmark."
    T_gpu = T - T_uvm
    L_uvm = uvm_bag_size

    cache_alg = CacheAlgorithm.LRU if cache_algorithm == "lru" else CacheAlgorithm.LFU
    managed_type = (
        EmbeddingLocation.MANAGED_CACHING if use_cache else EmbeddingLocation.MANAGED
    )

    if mixed:
        # Per-table dims in [0.5 * D, 1.5 * D) rounded up to a multiple of 4;
        # D becomes their mean and is only used in log output.
        Ds = [
            round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T
    emb_uvm = SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                managed_type,
                ComputeDevice.CUDA,
            )
            for d in Ds[:T_uvm]
        ],
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
        output_dtype=output_dtype,
        cache_load_factor=cache_load_factor,
        cache_algorithm=cache_alg,
        enforce_hbm=enforce_hbm,
    ).cuda()

    if weights_precision == SparseType.INT8:
        emb_uvm.init_embedding_weights_uniform(-0.0003, 0.0003)

    if T_gpu > 0:
        emb_gpu = SplitTableBatchedEmbeddingBagsCodegen(
            [
                (
                    E,
                    d,
                    EmbeddingLocation.DEVICE,
                    ComputeDevice.CUDA,
                )
                for d in Ds[T_uvm:]
            ],
            weights_precision=weights_precision,
            stochastic_rounding=stoc,
            # Match emb_uvm / emb_mixed: the GPU BW estimate below counts output
            # bytes with output_dtype's size, and the three timings are compared.
            output_dtype=output_dtype,
        ).cuda()

        if weights_precision == SparseType.INT8:
            emb_gpu.init_embedding_weights_uniform(-0.0003, 0.0003)

        emb_mixed = SplitTableBatchedEmbeddingBagsCodegen(
            [
                (
                    E,
                    d,
                    managed_option,
                    ComputeDevice.CUDA,
                )
                for (d, managed_option) in zip(
                    Ds,
                    [managed_type] * T_uvm + [EmbeddingLocation.DEVICE] * T_gpu,
                )
            ],
            weights_precision=weights_precision,
            stochastic_rounding=stoc,
            output_dtype=output_dtype,
            cache_load_factor=cache_load_factor,
            cache_algorithm=cache_alg,
            enforce_hbm=enforce_hbm,
        ).cuda()

        if weights_precision == SparseType.INT8:
            emb_mixed.init_embedding_weights_uniform(-0.0003, 0.0003)

    requests_uvm = generate_requests(
        iters,
        B,
        T_uvm,
        L_uvm,
        E,
        reuse=reuse,
        alpha=alpha,
        weights_precision=weights_precision,
        weighted=weighted,
        requests_data_file=requests_data_file,
        tables=tables,
    )

    requests_gpu = None
    if T_gpu > 0:
        requests_gpu = generate_requests(
            iters,
            B,
            T_gpu,
            L,
            E,
            reuse=reuse,
            alpha=alpha,
            weights_precision=weights_precision,
            # Must follow --weighted like the UVM requests: the mixed benchmark
            # concatenates both sets of per-sample weights (previously this was
            # hard-coded False, so --weighted always failed below).
            weighted=weighted,
            requests_data_file=requests_data_file,
            tables=tables,
        )

    param_size_multiplier = weights_precision.bit_rate() / 8.0
    output_size_multiplier = output_dtype.bit_rate() / 8.0
    # Bytes moved per UVM forward: pooled output written + weight rows read.
    read_write_bytes_uvm = (
        output_size_multiplier * B * sum(Ds[:T_uvm])
        + param_size_multiplier * B * sum(Ds[:T_uvm]) * L_uvm
    )

    time_per_iter = benchmark_requests(
        requests_uvm,
        lambda indices, offsets, per_sample_weights: emb_uvm.forward(
            indices.long(),
            offsets.long(),
            per_sample_weights,
        ),
        flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
    )
    logging.info(
        f"UVM Forward, B: {B}, "
        f"E: {E}, T: {T_uvm}, D: {D}, L: {L_uvm}, W: {weighted}, "
        f"BW: {read_write_bytes_uvm / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )

    if T_gpu > 0:
        # Merge each UVM request with its GPU request into one request for the
        # mixed op: UVM tables first, then device tables.
        requests = []
        assert requests_gpu is not None
        for rs_uvm, rs_gpu in zip(requests_uvm, requests_gpu):
            indices = torch.cat([rs_uvm[0], rs_gpu[0]])
            lengths = [L_uvm] * (T_uvm * B) + [L] * (T_gpu * B)
            offsets = torch.tensor(([0] + np.cumsum(lengths).tolist())).int().cuda()
            per_sample_weights = None
            if weighted:
                # Plain assignments rather than `assert (x := ...)`: under
                # `python -O` asserts are stripped and those bindings vanish.
                uvm_weights = rs_uvm[2]
                gpu_weights = rs_gpu[2]
                assert uvm_weights is not None and gpu_weights is not None
                per_sample_weights = torch.cat([uvm_weights, gpu_weights])
            requests.append((indices, offsets, per_sample_weights))

        # forward
        time_per_iter = benchmark_requests(
            requests_gpu,
            lambda indices, offsets, per_sample_weights: emb_gpu.forward(
                indices.long(),
                offsets.long(),
                per_sample_weights,
            ),
            flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
        )
        read_write_bytes_hbm = (
            output_size_multiplier * B * sum(Ds[T_uvm:])
            + param_size_multiplier * B * sum(Ds[T_uvm:]) * L
        )
        logging.info(
            f"GPU Forward, B: {B}, "
            f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, "
            f"BW: {read_write_bytes_hbm / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
            f"T: {time_per_iter * 1.0e6:.0f}us"
        )

        time_per_iter = benchmark_requests(
            requests,
            lambda indices, offsets, per_sample_weights: emb_mixed.forward(
                indices.long(),
                offsets.long(),
                per_sample_weights,
            ),
            flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
        )
        read_write_bytes_total = read_write_bytes_uvm + read_write_bytes_hbm
        logging.info(
            f"Mixed Forward, B: {B}, "
            f"E: {E}, T_GPU: {T_gpu}, T_UVM: {T_uvm}, D: {D}, L_GPU: {L}, L_UVM: {L_uvm}, W: {weighted}, "
            f"BW: {read_write_bytes_total / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
            f"T: {time_per_iter * 1.0e6:.0f}us"
        )


@cli.command()
@click.option("--alpha", default=1.0)
@click.option("--bag-size", default=20)
@click.option("--batch-size", default=512)
@click.option("--cache-algorithm", default="lru")
@click.option("--cache-load-factor", default=0.2)
@click.option("--embedding-dim", default=128)
@click.option("--weights-precision", type=SparseType, default=SparseType.FP32)
@click.option("--stoc", is_flag=True, default=False)
@click.option("--long-index", is_flag=True, default=False)
@click.option("--iters", default=100)
@click.option("--mixed", is_flag=True, default=False)
@click.option("--num-embeddings", default=int(1e5))
@click.option("--num-tables", default=32)
@click.option("--reuse", default=0.1)
@click.option("--weighted", is_flag=True, default=False)
@click.option("--flush-gpu-cache-size-mb", default=0)
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
def cache(  # noqa C901
    alpha: float,
    bag_size: int,
    batch_size: int,
    cache_algorithm: str,
    cache_load_factor: float,
    embedding_dim: int,
    weights_precision: SparseType,
    stoc: bool,
    iters: int,
    long_index: bool,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    weighted: bool,
    flush_gpu_cache_size_mb: int,
    requests_data_file: Optional[str],
    tables: Optional[str],
) -> None:
    # Compare forward+backward with all tables in managed (UVM) memory (emb_nc)
    # against the same tables in MANAGED_CACHING mode (emb). Then report, for
    # emb, per-request cache-state changes and misses, plus separately timed
    # prefetch and forward+backward stages. Requires CUDA.
    # Kept as comments: a docstring here would become the click --help text.
    # NOTE(review): --long-index is accepted but never read in this body.
    np.random.seed(42)
    torch.manual_seed(42)
    optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    cache_alg = CacheAlgorithm.LRU if cache_algorithm == "lru" else CacheAlgorithm.LFU
    if mixed:
        # Per-table dims in [0.5 * D, 1.5 * D) rounded up to a multiple of 4;
        # D becomes their mean and is only used in logged figures.
        Ds = [
            round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T

    # Baseline: all tables in managed memory with no cache.
    emb_nc = SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                EmbeddingLocation.MANAGED,
                ComputeDevice.CUDA,
            )
            for d in Ds
        ],
        optimizer=optimizer,
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
    ).cuda()

    if weights_precision == SparseType.INT8:
        emb_nc.init_embedding_weights_uniform(-0.0003, 0.0003)

    # Same tables behind a cache sized by --cache-load-factor / --cache-algorithm.
    emb = SplitTableBatchedEmbeddingBagsCodegen(
        [
            (
                E,
                d,
                EmbeddingLocation.MANAGED_CACHING,
                ComputeDevice.CUDA,
            )
            for d in Ds
        ],
        optimizer=optimizer,
        weights_precision=weights_precision,
        stochastic_rounding=stoc,
        cache_load_factor=cache_load_factor,
        cache_algorithm=cache_alg,
    ).cuda()

    if weights_precision == SparseType.INT8:
        emb.init_embedding_weights_uniform(-0.0003, 0.0003)

    nparams = sum(w.numel() for w in emb.split_embedding_weights())
    param_size_multiplier = weights_precision.bit_rate() / 8.0
    logging.info(
        f"Embedding tables: {E * T} rows, {nparams / 1.0e9: .2f} GParam, "
        f"{nparams * param_size_multiplier  / 1.0e9: .2f} GB"
    )
    logging.info(
        f"Accessed weights per batch: {B * T * L} rows, "
        f"{B * T * L * D * param_size_multiplier / 1.0e9: .2f} GB"
    )

    # 2 * iters requests: the first `iters` warm the cache, the rest are measured.
    requests = generate_requests(
        2 * iters,
        B,
        T,
        L,
        E,
        reuse=reuse,
        alpha=alpha,
        weighted=weighted,
        requests_data_file=requests_data_file,
        tables=tables,
    )
    warmup_requests, requests = requests[:iters], requests[iters:]
    grad_output = torch.randn(B, sum(Ds)).cuda()

    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb_nc(
            indices.long(), offsets.long(), per_sample_weights
        ).backward(grad_output),
        flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
    )
    logging.info(
        f"ForwardBackward (UVM), B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, "
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )

    # warm up
    for indices, offsets, _ in warmup_requests:
        emb.forward(indices.long(), offsets.long())
    # get cache miss rate (forward and backward) and exchanged cache lines (prefetch)
    # "Exchanged" = number of lxu_cache_state entries changed by prefetch.
    # "Misses" = entries of lxu_cache_locations_list[0] equal to NOT_FOUND (-1).
    cache_misses = []
    exchanged_cache_lines = []
    NOT_FOUND = -1
    for indices, offsets, _ in requests:
        # pyre-fixme[29]:
        #  `Union[BoundMethod[typing.Callable(Tensor.clone)[[Named(self,
        #  Variable[torch._TTensor (bound to Tensor)])], Variable[torch._TTensor (bound
        #  to Tensor)]], Tensor], Tensor, torch.nn.Module]` is not a function.
        old_lxu_cache_state = emb.lxu_cache_state.clone()
        emb.prefetch(indices.long(), offsets.long())
        exchanged_cache_lines.append(
            # pyre-fixme[16]: `bool` has no attribute `sum`.
            (emb.lxu_cache_state != old_lxu_cache_state)
            .sum()
            .item()
        )
        cache_misses.append((emb.lxu_cache_locations_list[0] == NOT_FOUND).sum().item())
        emb.forward(indices.long(), offsets.long())
    logging.info(
        f"Exchanged cache lines -- mean: {sum(exchanged_cache_lines)/len(requests): .2f}, "
        f"max: {max(exchanged_cache_lines)}, min: {min(exchanged_cache_lines)}"
    )
    logging.info(
        f"Cache miss -- mean: {sum(cache_misses)/len(requests)}, "
        f"max: {max(cache_misses)}, min: {min(cache_misses)}"
    )

    # benchmark prefetch
    # Reset and re-warm the cache, then time prefetch and forward+backward as
    # two separate stages per request.
    # NOTE(review): unlike the calls above, indices/offsets are passed here
    # without .long(); confirm the op accepts the int32 indices
    # generate_requests can produce.
    emb.reset_cache_states()
    for indices, offsets, _ in warmup_requests:
        emb.forward(indices, offsets)
    prefetch_time, forward_backward_time = benchmark_pipelined_requests(
        requests,
        lambda indices, offsets, indices_weights: emb.prefetch(indices, offsets),
        lambda indices, offsets, indices_weights: emb.forward(
            indices, offsets, indices_weights
        ).backward(grad_output),
        flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
    )
    e2e_time = prefetch_time + forward_backward_time

    # The prefetch BW counts 2 * (mean exchanged entries) * D elements
    # (presumably one read and one write per exchanged row -- not stated here).
    logging.info(
        f"ForwardBackward (LXU), reuse: {reuse}, alpha: {alpha}, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / e2e_time / 1.0e9: .2f} GB/s, "
        f"Tprefetch: {prefetch_time * 1.0e6:.0f}us, "
        f"{2 * sum(exchanged_cache_lines) * param_size_multiplier * D / prefetch_time / len(requests) / 1.0e9: .2f} GB/s, "
        f"Tfwdbwd: {forward_backward_time * 1.0e6:.0f}us, "
        f"{3 * param_size_multiplier * B * sum(Ds) * L / forward_backward_time / 1.0e9: .2f} GB/s, "
        f"Te2e: {e2e_time * 1.0e6:.0f}us, "
    )


def benchmark_cpu_requests(
    requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[torch.Tensor]]],
    func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor],
) -> float:
    """Return the mean wall-clock seconds ``func`` takes per request.

    Every ``(indices, offsets, weights)`` triple in ``requests`` is unpacked and
    passed positionally to ``func``. The whole loop is timed once with
    ``time.perf_counter``. No device synchronization is performed. An empty
    ``requests`` list raises ``ZeroDivisionError``.
    """
    clock = time.perf_counter
    began = clock()
    for idx, offs, wts in requests:
        func(idx, offs, wts)
    elapsed = clock() - began
    return elapsed / len(requests)


@cli.command()
@click.option("--alpha", default=1.0)
@click.option("--bag-size", default=20)
@click.option("--batch-size", default=512)
@click.option("--embedding-dim", default=128)
@click.option("--weights-precision", type=SparseType, default=SparseType.INT4)
@click.option("--stoc", is_flag=True, default=False)
@click.option("--iters", default=100)
@click.option("--managed", default="device")
@click.option("--mixed", is_flag=True, default=False)
@click.option("--num-embeddings", default=int(1e5))
@click.option("--num-tables", default=32)
@click.option("--reuse", default=0.0)
@click.option("--row-wise/--no-row-wise", default=True)
@click.option("--weighted", is_flag=True, default=False)
@click.option("--index-remapping", is_flag=True, default=False)
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
@click.option("--output-dtype", type=SparseType, default=SparseType.FP16)
def nbit_cpu(  # noqa C901
    alpha: float,
    bag_size: int,
    batch_size: int,
    embedding_dim: int,
    weights_precision: SparseType,
    stoc: bool,
    iters: int,
    managed: str,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    row_wise: bool,
    weighted: bool,
    index_remapping: bool,
    requests_data_file: Optional[str],
    tables: Optional[str],
    output_dtype: SparseType,
) -> None:
    """Benchmark the CPU forward pass of ``IntNBitTableBatchedEmbeddingBagsCodegen``.

    Builds ``num_tables`` host (``EmbeddingLocation.HOST``) tables quantized to
    ``weights_precision``. It then generates ``iters`` requests (synthetic or
    from ``--requests_data_file``), times the forward pass with
    ``benchmark_cpu_requests``, and logs time and effective bandwidth.
    ``stoc``, ``managed`` and ``row_wise`` are accepted on the command line but
    are not used by this command.
    """
    np.random.seed(42)
    torch.manual_seed(42)
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    if mixed:
        Ds = [
            # int4 table batched emb op can only handle mixed D where D is multiple of 8
            round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 8)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T

    emb = IntNBitTableBatchedEmbeddingBagsCodegen(
        [("", E, d, weights_precision, EmbeddingLocation.HOST) for d in Ds],
        device="cpu",
        # --index-remapping installs an identity remapping per table, which
        # exercises the remapping code path without changing the result.
        index_remapping=[torch.arange(E) for _ in Ds] if index_remapping else None,
        output_dtype=output_dtype,
    ).cpu()
    emb.fill_random_weights()

    nparams_byte = sum(w.numel() for (w, _) in emb.split_embedding_weights())
    param_size_multiplier = weights_precision.bit_rate() / 8.0
    output_size_multiplier = output_dtype.bit_rate() / 8.0
    # Bytes per batch: one output row per (bag, table) written plus L weight
    # rows per (bag, table) read.
    read_write_bytes = (
        output_size_multiplier * B * T * D + param_size_multiplier * B * T * L * D
    )
    logging.info(
        f"{weights_precision} Embedding tables: {E * T} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, "
        f"{nparams_byte / 1.0e9: .2f} GB"  # IntN TBE use byte for storage
    )
    logging.info(
        f"Accessed weights per batch: {B * T * L} rows, "
        f"{B * T * L * D * param_size_multiplier / 1.0e9: .2f} GB"
    )

    requests = generate_requests(
        iters,
        B,
        T,
        L,
        E,
        reuse=reuse,
        alpha=alpha,
        weights_precision=weights_precision,
        weighted=weighted,
        requests_data_file=requests_data_file,
        tables=tables,
    )
    # Test the optional per-sample weights against None explicitly: calling
    # bool() on a multi-element tensor raises, which broke every --weighted run.
    requests = [
        (a.cpu().int(), b.cpu().int(), c.cpu() if c is not None else None)
        for (a, b, c) in requests
    ]

    time_per_iter = benchmark_cpu_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb.forward(
            indices,
            offsets,
            per_sample_weights,
        ),
    )

    logging.info(
        f"{weights_precision} Forward, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )


@cli.command()
@click.option("--alpha", default=1.0)
@click.option("--bag-size", default=20)
@click.option("--batch-size", default=512)
@click.option("--embedding-dim", default=128)
@click.option("--weights-precision", type=SparseType, default=SparseType.INT4)
@click.option("--stoc", is_flag=True, default=False)
@click.option("--managed", default="device")
@click.option("--mixed", is_flag=True, default=False)
@click.option("--num-embeddings", default=int(1e5))
@click.option("--num-tables", default=32)
@click.option("--reuse", default=0.0)
@click.option("--row-wise/--no-row-wise", default=True)
@click.option("--weighted", is_flag=True, default=False)
# Reject typos instead of silently falling back to no pooling.
@click.option("--pooling", type=click.Choice(["sum", "mean", "none"]), default="sum")
@click.option("--weighted-num-requires-grad", type=int, default=None)
@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.WARNING.value)
@click.option("--pruning-ratio", type=float, default=None)
@click.option("--load-factor", default=0.75)
# A lone ``is_flag=True`` option with ``default=True`` offers no way to ask for
# False (depending on the click version the bare flag may even invert it), so
# expose explicit on/off switches. The ``--<name>`` spelling keeps working.
@click.option(
    "--use-array-for-index-remapping/--no-use-array-for-index-remapping",
    default=True,
)
@click.option("--check-median/--no-check-median", default=True)
@click.option("--iters", default=100)
@click.option("--runs-of-iters", default=5)
@click.option("--warmup-runs", default=2)
@click.option("--output-dtype", type=SparseType, default=SparseType.FP16)
@click.option("--report-aibench", is_flag=True)
@click.option("--run-reference", is_flag=True, default=False)
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
def nbit_device(  # noqa C901
    alpha: float,
    bag_size: int,
    batch_size: int,
    embedding_dim: int,
    weights_precision: SparseType,
    stoc: bool,
    managed: str,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    row_wise: bool,
    weighted: bool,
    pooling: str,
    weighted_num_requires_grad: Optional[int],
    bounds_check_mode: int,
    pruning_ratio: Optional[float],
    load_factor: float,
    use_array_for_index_remapping: bool,
    check_median: bool,
    iters: int,
    runs_of_iters: int,
    warmup_runs: int,
    output_dtype: SparseType,
    report_aibench: bool,
    run_reference: bool,
    requests_data_file: Optional[str],
    tables: Optional[str],
) -> None:
    """Benchmark the CUDA forward pass of ``IntNBitTableBatchedEmbeddingBagsCodegen``.

    Runs ``runs_of_iters`` rounds of ``iters`` requests each. The first
    ``warmup_runs`` rounds are discarded, and the mean time and effective
    bandwidth of the remaining rounds are logged.

    Optional extras:
    - ``--pruning-ratio`` prunes each table through an index remapping.
    - ``--report-aibench`` emits the results to AIBench when it is available.
    - ``--run-reference`` times ``benchmark_requests_refer`` on requests of the
      same shape.

    Raises:
        click.BadParameter: if ``runs_of_iters <= warmup_runs``. No timed round
            would remain, so this is checked before any work is done.
    """
    if runs_of_iters <= warmup_runs:
        raise click.BadParameter(
            f"must be greater than --warmup-runs ({warmup_runs})",
            param_hint="--runs-of-iters",
        )
    np.random.seed(42)
    torch.manual_seed(42)
    # ``random`` drives the pruning sample below; seed it as well so pruned
    # runs are as reproducible as the numpy/torch parts.
    random.seed(42)
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    original_E = E
    T = num_tables
    index_remapping = None
    if mixed:
        # int4 table batched emb op can only handle mixed D where D is multiple of 8
        Ds = [
            round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 8)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T

    mem_for_pruning = 0
    if pruning_ratio:
        assert pruning_ratio < 1 and pruning_ratio >= 0
        E = math.ceil(E * (1.0 - pruning_ratio))
        index_remapping = []
        for _ in range(T):
            # Keep a random subset of E of the original_E rows. Kept row
            # selected_indices[i] maps to dense row i; every other row maps to -1.
            selected_indices = random.sample(range(original_E), E)
            mapping = torch.full((original_E,), -1, dtype=torch.int32)
            # One vectorized scatter instead of a Python-level write per row.
            mapping[torch.tensor(selected_indices, dtype=torch.long)] = torch.arange(
                E, dtype=torch.int32
            )
            index_remapping.append(mapping)
            if use_array_for_index_remapping:
                mem_for_pruning += mapping.numel() * 4
            else:
                # Presumably a hash map of int32 (key, value) pairs sized
                # E / load_factor -- estimate only, not measured.
                mem_for_pruning += E / load_factor * 2 * 4

    if managed == "device":
        managed_option = EmbeddingLocation.DEVICE
    else:
        managed_option = EmbeddingLocation.MANAGED

    pooling_mode, do_pooling = {
        "sum": (PoolingMode.SUM, True),
        "mean": (PoolingMode.MEAN, True),
        "none": (PoolingMode.NONE, False),
    }[pooling]

    emb = IntNBitTableBatchedEmbeddingBagsCodegen(
        [("", E, d, weights_precision, managed_option) for d in Ds],
        bounds_check_mode=BoundsCheckMode(bounds_check_mode),
        index_remapping=index_remapping,
        load_factor=load_factor,
        use_array_for_index_remapping=use_array_for_index_remapping,
        output_dtype=output_dtype,
        pooling_mode=pooling_mode,
    ).cuda()
    emb.fill_random_weights()

    nparams_byte = sum(w.numel() for (w, _) in emb.split_embedding_weights())
    param_size_multiplier = weights_precision.bit_rate() / 8.0
    output_size_multiplier = output_dtype.bit_rate() / 8.0
    if do_pooling:
        # One pooled output row per (bag, table).
        read_write_bytes = (
            output_size_multiplier * B * T * D + param_size_multiplier * B * T * L * D
        )
    else:
        # Without pooling every looked-up row is written out.
        read_write_bytes = (
            output_size_multiplier * B * T * L * D
            + param_size_multiplier * B * T * L * D
        )
    logging.info(
        f"{weights_precision} Embedding tables: {E * T} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, "
        f"{nparams_byte / 1.0e9: .2f} GB"  # IntN TBE use byte for storage
    )
    logging.info(
        f"Accessed weights per batch: {B * T * L} rows, "
        f"{B * T * L * D * param_size_multiplier / 1.0e9: .2f} GB"
    )

    times = []
    for i in range(runs_of_iters):
        requests = generate_requests(
            iters,
            B,
            T,
            L,
            E,
            reuse=reuse,
            alpha=alpha,
            weights_precision=weights_precision,
            weighted=weighted,
            requests_data_file=requests_data_file,
            tables=tables,
        )
        # Optional per-sample weights pass through unchanged; truth-testing a
        # multi-element tensor raises.
        requests = [(a.int(), b.int(), c) for (a, b, c) in requests]

        # forward
        time_per_iter = benchmark_requests(
            requests,
            lambda indices, offsets, per_sample_weights: emb.forward(
                indices.int(),
                offsets.int(),
                per_sample_weights,
            ),
            check_median=check_median,
        )

        # free up GPU memory
        del requests

        logging.info(
            f"Iteration {i}: "
            f"{weights_precision} Forward, B: {B}, "
            f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
            f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
            f"Time: {time_per_iter * 1.0e6:.0f}us, "
            f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.2f} GB"
        )

        if i >= warmup_runs:
            times.append(time_per_iter)

    time_per_iter = statistics.mean(times)
    bandwidth = read_write_bytes / time_per_iter / 1.0e9

    logging.info(
        f"Average of all iterations: "
        f"{weights_precision} Forward, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
        f"BW: {bandwidth: .2f} GB/s, "  # noqa: B950
        f"Time: {time_per_iter * 1.0e6:.0f}us, "
        f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.2f} GB"
    )

    if report_aibench and haveAIBench:
        print(
            emitMetric(
                type="NET",
                metric=f"bandwidth_{weights_precision}",
                unit="scalar",
                value=str(bandwidth),
            )
        )
        print(
            emitMetric(
                type="NET",
                metric=f"time_per_iter_{weights_precision}",
                unit="scalar",
                value=str(time_per_iter * 1.0e6),
            )
        )

    if run_reference:
        times = []
        for i in range(runs_of_iters):
            requests = generate_requests(
                iters,
                B,
                T,
                L,
                E,
                reuse=reuse,
                alpha=alpha,
                weights_precision=weights_precision,
                weighted=weighted,
                requests_data_file=requests_data_file,
                tables=tables,
            )
            requests = [(a.int(), b.int(), c) for (a, b, c) in requests]

            # forward
            time_per_iter_refer = benchmark_requests_refer(
                requests,
                T,
                B,
                L,
                E,
                D,
                pooling,
                weighted,
                check_median=check_median,
            )

            # free up GPU memory
            del requests

            logging.info(
                f"Reference (nn.Embedding(Bag)) Iteration {i}: "
                f"Forward, B: {B}, "
                f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
                f"BW: {read_write_bytes / time_per_iter_refer / 1.0e9: .2f} GB/s, "  # noqa: B950
                f"Time: {time_per_iter_refer * 1.0e6:.0f}us "
            )

            if i >= warmup_runs:
                times.append(time_per_iter_refer)

        time_per_iter_refer = statistics.mean(times)
        bandwidth = read_write_bytes / time_per_iter_refer / 1.0e9

        logging.info(
            f"Average of all iterations: "
            f"Forward, B: {B}, "
            f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
            f"Effective BW: {bandwidth: .2f} GB/s, "  # noqa: B950
            f"Time: {time_per_iter_refer * 1.0e6:.0f}us "
        )


@cli.command()
@click.option("--alpha", default=1.0)
@click.option("--bag-size", default=20)
@click.option("--batch-size", default=512)
@click.option("--embedding-dim", default=128)
@click.option("--weights-precision", type=SparseType, default=SparseType.INT4)
@click.option("--iters", default=100)
@click.option("--mixed", is_flag=True, default=False)
@click.option("--num-embeddings", default=int(1e5))
@click.option("--num-tables", default=32)
@click.option("--reuse", default=0.1)
@click.option("--uvm-num-embeddings", default=int(1e5))
@click.option("--uvm-tables", default=1)
@click.option("--uvm-bag-size", default=1)
@click.option("--weighted", is_flag=True, default=False)
@click.option("--flush-gpu-cache-size-mb", default=0)
@click.option("--output-dtype", type=SparseType, default=SparseType.FP16)
@click.option("--use-cache", is_flag=True, default=False)
@click.option("--cache-algorithm", default="lru")
@click.option("--cache-load-factor", default=0.2)
@click.option("--enforce-hbm", is_flag=True, default=False)
def nbit_uvm(
    alpha: float,
    bag_size: int,
    batch_size: int,
    embedding_dim: int,
    weights_precision: SparseType,
    iters: int,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    uvm_num_embeddings: int,
    uvm_tables: int,
    uvm_bag_size: int,
    weighted: bool,
    flush_gpu_cache_size_mb: int,
    output_dtype: SparseType,
    use_cache: bool,
    cache_algorithm: str,
    cache_load_factor: float,
    enforce_hbm: bool,
) -> None:
    """Benchmark nbit TBE forward with part of the tables in UVM.

    The first ``uvm_tables`` tables have ``uvm_num_embeddings`` rows and bag
    size ``uvm_bag_size``. They are placed in ``MANAGED`` memory, or in
    ``MANAGED_CACHING`` with ``--use-cache``. The remaining ``T - T_uvm``
    tables live in ``DEVICE`` memory.

    Logs the UVM-only forward first. When device tables exist, it then logs:
    - the device-only forward,
    - the forward of one module holding both kinds of tables,
    - the pipelined prefetch + forward of that mixed module.
    """
    np.random.seed(42)
    torch.manual_seed(42)
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    E_uvm = uvm_num_embeddings
    T = num_tables
    T_uvm = uvm_tables
    assert T_uvm <= T
    assert (
        T_uvm > 0
    ), f"T_uvm specified {T_uvm} <= 0. If not testing UVM, please use device benchmark."
    T_gpu = T - T_uvm
    L_uvm = uvm_bag_size
    cache_alg = CacheAlgorithm.LRU if cache_algorithm == "lru" else CacheAlgorithm.LFU
    managed_type = (
        EmbeddingLocation.MANAGED_CACHING if use_cache else EmbeddingLocation.MANAGED
    )

    logging.info(f"T: {T}, T_uvm: {T_uvm}, T_gpu: {T_gpu}")

    if mixed:
        Ds = [
            round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T
    emb_uvm = IntNBitTableBatchedEmbeddingBagsCodegen(
        [
            (
                "",
                E_uvm,
                d,
                weights_precision,
                managed_type,
            )
            for d in Ds[:T_uvm]
        ],
        output_dtype=output_dtype,
        cache_load_factor=cache_load_factor,
        cache_algorithm=cache_alg,
        enforce_hbm=enforce_hbm,
    ).cuda()
    emb_uvm.fill_random_weights()

    if T_gpu > 0:
        emb_gpu = IntNBitTableBatchedEmbeddingBagsCodegen(
            [
                (
                    "",
                    E,
                    d,
                    weights_precision,
                    EmbeddingLocation.DEVICE,
                )
                for d in Ds[T_uvm:]
            ],
            output_dtype=output_dtype,
        ).cuda()
        emb_gpu.fill_random_weights()

        # UVM tables first, then device tables, matching the order in which
        # the mixed requests below concatenate their indices.
        emb_mixed = IntNBitTableBatchedEmbeddingBagsCodegen(
            [
                (
                    "",
                    e,
                    d,
                    weights_precision,
                    managed_option,
                )
                for (e, d, managed_option) in zip(
                    [E_uvm] * T_uvm + [E] * T_gpu,
                    Ds,
                    [managed_type] * T_uvm + [EmbeddingLocation.DEVICE] * T_gpu,
                )
            ],
            output_dtype=output_dtype,
            cache_load_factor=cache_load_factor,
            cache_algorithm=cache_alg,
            enforce_hbm=enforce_hbm,
        ).cuda()
        emb_mixed.fill_random_weights()

    requests_uvm = generate_requests(
        iters,
        B,
        T_uvm,
        L_uvm,
        E_uvm,
        reuse=reuse,
        alpha=alpha,
        weights_precision=weights_precision,
        weighted=weighted,
    )
    # Optional per-sample weights pass through unchanged; truth-testing a
    # multi-element tensor raises.
    requests_uvm = [(a.int(), b.int(), c) for (a, b, c) in requests_uvm]

    requests_gpu = None
    if T_gpu > 0:
        # Use the same weighting as the UVM requests. Otherwise weighted mixed
        # requests cannot be assembled and the GPU "W:" log would be wrong.
        requests_gpu = generate_requests(
            iters,
            B,
            T_gpu,
            L,
            E,
            reuse=reuse,
            alpha=alpha,
            weights_precision=weights_precision,
            weighted=weighted,
        )
        requests_gpu = [(a.int(), b.int(), c) for (a, b, c) in requests_gpu]

    param_size_multiplier = weights_precision.bit_rate() / 8.0
    output_size_multiplier = output_dtype.bit_rate() / 8.0
    read_write_bytes_uvm = (
        output_size_multiplier * B * sum(Ds[:T_uvm])
        + param_size_multiplier * B * sum(Ds[:T_uvm]) * L_uvm
    )

    if T_gpu > 0:
        nparams_byte = sum(w.numel() for (w, _) in emb_mixed.split_embedding_weights())
        # T_gpu device tables of E rows (bag size L) plus T_uvm UVM tables of
        # E_uvm rows (bag size L_uvm). sum(Ds[...]) already sums over tables.
        logging.info(
            f"{weights_precision} Embedding tables: {E * T_gpu + E_uvm * T_uvm} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, "
            f"{nparams_byte / 1.0e9: .2f} GB"  # IntN TBE use byte for storage
        )
        logging.info(
            f"Accessed weights per batch: {B * (T_gpu * L + T_uvm * L_uvm)} rows, "
            f"{B * (L * sum(Ds[T_uvm:]) + L_uvm * sum(Ds[:T_uvm])) * param_size_multiplier / 1.0e9: .2f} GB"
        )

    time_per_iter = benchmark_requests(
        requests_uvm,
        lambda indices, offsets, per_sample_weights: emb_uvm.forward(
            indices.int(),
            offsets.int(),
            per_sample_weights,
        ),
        flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
    )
    logging.info(
        f"UVM NBit Forward, {weights_precision}, B: {B}, "
        f"E_uvm: {E_uvm}, T: {T_uvm}, D: {D}, L: {L_uvm}, W: {weighted}, "
        f"BW: {read_write_bytes_uvm / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
        f"Time: {time_per_iter * 1.0e6:.0f}us"
    )

    if T_gpu > 0:
        requests = []
        assert requests_gpu is not None
        # Every mixed request has the same layout: T_uvm * B bags of length
        # L_uvm followed by T_gpu * B bags of length L. Compute it once.
        lengths = [L_uvm] * (T_uvm * B) + [L] * (T_gpu * B)
        mixed_offsets = [0] + np.cumsum(lengths).tolist()
        for rs_uvm, rs_gpu in zip(requests_uvm, requests_gpu):
            indices = torch.cat([rs_uvm[0], rs_gpu[0]])
            offsets = torch.tensor(mixed_offsets).int().cuda()
            per_sample_weights = None
            if weighted:
                # Bind first, then assert: ``assert (x := ...)`` leaves the
                # names unbound when asserts are stripped by ``python -O``.
                uvm_weights, gpu_weights = rs_uvm[2], rs_gpu[2]
                assert uvm_weights is not None and gpu_weights is not None
                per_sample_weights = torch.cat([uvm_weights, gpu_weights])
            requests.append((indices, offsets, per_sample_weights))

        # forward
        time_per_iter = benchmark_requests(
            requests_gpu,
            lambda indices, offsets, per_sample_weights: emb_gpu.forward(
                indices.int(),
                offsets.int(),
                per_sample_weights,
            ),
            flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
        )

        read_write_bytes_hbm = (
            output_size_multiplier * B * sum(Ds[T_uvm:])
            + param_size_multiplier * B * sum(Ds[T_uvm:]) * L
        )
        logging.info(
            f"GPU NBit Forward, {weights_precision}, B: {B}, "
            f"E: {E}, T: {T_gpu}, D: {D}, L: {L}, W: {weighted}, "
            f"BW: {read_write_bytes_hbm / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
            f"Time: {time_per_iter * 1.0e6:.0f}us"
        )

        time_per_iter = benchmark_requests(
            requests,
            lambda indices, offsets, per_sample_weights: emb_mixed.forward(
                indices.int(),
                offsets.int(),
                per_sample_weights,
            ),
            flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
        )
        read_write_bytes_total = read_write_bytes_uvm + read_write_bytes_hbm
        logging.info(
            f"Mixed NBit Forward, {weights_precision}, B: {B}, "
            f"E_GPU: {E}, E_UVM: {E_uvm}, T_GPU: {T_gpu}, T_UVM: {T_uvm}, D: {D}, L_GPU: {L}, L_UVM: {L_uvm}, W: {weighted}, "
            f"BW: {read_write_bytes_total / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
            f"Time: {time_per_iter * 1.0e6:.0f}us"
        )

        # benchmark prefetch
        emb_mixed.reset_cache_states()
        for indices, offsets, _ in requests:
            emb_mixed.forward(indices, offsets)
        prefetch_time, forward_time = benchmark_pipelined_requests(
            requests,
            lambda indices, offsets, indices_weights: emb_mixed.prefetch(
                indices,
                offsets,
            ),
            # pyre-fixme[6]: Expected `(Tensor, Tensor, Optional[Tensor]) -> None` for
            #  3rd param but got `(indices: Any, offsets: Any, indices_weights: Any) ->
            #  Tensor`.
            lambda indices, offsets, indices_weights: emb_mixed.forward(
                indices,
                offsets,
                indices_weights,
            ),
            flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
        )
        e2e_time = prefetch_time + forward_time

        logging.info(
            f"Forward(LXU) {weights_precision}, reuse: {reuse}, alpha: {alpha}, B: {B}, "
            f"E: {E}, T: {T}, D: {D}, L: {L}, "
            f"Te2e: {e2e_time * 1.0e6:.0f}us, "
            f"e2e BW: {read_write_bytes_total / e2e_time / 1.0e9: .2f} GB/s, "
            f"Tprefetch: {prefetch_time * 1.0e6:.0f}us, "
            f"TfwdTime: {forward_time * 1.0e6:.0f}us, "
            f"{read_write_bytes_total / forward_time / 1.0e9: .2f} GB/s"
        )


@cli.command()
@click.option("--alpha", default=1.0)
@click.option("--bag-size", default=20)
@click.option("--batch-size", default=512)
@click.option("--cache-algorithm", default="lru")
@click.option("--cache-load-factor", default=0.2)
@click.option("--embedding-dim", default=128)
@click.option("--weights-precision", type=SparseType, default=SparseType.INT4)
@click.option("--iters", default=100)
@click.option("--mixed", is_flag=True, default=False)
@click.option("--num-embeddings", default=int(1e5))
@click.option("--num-tables", default=32)
@click.option("--reuse", default=0.1)
@click.option("--weighted", is_flag=True, default=False)
@click.option("--flush-gpu-cache-size-mb", default=0)
@click.option("--output-dtype", type=SparseType, default=SparseType.FP16)
@click.option("--enforce-hbm", is_flag=True, default=False)
def nbit_cache(  # noqa C901
    alpha: float,
    bag_size: int,
    batch_size: int,
    cache_algorithm: str,
    cache_load_factor: float,
    embedding_dim: int,
    weights_precision: SparseType,
    iters: int,
    mixed: bool,
    num_embeddings: int,
    num_tables: int,
    reuse: float,
    weighted: bool,
    flush_gpu_cache_size_mb: int,
    output_dtype: SparseType,
    enforce_hbm: bool,
) -> None:
    """Compare nbit TBE forward on plain UVM tables against cached UVM tables.

    Logs, in order:
    - forward time/bandwidth of a ``MANAGED`` (uncached) module;
    - for a ``MANAGED_CACHING`` module, after ``iters`` warm-up batches, the
      cache lines exchanged by each ``prefetch`` and the cache misses per
      batch;
    - the pipelined prefetch + forward timing of the cached module.

    ``2 * iters`` requests are generated: the first half is used for warm-up,
    the second half for measurement.
    """
    np.random.seed(42)
    torch.manual_seed(42)
    B = batch_size
    D = embedding_dim
    L = bag_size
    E = num_embeddings
    T = num_tables
    cache_alg = CacheAlgorithm.LRU if cache_algorithm == "lru" else CacheAlgorithm.LFU
    if mixed:
        Ds = [
            round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
            for _ in range(T)
        ]
        D = np.average(Ds)
    else:
        Ds = [D] * T

    emb_nc = IntNBitTableBatchedEmbeddingBagsCodegen(
        [
            (
                "",
                E,
                d,
                weights_precision,
                EmbeddingLocation.MANAGED,
            )
            for d in Ds
        ],
        output_dtype=output_dtype,
        enforce_hbm=enforce_hbm,
    ).cuda()
    emb_nc.fill_random_weights()

    emb = IntNBitTableBatchedEmbeddingBagsCodegen(
        [
            (
                "",
                E,
                d,
                weights_precision,
                EmbeddingLocation.MANAGED_CACHING,
            )
            for d in Ds
        ],
        cache_load_factor=cache_load_factor,
        cache_algorithm=cache_alg,
        output_dtype=output_dtype,
        enforce_hbm=enforce_hbm,
    ).cuda()
    emb.fill_random_weights()

    nparams_byte = sum(w.numel() for (w, _) in emb.split_embedding_weights())
    param_size_multiplier = weights_precision.bit_rate() / 8.0
    output_size_multiplier = output_dtype.bit_rate() / 8.0
    read_write_bytes = (
        output_size_multiplier * B * sum(Ds) + param_size_multiplier * B * sum(Ds) * L
    )
    logging.info(
        f"{weights_precision} Embedding tables: {E * T} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, "
        f"{nparams_byte / 1.0e9: .2f} GB"  # IntN TBE use byte for storage
    )
    logging.info(
        f"Accessed weights per batch: {B * T * L} rows, "
        f"{B * T * L * D * param_size_multiplier / 1.0e9: .2f} GB"
    )

    requests = generate_requests(
        2 * iters, B, T, L, E, reuse=reuse, alpha=alpha, weighted=weighted
    )
    # Optional per-sample weights pass through unchanged; truth-testing a
    # multi-element tensor raises, which broke every --weighted run.
    requests = [(a.int(), b.int(), c) for (a, b, c) in requests]
    warmup_requests, requests = requests[:iters], requests[iters:]

    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, per_sample_weights: emb_nc(
            indices.int(), offsets.int(), per_sample_weights
        ),
        flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
    )
    logging.info(
        f"Forward (UVM) {weights_precision}, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
        f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )

    # exchanged_cache_lines = [100]
    # warm up
    for indices, offsets, _ in warmup_requests:
        emb.forward(indices.int(), offsets.int())
    # get cache miss rate (forward only) and exchanged cache lines (prefetch)
    cache_misses = []
    exchanged_cache_lines = []
    NOT_FOUND = -1
    for indices, offsets, _ in requests:
        # pyre-fixme[29]:
        #  `Union[BoundMethod[typing.Callable(Tensor.clone)[[Named(self,
        #  Variable[torch._TTensor (bound to Tensor)])], Variable[torch._TTensor (bound
        #  to Tensor)]], Tensor], Tensor, torch.nn.Module]` is not a function.
        old_lxu_cache_state = emb.lxu_cache_state.clone()
        emb.prefetch(indices, offsets)
        # Number of cache-state entries changed by this prefetch.
        exchanged_cache_lines.append(
            # pyre-fixme[16]: `bool` has no attribute `sum`.
            (emb.lxu_cache_state != old_lxu_cache_state)
            .sum()
            .item()
        )
        cache_misses.append(
            (emb.lxu_cache_locations_list.top() == NOT_FOUND).sum().item()
        )
        emb.forward(indices, offsets)
    logging.info(
        f"Exchanged cache lines -- mean: {sum(exchanged_cache_lines)/len(requests): .2f}, "
        f"max: {max(exchanged_cache_lines)}, min: {min(exchanged_cache_lines)}"
    )
    logging.info(
        f"Cache miss -- mean: {sum(cache_misses)/len(requests)}, "
        f"max: {max(cache_misses)}, min: {min(cache_misses)}"
    )

    # benchmark prefetch
    emb.reset_cache_states()
    for indices, offsets, _ in warmup_requests:
        emb.forward(indices, offsets)
    prefetch_time, forward_time = benchmark_pipelined_requests(
        requests,
        lambda indices, offsets, indices_weights: emb.prefetch(
            indices,
            offsets,
        ),
        # pyre-fixme[6]: Expected `(Tensor, Tensor, Optional[Tensor]) -> None` for
        #  3rd param but got `(indices: Any, offsets: Any, indices_weights: Any) ->
        #  Tensor`.
        lambda indices, offsets, indices_weights: emb.forward(
            indices,
            offsets,
            indices_weights,
        ),
        flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
    )
    e2e_time = prefetch_time + forward_time

    logging.info(
        f"Forward(LXU) {weights_precision}, reuse: {reuse}, alpha: {alpha}, B: {B}, "
        f"E: {E}, T: {T}, D: {D}, L: {L}, "
        f"Te2e: {e2e_time * 1.0e6:.0f}us, "
        f"e2e BW: {read_write_bytes / e2e_time / 1.0e9: .2f} GB/s, "
        f"Tprefetch: {prefetch_time * 1.0e6:.0f}us, "
        f"{2 * sum(exchanged_cache_lines) * param_size_multiplier * D / prefetch_time / len(requests) / 1.0e9: .2f} GB/s, "
        f"TfwdTime: {forward_time * 1.0e6:.0f}us, "
        f"{read_write_bytes / forward_time / 1.0e9: .2f} GB/s"
    )


@cli.command()
@click.option("--bag-size", default=20)
@click.option("--batch-size", default=2048)
@click.option("--iters", default=10)
@click.option("--num-embeddings", default=int(1e5))
@click.option("--num-tables", default=100)
@click.option("--load-factor", default=0.75)
@click.option("--hit-rate", default=0.9)
@click.option("--use-cpu", is_flag=True, default=False)
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
def hashtable(  # noqa C901
    bag_size: int,
    batch_size: int,
    iters: int,
    num_embeddings: int,
    num_tables: int,
    load_factor: float,
    hit_rate: float,
    use_cpu: bool,
    requests_data_file: Optional[str],
    tables: Optional[str],
) -> None:
    """Benchmark lookups in the pruned-index hash map.

    For each of ``num_tables`` tables, ``num_embeddings`` keys are inserted
    with ``pruned_hashmap_insert``. Each table's slice of the hash table has
    ``num_embeddings / load_factor`` slots, rounded up to a multiple of 32.
    With ``hit_rate < 1`` the keys are drawn uniformly from
    ``[0, E / hit_rate)``, presumably so that roughly ``hit_rate`` of the
    requested indices are found.

    Logs QPS, time per request, load factor, measured hit rate and table size
    for ``pruned_hashmap_lookup``. With ``--use-cpu`` it also times
    ``torch.classes.fb.PrunedMapCPU`` on the same requests.
    """
    B = batch_size
    T = num_tables
    L = bag_size
    E = num_embeddings
    np.random.seed(42)
    torch.manual_seed(42)

    # Dense slot ids 0..E-1, laid out back to back for every table.
    dense_indices = torch.cat([torch.arange(E)] * T, dim=0).int()
    if hit_rate == 1.0:
        # Every requested id is present: the keys are the dense ids themselves.
        chosen_indices = dense_indices.clone()
    else:
        key_space = int(E * 1.0 / hit_rate)
        chosen_indices = torch.randint(low=0, high=key_space, size=(E * T,)).int()
    # Each table owns exactly E consecutive keys.
    table_offsets = (torch.arange(T + 1) * E).int()
    assert table_offsets[-1] == chosen_indices.numel()
    assert table_offsets.numel() == T + 1
    assert (table_offsets.numel() - 1) // T == 1

    capacities = [round_up(int(E / load_factor), 32)] * T
    # Each slot holds an int32 pair; -1 marks an empty slot.
    hash_table = torch.full((sum(capacities), 2), -1, dtype=torch.int32)
    hash_table_offsets = torch.tensor(
        list(itertools.accumulate(capacities, initial=0))
    ).long()
    # Keep the byte size of the table addressable with 32 bits.
    assert hash_table.numel() * 4 < 2 ** 32
    torch.ops.fbgemm.pruned_hashmap_insert(
        chosen_indices, dense_indices, table_offsets, hash_table, hash_table_offsets
    )

    requests = generate_requests(
        iters,
        B,
        T,
        L,
        E,
        requests_data_file=requests_data_file,
        tables=tables,
    )

    if use_cpu:
        requests = [(a.int().cpu(), b.int().cpu(), c) for (a, b, c) in requests]
    else:
        hash_table = hash_table.cuda()
        hash_table_offsets = hash_table_offsets.cuda()
        requests = [(a.cuda().int(), b.cuda().int(), c) for (a, b, c) in requests]

    def linear_lookup(idx, offs, _weights):
        return torch.ops.fbgemm.pruned_hashmap_lookup(
            idx, offs, hash_table, hash_table_offsets
        )

    # Fraction of looked-up ids that resolved to a slot (not -1), per request.
    hit_fractions = []
    for idx, offs, _ in requests:
        found = linear_lookup(idx, offs, None).ne(-1).sum().item()
        hit_fractions.append(found / idx.numel())
    empirical_hit_rate = np.mean(hit_fractions)

    def report(label, seconds):
        logging.info(
            f"{label}: B: {B}, T: {T}, L: {L}, E: {E}, QPS: {B * T * L / seconds / 1.0e9:.2f}B QPS/s, "
            f"T: {seconds * 1.0e6:.0f}us, load factor: {E * T / hash_table.shape[0] * 100:.1f}%, hit rate: {empirical_hit_rate * 100:.2f}%, Table size: {hash_table.numel() * 4 / 1.0e9:.0f} GB"
        )

    report("LinearTable", benchmark_requests(requests, linear_lookup))

    if not use_cpu:
        return

    ht = torch.classes.fb.PrunedMapCPU()
    ht.insert(chosen_indices, dense_indices, table_offsets, T)

    def map_lookup(idx, offs, _weights):
        return ht.lookup(idx, offs)

    report("HashTable", benchmark_requests(requests, map_lookup))


@cli.command()
@click.option("--bag-size", default=20)
@click.option("--batch-size", default=2048)
@click.option("--iters", default=100)
@click.option("--num-embeddings", default=int(1e5))
@click.option("--num-tables", default=100)
@click.option("--pruning-ratio", default=0.9)
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
def pruned_array(  # noqa C901
    bag_size: int,
    batch_size: int,
    iters: int,
    num_embeddings: int,
    num_tables: int,
    pruning_ratio: float,
    requests_data_file: Optional[str],
    tables: Optional[str],
) -> None:
    """Benchmark ``torch.ops.fbgemm.pruned_array_lookup`` on CUDA.

    For each of ``num_tables`` tables, a remapping array of
    ``original_E = int(num_embeddings / (1 - pruning_ratio))`` int32 slots is
    built. ``num_embeddings`` randomly chosen slots map to ``0..E-1``, and all
    other slots keep the value -1. The per-table arrays are concatenated into
    ``index_remappings``, and ``index_remappings_offsets`` (length T + 1)
    marks where each table starts. Requests from ``generate_requests`` are
    then timed through the lookup, and throughput is logged.

    Args:
        bag_size: Indices per bag (L).
        batch_size: Bags per table per request (B).
        iters: Number of requests passed to ``generate_requests``.
        num_embeddings: Rows kept per table after pruning (E).
        num_tables: Number of tables (T).
        pruning_ratio: Fraction of original rows pruned; must be in (0, 1).
        requests_data_file: Forwarded unchanged to ``generate_requests``.
        tables: Forwarded unchanged to ``generate_requests``.
    """
    B = batch_size
    T = num_tables
    L = bag_size
    E = num_embeddings
    np.random.seed(42)
    torch.manual_seed(42)
    # The upper bound must be open: pruning_ratio == 1 would divide by zero
    # when computing original_E below.
    assert (
        0 < pruning_ratio < 1
    ), f"pruning_ratio must be in (0, 1), got {pruning_ratio}"
    original_E = int(E / (1.0 - pruning_ratio))

    # Allocate the -1-filled remapping directly on the device. Building a
    # Python list of original_E * T elements first is very slow and
    # memory-hungry at benchmark sizes.
    index_remappings = torch.full(
        (original_E * T,), -1, dtype=torch.int32, device="cuda"
    )
    # Table t starts at t * original_E. Computing all offsets in one op avoids
    # a host<->device scalar round-trip per table.
    index_remappings_offsets = (
        torch.arange(T + 1, dtype=torch.int32, device="cuda") * original_E
    )
    dense_indices = torch.arange(E, dtype=torch.int32, device="cuda")
    for t in range(T):
        # Keep a random subset of E slots out of table t's original_E slots.
        selected_indices = torch.add(
            torch.randperm(original_E, device="cuda"), t * original_E
        )[:E]
        index_remappings[selected_indices] = dense_indices

    requests = generate_requests(
        iters,
        B,
        T,
        L,
        E,
        requests_data_file=requests_data_file,
        tables=tables,
    )
    requests = [(a.cuda().int(), b.cuda().int(), c) for (a, b, c) in requests]

    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, _: torch.ops.fbgemm.pruned_array_lookup(
            indices,
            offsets,
            index_remappings,
            index_remappings_offsets,
        ),
    )

    logging.info(
        f"LinearTable: B: {B}, T: {T}, L: {L}, E: {E}, QPS: {B * T * L / time_per_iter / 1.0e9:.2f}B QPS/s, "
        f"T: {time_per_iter * 1.0e6:.0f}us, Pruning Ratio: {pruning_ratio * 100:.2f}%, Table size: {original_E * T * 4 / 1.0e9:.0f} GB"
    )


@cli.command()
@click.option("--bag-size", default=20)
@click.option("--batch-size", default=512)
@click.option("--iters", default=100)
@click.option("--num-embeddings", default=int(1e5))
@click.option("--num-tables", default=32)
@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.WARNING.value)
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
def bounds_check_indices(  # noqa C901
    bag_size: int,
    batch_size: int,
    iters: int,
    num_embeddings: int,
    num_tables: int,
    bounds_check_mode: int,
    requests_data_file: Optional[str],
    tables: Optional[str],
) -> None:
    """Benchmark ``torch.ops.fbgemm.bounds_check_indices``.

    Generates requests via ``generate_requests`` and times the bounds check
    over each one, with every table declared to have ``num_embeddings`` rows.
    Logs an estimated bandwidth and the per-iteration time.

    Args:
        bag_size: Indices per bag (L).
        batch_size: Bags per table per request (B).
        iters: Number of requests passed to ``generate_requests``.
        num_embeddings: Rows per table used as the bound (E).
        num_tables: Number of tables (T).
        bounds_check_mode: Integer value of a ``BoundsCheckMode`` member.
        requests_data_file: Forwarded unchanged to ``generate_requests``.
        tables: Forwarded unchanged to ``generate_requests``.
    """
    np.random.seed(42)
    torch.manual_seed(42)
    B = batch_size
    L = bag_size
    E = num_embeddings
    T = num_tables
    # Convert once, up front. An invalid value is rejected before any work is
    # done, and the enum construction stays out of the timed lambda.
    mode = BoundsCheckMode(bounds_check_mode)

    requests = generate_requests(
        iters,
        B,
        T,
        L,
        E,
        requests_data_file=requests_data_file,
        tables=tables,
    )

    device = get_device()
    warning = torch.tensor([0]).long().to(device)
    # All T tables share the same row count E.
    rows_per_table = torch.full((T,), E, dtype=torch.long, device=device)
    # forward
    time_per_iter = benchmark_requests(
        requests,
        lambda indices, offsets, _: torch.ops.fbgemm.bounds_check_indices(
            rows_per_table,
            indices,
            offsets,
            mode,
            warning,
        ),
    )

    # NOTE(review): the byte count assumes 8-byte (int64) indices and offsets
    # (B*T*L indices plus B*T+1 offsets); confirm the dtype that
    # generate_requests actually returns.
    logging.info(
        f"Bounds Check Indices:  B: {B}, "
        f"E: {E}, T: {T}, L: {L}, "
        f"BW: {(8 * B * T * L + 8 * (B * T + 1)) / time_per_iter / 1.0e9:.2f} GB/s, "  # noqa: B950
        f"T: {time_per_iter * 1.0e6:.0f}us"
    )


if __name__ == "__main__":
    # Script entry point: hand the command line to the `cli` click group,
    # which dispatches to the registered benchmark subcommands.
    cli()
