# tzrec/modules/hstu.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import math
from typing import Callable, Optional, Tuple

import torch
import torch.nn.functional as F

TIMESTAMPS_KEY = "timestamps"


class RelativeAttentionBiasModule(torch.nn.Module):
    """Relative Attention Bias Module for transformer-based architectures.

    This module computes relative positional biases for attention mechanisms,
    allowing the model to take the relative positions of tokens in a sequence
    into account. Subclasses implement learnable relative-position embeddings
    that are added to the attention scores.

    Inherits from:
        torch.nn.Module: Base PyTorch Module class.

    Note:
        The relative attention bias is typically added to the attention scores
        before the softmax operation in the attention mechanism.
    """

    @abc.abstractmethod
    def forward(
        self,
        all_timestamps: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate bias with timestamps.

        Args:
            all_timestamps: [B, N] x int64.

        Returns:
            torch.float tensor broadcastable to [B, N, N].
        """
        pass


class RelativePositionalBias(RelativeAttentionBiasModule):
    """Implements relative positional bias for attention mechanisms.

    This class provides learnable attention biases derived from the relative
    positions of elements in a sequence, up to a maximum sequence length.

    Args:
        max_seq_len (int): Maximum sequence length supported by this bias module.
    """

    def __init__(self, max_seq_len: int) -> None:
        super().__init__()
        self._max_seq_len: int = max_seq_len
        self._w = torch.nn.Parameter(
            torch.empty(2 * max_seq_len - 1).normal_(mean=0, std=0.02),
        )

    def forward(
        self,
        all_timestamps: torch.Tensor,
    ) -> torch.Tensor:
        """Computes relative positional biases for attention.

        The bias depends only on the relative positions of elements, so the
        timestamps passed in are ignored.

        Args:
            all_timestamps: Tensor of shape [B, N] containing int64 timestamps
                (unused in this implementation).

        Returns:
            torch.Tensor: Attention bias tensor broadcastable to shape [B, N, N].
        """
del all_timestamps
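        # The pad/repeat/reshape below builds a Toeplitz matrix so that entry
        # (i, j) of the result equals self._w[(j - i) + n - 1], i.e. the learned
        # weight indexed by the relative offset j - i. E.g. for n = 2 and
        # w = [w0, w1, w2] the result is [[w1, w2], [w0, w1]].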
n: int = self._max_seq_len
t = F.pad(self._w[: 2 * n - 1], [0, n]).repeat(n)
t = t[..., :-n].reshape(1, n, 3 * n - 2)
r = (2 * n - 1) // 2
return t[..., r:-r]


class RelativeBucketedTimeAndPositionBasedBias(RelativeAttentionBiasModule):
    """Bucketizes timespans based on ts(next-item) - ts(current-item)."""

    def __init__(
self,
max_seq_len: int,
num_buckets: int,
bucketization_fn: Callable[[torch.Tensor], torch.Tensor],
) -> None:
super().__init__()
self._max_seq_len: int = max_seq_len
self._ts_w = torch.nn.Parameter(
torch.empty(num_buckets + 1).normal_(mean=0, std=0.02),
)
self._pos_w = torch.nn.Parameter(
torch.empty(2 * max_seq_len - 1).normal_(mean=0, std=0.02),
)
self._num_buckets: int = num_buckets
self._bucketization_fn: Callable[[torch.Tensor], torch.Tensor] = (
bucketization_fn
)

    def forward(
        self,
        all_timestamps: torch.Tensor,
    ) -> torch.Tensor:
        """Forward function.

        Args:
            all_timestamps: (B, N).

        Returns:
            (B, N, N).
        """
B = all_timestamps.size(0)
N = self._max_seq_len
t = F.pad(self._pos_w[: 2 * N - 1], [0, N]).repeat(N)
t = t[..., :-N].reshape(1, N, 3 * N - 2)
r = (2 * N - 1) // 2
# [B, N + 1] to simplify tensor manipulations.
ext_timestamps = torch.cat(
[all_timestamps, all_timestamps[:, N - 1 : N]], dim=1
)
# causal masking. Otherwise [:, :-1] - [:, 1:] works
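        # Entry (i, j) buckets the timespan between position j and the item
        # following position i; the bucket ids are clamped to [0, num_buckets]
        # and detached before being used as indices into self._ts_w.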
bucketed_timestamps = torch.clamp(
self._bucketization_fn(
ext_timestamps[:, 1:].unsqueeze(2) - ext_timestamps[:, :-1].unsqueeze(1)
),
min=0,
max=self._num_buckets,
).detach()
rel_pos_bias = t[:, :, r:-r]
rel_ts_bias = torch.index_select(
self._ts_w, dim=0, index=bucketed_timestamps.view(-1)
).view(B, N, N)
return rel_pos_bias + rel_ts_bias
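

# Cache layout returned by SequentialTransductionUnitJagged.forward:
# (v, padded_q, padded_k, new_outputs). v and new_outputs are jagged;
# padded_q and padded_k are dense (B, N, ...) tensors.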
HSTUCacheState = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]


def _hstu_attention_maybe_from_cache(
num_heads: int,
attention_dim: int,
linear_dim: int,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cached_q: Optional[torch.Tensor],
cached_k: Optional[torch.Tensor],
delta_x_offsets: Optional[Tuple[torch.Tensor, torch.Tensor]],
x_offsets: torch.Tensor,
all_timestamps: Optional[torch.Tensor],
invalid_attn_mask: torch.Tensor,
rel_attn_bias: RelativeAttentionBiasModule,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
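    """Compute HSTU pointwise attention over jagged q/k/v.

    When delta_x_offsets is provided, the cached padded q/k are updated in
    place with the newly computed rows instead of re-padding from scratch.

    Returns:
        A tuple of (jagged attention output, padded q, padded k).
    """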
B: int = x_offsets.size(0) - 1
n: int = invalid_attn_mask.size(-1)
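    # Incremental path: q/k only contain the delta positions; scatter them into
    # the cached dense (B, n, ...) tensors at the flattened target offsets.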
if delta_x_offsets is not None:
padded_q, padded_k = cached_q, cached_k
flattened_offsets = delta_x_offsets[1] + torch.arange(
start=0,
end=B * n,
step=n,
device=delta_x_offsets[1].device,
dtype=delta_x_offsets[1].dtype,
)
assert isinstance(padded_q, torch.Tensor)
assert isinstance(padded_k, torch.Tensor)
padded_q = (
padded_q.view(B * n, -1)
.index_copy_(
dim=0,
index=flattened_offsets,
source=q,
)
.view(B, n, -1)
)
padded_k = (
padded_k.view(B * n, -1)
.index_copy_(
dim=0,
index=flattened_offsets,
source=k,
)
.view(B, n, -1)
)
else:
padded_q = torch.ops.fbgemm.jagged_to_padded_dense(
values=q, offsets=[x_offsets], max_lengths=[n], padding_value=0.0
)
padded_k = torch.ops.fbgemm.jagged_to_padded_dense(
values=k, offsets=[x_offsets], max_lengths=[n], padding_value=0.0
)
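    # HSTU pointwise attention: per-head q @ k scores plus an optional relative
    # attention bias, followed by SiLU / n rather than a softmax.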
qk_attn = torch.einsum(
"bnhd,bmhd->bhnm",
padded_q.view(B, n, num_heads, attention_dim),
padded_k.view(B, n, num_heads, attention_dim),
)
if all_timestamps is not None:
qk_attn = qk_attn + rel_attn_bias(all_timestamps).unsqueeze(1)
qk_attn = F.silu(qk_attn) / n
qk_attn = qk_attn * invalid_attn_mask.unsqueeze(0).unsqueeze(0)
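    # Apply the attention weights to the padded values, then convert the dense
    # (B, n, num_heads * linear_dim) result back to the jagged layout.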
attn_output = torch.ops.fbgemm.dense_to_jagged(
torch.einsum(
"bhnm,bmhd->bnhd",
qk_attn,
torch.ops.fbgemm.jagged_to_padded_dense(v, [x_offsets], [n]).reshape(
B, n, num_heads, linear_dim
),
).reshape(B, n, num_heads * linear_dim),
[x_offsets],
)[0]
return attn_output, padded_q, padded_k


class SequentialTransductionUnitJagged(torch.nn.Module):
    """A jagged sequential transduction unit for variable-length sequences.

    This module processes jagged (variable-length) sequences using a
    combination of attention mechanisms and linear transformations.
    It supports various normalization strategies and attention bias
    configurations.

    Args:
        embedding_dim (int): Dimension of input embeddings.
        linear_hidden_dim (int): Dimension of hidden linear layers.
        attention_dim (int): Dimension of attention mechanism.
        dropout_ratio (float): Dropout probability for linear layers.
        attn_dropout_ratio (float): Dropout probability for attention.
        num_heads (int): Number of attention heads.
        linear_activation (str):
            Activation function for linear layers ('silu' or 'none').
        relative_attention_bias_module (Optional[RelativeAttentionBiasModule]):
            Module for relative position biases.
        normalization (str, optional):
            Normalization strategy. Defaults to "rel_bias".
            Options: "rel_bias", "hstu_rel_bias", "softmax_rel_bias".
        linear_config (str, optional):
            Linear layer configuration. Defaults to "uvqk".
        concat_ua (bool, optional):
            Whether to concatenate u and a in the output. Defaults to False.
        epsilon (float, optional):
            Small constant for numerical stability. Defaults to 1e-6.
        max_length (Optional[int], optional):
            Maximum sequence length. Defaults to None.

    Attributes:
        _embedding_dim (int): Dimension of input embeddings.
        _linear_dim (int): Dimension of hidden linear layers.
        _attention_dim (int): Dimension of attention mechanism.
        _num_heads (int): Number of attention heads.
        _rel_attn_bias (Optional[RelativeAttentionBiasModule]):
            Module for relative position biases.
        _normalization (str): Normalization strategy.
        _linear_config (str): Linear layer configuration.
        _concat_ua (bool): Whether to concatenate u and a in the output.
        _eps (float): Small constant for numerical stability.

    Note:
        This implementation supports caching for efficient sequential processing and
        handles jagged sequences through FBGEMM operations for dense-jagged conversions.
    """

    def __init__(
self,
embedding_dim: int,
linear_hidden_dim: int,
attention_dim: int,
dropout_ratio: float,
attn_dropout_ratio: float,
num_heads: int,
linear_activation: str,
relative_attention_bias_module: Optional[RelativeAttentionBiasModule] = None,
normalization: str = "rel_bias",
linear_config: str = "uvqk",
concat_ua: bool = False,
epsilon: float = 1e-6,
max_length: Optional[int] = None,
) -> None:
super().__init__()
self._embedding_dim: int = embedding_dim
self._linear_dim: int = linear_hidden_dim
self._attention_dim: int = attention_dim
self._dropout_ratio: float = dropout_ratio
self._attn_dropout_ratio: float = attn_dropout_ratio
self._num_heads: int = num_heads
self._rel_attn_bias: Optional[RelativeAttentionBiasModule] = (
relative_attention_bias_module
)
self._normalization: str = normalization
self._linear_config: str = linear_config
if self._linear_config == "uvqk":
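            # One fused projection weight producing u, v (each linear_hidden_dim
            # * num_heads wide) and q, k (each attention_dim * num_heads wide);
            # the output is split in forward().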
self._uvqk: torch.nn.Parameter = torch.nn.Parameter(
torch.empty(
(
embedding_dim,
linear_hidden_dim * 2 * num_heads
+ attention_dim * num_heads * 2,
)
).normal_(mean=0, std=0.02),
)
else:
raise ValueError(f"Unknown linear_config {self._linear_config}")
self._linear_activation: str = linear_activation
self._concat_ua: bool = concat_ua
self._o = torch.nn.Linear(
in_features=linear_hidden_dim * num_heads * (3 if concat_ua else 1),
out_features=embedding_dim,
)
torch.nn.init.xavier_uniform_(self._o.weight)
self._eps: float = epsilon

    def _norm_input(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, normalized_shape=[self._embedding_dim], eps=self._eps)

    def _norm_attn_output(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            x, normalized_shape=[self._linear_dim * self._num_heads], eps=self._eps
        )

    def forward(  # pyre-ignore [3]
        self,
        x: torch.Tensor,
        x_offsets: torch.Tensor,
        all_timestamps: Optional[torch.Tensor],
        invalid_attn_mask: torch.Tensor,
        delta_x_offsets: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cache: Optional[HSTUCacheState] = None,
        return_cache_states: bool = False,
    ):
        r"""Forward function.

        Args:
            x: (\sum_i N_i, D) x float.
            x_offsets: (B + 1) x int32.
            all_timestamps: optional (B, N) x int64.
            invalid_attn_mask: (N, N) x float, each element in {0, 1};
                broadcast over batch and heads.
            delta_x_offsets: optional 2-tuple ((B,) x int32, (B,) x int32).
                For the 1st element in the tuple, each element is in
                [0, x_offsets[-1]). For the 2nd element in the tuple, each
                element is in [0, N).
            cache: Optional 4-tuple of (v, padded_q, padded_k, output) from
                prior runs, where all except padded_q, padded_k are jagged.
            return_cache_states: Whether to return cache states.

        Returns:
            x' = f(x), (\sum_i N_i, D) x float.
        """
n: int = invalid_attn_mask.size(-1)
cached_q = None
cached_k = None
if delta_x_offsets is not None:
# In this case, for all the following code, x, u, v, q, k
# become restricted to [delta_x_offsets[0], :].
assert cache is not None
x = x[delta_x_offsets[0], :]
cached_v, cached_q, cached_k, cached_outputs = cache
normed_x = self._norm_input(x)
if self._linear_config == "uvqk":
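            # Single fused matmul, then split into u (elementwise gate),
            # v (values), q (queries) and k (keys).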
batched_mm_output = torch.mm(normed_x, self._uvqk)
if self._linear_activation == "silu":
batched_mm_output = F.silu(batched_mm_output)
elif self._linear_activation == "none":
batched_mm_output = batched_mm_output
u, v, q, k = torch.split(
batched_mm_output,
[
self._linear_dim * self._num_heads,
self._linear_dim * self._num_heads,
self._attention_dim * self._num_heads,
self._attention_dim * self._num_heads,
],
dim=1,
)
else:
raise ValueError(f"Unknown self._linear_config {self._linear_config}")
if delta_x_offsets is not None:
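            # Incremental path: scatter the freshly computed rows into the
            # cached jagged v.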
v = cached_v.index_copy_(dim=0, index=delta_x_offsets[0], source=v)
B: int = x_offsets.size(0) - 1
if self._normalization == "rel_bias" or self._normalization == "hstu_rel_bias":
assert self._rel_attn_bias is not None
attn_output, padded_q, padded_k = _hstu_attention_maybe_from_cache(
num_heads=self._num_heads,
attention_dim=self._attention_dim,
linear_dim=self._linear_dim,
q=q,
k=k,
v=v,
cached_q=cached_q,
cached_k=cached_k,
delta_x_offsets=delta_x_offsets,
x_offsets=x_offsets,
all_timestamps=all_timestamps,
invalid_attn_mask=invalid_attn_mask,
rel_attn_bias=self._rel_attn_bias,
)
elif self._normalization == "softmax_rel_bias":
if delta_x_offsets is not None:
B = x_offsets.size(0) - 1
padded_q, padded_k = cached_q, cached_k
flattened_offsets = delta_x_offsets[1] + torch.arange(
start=0,
end=B * n,
step=n,
device=delta_x_offsets[1].device,
dtype=delta_x_offsets[1].dtype,
)
assert padded_q is not None
assert padded_k is not None
padded_q = (
padded_q.view(B * n, -1)
.index_copy_(
dim=0,
index=flattened_offsets,
source=q,
)
.view(B, n, -1)
)
padded_k = (
padded_k.view(B * n, -1)
.index_copy_(
dim=0,
index=flattened_offsets,
source=k,
)
.view(B, n, -1)
)
else:
padded_q = torch.ops.fbgemm.jagged_to_padded_dense(
values=q, offsets=[x_offsets], max_lengths=[n], padding_value=0.0
)
padded_k = torch.ops.fbgemm.jagged_to_padded_dense(
values=k, offsets=[x_offsets], max_lengths=[n], padding_value=0.0
)
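            # Softmax attention with an optional relative bias; the invalid
            # attention mask is applied after the softmax.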
qk_attn = torch.einsum("bnd,bmd->bnm", padded_q, padded_k)
if self._rel_attn_bias is not None:
qk_attn = qk_attn + self._rel_attn_bias(all_timestamps)
qk_attn = F.softmax(qk_attn / math.sqrt(self._attention_dim), dim=-1)
qk_attn = qk_attn * invalid_attn_mask
attn_output = torch.ops.fbgemm.dense_to_jagged(
torch.bmm(
qk_attn,
torch.ops.fbgemm.jagged_to_padded_dense(v, [x_offsets], [n]),
),
[x_offsets],
)[0]
else:
raise ValueError(f"Unknown normalization method {self._normalization}")
attn_output = (
attn_output
if delta_x_offsets is None
else attn_output[delta_x_offsets[0], :]
)
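        # Gate the normalized attention output elementwise with u; when
        # _concat_ua is set, feed [u, a, u * a] to the output projection instead.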
if self._concat_ua:
a = self._norm_attn_output(attn_output)
o_input = torch.cat([u, a, u * a], dim=-1)
else:
o_input = u * self._norm_attn_output(attn_output)
new_outputs = (
self._o(
F.dropout(
o_input,
p=self._dropout_ratio,
training=self.training,
)
)
+ x
)
if delta_x_offsets is not None:
new_outputs = cached_outputs.index_copy_(
dim=0, index=delta_x_offsets[0], source=new_outputs
)
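        # v is made contiguous only when fresh cache states are requested on a
        # full (non-incremental) pass; the cache tuple is returned either way.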
if return_cache_states and delta_x_offsets is None:
v = v.contiguous()
return new_outputs, (v, padded_q, padded_k, new_outputs)
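

if __name__ == "__main__":
    # Minimal usage sketch (assumes fbgemm_gpu with CPU jagged ops is
    # installed): run one HSTU layer over two jagged sequences of lengths
    # 3 and 2. The sizes and the causal mask below are illustrative only.
    import fbgemm_gpu  # noqa: F401

    torch.manual_seed(0)
    max_len = 4
    layer = SequentialTransductionUnitJagged(
        embedding_dim=16,
        linear_hidden_dim=8,
        attention_dim=8,
        dropout_ratio=0.0,
        attn_dropout_ratio=0.0,
        num_heads=2,
        linear_activation="silu",
        relative_attention_bias_module=RelativePositionalBias(max_seq_len=max_len),
    )
    x = torch.randn(5, 16)  # two sequences flattened: 3 + 2 = 5 items
    x_offsets = torch.tensor([0, 3, 5])
    # Lower-triangular (causal) mask: position i attends to positions <= i.
    invalid_attn_mask = torch.tril(torch.ones(max_len, max_len))
    out, _ = layer(
        x=x,
        x_offsets=x_offsets,
        all_timestamps=None,
        invalid_attn_mask=invalid_attn_mask,
    )
    assert out.shape == x.shape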