tzrec/models/multi_task_rank.py (136 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Union

import torch

from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.rank_model import RankModel
from tzrec.modules.utils import div_no_nan
from tzrec.protos.model_pb2 import ModelConfig
from tzrec.protos.tower_pb2 import BayesTaskTower, InterventionTaskTower, TaskTower


class MultiTaskRank(RankModel):
    """Multi task model for ranking.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): sample weight names.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        features: List[BaseFeature],
        labels: List[str],
        sample_weights: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(model_config, features, labels, sample_weights, **kwargs)
        # Materialize the repeated task_towers field once; every per-task loop in
        # this class (predictions, losses, metrics) iterates over this list.
        self._task_tower_cfgs = list(self._model_config.task_towers)

    def _multi_task_output_to_prediction(
        self, output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Convert per-tower model outputs into a flat prediction dict.

        For every loss configured on every task tower, ``output[tower_name]``
        is passed to ``_output_to_prediction_impl`` with the suffix
        ``_{tower_name}``, and the resulting entries are merged.

        Args:
            output (dict): tower name -> raw tower output tensor.

        Returns:
            dict: merged predictions of all task towers.
        """
        predictions = {}
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            for loss_cfg in task_tower_cfg.losses:
                predictions.update(
                    self._output_to_prediction_impl(
                        output[tower_name],
                        loss_cfg,
                        num_class=task_tower_cfg.num_class,
                        suffix=f"_{tower_name}",
                    )
                )
        return predictions

    def has_weight(
        self, task_cfg: Union[TaskTower, BayesTaskTower, InterventionTaskTower]
    ) -> bool:
        """Return whether the task tower's loss needs a loss weight tensor.

        A tower is weighted when any of ``sample_weight_name``, ``weight`` or
        ``task_space_indicator_label`` is explicitly set in its config.

        Args:
            task_cfg: task tower config.

        Returns:
            bool: True if the tower's loss must be weighted.
        """
        return (
            task_cfg.HasField("sample_weight_name")
            or task_cfg.HasField("weight")
            or task_cfg.HasField("task_space_indicator_label")
        )

    def init_loss(self) -> None:
        """Initialize loss modules."""
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            # Weighted towers get an unreduced ("none") loss, because loss()
            # passes a loss_weight tensor to _loss_impl exactly for those towers.
            reduction = "none" if self.has_weight(task_tower_cfg) else "mean"
            for loss_cfg in task_tower_cfg.losses:
                self._init_loss_impl(
                    loss_cfg,
                    num_class=task_tower_cfg.num_class,
                    reduction=reduction,
                    suffix=f"_{tower_name}",
                )

    def _task_tower_loss_weight(
        self,
        task_tower_cfg: Union[TaskTower, BayesTaskTower, InterventionTaskTower],
        batch: Batch,
    ) -> Optional[torch.Tensor]:
        """Build the loss weight tensor for one task tower.

        The weight starts from the configured sample weight, or from a single
        1.0 when no sample weight name is set. When a task space indicator
        label is configured, the weight is scaled by ``in_task_space_weight``
        where that label is > 0 and by ``out_task_space_weight`` elsewhere.
        The weight is then divided by its mean via ``div_no_nan`` and finally
        multiplied by the tower ``weight``.

        Args:
            task_tower_cfg: task tower config.
            batch (Batch): input batch data.

        Returns:
            The loss weight tensor, or None when ``has_weight`` is False.
        """
        if not self.has_weight(task_tower_cfg):
            return None

        if task_tower_cfg.sample_weight_name:
            loss_weight = batch.sample_weights[task_tower_cfg.sample_weight_name]
        else:
            # A single 1.0 that broadcasts over the batch. Allocate it directly
            # on the label's device instead of building it on CPU and copying.
            loss_weight = torch.ones(
                1, device=batch.labels[task_tower_cfg.label_name].device
            )

        if task_tower_cfg.HasField("task_space_indicator_label"):
            in_task_space = (
                batch.labels[task_tower_cfg.task_space_indicator_label] > 0
            ).float()
            loss_weight = loss_weight * (
                task_tower_cfg.in_task_space_weight * in_task_space
                + task_tower_cfg.out_task_space_weight * (1 - in_task_space)
            )

        # Normalize by the mean weight; div_no_nan presumably avoids NaN when
        # the mean is zero (TODO confirm against tzrec.modules.utils).
        loss_weight = div_no_nan(loss_weight, torch.mean(loss_weight))
        # Out-of-place multiply so no tensor reachable from `batch` is mutated.
        return loss_weight * task_tower_cfg.weight

    def loss(
        self, predictions: Dict[str, torch.Tensor], batch: Batch
    ) -> Dict[str, torch.Tensor]:
        """Compute loss of the model.

        Args:
            predictions (dict): a dict of predicted result.
            batch (Batch): input batch data.

        Returns:
            dict: per-tower losses (suffixed with ``_{tower_name}``) merged with
            the entries of ``self._loss_collection``.
        """
        losses: Dict[str, torch.Tensor] = {}
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            label_name = task_tower_cfg.label_name
            loss_weight = self._task_tower_loss_weight(task_tower_cfg, batch)
            for loss_cfg in task_tower_cfg.losses:
                losses.update(
                    self._loss_impl(
                        predictions,
                        batch,
                        batch.labels[label_name],
                        loss_weight,
                        loss_cfg,
                        num_class=task_tower_cfg.num_class,
                        suffix=f"_{tower_name}",
                    )
                )
        losses.update(self._loss_collection)
        return losses

    def init_metric(self) -> None:
        """Initialize metric modules.

        Registers every configured metric and one loss metric per configured
        loss for each task tower, all suffixed with ``_{tower_name}``.
        """
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            for metric_cfg in task_tower_cfg.metrics:
                self._init_metric_impl(
                    metric_cfg,
                    num_class=task_tower_cfg.num_class,
                    suffix=f"_{tower_name}",
                )
            for loss_cfg in task_tower_cfg.losses:
                self._init_loss_metric_impl(loss_cfg, suffix=f"_{tower_name}")

    def update_metric(
        self,
        predictions: Dict[str, torch.Tensor],
        batch: Batch,
        losses: Optional[Dict[str, torch.Tensor]] = None,
    ) -> None:
        """Update metric state.

        Args:
            predictions (dict): a dict of predicted result.
            batch (Batch): input batch data.
            losses (dict, optional): a dict of loss. Loss metrics are only
                updated when this is provided.
        """
        for task_tower_cfg in self._task_tower_cfgs:
            tower_name = task_tower_cfg.tower_name
            label_name = task_tower_cfg.label_name
            for metric_cfg in task_tower_cfg.metrics:
                self._update_metric_impl(
                    predictions,
                    batch,
                    batch.labels[label_name],
                    metric_cfg,
                    num_class=task_tower_cfg.num_class,
                    suffix=f"_{tower_name}",
                )
            if losses is not None:
                for loss_cfg in task_tower_cfg.losses:
                    self._update_loss_metric_impl(
                        losses,
                        batch,
                        batch.labels[label_name],
                        loss_cfg,
                        suffix=f"_{tower_name}",
                    )