# tzrec/models/rank_model.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional

import torch
import torchmetrics
from torch import nn

from tzrec.datasets.utils import BASE_DATA_GROUP, Batch
from tzrec.features.feature import BaseFeature
from tzrec.loss.jrc_loss import JRCLoss
from tzrec.metrics.grouped_auc import GroupedAUC
from tzrec.models.model import BaseModel
from tzrec.modules.embedding import EmbeddingGroup
from tzrec.modules.utils import div_no_nan
from tzrec.modules.variational_dropout import VariationalDropout
from tzrec.protos import model_pb2
from tzrec.protos.loss_pb2 import LossConfig
from tzrec.protos.metric_pb2 import MetricConfig
from tzrec.utils.config_util import config_to_kwargs


@torch.fx.wrap
def _update_tensor_dict(
tensor_dict: Dict[str, torch.Tensor], new_tensor: torch.Tensor, key: str
) -> None:
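    """Assign `new_tensor` into `tensor_dict`.

    Wrapped with `torch.fx.wrap` so symbolic tracing records this as a leaf
    call instead of tracing through the dict mutation.
    """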
tensor_dict[key] = new_tensor


class RankModel(BaseModel):
    """Base model for ranking.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): sample weight names.
    """

    def __init__(
self,
model_config: model_pb2.ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__(model_config, features, labels, sample_weights, **kwargs)
self._num_class = model_config.num_class
self._label_name = labels[0]
        self._sample_weight_name = sample_weights[0] if sample_weights else None
        self._loss_collection: Dict[str, torch.Tensor] = {}
self.embedding_group = None
self.group_variational_dropouts = None

    def init_input(self) -> None:
"""Build embedding group and group variational dropout."""
self.embedding_group = EmbeddingGroup(
self._features, list(self._base_model_config.feature_groups)
)
if self._base_model_config.HasField("variational_dropout"):
self.group_variational_dropouts = nn.ModuleDict()
variational_dropout_config = self._base_model_config.variational_dropout
variational_dropout_config_dict = config_to_kwargs(
variational_dropout_config
)
for feature_group in list(self._base_model_config.feature_groups):
group_name = feature_group.group_name
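                # Variational dropout is only applied to non-sequence feature
                # groups that contain more than one feature, where a
                # per-feature dropout rate is meaningful.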
if feature_group.group_type != model_pb2.SEQUENCE:
feature_dim = self.embedding_group.group_feature_dims(group_name)
if len(feature_dim) > 1:
variational_dropout = VariationalDropout(
feature_dim, group_name, **variational_dropout_config_dict
)
self.group_variational_dropouts[group_name] = (
variational_dropout
)

    def build_input(self, batch: Batch) -> Dict[str, torch.Tensor]:
        """Build input feature embeddings.

        Args:
            batch (Batch): input batch data.

        Returns:
            a dict of grouped feature embeddings keyed by group name.
        """
feature_dict = self.embedding_group(batch)
if self.group_variational_dropouts is not None:
for (
group_name,
variational_dropout,
) in self.group_variational_dropouts.items():
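                # Apply the learned dropout mask to the group's features and
                # collect the regularization loss; it is merged into the loss
                # dict in `loss()` via `_loss_collection`.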
feature, variational_dropout_loss = variational_dropout(
feature_dict[group_name]
)
_update_tensor_dict(feature_dict, feature, group_name)
_update_tensor_dict(
self._loss_collection,
variational_dropout_loss,
group_name + "_feature_p_loss",
)
return feature_dict

    def _output_to_prediction_impl(
self,
output: torch.Tensor,
loss_cfg: LossConfig,
num_class: int = 1,
suffix: str = "",
) -> Dict[str, torch.Tensor]:
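        """Convert a tower output into prediction tensors for one loss config.

        The prediction keys depend on the loss type: "logits"/"probs" (and
        "probs1" for the binary case) for classification losses, "y" for
        regression.
        """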
predictions = {}
loss_type = loss_cfg.WhichOneof("loss")
if loss_type == "binary_cross_entropy":
assert num_class == 1, f"num_class must be 1 when loss type is {loss_type}"
output = torch.squeeze(output, dim=1)
predictions["logits" + suffix] = output
predictions["probs" + suffix] = torch.sigmoid(output)
elif loss_type == "softmax_cross_entropy":
assert num_class > 1, (
f"num_class must be greater than 1 when loss type is {loss_type}"
)
probs = torch.softmax(output, dim=1)
predictions["logits" + suffix] = output
predictions["probs" + suffix] = probs
if num_class == 2:
predictions["probs1" + suffix] = probs[:, 1]
elif loss_type == "jrc_loss":
assert num_class == 2, f"num_class must be 2 when loss type is {loss_type}"
probs = torch.softmax(output, dim=1)
predictions["logits" + suffix] = output
predictions["probs" + suffix] = probs
predictions["probs1" + suffix] = probs[:, 1]
elif loss_type == "l2_loss":
output = torch.squeeze(output, dim=1)
predictions["y" + suffix] = output
        else:
            raise NotImplementedError(f"loss[{loss_type}] is not supported yet.")
return predictions

    def _output_to_prediction(
self, output: torch.Tensor, suffix: str = ""
) -> Dict[str, torch.Tensor]:
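        """Convert a tower output into predictions for all configured losses."""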
predictions = {}
for loss_cfg in self._base_model_config.losses:
predictions.update(
self._output_to_prediction_impl(
output, loss_cfg, num_class=self._num_class, suffix=suffix
)
)
return predictions

    def _init_loss_impl(
self,
loss_cfg: LossConfig,
num_class: int = 1,
reduction: str = "none",
suffix: str = "",
) -> None:
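        """Build and register the loss module for one loss config."""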
loss_type = loss_cfg.WhichOneof("loss")
loss_name = loss_type + suffix
if loss_type == "binary_cross_entropy":
self._loss_modules[loss_name] = nn.BCEWithLogitsLoss(reduction=reduction)
elif loss_type == "softmax_cross_entropy":
self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction=reduction)
elif loss_type == "jrc_loss":
assert num_class == 2, f"num_class must be 2 when loss type is {loss_type}"
self._loss_modules[loss_name] = JRCLoss(
alpha=loss_cfg.jrc_loss.alpha, reduction=reduction
)
elif loss_type == "l2_loss":
self._loss_modules[loss_name] = nn.MSELoss(reduction=reduction)
else:
raise ValueError(f"loss[{loss_type}] is not supported yet.")

    def init_loss(self) -> None:
"""Initialize loss modules."""
for loss_cfg in self._base_model_config.losses:
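            # Keep per-sample losses ("none") when sample weights are used so
            # they can be re-weighted in `_loss_impl`; otherwise reduce to mean.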
reduction = "none" if self._sample_weight_name else "mean"
self._init_loss_impl(loss_cfg, self._num_class, reduction=reduction)

    def _loss_impl(
self,
predictions: Dict[str, torch.Tensor],
batch: Batch,
label: torch.Tensor,
loss_weight: Optional[torch.Tensor],
loss_cfg: LossConfig,
num_class: int = 1,
suffix: str = "",
) -> Dict[str, torch.Tensor]:
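        """Compute the loss value for one loss config."""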
losses = {}
loss_type = loss_cfg.WhichOneof("loss")
loss_name = loss_type + suffix
if loss_type == "binary_cross_entropy":
pred = predictions["logits" + suffix]
label = label.to(torch.float32)
losses[loss_name] = self._loss_modules[loss_name](pred, label)
elif loss_type == "softmax_cross_entropy":
pred = predictions["logits" + suffix]
losses[loss_name] = self._loss_modules[loss_name](pred, label)
elif loss_type == "jrc_loss":
assert num_class == 2, f"num_class must be 2 when loss type is {loss_type}"
pred = predictions["logits" + suffix]
session_id = batch.sparse_features[BASE_DATA_GROUP][
loss_cfg.jrc_loss.session_name
].values()
losses[loss_name] = self._loss_modules[loss_name](pred, label, session_id)
elif loss_type == "l2_loss":
pred = predictions["y" + suffix]
losses[loss_name] = self._loss_modules[loss_name](pred, label)
else:
raise ValueError(f"loss[{loss_type}] is not supported yet.")
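        # Weight the per-sample losses and reduce them to a scalar mean.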
if loss_weight is not None:
losses[loss_name] = torch.mean(losses[loss_name] * loss_weight)
return losses

    def loss(
        self, predictions: Dict[str, torch.Tensor], batch: Batch
    ) -> Dict[str, torch.Tensor]:
        """Compute loss of the model.

        Args:
            predictions (dict): a dict of predicted result.
            batch (Batch): input batch data.

        Returns:
            a dict of loss tensors, including collected regularization losses.
        """
losses = {}
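        # Normalize sample weights to unit mean so weighting does not change
        # the overall loss scale; div_no_nan guards against a zero mean.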
if self._sample_weight_name:
loss_weight = batch.sample_weights[self._sample_weight_name]
loss_weight = div_no_nan(loss_weight, torch.mean(loss_weight))
else:
loss_weight = None
for loss_cfg in self._base_model_config.losses:
losses.update(
self._loss_impl(
predictions,
batch,
batch.labels[self._label_name],
loss_weight,
loss_cfg,
num_class=self._num_class,
)
)
losses.update(self._loss_collection)
return losses

    def _init_metric_impl(
self, metric_cfg: MetricConfig, num_class: int = 1, suffix: str = ""
) -> None:
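        """Build and register the metric module for one metric config."""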
metric_type = metric_cfg.WhichOneof("metric")
oneof_metric_cfg = getattr(metric_cfg, metric_type)
metric_kwargs = config_to_kwargs(oneof_metric_cfg)
metric_name = metric_type + suffix
if metric_type == "auc":
            assert num_class <= 2, (
                f"num_class must be at most 2 when metric type is {metric_type}"
            )
self._metric_modules[metric_name] = torchmetrics.AUROC(
task="binary", **metric_kwargs
)
elif metric_type == "multiclass_auc":
self._metric_modules[metric_name] = torchmetrics.AUROC(
task="multiclass", num_classes=num_class, **metric_kwargs
)
elif metric_type == "mean_absolute_error":
self._metric_modules[metric_name] = torchmetrics.MeanAbsoluteError()
elif metric_type == "mean_squared_error":
self._metric_modules[metric_name] = torchmetrics.MeanSquaredError()
elif metric_type == "accuracy":
self._metric_modules[metric_name] = torchmetrics.Accuracy(
task="multiclass" if num_class > 1 else "binary",
num_classes=num_class,
**metric_kwargs,
)
elif metric_type == "grouped_auc":
            assert num_class <= 2, (
                f"num_class must be at most 2 when metric type is {metric_type}"
            )
self._metric_modules[metric_name] = GroupedAUC()
else:
raise ValueError(f"{metric_type} is not supported for this model")

    def init_metric(self) -> None:
"""Initialize metric modules."""
for metric_cfg in self._base_model_config.metrics:
self._init_metric_impl(metric_cfg, self._num_class)
for loss_cfg in self._base_model_config.losses:
self._init_loss_metric_impl(loss_cfg)

    def _update_metric_impl(
self,
predictions: Dict[str, torch.Tensor],
batch: Batch,
label: torch.Tensor,
metric_cfg: MetricConfig,
num_class: int = 1,
suffix: str = "",
) -> None:
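        """Update the metric state for one metric config."""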
metric_type = metric_cfg.WhichOneof("metric")
oneof_metric_cfg = getattr(metric_cfg, metric_type)
metric_name = metric_type + suffix
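        # Only grouped metrics need raw sparse features (for the grouping key).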
base_sparse_feat = None
if metric_type in ["grouped_auc"]:
base_sparse_feat = batch.sparse_features[BASE_DATA_GROUP].to_dict()
if metric_type == "auc":
pred = (
predictions["probs" + suffix]
if num_class == 1
else predictions["probs1" + suffix]
)
self._metric_modules[metric_name].update(pred, label)
elif metric_type == "multiclass_auc":
pred = predictions["probs" + suffix]
self._metric_modules[metric_name].update(pred, label)
elif metric_type == "mean_absolute_error":
pred = predictions["y" + suffix]
self._metric_modules[metric_name].update(pred, label)
elif metric_type == "mean_squared_error":
pred = predictions["y" + suffix]
self._metric_modules[metric_name].update(pred, label)
elif metric_type == "accuracy":
pred = predictions["probs" + suffix]
self._metric_modules[metric_name].update(pred, label)
elif metric_type == "grouped_auc":
pred = (
predictions["probs" + suffix]
if num_class == 1
else predictions["probs1" + suffix]
)
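            # Use the first value of each sample's grouping-key feature as the
            # group id for grouped AUC.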
# pyre-ignore [16]
grouping_key = base_sparse_feat[
oneof_metric_cfg.grouping_key
].to_padded_dense(1)[:, 0]
self._metric_modules[metric_name].update(pred, label, grouping_key)
else:
raise ValueError(f"{metric_type} is not supported for this model")

    def update_metric(
        self,
        predictions: Dict[str, torch.Tensor],
        batch: Batch,
        losses: Optional[Dict[str, torch.Tensor]] = None,
    ) -> None:
        """Update metric state.

        Args:
            predictions (dict): a dict of predicted result.
            batch (Batch): input batch data.
            losses (dict, optional): a dict of loss.
        """
for metric_cfg in self._base_model_config.metrics:
self._update_metric_impl(
predictions,
batch,
batch.labels[self._label_name],
metric_cfg,
num_class=self._num_class,
)
if losses is not None:
for loss_cfg in self._base_model_config.losses:
self._update_loss_metric_impl(
losses, batch, batch.labels[self._label_name], loss_cfg
)