tzrec/models/multi_task_rank.py (136 lines of code) (raw):
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union
import torch
from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.rank_model import RankModel
from tzrec.modules.utils import div_no_nan
from tzrec.protos.model_pb2 import ModelConfig
from tzrec.protos.tower_pb2 import BayesTaskTower, InterventionTaskTower, TaskTower
class MultiTaskRank(RankModel):
    """Multi task model for ranking.

    Each entry of ``model_config.task_towers`` describes one task (tower name,
    label name, losses, metrics, num_class and optional loss weighting).
    Predictions, losses and metrics of a task are registered with the key
    suffix ``_{tower_name}``.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): sample weight names.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        features: List[BaseFeature],
        labels: List[str],
        sample_weights: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(model_config, features, labels, sample_weights, **kwargs)
        # Materialize the repeated task tower configs once; every per-task
        # loop in this class iterates over this list.
        self._task_tower_cfgs = list(self._model_config.task_towers)

    def _multi_task_output_to_prediction(
        self, output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Convert per-tower outputs into a merged prediction dict.

        Args:
            output (dict): tower outputs keyed by tower name.

        Returns:
            predictions (dict): predictions of all towers and all of their
                losses, with keys suffixed by ``_{tower_name}``.
        """
        predictions = {}
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            for loss_cfg in task_tower_cfg.losses:
                predictions.update(
                    self._output_to_prediction_impl(
                        output[tower_name],
                        loss_cfg,
                        num_class=task_tower_cfg.num_class,
                        suffix=f"_{tower_name}",
                    )
                )
        return predictions

    def has_weight(
        self, task_cfg: Union[TaskTower, BayesTaskTower, InterventionTaskTower]
    ) -> bool:
        """Return whether the task's loss needs per-sample weighting.

        True when the task config explicitly sets any of ``sample_weight_name``,
        ``weight`` or ``task_space_indicator_label``.

        Args:
            task_cfg: task tower config.

        Returns:
            bool: whether the task loss is weighted.
        """
        return (
            task_cfg.HasField("sample_weight_name")
            or task_cfg.HasField("weight")
            or task_cfg.HasField("task_space_indicator_label")
        )

    def init_loss(self) -> None:
        """Initialize loss modules for every task tower."""
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            # Weighted tasks keep unreduced losses (reduction="none") so that
            # the per-sample `loss_weight` computed in `loss` can be applied
            # (presumably inside `_loss_impl`) before reducing.
            reduction = "none" if self.has_weight(task_tower_cfg) else "mean"
            for loss_cfg in task_tower_cfg.losses:
                self._init_loss_impl(
                    loss_cfg,
                    num_class=task_tower_cfg.num_class,
                    reduction=reduction,
                    suffix=f"_{tower_name}",
                )

    def loss(
        self, predictions: Dict[str, torch.Tensor], batch: Batch
    ) -> Dict[str, torch.Tensor]:
        """Compute loss of the model.

        Args:
            predictions (dict): a dict of predicted result.
            batch (Batch): input batch data.

        Returns:
            losses (dict): per-task losses (keys suffixed with
                ``_{tower_name}``) merged with ``self._loss_collection``.
        """
        losses = {}
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            label_name = task_tower_cfg.label_name
            if self.has_weight(task_tower_cfg):
                if task_tower_cfg.sample_weight_name:
                    sample_weight = task_tower_cfg.sample_weight_name
                    loss_weight = batch.sample_weights[sample_weight]
                else:
                    # No sample weight column: start from a single weight of 1,
                    # allocated directly on the label's device.
                    loss_weight = torch.ones(
                        1, device=batch.labels[label_name].device
                    )
                if task_tower_cfg.HasField("task_space_indicator_label"):
                    # Samples whose indicator label is > 0 are in the task
                    # space; in/out samples get different weights.
                    in_task_space = (
                        batch.labels[task_tower_cfg.task_space_indicator_label] > 0
                    ).float()
                    loss_weight = loss_weight * (
                        task_tower_cfg.in_task_space_weight * in_task_space
                        + task_tower_cfg.out_task_space_weight * (1 - in_task_space)
                    )
                # Rescale weights by their batch mean (div_no_nan guards the
                # zero-mean case), then apply the task-level weight.
                loss_weight = div_no_nan(loss_weight, torch.mean(loss_weight))
                loss_weight = loss_weight * task_tower_cfg.weight
            else:
                loss_weight = None
            for loss_cfg in task_tower_cfg.losses:
                losses.update(
                    self._loss_impl(
                        predictions,
                        batch,
                        batch.labels[label_name],
                        loss_weight,
                        loss_cfg,
                        num_class=task_tower_cfg.num_class,
                        suffix=f"_{tower_name}",
                    )
                )
        losses.update(self._loss_collection)
        return losses

    def init_metric(self) -> None:
        """Initialize metric modules (task metrics and loss metrics)."""
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            for metric_cfg in task_tower_cfg.metrics:
                self._init_metric_impl(
                    metric_cfg,
                    num_class=task_tower_cfg.num_class,
                    suffix=f"_{tower_name}",
                )
            for loss_cfg in task_tower_cfg.losses:
                self._init_loss_metric_impl(loss_cfg, suffix=f"_{tower_name}")

    def update_metric(
        self,
        predictions: Dict[str, torch.Tensor],
        batch: Batch,
        losses: Optional[Dict[str, torch.Tensor]] = None,
    ) -> None:
        """Update metric state.

        Args:
            predictions (dict): a dict of predicted result.
            batch (Batch): input batch data.
            losses (dict, optional): a dict of loss; loss metrics are only
                updated when it is given.
        """
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            label_name = task_tower_cfg.label_name
            for metric_cfg in task_tower_cfg.metrics:
                self._update_metric_impl(
                    predictions,
                    batch,
                    batch.labels[label_name],
                    metric_cfg,
                    num_class=task_tower_cfg.num_class,
                    suffix=f"_{tower_name}",
                )
            if losses is not None:
                for loss_cfg in task_tower_cfg.losses:
                    self._update_loss_metric_impl(
                        losses,
                        batch,
                        batch.labels[label_name],
                        loss_cfg,
                        suffix=f"_{tower_name}",
                    )