# tzrec/modules/dense_embedding_collection.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections import OrderedDict
from enum import Enum
from math import sqrt
from typing import Dict, List, Optional

import torch
from torch import Tensor, nn
from torchrec.sparse.jagged_tensor import KeyedTensor


class DenseEmbeddingType(Enum):
    """Dense Embedding Type."""

    MLP = 0
    AUTO_DIS = 1


class DenseEmbeddingConfig:
    """DenseEmbeddingConfig base class."""

    def __init__(
self,
embedding_dim: int,
feature_names: List[str],
embedding_type: DenseEmbeddingType,
) -> None:
self.embedding_dim = embedding_dim
self.feature_names = feature_names
        self.embedding_type = embedding_type

    @property
def group_key(self) -> str:
"""Config group key."""
raise NotImplementedError(
"Subclasses of DenseEmbeddingConfig should implement this."
        )


class MLPDenseEmbeddingConfig(DenseEmbeddingConfig):
    """MLPDenseEmbeddingConfig class."""

    def __init__(self, embedding_dim: int, feature_names: List[str]) -> None:
        super().__init__(embedding_dim, feature_names, DenseEmbeddingType.MLP)

    @property
def group_key(self) -> str:
"""Config group key."""
return f"mlp#{self.embedding_dim}"
class AutoDisEmbeddingConfig(DenseEmbeddingConfig):
"""AutoDisEmbeddingConfig class."""
def __init__(
self,
embedding_dim: int,
n_channels: int,
temperature: float,
keep_prob: float,
feature_names: List[str],
) -> None:
super().__init__(embedding_dim, feature_names, DenseEmbeddingType.AUTO_DIS)
self.n_channels = n_channels
self.temperature = temperature
        self.keep_prob = keep_prob

    @property
def group_key(self) -> str:
"""Config group key."""
return (
f"autodis#{self.embedding_dim}#{self.n_channels}#{self.keep_prob:.6f}"
f"#{self.temperature:.6f}".replace(".", "_")
        )


class AutoDisEmbedding(nn.Module):
    """An Embedding Learning Framework for Numerical Features in CTR Prediction.

    https://arxiv.org/pdf/2012.08986
    """

    def __init__(
self,
num_dense_feature: int,
embedding_dim: int,
num_channels: int,
temperature: float = 0.1,
keep_prob: float = 0.8,
device: Optional[torch.device] = None,
) -> None:
super().__init__()
self.num_dense_feature = num_dense_feature
self.embedding_dim = embedding_dim
self.keep_prob = keep_prob
self.temperature = temperature
self.num_channels = num_channels
self.meta_emb = nn.Parameter(
torch.randn(num_dense_feature, num_channels, embedding_dim, device=device)
)
# glorot normal initialization, std = sqrt(2 /(1+c))
self.proj_w = nn.Parameter(
torch.randn(num_dense_feature, num_channels, device=device)
* sqrt(2 / (1 + num_channels))
)
# glorot normal initialization, std = sqrt(2 /(c+c))
self.proj_m = nn.Parameter(
torch.randn(num_dense_feature, num_channels, num_channels, device=device)
* sqrt(1 / num_channels)
)
self.leaky_relu = nn.LeakyReLU()
self.softmax = nn.Softmax(dim=-1)
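        # reset_parameters() below re-draws the parameters with the same stds as the
        # inline scaling above, so the scaling mainly documents the intended
        # (Glorot-style) initialization.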
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the parameters."""
nn.init.normal_(self.meta_emb, 0, 1.0)
nn.init.normal_(self.proj_w, 0, sqrt(2 / (1 + self.num_channels)))
        nn.init.normal_(self.proj_m, 0, sqrt(1 / self.num_channels))

    def forward(self, dense_input: Tensor) -> Tensor:
        """Forward the module.

        Args:
            dense_input (Tensor): dense input feature, shape = [b, n],
                where b is batch_size, n is the number of dense features.

        Returns:
            atde (Tensor): Tensor of autodis embedding, shape = [b, n * d].
        """
hidden = self.leaky_relu(
torch.einsum("nc,bn->bnc", self.proj_w, dense_input)
) # shape [b, n, c]
x_bar = (
torch.einsum("nij,bnj->bni", self.proj_m, hidden) + self.keep_prob * hidden
) # shape [b, n, c]
x_hat = self.softmax(x_bar / self.temperature) # shape = [b, n, c]
emb = torch.einsum("ncd,bnc->bnd", self.meta_emb, x_hat) # shape = [b, n, d]
output = emb.reshape(
(-1, self.num_dense_feature * self.embedding_dim)
) # shape = [b, n * d]
        return output


class MLPEmbedding(nn.Module):
    """MLP embedding for dense features."""

    def __init__(
self,
num_dense_feature: int,
embedding_dim: int,
device: Optional[torch.device] = None,
) -> None:
super().__init__()
self.num_dense_feature = num_dense_feature
self.embedding_dim = embedding_dim
        # glorot normal initialization, std = sqrt(2 / (1 + d))
        self.proj_w = nn.Parameter(
            torch.randn(num_dense_feature, embedding_dim, device=device)
            * sqrt(2 / (1 + embedding_dim))
        )

    def forward(self, input: Tensor) -> Tensor:
"""Forward the module.
Args:
input (Tensor): dense input feature, shape = [b, n],
where b is batch_size, n is the number of dense features.
Returns:
output (Tensor): Tensor of mlp embedding, shape = [b, n * d],
where d is the embedding_dim.
"""
return torch.einsum("ni,bn->bni", self.proj_w, input).reshape(
(-1, self.num_dense_feature * self.embedding_dim)
        )


def merge_same_config_features(
    conf_list: List[DenseEmbeddingConfig],
) -> List[DenseEmbeddingConfig]:
    """Merge features whose configs share the same group_key.

    For example:
conf_list: List[DenseEmbeddingConfig]
= [
{'embedding_dim': 128, 'n_channels': 32, 'keep_prob': 0.5,
'temperature': 1.0, 'feature_names': ['f1']},
{'embedding_dim': 128, 'n_channels': 32, 'keep_prob': 0.5,
'temperature': 1.0, 'feature_names': ['f2', 'f3']},
{'embedding_dim': 256, 'n_channels': 64, 'keep_prob': 0.5,
'temperature': 0.8, 'feature_names': ['f4']}
]
will be merged as:
[
{'embedding_dim': 128, 'n_channels': 32, 'keep_prob': 0.5,
'temperature': 1.0, 'feature_names': ['f1', 'f2', 'f3']},
{'embedding_dim': 256, 'n_channels': 64, 'keep_prob': 0.5,
'temperature': 0.8, 'feature_names': ['f4']}
]
"""
unique_dict = {}
for conf in conf_list:
if conf.group_key in unique_dict:
unique_dict[conf.group_key].feature_names.extend(conf.feature_names)
else:
unique_dict[conf.group_key] = copy.copy(conf)
for key in unique_dict:
unique_dict[key].feature_names = sorted(
list(set(unique_dict[key].feature_names))
)
unique_list = list(unique_dict.values())
return unique_list
class DenseEmbeddingCollection(nn.Module):
"""DenseEmbeddingCollection module.
Args:
emb_dense_configs (list): list of DenseEmbeddingConfig.
device (torch.device): embedding device, default is meta.
raw_dense_feature_to_dim (dict): a feature_name to feature dim dict for
raw dense features do not need to do embedding. If specified,
the returned keyed tensor will also include these features.
"""
def __init__(
self,
emb_dense_configs: List[DenseEmbeddingConfig],
device: Optional[torch.device] = None,
raw_dense_feature_to_dim: Optional[Dict[str, int]] = None,
) -> None:
        super().__init__()
self.emb_dense_configs = emb_dense_configs
self._raw_dense_feature_to_dim = raw_dense_feature_to_dim
self.grouped_configs = merge_same_config_features(emb_dense_configs)
self.all_dense_names = []
self.all_dense_dims = []
self._group_to_feature_names = OrderedDict()
self.dense_embs = nn.ModuleDict()
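        # One embedding module is created per merged (group_key) config; all_dense_names
        # and all_dense_dims record the key order and per-key widths of the output
        # KeyedTensor assembled in forward().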
for conf in self.grouped_configs:
feature_names = conf.feature_names
embedding_dim = conf.embedding_dim
if conf.embedding_type == DenseEmbeddingType.MLP:
self.dense_embs[conf.group_key] = MLPEmbedding(
num_dense_feature=len(feature_names),
embedding_dim=embedding_dim,
device=device,
)
elif conf.embedding_type == DenseEmbeddingType.AUTO_DIS:
self.dense_embs[conf.group_key] = AutoDisEmbedding(
num_dense_feature=len(feature_names),
embedding_dim=embedding_dim,
num_channels=conf.n_channels,
temperature=conf.temperature,
keep_prob=conf.keep_prob,
device=device,
)
self.all_dense_names.extend(feature_names)
self.all_dense_dims.extend([embedding_dim] * len(feature_names))
self._group_to_feature_names[conf.group_key] = feature_names
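        # Raw (already numeric) dense features bypass embedding; forward() passes them
        # through unchanged and appends them to the output KeyedTensor.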
if raw_dense_feature_to_dim is not None and len(raw_dense_feature_to_dim) > 0:
feature_names, feature_dims = (
list(raw_dense_feature_to_dim.keys()),
list(raw_dense_feature_to_dim.values()),
)
self._group_to_feature_names["__raw_dense_group__"] = feature_names
self.all_dense_names.extend(feature_names)
            self.all_dense_dims.extend(feature_dims)

    def forward(self, dense_feature: KeyedTensor) -> KeyedTensor:
        """Forward the module."""
grouped_features = KeyedTensor.regroup_as_dict(
[dense_feature],
list(self._group_to_feature_names.values()),
list(self._group_to_feature_names.keys()),
)
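        # regroup_as_dict concatenates each group's feature values along the key dim,
        # yielding one [B, sum_of_feature_dims] tensor per group_key (for scalar dense
        # features this is [B, num_features_in_group]).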
emb_list = []
for group_key, emb_module in self.dense_embs.items():
emb_list.append(emb_module(grouped_features[group_key]))
if self._raw_dense_feature_to_dim:
emb_list.append(grouped_features["__raw_dense_group__"])
kt_dense_emb = KeyedTensor(
keys=self.all_dense_names,
length_per_key=self.all_dense_dims,
values=torch.cat(emb_list, dim=1),
key_dim=1,
)
return kt_dense_emb
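

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API): builds a
# DenseEmbeddingCollection over three hypothetical scalar dense features "f1"-"f3"
# and runs a forward pass. Feature names, dims, and hyper-parameters are assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    configs = [
        MLPDenseEmbeddingConfig(embedding_dim=8, feature_names=["f1", "f2"]),
        AutoDisEmbeddingConfig(
            embedding_dim=8,
            n_channels=4,
            temperature=0.1,
            keep_prob=0.8,
            feature_names=["f3"],
        ),
    ]
    dec = DenseEmbeddingCollection(configs)
    # one scalar value per dense feature, batch_size = 2
    dense_kt = KeyedTensor(
        keys=["f1", "f2", "f3"],
        length_per_key=[1, 1, 1],
        values=torch.randn(2, 3),
        key_dim=1,
    )
    out = dec(dense_kt)
    print(out.keys(), out.values().shape)  # ['f1', 'f2', 'f3'], torch.Size([2, 24])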