# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional

import torch
from torch import nn

from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.rank_model import RankModel
from tzrec.modules.interaction import InteractionArch
from tzrec.modules.mlp import MLP
from tzrec.protos.model_pb2 import ModelConfig
from tzrec.utils.config_util import config_to_kwargs


class DLRM(RankModel):
    """DLRM model.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): sample weight names.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        features: List[BaseFeature],
        labels: List[str],
        sample_weights: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(model_config, features, labels, sample_weights, **kwargs)
        self.init_input()

        assert self.embedding_group.has_group("dense"), "dense group is not specified"
        dense_dim = self.embedding_group.group_total_dim("dense")
        dense_feature_dims = self.embedding_group.group_feature_dims("dense")
        for feature_name in dense_feature_dims.keys():
            if "seq_encoder" in feature_name:
                raise Exception("dense group not have sequence features.")
        self.dense_mlp = MLP(
            dense_dim, **config_to_kwargs(self._model_config.dense_mlp)
        )

        assert self.embedding_group.has_group("sparse"), "sparse group is not specified"
        sparse_feature_dims = self.embedding_group.group_feature_dims("sparse")
        sparse_dim = self.embedding_group.group_total_dim("sparse")
        self.per_sparse_dim = 0
        for feature_name, dim in sparse_feature_dims.items():
            self.per_sparse_dim = dim
            if "seq_encoder" in feature_name:
                raise Exception("sparse group not have sequence features.")

        self.sparse_num = len(sparse_feature_dims)
        sparse_dims = set(sparse_feature_dims.values())
        if len(sparse_dims) > 1:
            raise Exception(
                f"sparse group feature dims must be the same, but we find {sparse_dims}"
            )
        if self.per_sparse_dim != self.dense_mlp.output_dim():
            raise Exception(
                "dense mlp last hidden_unit must be the same sparse feature dim"
            )

        self.interaction = InteractionArch(self.sparse_num + 1)

        feature_dim = self.dense_mlp.output_dim() + self.interaction.output_dim()
        if self._model_config.arch_with_sparse:
            feature_dim += sparse_dim

        self.final_mlp = MLP(feature_dim, **config_to_kwargs(self._model_config.final))
        self.output_mlp = nn.Linear(self.final_mlp.output_dim(), self._num_class)

    def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
        """Forward the model.

        Args:
            batch (Batch): input batch data.

        Return:
            predictions (dict): a dict of predicted result.
        """
        grouped_features = self.build_input(batch)

        # dense
        dense_group_feat = grouped_features["dense"]
        dense_feat = self.dense_mlp(dense_group_feat)

        # sparse
        sparse_group_feat = grouped_features["sparse"]
        sparse_feat = sparse_group_feat.reshape(
            -1, self.sparse_num, self.per_sparse_dim
        )

        # interaction
        interaction_feat = self.interaction(dense_feat, sparse_feat)

        # final mlp
        all_feat = torch.cat([interaction_feat, dense_feat], dim=-1)
        if self._model_config.arch_with_sparse:
            all_feat = torch.cat([all_feat, sparse_group_feat], dim=-1)
        y_final = self.final_mlp(all_feat)

        # output
        y = self.output_mlp(y_final)
        return self._output_to_prediction(y)
