tzrec/models/rank_model.py (311 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, List, Optional import torch import torchmetrics from torch import nn from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import BaseFeature from tzrec.loss.jrc_loss import JRCLoss from tzrec.metrics.grouped_auc import GroupedAUC from tzrec.models.model import BaseModel from tzrec.modules.embedding import EmbeddingGroup from tzrec.modules.utils import div_no_nan from tzrec.modules.variational_dropout import VariationalDropout from tzrec.protos import model_pb2 from tzrec.protos.loss_pb2 import LossConfig from tzrec.protos.metric_pb2 import MetricConfig from tzrec.utils.config_util import config_to_kwargs @torch.fx.wrap def _update_tensor_dict( tensor_dict: Dict[str, torch.Tensor], new_tensor: torch.Tensor, key: str ) -> None: tensor_dict[key] = new_tensor class RankModel(BaseModel): """Base model for ranking. Args: model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. sample_weights (list): sample weight names. """ def __init__( self, model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) self._num_class = model_config.num_class self._label_name = labels[0] self._sample_weight_name = ( sample_weights[0] if sample_weights else sample_weights ) self._loss_collection = {} self.embedding_group = None self.group_variational_dropouts = None def init_input(self) -> None: """Build embedding group and group variational dropout.""" self.embedding_group = EmbeddingGroup( self._features, list(self._base_model_config.feature_groups) ) if self._base_model_config.HasField("variational_dropout"): self.group_variational_dropouts = nn.ModuleDict() variational_dropout_config = self._base_model_config.variational_dropout variational_dropout_config_dict = config_to_kwargs( variational_dropout_config ) for feature_group in list(self._base_model_config.feature_groups): group_name = feature_group.group_name if feature_group.group_type != model_pb2.SEQUENCE: feature_dim = self.embedding_group.group_feature_dims(group_name) if len(feature_dim) > 1: variational_dropout = VariationalDropout( feature_dim, group_name, **variational_dropout_config_dict ) self.group_variational_dropouts[group_name] = ( variational_dropout ) def build_input(self, batch: Batch) -> Dict[str, torch.Tensor]: """Build input feature.""" feature_dict = self.embedding_group(batch) if self.group_variational_dropouts is not None: for ( group_name, variational_dropout, ) in self.group_variational_dropouts.items(): feature, variational_dropout_loss = variational_dropout( feature_dict[group_name] ) _update_tensor_dict(feature_dict, feature, group_name) _update_tensor_dict( self._loss_collection, variational_dropout_loss, group_name + "_feature_p_loss", ) return feature_dict def _output_to_prediction_impl( self, output: torch.Tensor, loss_cfg: LossConfig, num_class: int = 1, suffix: str = "", ) -> Dict[str, torch.Tensor]: predictions = {} loss_type = loss_cfg.WhichOneof("loss") if loss_type == "binary_cross_entropy": assert num_class == 1, f"num_class must be 1 when loss type is {loss_type}" output = torch.squeeze(output, dim=1) predictions["logits" + suffix] = output predictions["probs" + suffix] = torch.sigmoid(output) elif loss_type == "softmax_cross_entropy": assert num_class > 1, ( f"num_class must be greater than 1 when loss type is {loss_type}" ) probs = torch.softmax(output, dim=1) predictions["logits" + suffix] = output predictions["probs" + suffix] = probs if num_class == 2: predictions["probs1" + suffix] = probs[:, 1] elif loss_type == "jrc_loss": assert num_class == 2, f"num_class must be 2 when loss type is {loss_type}" probs = torch.softmax(output, dim=1) predictions["logits" + suffix] = output predictions["probs" + suffix] = probs predictions["probs1" + suffix] = probs[:, 1] elif loss_type == "l2_loss": output = torch.squeeze(output, dim=1) predictions["y" + suffix] = output else: raise NotImplementedError return predictions def _output_to_prediction( self, output: torch.Tensor, suffix: str = "" ) -> Dict[str, torch.Tensor]: predictions = {} for loss_cfg in self._base_model_config.losses: predictions.update( self._output_to_prediction_impl( output, loss_cfg, num_class=self._num_class, suffix=suffix ) ) return predictions def _init_loss_impl( self, loss_cfg: LossConfig, num_class: int = 1, reduction: str = "none", suffix: str = "", ) -> None: loss_type = loss_cfg.WhichOneof("loss") loss_name = loss_type + suffix if loss_type == "binary_cross_entropy": self._loss_modules[loss_name] = nn.BCEWithLogitsLoss(reduction=reduction) elif loss_type == "softmax_cross_entropy": self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction=reduction) elif loss_type == "jrc_loss": assert num_class == 2, f"num_class must be 2 when loss type is {loss_type}" self._loss_modules[loss_name] = JRCLoss( alpha=loss_cfg.jrc_loss.alpha, reduction=reduction ) elif loss_type == "l2_loss": self._loss_modules[loss_name] = nn.MSELoss(reduction=reduction) else: raise ValueError(f"loss[{loss_type}] is not supported yet.") def init_loss(self) -> None: """Initialize loss modules.""" for loss_cfg in self._base_model_config.losses: reduction = "none" if self._sample_weight_name else "mean" self._init_loss_impl(loss_cfg, self._num_class, reduction=reduction) def _loss_impl( self, predictions: Dict[str, torch.Tensor], batch: Batch, label: torch.Tensor, loss_weight: Optional[torch.Tensor], loss_cfg: LossConfig, num_class: int = 1, suffix: str = "", ) -> Dict[str, torch.Tensor]: losses = {} loss_type = loss_cfg.WhichOneof("loss") loss_name = loss_type + suffix if loss_type == "binary_cross_entropy": pred = predictions["logits" + suffix] label = label.to(torch.float32) losses[loss_name] = self._loss_modules[loss_name](pred, label) elif loss_type == "softmax_cross_entropy": pred = predictions["logits" + suffix] losses[loss_name] = self._loss_modules[loss_name](pred, label) elif loss_type == "jrc_loss": assert num_class == 2, f"num_class must be 2 when loss type is {loss_type}" pred = predictions["logits" + suffix] session_id = batch.sparse_features[BASE_DATA_GROUP][ loss_cfg.jrc_loss.session_name ].values() losses[loss_name] = self._loss_modules[loss_name](pred, label, session_id) elif loss_type == "l2_loss": pred = predictions["y" + suffix] losses[loss_name] = self._loss_modules[loss_name](pred, label) else: raise ValueError(f"loss[{loss_type}] is not supported yet.") if loss_weight is not None: losses[loss_name] = torch.mean(losses[loss_name] * loss_weight) return losses def loss( self, predictions: Dict[str, torch.Tensor], batch: Batch ) -> Dict[str, torch.Tensor]: """Compute loss of the model.""" losses = {} if self._sample_weight_name: loss_weight = batch.sample_weights[self._sample_weight_name] loss_weight = div_no_nan(loss_weight, torch.mean(loss_weight)) else: loss_weight = None for loss_cfg in self._base_model_config.losses: losses.update( self._loss_impl( predictions, batch, batch.labels[self._label_name], loss_weight, loss_cfg, num_class=self._num_class, ) ) losses.update(self._loss_collection) return losses def _init_metric_impl( self, metric_cfg: MetricConfig, num_class: int = 1, suffix: str = "" ) -> None: metric_type = metric_cfg.WhichOneof("metric") oneof_metric_cfg = getattr(metric_cfg, metric_type) metric_kwargs = config_to_kwargs(oneof_metric_cfg) metric_name = metric_type + suffix if metric_type == "auc": assert num_class <= 2, ( f"num_class must less than 2 when metric type is {metric_type}" ) self._metric_modules[metric_name] = torchmetrics.AUROC( task="binary", **metric_kwargs ) elif metric_type == "multiclass_auc": self._metric_modules[metric_name] = torchmetrics.AUROC( task="multiclass", num_classes=num_class, **metric_kwargs ) elif metric_type == "mean_absolute_error": self._metric_modules[metric_name] = torchmetrics.MeanAbsoluteError() elif metric_type == "mean_squared_error": self._metric_modules[metric_name] = torchmetrics.MeanSquaredError() elif metric_type == "accuracy": self._metric_modules[metric_name] = torchmetrics.Accuracy( task="multiclass" if num_class > 1 else "binary", num_classes=num_class, **metric_kwargs, ) elif metric_type == "grouped_auc": assert num_class <= 2, ( f"num_class must less than 2 when metric type is {metric_type}" ) self._metric_modules[metric_name] = GroupedAUC() else: raise ValueError(f"{metric_type} is not supported for this model") def init_metric(self) -> None: """Initialize metric modules.""" for metric_cfg in self._base_model_config.metrics: self._init_metric_impl(metric_cfg, self._num_class) for loss_cfg in self._base_model_config.losses: self._init_loss_metric_impl(loss_cfg) def _update_metric_impl( self, predictions: Dict[str, torch.Tensor], batch: Batch, label: torch.Tensor, metric_cfg: MetricConfig, num_class: int = 1, suffix: str = "", ) -> None: metric_type = metric_cfg.WhichOneof("metric") oneof_metric_cfg = getattr(metric_cfg, metric_type) metric_name = metric_type + suffix base_sparse_feat = None if metric_type in ["grouped_auc"]: base_sparse_feat = batch.sparse_features[BASE_DATA_GROUP].to_dict() if metric_type == "auc": pred = ( predictions["probs" + suffix] if num_class == 1 else predictions["probs1" + suffix] ) self._metric_modules[metric_name].update(pred, label) elif metric_type == "multiclass_auc": pred = predictions["probs" + suffix] self._metric_modules[metric_name].update(pred, label) elif metric_type == "mean_absolute_error": pred = predictions["y" + suffix] self._metric_modules[metric_name].update(pred, label) elif metric_type == "mean_squared_error": pred = predictions["y" + suffix] self._metric_modules[metric_name].update(pred, label) elif metric_type == "accuracy": pred = predictions["probs" + suffix] self._metric_modules[metric_name].update(pred, label) elif metric_type == "grouped_auc": pred = ( predictions["probs" + suffix] if num_class == 1 else predictions["probs1" + suffix] ) # pyre-ignore [16] grouping_key = base_sparse_feat[ oneof_metric_cfg.grouping_key ].to_padded_dense(1)[:, 0] self._metric_modules[metric_name].update(pred, label, grouping_key) else: raise ValueError(f"{metric_type} is not supported for this model") def update_metric( self, predictions: Dict[str, torch.Tensor], batch: Batch, losses: Optional[Dict[str, torch.Tensor]] = None, ) -> None: """Update metric state. Args: predictions (dict): a dict of predicted result. batch (Batch): input batch data. losses (dict, optional): a dict of loss. """ for metric_cfg in self._base_model_config.metrics: self._update_metric_impl( predictions, batch, batch.labels[self._label_name], metric_cfg, num_class=self._num_class, ) if losses is not None: for loss_cfg in self._base_model_config.losses: self._update_loss_metric_impl( losses, batch, batch.labels[self._label_name], loss_cfg )