# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple

import numpy as np
import numpy.typing as npt
import pyarrow as pa
import pyarrow.compute as pc
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
from torchrec.streamable import Pipelineable

from tzrec.protos.data_pb2 import FieldType

BASE_DATA_GROUP = "__BASE__"
NEG_DATA_GROUP = "__NEG__"
CROSS_NEG_DATA_GROUP = "__CNEG__"

C_SAMPLE_MASK = "__SAMPLE_MASK__"
C_NEG_SAMPLE_MASK = "__NEG_SAMPLE_MASK__"

FIELD_TYPE_TO_PA = {
    FieldType.INT32: pa.int32(),
    FieldType.INT64: pa.int64(),
    FieldType.FLOAT: pa.float32(),
    FieldType.DOUBLE: pa.float64(),
    FieldType.STRING: pa.string(),
}


@dataclass
class ParsedData:
    """Internal parsed data structure."""

    name: str


@dataclass
class SparseData(ParsedData):
    """Internal data structure for sparse feature."""

    values: npt.NDArray
    lengths: npt.NDArray
    weights: Optional[npt.NDArray] = None


@dataclass
class DenseData(ParsedData):
    """Internal data structure for dense feature."""

    values: npt.NDArray


@dataclass
class SequenceSparseData(ParsedData):
    """Internal data structure for sequence sparse feature."""

    values: npt.NDArray
    key_lengths: npt.NDArray
    seq_lengths: npt.NDArray


@dataclass
class SequenceDenseData(ParsedData):
    """Internal data structure for sequence dense feature."""

    values: npt.NDArray
    seq_lengths: npt.NDArray


class RecordBatchTensor:
    """PyArrow RecordBatch use Tensor as buffer.

    For efficient transfer data between processes, e.g., mp.Queue.
    """

    def __init__(self, record_batch: Optional[pa.RecordBatch] = None) -> None:
        self._schema = None
        self._buff = None
        if record_batch:
            self._schema = record_batch.schema
            self._buff = torch.UntypedStorage.from_buffer(
                record_batch.serialize(), dtype=torch.uint8
            )

    def get(self) -> Optional[pa.RecordBatch]:
        """Get RecordBatch."""
        if self._buff is not None:
            # pyre-ignore[16]
            return pa.ipc.read_record_batch(
                pa.foreign_buffer(self._buff.data_ptr(), self._buff.size()),
                self._schema,
            )
        else:
            return None


