# tzrec/models/model.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Alibaba, Inc. and its affiliates.
from collections import OrderedDict
from itertools import chain
from queue import Queue
from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
import torchmetrics
from torch import nn
from torchrec.modules.embedding_modules import (
EmbeddingBagCollectionInterface,
EmbeddingCollectionInterface,
)
from tzrec.datasets.data_parser import DataParser
from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.modules.utils import BaseModule
from tzrec.protos.loss_pb2 import LossConfig
from tzrec.protos.model_pb2 import FeatureGroupConfig, ModelConfig
from tzrec.utils.load_class import get_register_class_meta
# Registry of model classes keyed by class name.
_MODEL_CLASS_MAP = {}
# Metaclass produced by the registration helper; presumably it adds each
# BaseModel subclass to _MODEL_CLASS_MAP on definition — confirm in
# tzrec.utils.load_class.
_meta_cls = get_register_class_meta(_MODEL_CLASS_MAP)
class BaseModel(BaseModule, metaclass=_meta_cls):
    """TorchEasyRec base model.

    Subclasses implement the predict/loss/metric hooks below; the metaclass
    registers each subclass so it can be looked up by name.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): sample weight names.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        features: List[BaseFeature],
        labels: List[str],
        sample_weights: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._base_model_config = model_config
        # Name of the oneof "model" field that is set; used to pull the
        # concrete sub-config out of model_config below.
        self._model_type = model_config.WhichOneof("model")
        self._features = features
        self._labels = labels
        self._model_config = (
            getattr(model_config, self._model_type) if self._model_type else None
        )
        self._metric_modules = nn.ModuleDict()
        self._loss_modules = nn.ModuleDict()
        # Fix: always define the attribute. The previous code only assigned it
        # when sample_weights was truthy, leaving it undefined otherwise — an
        # AttributeError hazard for any reader of self._sample_weights.
        self._sample_weights = sample_weights or []

    def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
        """Predict the model.

        Args:
            batch (Batch): input batch data.

        Return:
            predictions (dict): a dict of predicted result.
        """
        raise NotImplementedError

    def init_loss(self) -> None:
        """Initialize loss modules."""
        raise NotImplementedError

    def loss(
        self, predictions: Dict[str, torch.Tensor], batch: Batch
    ) -> Dict[str, torch.Tensor]:
        """Compute loss of the model.

        Args:
            predictions (dict): a dict of predicted result.
            batch (Batch): input batch data.

        Return:
            losses (dict): a dict of loss tensor.
        """
        raise NotImplementedError

    def init_metric(self) -> None:
        """Initialize metric modules."""
        raise NotImplementedError

    def update_metric(
        self,
        predictions: Dict[str, torch.Tensor],
        batch: Batch,
        losses: Optional[Dict[str, torch.Tensor]] = None,
    ) -> None:
        """Update metric state.

        Args:
            predictions (dict): a dict of predicted result.
            batch (Batch): input batch data.
            losses (dict, optional): a dict of loss.
        """
        raise NotImplementedError

    def compute_metric(self) -> Dict[str, torch.Tensor]:
        """Compute metrics and reset their state.

        Return:
            metric_result (dict): a dict of metric result tensor.
        """
        metric_results = {}
        for metric_name, metric in self._metric_modules.items():
            metric_results[metric_name] = metric.compute()
            # Reset so the next evaluation window accumulates from scratch.
            metric.reset()
        return metric_results

    def sparse_parameters(self) -> Iterable[nn.Parameter]:
        """Get an iterator over sparse parameters of the module.

        Walks the module tree breadth-first and collects the parameters of
        every embedding(-bag) collection; children of a matched collection
        are not traversed further.
        """
        q = Queue()
        q.put(self)
        parameters_list = []
        while not q.empty():
            m = q.get()
            if isinstance(
                m, (EmbeddingBagCollectionInterface, EmbeddingCollectionInterface)
            ):
                parameters_list.append(m.parameters())
            else:
                for child in m.children():
                    q.put(child)
        return chain.from_iterable(parameters_list)

    def forward(self, batch: Batch) -> Dict[str, torch.Tensor]:
        """Predict the model."""
        return self.predict(batch)

    def _init_loss_metric_impl(self, loss_cfg: LossConfig, suffix: str = "") -> None:
        # Track a running mean of each configured loss under the key
        # "<loss_type><suffix>".
        loss_type = loss_cfg.WhichOneof("loss")
        loss_name = loss_type + suffix
        self._metric_modules[loss_name] = torchmetrics.MeanMetric()

    def _update_loss_metric_impl(
        self,
        losses: Dict[str, torch.Tensor],
        batch: Batch,
        label: torch.Tensor,
        loss_cfg: LossConfig,
        suffix: str = "",
    ) -> None:
        loss_type = loss_cfg.WhichOneof("loss")
        loss_name = loss_type + suffix
        loss = losses[loss_name]
        # Weight the running mean by batch size so batches of different sizes
        # contribute proportionally.
        self._metric_modules[loss_name].update(loss, loss.new_tensor(label.size(0)))

    def get_features_in_feature_groups(
        self, feature_groups: List[FeatureGroupConfig]
    ) -> List[BaseFeature]:
        """Select features order by feature groups.

        Args:
            feature_groups (list): feature group configs.

        Return:
            features (list): features referenced by the groups, in first-seen
                order with duplicates removed.
        """
        name_to_feature = {x.name: x for x in self._features}
        grouped_features = OrderedDict()
        for feature_group in feature_groups:
            for x in feature_group.feature_names:
                grouped_features[x] = name_to_feature[x]
            for sequence_group in feature_group.sequence_groups:
                for x in sequence_group.feature_names:
                    grouped_features[x] = name_to_feature[x]
        return list(grouped_features.values())
# (losses, predictions, batch) tuple carried alongside the total loss.
TRAIN_OUT_TYPE = Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Batch]
# Return type of TrainWrapper.forward: (total_loss, (losses, predictions, batch)).
TRAIN_FWD_TYPE = Tuple[torch.Tensor, TRAIN_OUT_TYPE]
class TrainWrapper(BaseModule):
    """Model train wrapper for pipeline."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.model = module
        self.model.init_loss()
        self.model.init_metric()

    def forward(self, batch: Batch) -> TRAIN_FWD_TYPE:
        """Predict and compute loss.

        Args:
            batch (Batch): input batch data.

        Return:
            total_loss (Tensor): total loss.
            losses (dict): a dict of loss tensor.
            predictions (dict): a dict of predicted result.
            batch (Batch): input batch data.
        """
        preds = self.model.predict(batch)
        loss_dict = self.model.loss(preds, batch)
        # Sum all per-head losses into the single scalar used for backward.
        total_loss = torch.stack(list(loss_dict.values())).sum()
        # Detach what is returned for logging/metrics so it does not hold on
        # to the autograd graph.
        detached_losses = {name: t.detach() for name, t in loss_dict.items()}
        detached_preds = {name: t.detach() for name, t in preds.items()}
        return total_loss, (detached_losses, detached_preds, batch)
