tzrec/ops/pytorch/pt_hstu_linear.py (72 lines of code) (raw):
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# We use the hstu_linear ops from generative-recommenders as a starting point.
# https://github.com/facebookresearch/generative-recommenders
# thanks to their public work.
import torch
import torch.nn.functional as F
def pytorch_norm_mul_dropout(
    x: torch.Tensor,
    u: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    dropout_ratio: float,
    training: bool,
    concat_ux: bool = False,
    group_norm: bool = False,
    num_heads: int = 1,
    linear_dim: int = -1,
) -> torch.Tensor:
    """Normalize ``x``, gate it with ``u``, then apply dropout.

    Computes ``u * norm(x)`` where ``norm`` is either GroupNorm (one group
    per head, requires ``linear_dim`` and ``num_heads``) or LayerNorm over
    the last dimension. With ``concat_ux=True`` the result is
    ``cat([u, x, u * norm(x)])`` along dim 1. All arithmetic runs in fp32
    for numerical stability; the output is cast back to ``x``'s dtype.

    Args:
        x: input tensor to normalize.
        u: gating tensor, same leading shape as ``x``.
        weight: norm affine scale.
        bias: norm affine shift.
        eps: norm epsilon.
        dropout_ratio: dropout probability.
        training: whether dropout is active.
        concat_ux: concatenate [u, x, u*norm(x)] along dim 1 before dropout.
        group_norm: use GroupNorm instead of LayerNorm.
        num_heads: number of groups/heads for GroupNorm.
        linear_dim: per-head dim for GroupNorm (required when group_norm).

    Returns:
        Tensor in ``x``'s original dtype.
    """
    out_dtype = x.dtype
    # Upcast everything once; the norm and the gating both run in fp32.
    x_fp32 = x.to(torch.float32)
    u_fp32 = u.to(torch.float32)
    w_fp32 = weight.to(torch.float32)
    b_fp32 = bias.to(torch.float32)
    if group_norm:
        # One group per head: view as (tokens, heads, dim), norm, flatten back.
        normed = F.group_norm(
            x_fp32.view(-1, num_heads, linear_dim),
            num_groups=num_heads,
            weight=w_fp32,
            bias=b_fp32,
            eps=eps,
        ).view(-1, num_heads * linear_dim)
    else:
        normed = F.layer_norm(
            x_fp32,
            normalized_shape=(x_fp32.shape[-1],),
            weight=w_fp32,
            bias=b_fp32,
            eps=eps,
        )
    gated = u_fp32 * normed
    if concat_ux:
        gated = torch.cat([u_fp32, x_fp32, gated], dim=1)
    return F.dropout(gated, p=dropout_ratio, training=training).to(out_dtype)
def pytorch_hstu_compute_output(
    attn: torch.Tensor,
    u: torch.Tensor,
    x: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    output_weight: torch.Tensor,
    eps: float,
    dropout_ratio: float,
    training: bool,
    concat_ux: bool = False,
    group_norm: bool = False,
    num_heads: int = 1,
    linear_dim: int = -1,
) -> torch.Tensor:
    """Produce the HSTU layer output: residual plus projected gated-norm.

    Applies ``pytorch_norm_mul_dropout`` to the attention output ``attn``
    (gated by ``u``), projects the result through ``output_weight``, and
    adds the residual ``x`` — i.e. ``x + dropout(u * norm(attn)) @ W``,
    fused into a single ``torch.addmm``.

    Args:
        attn: attention output to normalize and gate.
        u: gating tensor.
        x: residual input (also fixes the output dtype).
        norm_weight: norm affine scale.
        norm_bias: norm affine shift.
        output_weight: projection matrix for the gated activations.
        eps: norm epsilon.
        dropout_ratio: dropout probability.
        training: whether dropout is active.
        concat_ux: concatenate [u, attn, gated] before projecting.
        group_norm: use GroupNorm instead of LayerNorm.
        num_heads: number of groups/heads for GroupNorm.
        linear_dim: per-head dim for GroupNorm (required when group_norm).

    Returns:
        Tensor in ``x``'s original dtype.
    """
    out_dtype = x.dtype
    gated = pytorch_norm_mul_dropout(
        x=attn,
        u=u,
        weight=norm_weight,
        bias=norm_bias,
        eps=eps,
        dropout_ratio=dropout_ratio,
        training=training,
        concat_ux=concat_ux,
        group_norm=group_norm,
        num_heads=num_heads,
        linear_dim=linear_dim,
    )
    # addmm fuses the residual add with the matmul: x + gated @ output_weight.
    projected = torch.addmm(x, gated, output_weight.to(x.dtype))
    return projected.to(out_dtype)