@dataclass
class Batch(Pipelineable):
    """Input Batch."""

    # key of dense_features is data group name
    dense_features: Dict[str, KeyedTensor] = field(default_factory=dict)
    # key of sparse_features is data group name
    sparse_features: Dict[str, KeyedJaggedTensor] = field(default_factory=dict)
    # key of sequence_mulval_lengths is data group name
    #
    # for multi-value sequence, we flatten it, then store values & accumate lengths
    # into sparse_features, store key_lengths & seq_lengths into sequence_mulval_lengths
    #
    # e.g.
    # for the sequence `click_seq`: [[[3, 4], [5]], [6, [7, 8]]]
    # we can denote it in jagged formular with:
    #   values: [3, 4, 5, 6, 7, 8]
    #   key_lengths: [2, 1, 1, 2]
    #   seq_lengths: [2, 2]
    # then:
    #   sparse_features[dg]['click_seq'].values() = [3, 4, 5, 6, 7, 8]  # values
    #   sparse_features[dg]['click_seq'].lengths() = [3, 3]  # accumate lengths
    #   sequence_mulval_lengths[dg]['click_seq'].values() = [2, 1, 1, 2]  # key_lengths
    #   sequence_mulval_lengths[dg]['click_seq'].lengths() = [2, 2]  # seq_lengths
    sequence_mulval_lengths: Dict[str, KeyedJaggedTensor] = field(default_factory=dict)
    # key of sequence_dense_features is feature name
    sequence_dense_features: Dict[str, JaggedTensor] = field(default_factory=dict)
    # key of labels is label name
    labels: Dict[str, torch.Tensor] = field(default_factory=dict)
    # reserved inputs [for predict]
    reserves: RecordBatchTensor = field(default_factory=RecordBatchTensor)
    # size for user side input tile when do inference and INPUT_TILE=2 or 3
    tile_size: int = field(default=-1)
    # sample_weight
    sample_weights: Dict[str, torch.Tensor] = field(default_factory=dict)

    def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":
        """Copy to specified device."""
        return Batch(
            dense_features={
                k: v.to(device=device, non_blocking=non_blocking)
                for k, v in self.dense_features.items()
            },
            sparse_features={
                k: v.to(device=device, non_blocking=non_blocking)
                for k, v in self.sparse_features.items()
            },
            sequence_mulval_lengths={
                k: v.to(device=device, non_blocking=non_blocking)
                for k, v in self.sequence_mulval_lengths.items()
            },
            sequence_dense_features={
                k: v.to(device=device, non_blocking=non_blocking)
                for k, v in self.sequence_dense_features.items()
            },
            labels={
                k: v.to(device=device, non_blocking=non_blocking)
                for k, v in self.labels.items()
            },
            reserves=self.reserves,
            tile_size=self.tile_size,
            sample_weights={
                k: v.to(device=device, non_blocking=non_blocking)
                for k, v in self.sample_weights.items()
            },
        )

    def record_stream(self, stream: torch.Stream) -> None:
        """Record which streams have used the tensor."""
        for v in self.dense_features.values():
            v.record_stream(stream)
        for v in self.sparse_features.values():
            v.record_stream(stream)
        for v in self.sequence_mulval_lengths.values():
            v.record_stream(stream)
        for v in self.sequence_dense_features.values():
            v.record_stream(stream)
        for v in self.labels.values():
            v.record_stream(stream)
        for v in self.sample_weights.values():
            v.record_stream(stream)

    def pin_memory(self) -> "Batch":
        """Copy to pinned memory."""
        # TODO(hongsheng.jhs): KeyedTensor do not have pin_memory()
        dense_features = {}
        for k, v in self.dense_features.items():
            dense_features[k] = KeyedTensor(
                keys=v.keys(),
                length_per_key=v.length_per_key(),
                values=v.values().pin_memory(),
                key_dim=v.key_dim(),
            )
        sequence_dense_features = {}
        for k, v in self.sequence_dense_features.items():
            weights = v._weights
            lengths = v._lengths
            offsets = v._offsets
            sequence_dense_features[k] = JaggedTensor(
                values=v.values().pin_memory(),
                weights=weights.pin_memory() if weights is not None else None,
                lengths=lengths.pin_memory() if lengths is not None else None,
                offsets=offsets.pin_memory() if offsets is not None else None,
            )
        return Batch(
            dense_features=dense_features,
            sparse_features={
                k: v.pin_memory() for k, v in self.sparse_features.items()
            },
            sequence_mulval_lengths={
                k: v.pin_memory() for k, v in self.sequence_mulval_lengths.items()
            },
            sequence_dense_features=sequence_dense_features,
            labels={k: v.pin_memory() for k, v in self.labels.items()},
            reserves=self.reserves,
            tile_size=self.tile_size,
            sample_weights={k: v.pin_memory() for k, v in self.sample_weights.items()},
        )

    def to_dict(
        self, sparse_dtype: Optional[torch.dtype] = None
    ) -> Dict[str, torch.Tensor]:
        """Convert to feature tensor dict."""
        tensor_dict = {}
        for x in self.dense_features.values():
            for k, v in x.to_dict().items():
                tensor_dict[f"{k}.values"] = v
        for x in self.sparse_features.values():
            if sparse_dtype:
                x = KeyedJaggedTensor(
                    keys=x.keys(),
                    values=x.values().to(sparse_dtype),
                    lengths=x.lengths().to(sparse_dtype),
                    weights=x.weights_or_none(),
                )
            for k, v in x.to_dict().items():
                tensor_dict[f"{k}.values"] = v.values()
                tensor_dict[f"{k}.lengths"] = v.lengths()
                if v.weights_or_none() is not None:
                    tensor_dict[f"{k}.weights"] = v.weights()
        for x in self.sequence_mulval_lengths.values():
            if sparse_dtype:
                x = KeyedJaggedTensor(
                    keys=x.keys(),
                    values=x.values().to(sparse_dtype),
                    lengths=x.lengths().to(sparse_dtype),
                )
            for k, v in x.to_dict().items():
                tensor_dict[f"{k}.key_lengths"] = v.values()
                tensor_dict[f"{k}.lengths"] = v.lengths()
        for k, v in self.sequence_dense_features.items():
            tensor_dict[f"{k}.values"] = v.values()
            tensor_dict[f"{k}.lengths"] = v.lengths()
        for k, v in self.labels.items():
            tensor_dict[f"{k}"] = v
        for k, v in self.sample_weights.items():
            tensor_dict[f"{k}"] = v
        if self.tile_size > 0:
            tensor_dict["batch_size"] = torch.tensor(self.tile_size, dtype=torch.int64)
        return tensor_dict


