# tzrec/modules/sequence.py
# Copyright (c) 2024-2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# HSTUEncoder is from generative-recommenders,
# https://github.com/facebookresearch/generative-recommenders,
# thanks to their public work.
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from tzrec.modules.hstu import (
HSTUCacheState,
RelativeBucketedTimeAndPositionBasedBias,
SequentialTransductionUnitJagged,
)
from tzrec.modules.mlp import MLP
from tzrec.protos.seq_encoder_pb2 import SeqEncoderConfig
from tzrec.utils import config_util
from tzrec.utils.fx_util import fx_arange
from tzrec.utils.load_class import get_register_class_meta

torch.fx.wrap(fx_arange)

_SEQ_ENCODER_CLASS_MAP = {}
_meta_cls = get_register_class_meta(_SEQ_ENCODER_CLASS_MAP)


class SequenceEncoder(nn.Module, metaclass=_meta_cls):
    """Base module of sequence encoder."""

    def __init__(self, input: str) -> None:
        super().__init__()
        self._input = input

    def input(self) -> str:
        """Get sequence encoder input group name."""
        return self._input

    def output_dim(self) -> int:
        """Output dimension of the module."""
        raise NotImplementedError


class DINEncoder(SequenceEncoder):
    """DIN sequence encoder.

    Args:
        sequence_dim (int): sequence tensor channel dimension.
        query_dim (int): query tensor channel dimension.
        input (str): input feature group name.
        attn_mlp (dict): target attention MLP module parameters.
    """

    def __init__(
self,
sequence_dim: int,
query_dim: int,
input: str,
attn_mlp: Dict[str, Any],
**kwargs: Optional[Dict[str, Any]],
) -> None:
super().__init__(input)
self._query_dim = query_dim
self._sequence_dim = sequence_dim
if self._query_dim > self._sequence_dim:
raise ValueError("query_dim > sequence_dim not supported yet.")
self.mlp = MLP(in_features=sequence_dim * 4, dim=3, **attn_mlp)
self.linear = nn.Linear(self.mlp.hidden_units[-1], 1)
self._query_name = f"{input}.query"
self._sequence_name = f"{input}.sequence"
self._sequence_length_name = f"{input}.sequence_length"

    def output_dim(self) -> int:
        """Output dimension of the module."""
        return self._sequence_dim

    def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Forward the module."""
query = sequence_embedded[self._query_name]
sequence = sequence_embedded[self._sequence_name]
sequence_length = sequence_embedded[self._sequence_length_name]
        max_seq_length = sequence.size(1)
        sequence_mask = fx_arange(
            max_seq_length, device=sequence_length.device
        ).unsqueeze(0) < sequence_length.unsqueeze(1)  # [B, T], True at valid steps
        if self._query_dim < self._sequence_dim:
            query = F.pad(query, (0, self._sequence_dim - self._query_dim))
        queries = query.unsqueeze(1).expand(-1, max_seq_length, -1)  # [B, T, C]
        attn_input = torch.cat(
            [queries, sequence, queries - sequence, queries * sequence], dim=-1
        )  # [B, T, 4 * C]
        attn_output = self.mlp(attn_input)
        attn_output = self.linear(attn_output)  # [B, T, 1]
        attn_output = attn_output.transpose(1, 2)  # [B, 1, T]
        # mask padded positions with a large negative value so that softmax
        # assigns them (nearly) zero attention weight.
        padding = torch.ones_like(attn_output) * (-(2**32) + 1)
        scores = torch.where(sequence_mask.unsqueeze(1), attn_output, padding)
        scores = F.softmax(scores, dim=-1)  # [B, 1, T]
        return torch.matmul(scores, sequence).squeeze(1)  # [B, C]
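
# Illustrative usage of DINEncoder (a sketch, not executed; the group name
# "click_seq" and the attn_mlp keys below are assumptions -- see
# tzrec.modules.mlp.MLP for the accepted MLP kwargs):
#
#   encoder = DINEncoder(
#       sequence_dim=16,
#       query_dim=16,
#       input="click_seq",
#       attn_mlp={"hidden_units": [32, 8]},
#   )
#   out = encoder(
#       {
#           "click_seq.query": torch.randn(4, 16),  # [B, query_dim]
#           "click_seq.sequence": torch.randn(4, 50, 16),  # [B, T, sequence_dim]
#           "click_seq.sequence_length": torch.tensor([50, 12, 3, 1]),  # [B]
#       }
#   )  # -> [4, 16], i.e. [B, sequence_dim]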


class SimpleAttention(SequenceEncoder):
    """Simple attention encoder.

    Args:
        sequence_dim (int): sequence tensor channel dimension.
        query_dim (int): query tensor channel dimension.
        input (str): input feature group name.
    """

    def __init__(
self,
sequence_dim: int,
query_dim: int,
input: str,
**kwargs: Optional[Dict[str, Any]],
) -> None:
super().__init__(input)
self._sequence_dim = sequence_dim
self._query_dim = query_dim
self._query_name = f"{input}.query"
self._sequence_name = f"{input}.sequence"
self._sequence_length_name = f"{input}.sequence_length"

    def output_dim(self) -> int:
        """Output dimension of the module."""
        return self._sequence_dim

    def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Forward the module."""
query = sequence_embedded[self._query_name]
sequence = sequence_embedded[self._sequence_name]
sequence_length = sequence_embedded[self._sequence_length_name]
max_seq_length = sequence.size(1)
sequence_mask = fx_arange(max_seq_length, sequence_length.device).unsqueeze(
0
) < sequence_length.unsqueeze(1)
        attn_output = torch.matmul(sequence, query.unsqueeze(2)).squeeze(2)  # [B, T]
        # mask padded positions before softmax
        padding = torch.ones_like(attn_output) * (-(2**32) + 1)
        scores = torch.where(sequence_mask, attn_output, padding)
        scores = F.softmax(scores, dim=-1)  # [B, T]
        return torch.matmul(scores.unsqueeze(1), sequence).squeeze(1)  # [B, C]
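
# Illustrative usage of SimpleAttention (a sketch, not executed). Unlike
# DINEncoder, the query is not padded, so query_dim must equal sequence_dim
# for the dot-product scores:
#
#   encoder = SimpleAttention(sequence_dim=16, query_dim=16, input="click_seq")
#   out = encoder(
#       {
#           "click_seq.query": torch.randn(4, 16),
#           "click_seq.sequence": torch.randn(4, 50, 16),
#           "click_seq.sequence_length": torch.tensor([50, 12, 3, 1]),
#       }
#   )  # -> [4, 16]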


class PoolingEncoder(SequenceEncoder):
    """Mean/Sum pooling sequence encoder.

    Args:
        sequence_dim (int): sequence tensor channel dimension.
        input (str): input feature group name.
        pooling_type (str): pooling type, sum or mean.
    """

    def __init__(
self,
sequence_dim: int,
input: str,
pooling_type: str = "mean",
**kwargs: Optional[Dict[str, Any]],
) -> None:
super().__init__(input)
self._sequence_dim = sequence_dim
self._pooling_type = pooling_type
assert self._pooling_type in [
"sum",
"mean",
], "only sum|mean pooling type supported now."
self._sequence_name = f"{input}.sequence"
self._sequence_length_name = f"{input}.sequence_length"

    def output_dim(self) -> int:
        """Output dimension of the module."""
        return self._sequence_dim

    def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Forward the module."""
        sequence = sequence_embedded[self._sequence_name]
        feature = torch.sum(sequence, dim=1)
        if self._pooling_type == "mean":
            sequence_length = sequence_embedded[self._sequence_length_name]
            # clamp length to at least 1 to avoid division by zero on empty sequences
            sequence_length = torch.max(
                sequence_length, torch.ones_like(sequence_length)
            )
            feature = feature / sequence_length.unsqueeze(1)
        return feature
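
# Illustrative usage of PoolingEncoder (a sketch, not executed). The encoder
# sums over the time dimension; with pooling_type="mean" the sum is divided by
# clamp(sequence_length, min=1), so empty sequences stay finite:
#
#   encoder = PoolingEncoder(sequence_dim=16, input="click_seq", pooling_type="mean")
#   out = encoder(
#       {
#           "click_seq.sequence": torch.randn(4, 50, 16),
#           "click_seq.sequence_length": torch.tensor([50, 12, 3, 0]),
#       }
#   )  # -> [4, 16]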


class MultiWindowDINEncoder(SequenceEncoder):
    """Multi Window DIN module.

    Args:
        sequence_dim (int): sequence tensor channel dimension.
        query_dim (int): query tensor channel dimension.
        input (str): input feature group name.
        windows_len (list): length of each time window.
        attn_mlp (dict): target attention MLP module parameters.
    """

    def __init__(
self,
sequence_dim: int,
query_dim: int,
input: str,
windows_len: List[int],
attn_mlp: Dict[str, Any],
**kwargs: Optional[Dict[str, Any]],
) -> None:
super().__init__(input)
self._query_dim = query_dim
self._sequence_dim = sequence_dim
self._windows_len = windows_len
if self._query_dim > self._sequence_dim:
raise ValueError("query_dim > sequence_dim not supported yet.")
self.register_buffer("windows_len", torch.tensor(windows_len))
self.register_buffer(
"cumsum_windows_len", torch.tensor(np.cumsum([0] + list(windows_len)[:-1]))
)
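        # e.g. windows_len = [5, 10, 15] gives cumsum_windows_len = [0, 5, 15],
        # i.e. the start offset of each window along the time axis.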
self._sum_windows_len = sum(windows_len)
self.mlp = MLP(in_features=sequence_dim * 3, dim=3, **attn_mlp)
self.linear = nn.Linear(self.mlp.hidden_units[-1], 1)
self.active = nn.PReLU()
self._query_name = f"{input}.query"
self._sequence_name = f"{input}.sequence"
self._sequence_length_name = f"{input}.sequence_length"

    def output_dim(self) -> int:
        """Output dimension of the module."""
        return self._sequence_dim * (len(self._windows_len) + 1)

    def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Forward the module."""
query = sequence_embedded[self._query_name]
sequence = sequence_embedded[self._sequence_name]
sequence_length = sequence_embedded[self._sequence_length_name]
max_seq_length = sequence.size(1)
sequence_mask = fx_arange(
max_seq_length, device=sequence_length.device
).unsqueeze(0) < sequence_length.unsqueeze(1)
if self._query_dim < self._sequence_dim:
query = F.pad(query, (0, self._sequence_dim - self._query_dim))
queries = query.unsqueeze(1).expand(-1, max_seq_length, -1) # [B, T, C]
attn_input = torch.cat([sequence, queries * sequence, queries], dim=-1)
attn_output = self.mlp(attn_input)
attn_output = self.linear(attn_output)
attn_output = self.active(attn_output) # [B, T, 1]
        att_sequences = attn_output * sequence_mask.unsqueeze(2) * sequence
        # pad along time to sum(windows_len), then sum-reduce each window segment
        pad = (0, 0, 0, self._sum_windows_len - max_seq_length)
        pad_att_sequences = F.pad(att_sequences, pad).transpose(0, 1)
        result = torch.segment_reduce(
            pad_att_sequences, reduce="sum", lengths=self.windows_len, axis=0
        ).transpose(0, 1)  # [B, L, C]
        # number of valid (unpadded) steps falling into each window
        segment_length = torch.min(
            sequence_length.unsqueeze(1) - self.cumsum_windows_len.unsqueeze(0),
            self.windows_len,
        )
        result = result / torch.max(
            segment_length, torch.ones_like(segment_length)
        ).unsqueeze(2)
return torch.cat([result, query.unsqueeze(1)], dim=1).reshape(
result.shape[0], -1
) # [B, (L+1)*C]
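
# Illustrative usage of MultiWindowDINEncoder (a sketch, not executed). With
# windows_len=[5, 10, 15] the attention-weighted sequence is split along time
# into segments of length 5/10/15, each segment is averaged over its valid
# steps, and the (padded) query is appended, so output_dim() is
# (len(windows_len) + 1) * sequence_dim. The attn_mlp keys are assumptions:
#
#   encoder = MultiWindowDINEncoder(
#       sequence_dim=16,
#       query_dim=16,
#       input="click_seq",
#       windows_len=[5, 10, 15],  # sum should cover the max sequence length
#       attn_mlp={"hidden_units": [32, 8]},
#   )
#   out = encoder(
#       {
#           "click_seq.query": torch.randn(4, 16),
#           "click_seq.sequence": torch.randn(4, 30, 16),
#           "click_seq.sequence_length": torch.tensor([30, 12, 3, 1]),
#       }
#   )  # -> [4, 64], i.e. [B, (len(windows_len) + 1) * sequence_dim]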


class HSTUEncoder(SequenceEncoder):
    """HSTU sequence encoder.

    Args:
        sequence_dim (int): sequence tensor channel dimension.
        attn_dim (int): attention hidden dimension.
        linear_dim (int): linear layer hidden dimension.
        input (str): input feature group name.
        max_seq_length (int): maximum sequence length.
        pos_dropout_rate (float): positional embedding dropout rate.
        linear_dropout_rate (float): linear layer dropout rate.
        attn_dropout_rate (float): attention dropout rate.
        normalization (str): attention normalization method.
        linear_activation (str): linear layer activation.
        linear_config (str): linear projection layout.
        num_heads (int): number of attention heads.
        num_blocks (int): number of HSTU blocks.
        max_output_len (int): maximum output (candidate) length.
        time_bucket_size (int): bucket number of the relative time bias.
    """

    def __init__(
self,
sequence_dim: int,
attn_dim: int,
linear_dim: int,
input: str,
max_seq_length: int,
pos_dropout_rate: float = 0.2,
linear_dropout_rate: float = 0.2,
attn_dropout_rate: float = 0.0,
normalization: str = "rel_bias",
linear_activation: str = "silu",
linear_config: str = "uvqk",
num_heads: int = 1,
num_blocks: int = 2,
max_output_len: int = 10,
time_bucket_size: int = 128,
**kwargs: Optional[Dict[str, Any]],
) -> None:
super().__init__(input)
self._sequence_dim = sequence_dim
self._attn_dim = attn_dim
self._linear_dim = linear_dim
self._max_seq_length = max_seq_length
self._query_name = f"{input}.query"
self._sequence_name = f"{input}.sequence"
self._sequence_length_name = f"{input}.sequence_length"
max_output_len = max_output_len + 1 # for target
self.position_embed = nn.Embedding(
self._max_seq_length + max_output_len, self._sequence_dim
)
self.dropout_rate = pos_dropout_rate
self.enable_relative_attention_bias = True
self.autocast_dtype = None
self._attention_layers: nn.ModuleList = nn.ModuleList(
modules=[
SequentialTransductionUnitJagged(
embedding_dim=self._sequence_dim,
linear_hidden_dim=self._linear_dim,
attention_dim=self._attn_dim,
normalization=normalization,
linear_config=linear_config,
linear_activation=linear_activation,
num_heads=num_heads,
relative_attention_bias_module=(
RelativeBucketedTimeAndPositionBasedBias(
max_seq_len=max_seq_length + max_output_len,
num_buckets=time_bucket_size,
bucketization_fn=lambda x: (
torch.log(torch.abs(x).clamp(min=1)) / 0.301
).long(),
)
if self.enable_relative_attention_bias
else None
),
dropout_ratio=linear_dropout_rate,
attn_dropout_ratio=attn_dropout_rate,
concat_ua=False,
)
for _ in range(num_blocks)
]
)
self.register_buffer(
"_attn_mask",
torch.triu(
torch.ones(
(
self._max_seq_length + max_output_len,
self._max_seq_length + max_output_len,
),
dtype=torch.bool,
),
diagonal=1,
),
)
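        # _attn_mask marks future positions with True; e.g. for length 4 it is
        #   [[F, T, T, T],
        #    [F, F, T, T],
        #    [F, F, F, T],
        #    [F, F, F, F]]
        # forward() uses 1.0 - _attn_mask, i.e. 1.0 where a position may attend
        # (itself and the past) and 0.0 at future positions (causal attention).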
self._autocast_dtype = None

    def output_dim(self) -> int:
        """Output dimension of the module."""
        return self._sequence_dim

    def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Forward the module."""
sequence = sequence_embedded[self._sequence_name] # B, N, E
sequence_length = sequence_embedded[self._sequence_length_name] # N
# max_seq_length = sequence.size(1)
float_dtype = sequence.dtype
# Add positional embeddings and apply dropout
positions = (
fx_arange(sequence.size(1), device=sequence.device)
.unsqueeze(0)
.expand(sequence.size(0), -1)
)
sequence = sequence * (self._sequence_dim**0.5) + self.position_embed(positions)
sequence = F.dropout(sequence, p=self.dropout_rate, training=self.training)
sequence_mask = fx_arange(
sequence.size(1), device=sequence_length.device
).unsqueeze(0) < sequence_length.unsqueeze(1)
sequence = sequence * sequence_mask.unsqueeze(-1).to(float_dtype)
invalid_attn_mask = 1.0 - self._attn_mask.to(float_dtype)
sequence_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
sequence_length
)
sequence = torch.ops.fbgemm.dense_to_jagged(sequence, [sequence_offsets])[0]
all_timestamps = None
jagged_x, cache_states = self.jagged_forward(
x=sequence,
x_offsets=sequence_offsets,
all_timestamps=all_timestamps,
invalid_attn_mask=invalid_attn_mask,
delta_x_offsets=None,
cache=None,
return_cache_states=False,
)
# post processing: L2 Normalization
output_embeddings = jagged_x
output_embeddings = output_embeddings[..., : self._sequence_dim]
output_embeddings = output_embeddings / torch.clamp(
torch.linalg.norm(output_embeddings, ord=None, dim=-1, keepdim=True),
min=1e-6,
)
if not self.training:
output_embeddings = self.get_current_embeddings(
sequence_length, output_embeddings
)
return output_embeddings
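
    # Note on output shapes: in training mode forward() returns the jagged,
    # L2-normalized embeddings of shape [sum(sequence_length), sequence_dim];
    # in eval mode it returns a dense [B, sequence_dim] tensor holding each
    # sequence's last valid position (see get_current_embeddings below). The
    # fbgemm jagged ops used above require the fbgemm_gpu package at runtime.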

    def jagged_forward(
self,
x: torch.Tensor,
x_offsets: torch.Tensor,
all_timestamps: Optional[torch.Tensor],
invalid_attn_mask: torch.Tensor,
delta_x_offsets: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cache: Optional[List[HSTUCacheState]] = None,
return_cache_states: bool = False,
) -> Tuple[torch.Tensor, List[HSTUCacheState]]:
r"""Jagged forward.
Args:
x: (\sum_i N_i, D) x float
x_offsets: (B + 1) x int32
all_timestamps: (B, 1 + N) x int64
invalid_attn_mask: (B, N, N) x float, each element in {0, 1}
delta_x_offsets: offsets for x
cache: cache contents
return_cache_states: bool. True if we should return cache states.
Returns:
x' = f(x), (\sum_i N_i, D) x float
"""
cache_states: List[HSTUCacheState] = []
with torch.autocast(
"cuda",
enabled=self._autocast_dtype is not None,
dtype=self._autocast_dtype or torch.float16,
):
for i, layer in enumerate(self._attention_layers):
x, cache_states_i = layer(
x=x,
x_offsets=x_offsets,
all_timestamps=all_timestamps,
invalid_attn_mask=invalid_attn_mask,
delta_x_offsets=delta_x_offsets,
cache=cache[i] if cache is not None else None,
return_cache_states=return_cache_states,
)
if return_cache_states:
cache_states.append(cache_states_i)
return x, cache_states

    def get_current_embeddings(
self,
lengths: torch.Tensor,
encoded_embeddings: torch.Tensor,
) -> torch.Tensor:
"""Get the embeddings of the last past_id as the current embeds.
Args:
lengths: (B,) x int
encoded_embeddings: (B, N, D,) x float
Returns:
(B, D,) x float, where [i, :] == encoded_embeddings[i, lengths[i] - 1, :]
"""
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
indices = offsets[1:] - 1
return encoded_embeddings[indices]
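
# Illustrative usage of HSTUEncoder (a sketch, not executed; fbgemm_gpu is
# required for the jagged ops, and the group name "click_seq" is an assumed
# example):
#
#   encoder = HSTUEncoder(
#       sequence_dim=32,
#       attn_dim=32,
#       linear_dim=32,
#       input="click_seq",
#       max_seq_length=50,
#   )
#   encoder.eval()
#   out = encoder(
#       {
#           "click_seq.sequence": torch.randn(4, 50, 32),
#           "click_seq.sequence_length": torch.tensor([50, 12, 3, 1]),
#       }
#   )  # -> [4, 32] in eval mode; [sum(sequence_length), 32] in training mode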


def create_seq_encoder(
    seq_encoder_config: SeqEncoderConfig, group_total_dim: Dict[str, int]
) -> SequenceEncoder:
    """Build a sequence encoder module.

    Args:
        seq_encoder_config: a SeqEncoderConfig.
        group_total_dim: a dict with the total embedding dim of each sequence
            group, keyed by "<group>.query" and "<group>.sequence".

    Returns:
        model: a SequenceEncoder module.
    """
model_cls_name = config_util.which_msg(seq_encoder_config, "seq_module")
# pyre-ignore [16]
model_cls = SequenceEncoder.create_class(model_cls_name)
seq_type = seq_encoder_config.WhichOneof("seq_module")
seq_type_config = getattr(seq_encoder_config, seq_type)
input_name = seq_type_config.input
query_dim = group_total_dim[f"{input_name}.query"]
sequence_dim = group_total_dim[f"{input_name}.sequence"]
seq_config_dict = config_util.config_to_kwargs(seq_type_config)
seq_config_dict["sequence_dim"] = sequence_dim
seq_config_dict["query_dim"] = query_dim
seq_encoder = model_cls(**seq_config_dict)
return seq_encoder
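
# Resolution flow (illustrative): for a SeqEncoderConfig whose "seq_module"
# oneof selects a DIN encoder with input "click_seq", this factory builds
#   DINEncoder(
#       sequence_dim=group_total_dim["click_seq.sequence"],
#       query_dim=group_total_dim["click_seq.query"],
#       **config_util.config_to_kwargs(seq_encoder_config.din),
#   )
# i.e. sequence_dim / query_dim are always injected from group_total_dim and
# the remaining kwargs come from the selected proto message ("din" is an
# assumed oneof field name; see tzrec/protos/seq_encoder.proto).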