class ScriptWrapper(BaseModule):
    """Model inference wrapper for jit.script."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.model = module
        # Reuse the wrapped model's feature list to build the data parser.
        self._features = self.model._features
        self._data_parser = DataParser(self._features)

    def get_batch(
        self,
        data: Dict[str, torch.Tensor],
        # pyre-ignore [9]
        # NOTE(review): default is the string "cpu" though annotated as
        # torch.device — presumably required for jit.script compatibility;
        # confirm before changing.
        device: torch.device = "cpu",
    ) -> Batch:
        """Get batch.

        Parses the flat tensor dict into a Batch and moves it to `device`.
        """
        batch = self._data_parser.to_batch(data)
        batch = batch.to(device, non_blocking=True)
        return batch

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        # pyre-ignore [9]
        device: torch.device = "cpu",
    ) -> Dict[str, torch.Tensor]:
        """Predict the model.

        Args:
            data (dict): a dict of input data for Batch.
            device (torch.device): inference device.

        Return:
            predictions (dict): a dict of predicted result.
        """
        batch = self.get_batch(data, device)
        return self.model.predict(batch)
class CudaExportWrapper(ScriptWrapper):
    """Model inference wrapper for cuda export(aot/trt)."""

    # pyre-ignore [14]
    def forward(
        self,
        data: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Predict the model.

        Args:
            data (dict): a dict of input data for Batch.

        Return:
            predictions (dict): a dict of predicted result.
        """
        # Export path always targets CUDA, so the device is fixed here
        # instead of being a forward() argument.
        cuda_device = torch.device("cuda")
        batch = self._data_parser.to_batch(data).to(cuda_device, non_blocking=True)
        return self.model.predict(batch)