def process_hstu_seq_data(
    input_data: Dict[str, pa.Array],
    seq_attr: str,
    seq_str_delim: str,
) -> Tuple[pa.Array, pa.Array, pa.Array]:
    """Process sequence data for HSTU match model.

    Args:
        input_data: Dictionary containing input arrays
        seq_attr: Name of the sequence attribute field
        seq_str_delim: Delimiter used to separate sequence items

    Returns:
        Tuple containing:
        - input_data_k_split: pa.Array, Original sequence items
        - input_data_k_split_slice: pa.Array, Target items for autoregressive training
        - pre_seq_filter_reshaped_joined: pa.Array,
        Training sequence for autoregressive training
    """
    # default sequence data is string
    if pa.types.is_string(input_data[seq_attr].type):
        input_data_k_split = pc.split_pattern(input_data[seq_attr], seq_str_delim)
        # Get target items for training for autoregressive training
        # Example: [1,2,3,4,5] -> [2,3,4,5]
        input_data_k_split_slice = pc.list_flatten(
            pc.list_slice(input_data_k_split, start=1)
        )

        # Directly extract the training sequence for autoregressive training
        # Operation target example: [1,2,3,4,5] -> [1,2,3,4]
        # (corresponding target: [2,3,4,5])
        # As this can not be achieved by pyarrow.compute, and for loop is costly
        # we need to do this using pa.ListArray.from_arrays using offsets
        # 1. transfer to numpy and filter out the last item
        pre_seq = pc.list_flatten(input_data_k_split).to_numpy(zero_copy_only=False)
        # Mark last items of each seq with '-1'
        pre_seq[input_data_k_split.offsets.to_numpy()[1:] - 1] = "-1"
        # Filter out -1 marker elements
        mask = pre_seq != "-1"
        pre_seq_filter = pre_seq[mask]
        # 2. create offsets for reshaping filtered sequence
        # The offsets should be created extract the training sequence
        # Example: if the original offsets are [0,2,5,9], after filter,
        # for the offsets should be [0, 1, 3, 6]
        # that is, [0] + [2-1, 5-2, 9-3]
        pre_seq_filter_offsets = pa.array(
            np.concatenate(
                [
                    np.array([0]),
                    input_data_k_split.offsets[1:].to_numpy(zero_copy_only=False)
                    - np.arange(
                        1,
                        len(
                            input_data_k_split.offsets[1:].to_numpy(
                                zero_copy_only=False
                            )
                        )
                        + 1,
                    ),
                ]
            )
        )
        pre_seq_filter_reshaped = pa.ListArray.from_arrays(
            pre_seq_filter_offsets, pre_seq_filter
        )
        # Join filtered sequence with delimiter
        pre_seq_filter_reshaped_joined = pc.binary_join(
            pre_seq_filter_reshaped, seq_str_delim
        )

        return (
            input_data_k_split,
            input_data_k_split_slice,
            pre_seq_filter_reshaped_joined,
        )


