# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
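"""TDM (tree-based deep model) retrieval with a scripted TorchEasyRec model.

Starting from ``first_recall_layer``, each tree level is handled by a small
pipeline: data workers sample candidate nodes per user, a forward thread scores
them with the scripted model and keeps the best candidates (beam search), and
the leaf-level results are written out together with a recall metric. The
script expects a torch.distributed environment (RANK / LOCAL_RANK), e.g. as
set up by ``torchrun``.
"""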

import argparse
import copy
import math
import os
import time
from collections import OrderedDict
from multiprocessing import Process, Queue
from threading import Thread
from typing import Dict, Optional, Tuple

import numpy as np
import pyarrow as pa
import torch
from torch import distributed as dist
from torch.distributed import ReduceOp

from tzrec.constant import PREDICT_QUEUE_TIMEOUT, Mode
from tzrec.datasets.data_parser import DataParser
from tzrec.datasets.dataset import BaseWriter, create_writer
from tzrec.datasets.sampler import TDMPredictSampler
from tzrec.datasets.utils import Batch, RecordBatchTensor
from tzrec.main import _create_features, _get_dataloader, init_process_group
from tzrec.protos.data_pb2 import DatasetType
from tzrec.utils import config_util
from tzrec.utils.logging_util import ProgressLogger, logger


def update_data(
    input_data: pa.RecordBatch, sampled_data: Dict[str, pa.Array]
) -> Dict[str, pa.Array]:
    """Update input data based on sampled data.

    Args:
        input_data (pa.RecordBatch): raw input data.
        sampled_data (dict): sampled data.

    Returns:
        Dict[str, pa.Array]: item columns replaced by the sampled candidates,
            with user columns repeated to stay aligned with them.
    """
    item_fea_fields = sampled_data.keys()
    all_fea_fields = set(input_data.column_names)
    user_fea_fields = all_fea_fields - item_fea_fields

    updated_data = {}
    for item_fea in item_fea_fields:
        updated_data[item_fea] = sampled_data[item_fea]

    item_field_0 = list(item_fea_fields)[0]
    expand_num = len(sampled_data[item_field_0]) // len(input_data[item_field_0])
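    # repeat each user row expand_num times so user features stay aligned with
    # the expanded candidate items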
    for user_fea in user_fea_fields:
        _user_fea_array = input_data[user_fea]
        index = np.repeat(np.arange(len(_user_fea_array)), expand_num)

        expand_user_fea = _user_fea_array.take(index)
        updated_data[user_fea] = expand_user_fea

    return updated_data


def _tdm_predict_data_worker(
    sampler: TDMPredictSampler,
    data_parser: DataParser,
    first_recall_layer: int,
    n_cluster: int,
    in_queue: Queue,
    out_queue: Queue,
    is_first_layer: bool,
    worker_id: int,
) -> None:
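    """Data worker for one tree level.

    Pulls (record_batch, node_ids) from ``in_queue``, expands each user's rows
    with candidate nodes sampled from the TDM tree (for the first layer,
    candidates are initialized from the root and expanded down to
    ``first_recall_layer`` without scoring), parses features, and pushes the
    resulting batch to ``out_queue``. A ``(None, None)`` input acts as a stop
    sentinel.
    """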
    item_id_field = sampler._item_id_field
    sampler.init(worker_id)
    sampler.init_sampler(n_cluster)

    while True:
        record_batch_t, node_ids = in_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)

        if record_batch_t is None:
            out_queue.put((None, None, None), timeout=PREDICT_QUEUE_TIMEOUT)
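            # grace period so the stop sentinel is flushed to out_queue before
            # this worker process exits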
            time.sleep(10)
            break

        record_batch = record_batch_t.get()
        if is_first_layer:
            sampler.init_sampler(1)

            gt_node_ids = record_batch[item_id_field]
            cur_batch_size = len(gt_node_ids)
            node_ids = sampler.get({item_id_field: pa.array([-1] * cur_batch_size)})[
                item_id_field
            ]

            # skip layers before first_recall_layer
            sampler.init_sampler(n_cluster)
            for _ in range(1, first_recall_layer):
                sampled_result_dict = sampler.get({item_id_field: node_ids})
                node_ids = sampled_result_dict[item_id_field]

        sampled_result_dict = sampler.get({item_id_field: node_ids})
        updated_inputs = update_data(record_batch, sampled_result_dict)
        output_data = data_parser.parse(updated_inputs)
        batch = data_parser.to_batch(output_data, force_no_tile=True)

        out_queue.put(
            (batch, record_batch_t, updated_inputs[item_id_field]),
            timeout=PREDICT_QUEUE_TIMEOUT,
        )


def tdm_retrieval(
    predict_input_path: str,
    predict_output_path: str,
    scripted_model_path: str,
    recall_num: int,
    n_cluster: int = 2,
    reserved_columns: Optional[str] = None,
    batch_size: Optional[int] = None,
    is_profiling: bool = False,
    debug_level: int = 0,
    dataset_type: Optional[str] = None,
    writer_type: Optional[str] = None,
    num_worker_per_level: int = 1,
) -> None:
    """Evaluate EasyRec TDM model.

    Args:
        predict_input_path (str): inference input data path.
        predict_output_path (str): inference output data path.
        scripted_model_path (str): path to scripted model.
        recall_num (int): recall item num per user.
        n_cluster (int): tree cluster num.
        reserved_columns (str, optional): columns to reserve in the output.
        batch_size (int, optional): predict batch_size.
        is_profiling (bool): profiling predict process or not.
        debug_level (int, optional): debug level for debugging parsed inputs etc.
        dataset_type (str, optional): dataset type, default use the type in pipeline.
        writer_type (str, optional): data writer type; defaults to the writer
            matching dataset_type in data_config.
        num_worker_per_level (int): number of data-generation workers per tree level.
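
    A minimal invocation sketch (paths below are placeholders)::

        tdm_retrieval(
            predict_input_path="odps://your_project/tables/user_input",
            predict_output_path="odps://your_project/tables/recall_result",
            scripted_model_path="export/scripted_model_dir",
            recall_num=200,
        )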
    """
    reserved_cols: Optional[list[str]] = None
    if reserved_columns is not None:
        reserved_cols = [x.strip() for x in reserved_columns.split(",")]

    pipeline_config = config_util.load_pipeline_config(
        os.path.join(scripted_model_path, "pipeline.config")
    )
    if batch_size:
        pipeline_config.data_config.batch_size = batch_size
    if dataset_type:
        pipeline_config.data_config.dataset_type = getattr(DatasetType, dataset_type)

    device_and_backend = init_process_group()
    device: torch.device = device_and_backend[0]
    sparse_dtype: torch.dtype = torch.int32 if device.type == "cuda" else torch.int64

    is_rank_zero = int(os.environ.get("RANK", 0)) == 0
    is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0

    data_config = pipeline_config.data_config
    data_config.ClearField("label_fields")
    data_config.drop_remainder = False
    # Build feature
    features = _create_features(list(pipeline_config.feature_configs), data_config)

    infer_data_config = copy.copy(data_config)
    infer_data_config.num_workers = 1
    infer_dataloader = _get_dataloader(
        infer_data_config,
        features,
        predict_input_path,
        reserved_columns=["ALL_COLUMNS"],
        mode=Mode.PREDICT,
        debug_level=debug_level,
    )
    infer_iterator = iter(infer_dataloader)

    if writer_type is None:
        writer_type = DatasetType.Name(data_config.dataset_type).replace(
            "Dataset", "Writer"
        )
    writer: BaseWriter = create_writer(
        predict_output_path,
        writer_type,
        quota_name=data_config.odps_data_quota_name,
    )

    # Disable JIT compilation for now, as it compiles too slowly.
    if "PYTORCH_TENSOREXPR_FALLBACK" not in os.environ:
        os.environ["PYTORCH_TENSOREXPR_FALLBACK"] = "2"
    model: torch.jit.ScriptModule = torch.jit.load(
        os.path.join(scripted_model_path, "scripted_model.pt"), map_location=device
    )
    model.eval()

    if is_local_rank_zero:
        plogger = ProgressLogger(desc="Predicting", miniters=10)

    if is_profiling:
        if is_rank_zero:
            logger.info(str(model))
        prof = torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=10, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                os.path.join(scripted_model_path, "predict_trace")
            ),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()

    parser = DataParser(features)

    sampler_config = pipeline_config.data_config.tdm_sampler
    item_id_field: str = sampler_config.item_id_field
    max_level: int = len(sampler_config.layer_num_sample)
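    # shallowest tree level with enough nodes to fill the beam (roughly
    # n_cluster ** level >= 2 * n_cluster * recall_num); levels above it are
    # traversed without scoring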
    first_recall_layer = int(math.ceil(math.log(2 * n_cluster * recall_num, n_cluster)))

    dataset = infer_dataloader.dataset
    # pyre-ignore [16]
    fields = dataset.input_fields
    # pyre-ignore [29]
    predict_sampler = TDMPredictSampler(
        sampler_config, fields, data_config.batch_size, is_training=False
    )
    predict_sampler.init_cluster(
        num_client_per_rank=(max_level - first_recall_layer) * num_worker_per_level
    )
    predict_sampler.launch_server()

    num_class = pipeline_config.model_config.num_class
    pos_prob_name: str = "probs1" if num_class == 2 else "probs"

    def _forward(
        batch: Batch,
        record_batch_t: RecordBatchTensor,
        node_ids: pa.Array,
        layer_id: int,
    ) -> Tuple[RecordBatchTensor, pa.Array]:
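        """Score one level's candidate nodes and keep the best ones per user.

        Keeps the top ``2 * recall_num`` nodes on intermediate levels and the
        final deduplicated top ``recall_num`` item ids on the leaf level.
        """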
        with torch.no_grad():
            parsed_inputs = batch.to_dict(sparse_dtype=sparse_dtype)
            # when predicting with a model exported using INPUT_TILE,
            #  we set the batch size tensor to 1 to disable tiling.
            parsed_inputs["batch_size"] = torch.tensor(1, dtype=torch.int64)
            predictions = model(parsed_inputs, device)

            gt_node_ids = record_batch_t.get()[item_id_field]
            cur_batch_size = len(gt_node_ids)
            probs = predictions[pos_prob_name].reshape(cur_batch_size, -1)
            if layer_id == max_level - 1:
                k = recall_num
                candidate_ids = node_ids.to_numpy(zero_copy_only=False).reshape(
                    cur_batch_size, -1
                )
                sort_prob_index = torch.argsort(-probs, dim=1).cpu().numpy()
                sort_cand_ids = np.take_along_axis(
                    candidate_ids, sort_prob_index, axis=1
                )
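                # dedupe each user's candidates while keeping probability order,
                # then take the top recall_num item ids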
                node_ids = []
                for i in range(cur_batch_size):
                    _, unique_indices = np.unique(sort_cand_ids[i], return_index=True)
                    node_ids.append(
                        np.take(sort_cand_ids[i], np.sort(unique_indices)[:k])
                    )
                node_ids = pa.array(node_ids)
            else:
                k = 2 * recall_num
                _, topk_indices_in_group = torch.topk(probs, k, dim=1)
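                # turn per-row top-k indices into flat offsets into node_ids,
                # which holds all users' candidates for this level row-major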
                topk_indices = topk_indices_in_group + torch.arange(
                    cur_batch_size, device=device
                ).unsqueeze(1) * probs.size(1)
                topk_indices = topk_indices.reshape(-1).cpu().numpy()
                node_ids = node_ids.take(topk_indices)

            return record_batch_t, node_ids

    def _forward_loop(data_queue: Queue, pred_queue: Queue, layer_id: int) -> None:
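        """Score sampled batches for one tree level and pass results downstream.

        Exits after receiving a stop sentinel from every upstream data worker,
        then forwards sentinels to the next queue.
        """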
        stop_cnt = 0
        while True:
            batch, record_batch_t, node_ids = data_queue.get(
                timeout=PREDICT_QUEUE_TIMEOUT
            )
            if batch is None:
                stop_cnt += 1
                if stop_cnt == num_worker_per_level:
                    for _ in range(num_worker_per_level):
                        pred_queue.put((None, None), timeout=PREDICT_QUEUE_TIMEOUT)
                    break
                else:
                    continue
            assert batch is not None
            pred = _forward(batch, record_batch_t, node_ids, layer_id)
            pred_queue.put(pred, timeout=PREDICT_QUEUE_TIMEOUT)

    def _write_loop(pred_queue: Queue, metric_queue: Queue) -> None:
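        """Write recall results and accumulate the recall metric for this rank."""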
        total = 0
        recall = 0
        while True:
            record_batch_t, node_ids = pred_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
            if record_batch_t is None:
                break

            output_dict = OrderedDict()
            reserve_batch_record = record_batch_t.get()
            gt_node_ids = reserve_batch_record[item_id_field]
            cur_batch_size = len(gt_node_ids)
            if reserved_cols is not None:
                for c in reserved_cols:
                    output_dict[c] = reserve_batch_record[c]
            output_dict["recall_ids"] = node_ids
            writer.write(output_dict)

            # calculate recall: whether the ground-truth item id appears
            # among the recalled ids
            retrieval_result = np.any(
                np.equal(
                    gt_node_ids.to_numpy(zero_copy_only=False)[:, None],
                    np.array(node_ids.to_pylist()),
                ),
                axis=1,
            )
            total += cur_batch_size
            recall += np.sum(retrieval_result)
        metric_queue.put((total, recall), timeout=PREDICT_QUEUE_TIMEOUT)

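    # Per-rank pipeline: in_queues[i] -> data workers (sample candidates for
    # level first_recall_layer + i) -> out_queues[i] -> forward thread (scores
    # that level) -> in_queues[i + 1]; the last in_queue feeds the writer.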
    in_queues = [Queue(maxsize=2) for _ in range(max_level - first_recall_layer + 1)]
    out_queues = [Queue(maxsize=2) for _ in range(max_level - first_recall_layer)]
    metric_queue = Queue(maxsize=1)

    data_p_list = []
    for i in range(max_level - first_recall_layer):
        for j in range(num_worker_per_level):
            p = Process(
                target=_tdm_predict_data_worker,
                args=(
                    predict_sampler,
                    parser,
                    first_recall_layer,
                    n_cluster,
                    in_queues[i],
                    out_queues[i],
                    i == 0,
                    i * num_worker_per_level + j,
                ),
            )
            p.start()
            data_p_list.append(p)

    forward_t_list = []
    for i in range(max_level - first_recall_layer):
        t = Thread(
            target=_forward_loop,
            args=(out_queues[i], in_queues[i + 1], i + first_recall_layer),
        )
        t.start()
        forward_t_list.append(t)

    write_t = Thread(
        target=_write_loop, args=(in_queues[-1], metric_queue)
    )
    write_t.start()

    i_step = 0
    while True:
        try:
            batch = next(infer_iterator)
            in_queues[0].put((batch.reserves, None), timeout=PREDICT_QUEUE_TIMEOUT)
            if is_local_rank_zero:
                plogger.log(i_step)
            if is_profiling:
                prof.step()
            i_step += 1
        except StopIteration:
            break

    for _ in range(num_worker_per_level):
        in_queues[0].put((None, None), timeout=PREDICT_QUEUE_TIMEOUT)
    for p in data_p_list:
        p.join()
    for t in forward_t_list:
        t.join()
    write_t.join()
    writer.close()

    total, recall = metric_queue.get(timeout=PREDICT_QUEUE_TIMEOUT)
    total_t = torch.tensor(total, device=device)
    recall_t = torch.tensor(recall, device=device)
    dist.all_reduce(total_t, op=ReduceOp.SUM)
    dist.all_reduce(recall_t, op=ReduceOp.SUM)
    # pyre-ignore [6]
    recall_ratio = recall_t.cpu().item() / total_t.cpu().item()

    if is_profiling:
        prof.stop()
    if is_rank_zero:
        logger.info(f"Retrieval Finished. Recall:{recall_ratio}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--scripted_model_path",
        type=str,
        default=None,
        help="scripted model to be evaled, if not specified, use the checkpoint",
    )
    parser.add_argument(
        "--predict_input_path",
        type=str,
        default=None,
        help="inference data input path",
    )
    parser.add_argument(
        "--predict_output_path",
        type=str,
        default=None,
        help="inference data output path",
    )
    parser.add_argument(
        "--reserved_columns",
        type=str,
        default=None,
        help="column names to reserved in output",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=None,
        help="predict batch size, default will use batch size in config.",
    )
    parser.add_argument(
        "--is_profiling",
        action="store_true",
        default=False,
        help="profiling predict progress.",
    )
    parser.add_argument(
        "--debug_level",
        type=int,
        default=0,
        help="debug level for debug parsed inputs etc.",
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default=None,
        help="dataset type, default will use dataset type in config.",
    )
    parser.add_argument(
        "--recall_num", type=int, default=200, help="recall item num per user."
    )
    parser.add_argument("--n_cluster", type=int, default=2, help="tree cluster num.")
    parser.add_argument(
        "--num_worker_per_level",
        type=int,
        default=1,
        help="num data generate worker per tree level.",
    )
    args, extra_args = parser.parse_known_args()

    tdm_retrieval(
        predict_input_path=args.predict_input_path,
        predict_output_path=args.predict_output_path,
        scripted_model_path=args.scripted_model_path,
        recall_num=args.recall_num,
        n_cluster=args.n_cluster,
        reserved_columns=args.reserved_columns,
        batch_size=args.batch_size,
        is_profiling=args.is_profiling,
        debug_level=args.debug_level,
        dataset_type=args.dataset_type,
        num_worker_per_level=args.num_worker_per_level,
    )
