# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.match_model import MatchModel, MatchTower
from tzrec.modules.capsule import CapsuleLayer
from tzrec.modules.mlp import MLP
from tzrec.protos import model_pb2, simi_pb2, tower_pb2
from tzrec.utils.config_util import config_to_kwargs


class MINDUserTower(MatchTower):
    """MIND user tower.

    Args:
        tower_config (MINDUserTower): MIND user tower config.
        output_dim (int): user output embedding dimension.
        similarity (Similarity): when COSINE similarity is used, the output
            embedding will be L2-normalized.
        user_feature_group (FeatureGroupConfig): user feature group config.
        hist_feature_group (FeatureGroupConfig): history sequence feature group config.
        user_features (list): list of user features.
        hist_features (list): list of history sequence features.
        model_config (ModelConfig): model config.
    """

    def __init__(
        self,
        tower_config: tower_pb2.MINDUserTower,
        output_dim: int,
        similarity: simi_pb2.Similarity,
        user_feature_group: model_pb2.FeatureGroupConfig,
        hist_feature_group: model_pb2.FeatureGroupConfig,
        user_features: List[BaseFeature],
        hist_features: List[BaseFeature],
        model_config: model_pb2.ModelConfig,
    ) -> None:
        super().__init__(
            tower_config,
            output_dim,
            similarity,
            [user_feature_group, hist_feature_group],
            user_features + hist_features,
            model_config,
        )

        self._hist_group_name = tower_config.history_input

        self.init_input()

        user_feature_in = self.embedding_group.group_total_dim(self._group_name)
        if len(tower_config.user_mlp.hidden_units) > 1:
            self.user_mlp = MLP(
                in_features=user_feature_in,
                hidden_units=tower_config.user_mlp.hidden_units[0:-1],
                activation=tower_config.user_mlp.activation,
                use_bn=tower_config.user_mlp.use_bn,
                dropout_ratio=tower_config.user_mlp.dropout_ratio[0]
                if tower_config.user_mlp.dropout_ratio
                else None,
            )
            self.user_mlp_out = nn.Linear(
                self.user_mlp.output_dim(), tower_config.user_mlp.hidden_units[-1]
            )

        else:
            self.user_mlp = nn.Linear(
                user_feature_in, tower_config.user_mlp.hidden_units[-1]
            )
            self.user_mlp_out = None

        hist_feature_dim = self.embedding_group.group_total_dim(
            self._hist_group_name + ".sequence"
        )

        if (
            tower_config.hist_seq_mlp
            and len(tower_config.hist_seq_mlp.hidden_units) > 1
        ):
            self._hist_seq_mlp = MLP(
                in_features=hist_feature_dim,
                dim=3,
                hidden_units=tower_config.hist_seq_mlp.hidden_units[0:-1],
                activation=tower_config.hist_seq_mlp.activation,
                use_bn=tower_config.hist_seq_mlp.use_bn,
                bias=False,
                dropout_ratio=tower_config.hist_seq_mlp.dropout_ratio[0]
                if tower_config.hist_seq_mlp.dropout_ratio
                else None,
            )
            self._hist_seq_mlp_out = nn.Linear(
                self._hist_seq_mlp.output_dim(),
                tower_config.hist_seq_mlp.hidden_units[-1],
                bias=False,
            )
            capsule_input_dim = tower_config.hist_seq_mlp.hidden_units[-1]
        elif (
            tower_config.hist_seq_mlp
            and len(tower_config.hist_seq_mlp.hidden_units) > 0
        ):
            self._hist_seq_mlp = nn.Linear(
                hist_feature_dim,
                tower_config.hist_seq_mlp.hidden_units[-1],
                bias=False,
            )
            self._hist_seq_mlp_out = None
            capsule_input_dim = tower_config.hist_seq_mlp.hidden_units[-1]
        else:
            self._hist_seq_mlp = None
            capsule_input_dim = hist_feature_dim

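        # Dynamic-routing capsule layer that extracts multiple high-level
        # interest capsules from the (optionally projected) behavior sequence.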
        self._capsule_layer = CapsuleLayer(
            capsule_config=tower_config.capsule_config,
            input_dim=capsule_input_dim,
        )
        self._concat_mlp = MLP(
            in_features=tower_config.user_mlp.hidden_units[-1]
            + tower_config.capsule_config.high_dim,
            dim=3,
            **config_to_kwargs(tower_config.concat_mlp),
        )
        if self._output_dim > 0:
            self.output = nn.Linear(
                self._concat_mlp.output_dim(), output_dim, bias=False
            )

    def forward(
        self, batch: Batch
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward the tower.

        Args:
            batch (Batch): input batch data.

        Returns:
            user_interests (Tensor): user interests.
        """
        user_feature_dict = self.build_input(batch)
        grp_hist_seq = user_feature_dict[self._hist_group_name + ".sequence"]
        grp_hist_len = user_feature_dict[self._hist_group_name + ".sequence_length"]
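        # grp_hist_seq: [B, L, hist_feature_dim] padded history embeddings;
        # grp_hist_len: [B] valid sequence lengths used for capsule masking.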

        if self.user_mlp_out:
            user_feature = self.user_mlp_out(
                self.user_mlp(user_feature_dict[self._group_name])
            )
        else:
            user_feature = self.user_mlp(user_feature_dict[self._group_name])
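        # user_feature: [B, user_mlp.hidden_units[-1]] dense user representation.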

        if self._hist_seq_mlp:
            if self._hist_seq_mlp_out:
                hist_seq_feas = self._hist_seq_mlp_out(self._hist_seq_mlp(grp_hist_seq))
            else:
                hist_seq_feas = self._hist_seq_mlp(grp_hist_seq)
        else:
            hist_seq_feas = grp_hist_seq
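        # hist_seq_feas: [B, L, capsule_input_dim] inputs to dynamic routing.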

        high_capsules, high_capsules_mask = self._capsule_layer(
            hist_seq_feas, grp_hist_len
        )
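        # high_capsules: [B, K, high_dim] interest capsules; high_capsules_mask:
        # [B, K] marks which capsules are valid given the sequence length.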

        # Tile the user feature across the K interest capsules and concatenate;
        # padded capsules are zeroed both before and after the concat MLP.
        user_feature = torch.unsqueeze(user_feature, dim=1)
        user_feature_tile = torch.tile(user_feature, [1, high_capsules.shape[1], 1])
        user_interests = torch.cat([user_feature_tile, high_capsules], dim=-1)
        user_interests = user_interests * high_capsules_mask.unsqueeze(-1).float()
        user_interests = self._concat_mlp(user_interests)
        user_interests = user_interests * high_capsules_mask.unsqueeze(-1).float()

        if self._output_dim > 0:
            user_interests = self.output(user_interests)

        if self._similarity == simi_pb2.Similarity.COSINE:
            user_interests = F.normalize(user_interests, p=2.0, dim=-1)

        if self.is_inference:
            return user_interests
        else:
            return user_interests, high_capsules_mask


class MINDItemTower(MatchTower):
    """MIND item tower.

    Args:
        tower_config (MINDItemTower): MIND item tower config.
        output_dim (int): item output embedding dimension.
        similarity (Similarity): when COSINE similarity is used, the output
            embedding will be L2-normalized.
        item_feature_group (FeatureGroupConfig): item feature group config.
        item_features (list): list of item features.
        model_config (ModelConfig): model config.

    """

    def __init__(
        self,
        tower_config: tower_pb2.MINDItemTower,
        output_dim: int,
        similarity: simi_pb2.Similarity,
        item_feature_group: model_pb2.FeatureGroupConfig,
        item_features: List[BaseFeature],
        model_config: model_pb2.ModelConfig,
    ) -> None:
        super().__init__(
            tower_config,
            output_dim,
            similarity,
            [item_feature_group],
            item_features,
            model_config,
        )
        self.init_input()
        tower_feature_in = self.embedding_group.group_total_dim(self._group_name)
        self.mlp = MLP(tower_feature_in, **config_to_kwargs(tower_config.mlp))
        if self._output_dim > 0:
            self.output = nn.Linear(self.mlp.output_dim(), output_dim)

    def forward(self, batch: Batch) -> torch.Tensor:
        """Forward the tower.

        Args:
            batch (Batch): input batch data.

        Returns:
            item_emb (Tensor): item embedding.

        """
        grouped_features = self.build_input(batch)
        item_emb = self.mlp(grouped_features[self._group_name])

        if self._output_dim > 0:
            item_emb = self.output(item_emb)

        if self._similarity == simi_pb2.Similarity.COSINE:
            item_emb = F.normalize(item_emb, p=2.0, dim=1)
        return item_emb


class MIND(MatchModel):
    """MIND model.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): list of sample weight names.

    """

    def __init__(
        self,
        model_config: model_pb2.ModelConfig,
        features: List[BaseFeature],
        labels: List[str],
        sample_weights: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(model_config, features, labels, sample_weights, **kwargs)
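        # Resolve the user, item, and history feature groups referenced by the
        # tower configs.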
        name_to_feature_group = {x.group_name: x for x in model_config.feature_groups}
        user_group = name_to_feature_group[self._model_config.user_tower.input]
        item_group = name_to_feature_group[self._model_config.item_tower.input]
        hist_group = name_to_feature_group[self._model_config.user_tower.history_input]

        user_features = self.get_features_in_feature_groups([user_group])
        item_features = self.get_features_in_feature_groups([item_group])
        hist_features = self.get_features_in_feature_groups([hist_group])

        self.user_tower = MINDUserTower(
            self._model_config.user_tower,
            self._model_config.output_dim,
            self._model_config.similarity,
            user_group,
            hist_group,
            user_features,
            hist_features,
            model_config,
        )
        self.item_tower = MINDItemTower(
            self._model_config.item_tower,
            self._model_config.output_dim,
            self._model_config.similarity,
            item_group,
            item_features,
            model_config,
        )

    def label_aware_attention(
        self,
        user_interests: torch.Tensor,
        item_emb: torch.Tensor,
        interest_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute label-aware attention for user interests.

        Args:
            user_interests (Tensor): user interests.
            item_emb (Tensor): item embedding.
            interest_mask (Tensor): interest mask.

        Returns:
            user_emb (Tensor): user embedding.
        """
        batch_size = user_interests.size(0)
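        # item_emb may include extra sampled-negative rows; the first
        # batch_size rows are the positives aligned with user_interests.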
        pos_item_emb = item_emb[:batch_size]

        simi_pow = self._model_config.simi_pow
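        # Per-interest similarity logits: [B, K, D] x [B, D] -> [B, K].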
        interest_weight = torch.einsum("bkd, bd->bk", user_interests, pos_item_emb)

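        # Mask invalid capsules: mask 1 -> threshold +1e32 (logit unchanged by
        # torch.minimum), mask 0 -> threshold -1e32 (softmax weight ~0). E.g.
        # (hypothetical values) logits [0.5, 0.8] with mask [1, 0] become
        # [0.5, -1e32], so the padded capsule gets ~zero attention.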
        threshold = (interest_mask.float() * 2 - 1) * 1e32
        interest_weight = torch.minimum(interest_weight, threshold)

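        # Sharpen the attention distribution with the simi_pow exponent, then
        # normalize over the K interests (dim=1).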
        interest_weight = interest_weight.unsqueeze(-1)
        interest_weight = interest_weight * simi_pow
        interest_weight = F.softmax(interest_weight, dim=1)

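        # Softmax-weighted sum of the interest capsules -> [B, D].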
        user_emb = torch.sum(torch.multiply(interest_weight, user_interests), dim=1)
        return user_emb

    def predict(self, batch: Batch) -> Dict[str, Tensor]:
        """Forward the model.

        Args:
            batch (Batch): input batch data.

        Returns:
            simi (dict): a dict of predicted results with key "similarity".

        """
        if self.is_inference:
            user_interests = self.user_tower(batch)
        else:
            user_interests, interest_mask = self.user_tower(batch)
        item_emb = self.item_tower(batch)

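        # interest_mask is only produced when is_inference is False; in
        # inference the towers are expected to be exported and served
        # separately, so this joint path is assumed to run only in
        # training/evaluation (hence the pyre-ignore below).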
        user_emb = self.label_aware_attention(
            user_interests,
            item_emb,
            interest_mask,  # pyre-ignore [61]
        )
        # if self._model_config.similarity == simi_pb2.Similarity.COSINE:
        #     user_emb = F.normalize(user_emb, p=2.0, dim=1)

        ui_sim = self.sim(user_emb, item_emb) / self._model_config.temperature
        return {"similarity": ui_sim}