def process_hstu_neg_sample(
    input_data: Dict[str, pa.Array],
    v: pa.Array,
    neg_sample_num: int,
    seq_str_delim: str,
    seq_attr: str,
) -> pa.Array:
    """Process negative samples for HSTU match model.

    Args:
        input_data: Dict[str, pa.Array], Dictionary containing input arrays
        v: pa.Array, negative samples.
        neg_sample_num: int, number of negative samples.
        seq_str_delim: str, delimiter for sequence string.
        seq_attr: str, attribute name of sequence.

    Returns:
        pa.Array: Processed negative samples
    """
    # The goal is to make neg samples concat to the training sequence
    # Example:
    # input_data[seq_attr] = ["1;2;3"]
    # neg_sample_num = 2
    # v = [4,5,6,7,8,9]
    # then the output should be [[1,4,5], [2,6,7], [3,8,9]]
    v_str = v.cast(pa.string())
    filtered_v_offsets = pa.array(
        np.concatenate(
            [
                np.array([0]),
                np.arange(neg_sample_num, len(v_str) + 1, neg_sample_num),
            ]
        )
    )
    # Reshape v for each input_data[seq_attr]
    # Example:[4,5,6,7,8,9] -> [[4,5], [6,7], [8,9]]
    filtered_v_palist = pa.ListArray.from_arrays(filtered_v_offsets, v_str)
    # Using string for join, as not found operation for ListArray achieving this
    # Example: [[4,5], [6,7], [8,9]] -> ["4;5", "6;7", "8;9"]
    sampled_joined = pc.binary_join(filtered_v_palist, seq_str_delim)
    # Combine training sequence and target items
    # Example: ["1;2;3"] + ["4;5", "6;7", "8;9"]
    # -> ["1;4;5", "2;6;7", "3;8;9"]
    return pc.binary_join_element_wise(
        input_data[seq_attr], sampled_joined, seq_str_delim
    )


def calc_slice_position(
    row_count: int,
    slice_id: int,
    slice_count: int,
    batch_size: int,
    drop_redundant_bs_eq_one: bool,
    pre_total_remain: int = 0,
) -> Tuple[int, int, int]:
    """Calc table read position according to the slice information.

    Args:
        row_count (int): table total row count.
        slice_id (int): worker id.
        slice_count (int): total worker number.
        batch_size (int): batch_size.
        drop_redundant_bs_eq_one (bool): drop last redundant batch with batch_size
            equal one to prevent train_eval hung.
        pre_total_remain (int): remaining total count in pre-table is
            insufficient to meet the batch_size requirement for each worker.

    Return:
        start (int): start row position in table.
        end (int): start row position in table.
        total_remain (int): remaining total count in curr-table is
            insufficient to meet the batch_size requirement for each worker.
    """
    pre_remain_size = int(pre_total_remain / slice_count)
    pre_remain_split_point = pre_total_remain % slice_count

    size = int((row_count + pre_total_remain) / slice_count)
    split_point = (row_count + pre_total_remain) % slice_count
    if slice_id < split_point:
        start = slice_id * (size + 1)
        end = start + (size + 1)
    else:
        start = split_point * (size + 1) + (slice_id - split_point) * size
        end = start + size

    real_start = (
        start - pre_remain_size * slice_id - min(pre_remain_split_point, slice_id)
    )
    real_end = (
        end
        - pre_remain_size * (slice_id + 1)
        - min(pre_remain_split_point, slice_id + 1)
    )
    # when (end - start) % bz = 1 on some workers and
    # (end - start) % bz = 0 on other workers, train_eval will hang
    if (
        drop_redundant_bs_eq_one
        and split_point != 0
        and (end - start) % batch_size == 1
        and size % batch_size == 0
    ):
        real_end = real_end - 1
        split_point = 0
    return real_start, real_end, (size % batch_size) * slice_count + split_point


def remove_nullable(field_type: pa.DataType) -> pa.DataType:
    """Recursive removal of the null=False property from lists and nested lists."""
    if pa.types.is_list(field_type):
        # Get element fields
        value_field = field_type.value_field
        # Change the nullable to True
        normalized_value_field = value_field.with_nullable(True)
        # Recursive processing of element types
        normalized_value_type = remove_nullable(normalized_value_field.type)
        # Construct a new list type
        return pa.list_(normalized_value_type)

    else:
        return field_type
