#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import signal
from typing import Tuple, List

import click
import fbgemm_gpu
import numpy as np
import tabulate
import torch

open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if open_source:
    # pyre-ignore[21]
    from bench_utils import benchmark_torch_function
else:
    from fbgemm_gpu.bench.bench_utils import benchmark_torch_function


from fbgemm_gpu.split_table_batched_embeddings_ops import (
    SparseType,
    BoundsCheckMode,
    IntNBitTableBatchedEmbeddingBagsCodegen,
    EmbeddingLocation,
)
from torch.profiler import ProfilerActivity, profile


def get_gpu_device(gpu_num) -> torch.device:
    """Return the ``torch.device`` that names CUDA device ``gpu_num``."""
    device_spec = "cuda:" + str(gpu_num)
    return torch.device(device_spec)


# Merged indices with shape (T, B, L) -> (flattened indices with shape
# (T * B * L), offsets with shape (T * B + 1)).
# Reference: https://fburl.com/code/5ueyfv5j
def get_table_batched_offsets_from_dense(
    merged_indices: torch.Tensor,
    gpu_num,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert dense ``(T, B, L)`` indices into TBE ``(indices, offsets)`` inputs.

    Every bag holds exactly ``L`` indices, so the offsets are the arithmetic
    sequence ``0, L, 2L, ..., T * B * L``.

    Args:
        merged_indices: index tensor of shape ``(T, B, L)``.
        gpu_num: CUDA device ordinal the outputs are placed on.

    Returns:
        ``(indices, offsets)``: int32 tensors on ``cuda:{gpu_num}`` with shapes
        ``(T * B * L,)`` and ``(T * B + 1,)``.
    """
    (T, B, L) = merged_indices.size()
    device = get_gpu_device(gpu_num)
    indices = merged_indices.int().contiguous().view(-1).to(device=device)
    # Compute offsets with exact integer arithmetic. Building them from Python
    # floats (np.cumsum(...).tolist()) makes torch.tensor infer float32, which
    # cannot represent integers above 2**24 exactly, so large T * B * L
    # configurations would get silently rounded offsets. Multiplying an int64
    # arange also handles L == 0 (all-zero offsets) without a zero step.
    offsets = (torch.arange(T * B + 1, dtype=torch.int64, device=device) * L).int()
    return (indices, offsets)


# Reference: https://fburl.com/code/o5600si0
def generate_requests(
    num_gpus: int,
    B: int,
    T: int,
    L: int,
    E: int,
    # inter-batch indices reuse rate
    reuse: float = 0.0,
) -> List[Tuple[torch.IntTensor, torch.IntTensor,]]:
    """Generate one random TBE request per GPU.

    For each GPU, draws ``T * B`` bags of ``L`` int32 indices uniformly from
    ``[0, E)`` directly on that GPU. It sorts every bag and converts the
    result with ``get_table_batched_offsets_from_dense``.

    ``reuse`` is accepted for interface compatibility but is not used here.

    Returns:
        A list of ``num_gpus`` ``(indices, offsets)`` tuples, indexed by GPU.
    """

    def _one_request(gpu_num: int):
        dense = torch.randint(
            low=0,
            high=E,
            size=(T, B, L),
            device=get_gpu_device(gpu_num),
            dtype=torch.int32,
        )
        # Bags are usually sorted; torch.sort orders along the last (bag) dim.
        sorted_dense = torch.sort(dense)[0]
        return get_table_batched_offsets_from_dense(sorted_dense, gpu_num)

    return [_one_request(gpu_num) for gpu_num in range(num_gpus)]


def _get_random_tensor(
    num_ads: int,
    embedding_dimension: int,
    ads_tables: int,
    data_type: str,
    gpu_idx: int,
    include_quantization: bool,
):
    if data_type == "FP16" or include_quantization:
        result_tensor = torch.randn(
            num_ads,
            embedding_dimension * ads_tables,
            dtype=torch.float16,
            device=torch.device(f"cuda:{gpu_idx}"),
        )
    elif data_type == "INT8":
        assert (
            embedding_dimension % 2
        ) == 0, "needs to align to 2 bytes (half type size) for INT8"
        result_tensor = torch.randint(
            0,
            255,
            # 2 FP16 numbers for scale and bias, total of 4 bytes overhead
            size=(num_ads, (embedding_dimension + 4) * ads_tables),
            dtype=torch.uint8,
            device=torch.device(f"cuda:{gpu_idx}"),
        )
    elif data_type == "INT4":
        assert (
            embedding_dimension % 4
        ) == 0, "needs to align to 2 bytes (half type size) for INT4"
        result_tensor = torch.randint(
            0,
            255,
            # Using torch.uint8 for int4 storage
            size=(num_ads, (embedding_dimension // 2 + 4) * ads_tables),
            dtype=torch.uint8,
            device=torch.device(f"cuda:{gpu_idx}"),
        )
    else:
        raise ValueError

    return result_tensor


def generate_tbe(
    batch_indices,
    num_ads: int,
    embedding_dimension: int,
    num_of_embeddings: int,
    pooling_factor: int,
    ads_tables: int,
    fused_tbe: bool,
    data_type: str,
    num_gpus: int,
):
    """Build one inference TBE per GPU, its random requests, and slice offsets.

    Each GPU gets an ``IntNBitTableBatchedEmbeddingBagsCodegen`` on that device.
    It has ``ads_tables`` INT4-weight tables of ``num_of_embeddings`` rows by
    ``embedding_dimension`` columns, bounds checking disabled, and randomly
    filled weights. Its output dtype is INT8 when ``fused_tbe`` is set,
    FP16 otherwise.

    Returns:
        ``(emb, requests, tbe_offset)``:
        * ``emb``: per-GPU TBE modules.
        * ``requests``: per-GPU ``(indices, offsets)`` from ``generate_requests``.
        * ``tbe_offset``: int32 cumulative byte offsets of all
          ``ads_tables * num_gpus`` table slices in the merged quantized output,
          on the current CUDA device.
    """
    if fused_tbe:
        assert data_type == "INT8"  # INT4 not implemented yet
        tbe_output_dtype = SparseType.INT8
    else:
        tbe_output_dtype = SparseType.FP16

    def _table_specs(gpu_idx):
        # NOTE: every table on a given GPU is named str(gpu_idx).
        return [
            (
                str(gpu_idx),
                num_of_embeddings,
                embedding_dimension,
                SparseType.INT4,
                EmbeddingLocation.DEVICE,
            )
            for _ in range(ads_tables)
        ]

    # Construct every module first, then fill weights, preserving call order.
    emb = []
    for gpu_idx in range(num_gpus):
        emb.append(
            IntNBitTableBatchedEmbeddingBagsCodegen(
                _table_specs(gpu_idx),
                output_dtype=tbe_output_dtype,
                device=get_gpu_device(gpu_idx),
                bounds_check_mode=BoundsCheckMode.NONE,
            )
        )
    for module in emb:
        module.fill_random_weights()

    requests = generate_requests(
        num_gpus, num_ads, ads_tables, pooling_factor, num_of_embeddings
    )

    # Per-table row width: packed payload bytes followed by two 4-byte floats
    # (https://fburl.com/code/doxxjc8c).
    size_of_float = 4
    elems_per_byte = 1 if data_type == "INT8" else 2
    assert embedding_dimension % (2 * elems_per_byte) == 0
    slice_bytes = (embedding_dimension + elems_per_byte - 1) // elems_per_byte
    slice_bytes += 2 * size_of_float
    slice_widths = [slice_bytes] * ads_tables * num_gpus

    start_and_widths = torch.tensor([0] + slice_widths, device=batch_indices.device)
    tbe_offset = torch.cumsum(start_and_widths, dim=0).to(torch.int).cuda()
    return emb, requests, tbe_offset


def print_p2p_bandwidth(
    num_gpus, iters, pooled_ad_embeddings, bytes_per_element
) -> None:
    """Measure and print an all-pairs GPU copy-bandwidth table in GB/s.

    Entry ``[dst, src]`` times copying GPU ``src``'s tensor into GPU ``dst``'s
    tensor with ``dst`` as the current device. The diagonal times a
    same-device ``clone()`` instead. Bandwidth is the destination tensor's
    size (``numel * bytes_per_element``) divided by the measured time.
    """
    print("Pairwise GPU Copy Bandwidth (GB/s)")
    bandwidth = np.zeros((num_gpus, num_gpus))
    for dst in range(num_gpus):
        dst_tensor = pooled_ad_embeddings[dst]
        for src in range(num_gpus):
            if src == dst:
                op = lambda: dst_tensor.clone()  # noqa: E731
            else:
                src_tensor = pooled_ad_embeddings[src]
                op = lambda: dst_tensor.copy_(src_tensor)  # noqa: E731
            with torch.cuda.device(dst):
                elapsed, _ = benchmark_torch_function(
                    op,
                    (),
                    flush_gpu_cache_size_mb=0,
                    iters=iters,
                )
            payload_bytes = dst_tensor.numel() * bytes_per_element
            bandwidth[dst, src] = payload_bytes / elapsed / 1.0e9
    print(
        tabulate.tabulate(
            bandwidth,
            headers=[f"GPU {gpu}" for gpu in range(num_gpus)],
            tablefmt="fancy_grid",
            floatfmt=".0f",
        )
    )


def benchmark(
    all_to_one_only: bool,
    num_ads: int,
    embedding_dimension: int,
    ads_tables: int,
    iters: int = 10,
    p2p_bw: bool = False,
    dst_device: int = 0,
    data_type: str = "FP16",
    mode: str = "P2P",
    skip_dequantization: bool = False,
    num_of_embeddings: int = 10000,
    pooling_factor: int = 25,
) -> str:
    """Benchmark gathering pooled embeddings from every GPU onto ``dst_device``.

    Every visible GPU contributes one pooled-embedding tensor: random data
    from ``_get_random_tensor``, or the output of a per-GPU TBE in the
    ``P2P_TBE`` / ``P2P_FUSED_TBE`` modes. The tensors are gathered with
    ``merge_pooled_embeddings``, or ``all_to_one_device`` when
    ``all_to_one_only`` is set. In quantized modes they are quantized per
    table before the transfer. Unless ``skip_dequantization`` is set, they
    are then dequantized back to FP16.

    Args:
        all_to_one_only: use ``all_to_one_device`` instead of merging
            (unquantized paths only).
        num_ads, embedding_dimension, ads_tables: B, D and T of each GPU's
            pooled embeddings.
        iters: timed iterations passed to ``benchmark_torch_function``.
        p2p_bw: also print the pairwise GPU copy-bandwidth table.
        dst_device: destination GPU; made the current CUDA device.
        data_type: "FP16", "INT8" or "INT4".
        mode: "P2P", "P2P_QUANT", "P2P_TBE" or "P2P_FUSED_TBE".
        skip_dequantization: return the quantized merged result as-is.
        num_of_embeddings, pooling_factor: TBE table rows (E) and bag size (L).

    Returns:
        One CSV row matching the header printed by ``main``.
    """
    import contextlib

    assert torch.cuda.is_available()
    torch.cuda.set_device(dst_device)
    num_gpus = torch.cuda.device_count()
    batch_indices = torch.zeros(num_ads).long().cuda()
    include_quantization = mode != "P2P"
    # Quantized modes end in FP16 (2 bytes per element). Plain P2P moves the
    # raw tensor: FP16 is 2 bytes; INT8/INT4 use uint8 storage (1 byte).
    bytes_per_element = 2 if (data_type == "FP16" or include_quantization) else 1
    total_elements = num_ads * embedding_dimension * ads_tables * num_gpus

    logging.debug(
        f"B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Data Type: {data_type}, Num GPUs: {num_gpus}, Destination GPU: {dst_device}"
    )

    fused_tbe = mode == "P2P_FUSED_TBE"
    include_tbe = fused_tbe or mode == "P2P_TBE"
    emb: List = []  # per-GPU TBE modules (TBE modes only)
    requests: List = []  # per-GPU (indices, offsets) inputs (TBE modes only)
    # int32 cumulative column offsets of every table slice in the merged
    # quantized output; read by the INT8 MixedDim dequantization below.
    tbe_offset = None
    if include_tbe:
        emb, requests, tbe_offset = generate_tbe(
            batch_indices,
            num_ads,
            embedding_dimension,
            num_of_embeddings,
            pooling_factor,
            ads_tables,
            fused_tbe,
            data_type,
            num_gpus,
        )
    elif include_quantization and data_type == "INT8":
        # P2P_QUANT builds no TBE, but INT8 dequantization still needs the
        # slice offsets. Without them it would hit an unbound tbe_offset.
        # Use the same per-table INT8 width as generate_tbe: D payload bytes
        # plus two 4-byte floats (scale and bias).
        quantized_slice_bytes = embedding_dimension + 2 * 4
        tbe_offset = torch.cumsum(
            torch.tensor(
                [0] + [quantized_slice_bytes] * ads_tables * num_gpus,
                device=batch_indices.device,
            ),
            dim=0,
        ).to(torch.int)

    pooled_ad_embeddings = [
        _get_random_tensor(
            num_ads,
            embedding_dimension,
            ads_tables,
            data_type,
            gpu_idx,
            include_quantization,
        )
        for gpu_idx in range(num_gpus)
    ]

    if p2p_bw:
        print_p2p_bandwidth(num_gpus, iters, pooled_ad_embeddings, bytes_per_element)

    def pool_func_with_quantization(
        batch_indices,
        include_quantization,
        include_tbe,
        fused_tbe,
        skip_dequantization,
        data_type,
    ):
        # Per-GPU inputs: TBE outputs in TBE modes, otherwise the random tensors.
        if include_tbe:
            embedding_results = []
            for idx, (indices, offsets) in enumerate(requests):
                with torch.cuda.device(idx):
                    embedding_results.append(emb[idx].forward(indices, offsets))
        else:
            embedding_results = pooled_ad_embeddings

        if data_type == "FP16" or (not fused_tbe and not include_quantization):
            if all_to_one_only:
                # Gather what this mode actually produced (the TBE outputs in
                # TBE modes), the same inputs the merge branch uses.
                return torch.ops.fbgemm.all_to_one_device(
                    embedding_results, batch_indices.device
                )
            else:
                return torch.ops.fbgemm.merge_pooled_embeddings(
                    embedding_results, batch_indices.size(0), batch_indices.device
                )

        assert data_type == "INT8" or data_type == "INT4"
        assert not all_to_one_only  # not supported
        if fused_tbe:
            # The TBE already emitted INT8 rows; just merge them.
            pooled_quantized_result = torch.ops.fbgemm.merge_pooled_embeddings(
                embedding_results, batch_indices.size(0), batch_indices.device
            )
        else:
            # Quantize each table slice independently, then merge.
            quantized = []
            for t in embedding_results:
                t_split_by_table = torch.split(t, embedding_dimension, dim=1)
                quantized_split_by_table = [
                    torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(t.float())
                    if data_type == "INT8"
                    else torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
                        t.float(), 4
                    )
                    for t in t_split_by_table
                ]
                result = torch.cat(quantized_split_by_table, dim=1)
                quantized.append(result)
            pooled_quantized_result = torch.ops.fbgemm.merge_pooled_embeddings(
                quantized, batch_indices.size(0), batch_indices.device
            )

        if skip_dequantization:
            return pooled_quantized_result

        PooledEmbeddingDequantizeDataTypeFP16 = 1
        if data_type == "INT8":
            return torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatMixedDim(
                pooled_quantized_result,
                tbe_offset,
                PooledEmbeddingDequantizeDataTypeFP16,
            )
        else:
            # TODO: the result here is wrong. Once MixedDim version for FusedNBit quantization is done, switch to that.
            # Since their performance is similar, keep using Fused8BitRowwiseQuantizedToHalf for now.
            return torch.ops.fbgemm.Fused8BitRowwiseQuantizedToHalf(
                pooled_quantized_result
            ).half()

    streams = [torch.cuda.Stream(device=i) for i in range(num_gpus)]

    with contextlib.ExitStack() as stack:
        # Make a dedicated stream current on every GPU for the whole run.
        for stream in streams:
            stack.enter_context(torch.cuda.stream(stream))

        # warm up
        merged = pool_func_with_quantization(
            batch_indices,
            include_quantization,
            include_tbe,
            fused_tbe,
            skip_dequantization,
            data_type,
        )
        t, _ = benchmark_torch_function(
            pool_func_with_quantization,
            (
                batch_indices,
                include_quantization,
                include_tbe,
                fused_tbe,
                skip_dequantization,
                data_type,
            ),
            flush_gpu_cache_size_mb=0,
            iters=iters,
        )
        with profile(activities=[ProfilerActivity.CUDA]) as prof:
            pool_func_with_quantization(
                batch_indices,
                include_quantization,
                include_tbe,
                fused_tbe,
                skip_dequantization,
                data_type,
            )
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    # all_to_one_device may return a list of tensors rather than a single
    # tensor, so count elements across the whole result.
    if isinstance(merged, (list, tuple)):
        output_numel = sum(x.numel() for x in merged)
    else:
        output_numel = merged.numel()
    output_bytes = output_numel * bytes_per_element

    logging.debug(
        f"Mode: {mode}, Data Type: {data_type}, B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Num GPUs: {num_gpus}, Destination GPU: {dst_device}, "
        f"Number of elements: {total_elements / 1.0e6:.2f} Million, Number of elements per GPU: {total_elements / 1.0e6 / num_gpus:.2f}, Billion elements per sec: {total_elements / t / 1.0e9:.1f}, "
        f"Output Size: {output_bytes / 1.0e6:.0f}MB, BW: {output_bytes / t / 1.0e9:.1f}GB/s, "
        f"t: {t * 1.0e3:.2f}ms"
    )
    # return result in CSV format
    return (
        f"{mode}, {data_type}, {num_ads}, {embedding_dimension}, {ads_tables}, {num_gpus}, {dst_device}, "
        f"{total_elements / 1.0e6:.2f}, {total_elements / 1.0e6 / num_gpus:.2f}, {total_elements / 1.0e9 / t:.1f}, "
        f"{output_bytes / 1.0e6:.0f}, {output_bytes / 1.0e9 / t:.1f}, "
        f"{t * 1.0e3:.2f}"
    )


@click.command()
@click.option("--all-to-one-only", is_flag=True, default=False)
@click.option("--num_ads", default=1024, type=int)
@click.option("--embedding_dimension", default=300, type=int)
@click.option("--ads_tables", default=100, type=int)
@click.option("--iters", default=10, type=int)
@click.option("--p2p_bw", is_flag=True, default=False)
@click.option("--dst_device", default=0, type=int)
@click.option(
    "--data_type",
    type=click.Choice(["FP16", "INT8", "INT4"]),
    default="FP16",
)
# P2P: merge_pooled_embeddings() or all_to_one_device() for tensor with "--data_type"
# P2P_QUANT: for INT8/INT4 data type, start with FP16, then quantize -> P2P -> dequantize to FP16
# P2P_TBE: add TBE in front of P2P_QUANT.  When "--data_type" is FP16, the flow is TBE -> P2P; for INT8/INT4, the flow is TBE -> quantize -> P2P -> dequantize
# P2P_FUSED_TBE: similar to P2P_TBE except fuse the quantization into TBE
@click.option(
    "--mode",
    type=click.Choice(["P2P", "P2P_QUANT", "P2P_TBE", "P2P_FUSED_TBE"]),
    default="P2P",
)
# For quantized communication: skip dequantizing back to FP16 at the end.
@click.option("--skip_dequantization", is_flag=True, default=False)
@click.option("--num_of_embeddings", default=100000, type=int)
@click.option("--pooling_factor", default=25, type=int)
@click.option("--sweep", is_flag=True, default=False)
def main(
    all_to_one_only: bool,
    num_ads: int,
    embedding_dimension: int,
    ads_tables: int,
    iters: int,
    p2p_bw: bool,
    dst_device: int,
    data_type: str,
    mode: str,
    skip_dequantization: bool,
    num_of_embeddings: int,
    pooling_factor: int,
    sweep: bool,
) -> None:
    """Benchmark gathering pooled embeddings from all GPUs onto one GPU.

    Prints a CSV header followed by one result row, or one row per
    configuration when --sweep is given.
    """
    csv_header = (
        "mode, data_type, num_ads, embedding_dimension, ads_tables, num_gpus, "
        "dst_device, number of elements (Million), number of elements per GPU (Million), throughput (billion elements per sec), "
        "output size (MB), BW (GB/s), t (ms)"
    )
    if sweep:

        def handler(signum, frame):
            logging.error("timeout")
            raise TimeoutError()

        results = []
        num_gpu = torch.cuda.device_count()
        # Scale num_ads so all GPU counts sweep through the same number of
        # total elements (calibrated for 8 GPUs); never scale below 1x.
        ads_scale = max(1, 8 // max(1, num_gpu))
        # signal.alarm() delivers SIGALRM, so that is the signal to handle.
        # Installing the handler for SIGTERM instead left SIGALRM at its
        # default action, which kills the process and loses all results.
        previous_handler = signal.signal(signal.SIGALRM, handler)
        try:
            for base_num_ads in [128, 256, 512, 1024, 2048]:
                sweep_num_ads = base_num_ads * ads_scale
                for sweep_dim in [16, 64, 112, 304]:
                    for sweep_tables in [25, 50, 100, 400, 800]:
                        if sweep_num_ads * sweep_dim * sweep_tables > 983040000:
                            continue  # Skip tests that are too large
                        signal.alarm(600)
                        logging.info(
                            f"config: num_ads: {sweep_num_ads}, embedding_dimension: {sweep_dim}, ads_tables: {sweep_tables}"
                        )
                        try:
                            result = benchmark(
                                all_to_one_only,
                                sweep_num_ads,
                                sweep_dim,
                                sweep_tables,
                                iters,
                                p2p_bw,
                                dst_device,
                                data_type,
                                mode,
                                skip_dequantization,
                                num_of_embeddings,
                                pooling_factor,
                            )
                            results.append(result)
                        except (TimeoutError, RuntimeError) as err:
                            logging.error(
                                f"B: {sweep_num_ads}, D: {sweep_dim}, T: {sweep_tables}, Data Type: {data_type}, Num GPU: {num_gpu}, time out or failed: {err}"
                            )
                        finally:
                            # Disarm so a finished config cannot be interrupted later.
                            signal.alarm(0)
        finally:
            # signal.signal returns None when the old handler was not set
            # from Python; it cannot be reinstalled in that case.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)
        print(csv_header)
        print(*results, sep="\n")
        return

    result = benchmark(
        all_to_one_only,
        num_ads,
        embedding_dimension,
        ads_tables,
        iters,
        p2p_bw,
        dst_device,
        data_type,
        mode,
        skip_dequantization,
        num_of_embeddings,
        pooling_factor,
    )
    print(csv_header)
    print(result)


if __name__ == "__main__":
    # Script entry point: `main` is a click command, so calling it with no
    # arguments parses the options declared on it from the command line.
    main()
