# tzrec/ops/hstu_compute.py
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# We use the hstu_compute ops from generative-recommenders as a starting point.
# https://github.com/facebookresearch/generative-recommenders
# Thanks for their public work.
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch.fx._symbolic_trace import is_fx_tracing
from tzrec.ops import Kernel
from tzrec.ops.hstu_attention import hstu_mha
from tzrec.ops.layer_norm import layer_norm
from tzrec.ops.mm import addmm
from tzrec.ops.pytorch.pt_hstu_linear import (
pytorch_hstu_compute_output,
)
from tzrec.ops.triton.triton_hstu_linear import (
triton_hstu_compute_output,
)
from tzrec.ops.triton.triton_hstu_preprocess_and_attention import (
triton_hstu_preprocess_and_attention,
)


def hstu_compute_uqvk(
x: torch.Tensor,
norm_weight: torch.Tensor,
norm_bias: torch.Tensor,
norm_eps: float,
num_heads: int,
attn_dim: int,
hidden_dim: int,
uvqk_weight: torch.Tensor,
uvqk_bias: torch.Tensor,
kernel: Kernel = Kernel.PYTORCH,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
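    """Layer-normalize ``x`` and compute the fused U/V/Q/K projections.

    The normalized input goes through a single fused linear layer
    (``uvqk_weight`` / ``uvqk_bias``) whose output is split, in U/V/Q/K order,
    into the SiLU-gated branch ``u`` and the attention tensors. ``q`` and ``k``
    are reshaped to ``[num_tokens, num_heads, attn_dim]`` and ``v`` to
    ``[num_tokens, num_heads, hidden_dim]``.

    Returns:
        Tuple of ``(u, q, k, v)``.
    """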
normed_x = layer_norm(
x,
weight=norm_weight,
bias=norm_bias,
eps=norm_eps,
kernel=kernel,
)
    # NOTE: for AMD training, we fall back to torch.addmm instead of the
    # Triton version until Triton on AMD achieves on-par performance with
    # NVIDIA GPUs.
if torch.version.hip and kernel == Kernel.TRITON:
uvqk = torch.addmm(uvqk_bias, normed_x, uvqk_weight)
else:
uvqk = addmm(uvqk_bias, normed_x, uvqk_weight, kernel)
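    # Split the fused projection into U, V, Q, K along the feature dimension.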
u, v, q, k = torch.split(
uvqk,
[
hidden_dim * num_heads,
hidden_dim * num_heads,
attn_dim * num_heads,
attn_dim * num_heads,
],
dim=1,
)
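    # SiLU-activate the gating branch U and reshape Q/K/V to per-head layout.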
u = F.silu(u)
q = q.view(-1, num_heads, attn_dim)
k = k.view(-1, num_heads, attn_dim)
v = v.view(-1, num_heads, hidden_dim)
    return u, q, k, v


def hstu_compute_output(
attn: torch.Tensor,
u: torch.Tensor,
x: torch.Tensor,
norm_weight: torch.Tensor,
norm_bias: torch.Tensor,
norm_eps: float,
output_weight: torch.Tensor,
num_heads: int,
linear_dim: int,
dropout_ratio: float,
training: bool,
concat_ux: bool,
group_norm: bool,
recompute_y_in_backward: bool,
kernel: Kernel = Kernel.PYTORCH,
) -> torch.Tensor:
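    """Compute the HSTU layer output from the attention result ``attn``.

    Dispatches to the fused Triton kernel when ``kernel == Kernel.TRITON`` and
    to the PyTorch reference implementation otherwise. Both variants combine
    the normalized (layer norm or group norm) attention output with the gating
    branch ``u`` and the input ``x`` (optionally concatenating them when
    ``concat_ux`` is set), apply dropout, and project with ``output_weight``.
    """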
if kernel == Kernel.TRITON:
return triton_hstu_compute_output(
attn=attn,
u=u,
x=x,
norm_weight=norm_weight,
norm_bias=norm_bias,
output_weight=output_weight,
eps=norm_eps,
dropout_ratio=dropout_ratio,
training=training,
concat_ux=concat_ux,
group_norm=group_norm,
num_heads=num_heads,
linear_dim=linear_dim,
seed=None,
recompute_y_in_backward=recompute_y_in_backward,
)
else:
return pytorch_hstu_compute_output(
attn=attn,
u=u,
x=x,
norm_weight=norm_weight,
norm_bias=norm_bias,
output_weight=output_weight,
eps=norm_eps,
dropout_ratio=dropout_ratio,
training=training,
concat_ux=concat_ux,
group_norm=group_norm,
num_heads=num_heads,
linear_dim=linear_dim,
        )


def hstu_preprocess_and_attention(
x: torch.Tensor,
norm_weight: torch.Tensor,
norm_bias: torch.Tensor,
norm_eps: float,
num_heads: int,
attn_dim: int,
hidden_dim: int,
uvqk_weight: torch.Tensor,
uvqk_bias: torch.Tensor,
max_seq_len: int,
seq_offsets: torch.Tensor,
attn_alpha: float,
causal: bool,
num_targets: Optional[torch.Tensor],
max_attn_len: int,
contextual_seq_len: int,
recompute_uvqk_in_backward: bool,
recompute_normed_x_in_backward: bool,
sort_by_length: bool,
prefill: bool = False,
kernel: Kernel = Kernel.PYTORCH,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
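    """Run HSTU preprocessing (layer norm + UVQK projection) and attention.

    When ``kernel == Kernel.TRITON`` and ``prefill`` is False, preprocessing
    and attention are executed by a fused Triton op that does not materialize
    ``k`` / ``v`` for the caller (both are returned as ``None``). Otherwise
    U/Q/K/V are computed via ``hstu_compute_uqvk``, attention via
    ``hstu_mha``, and ``k`` / ``v`` are returned to the caller.

    Returns:
        Tuple of ``(u, attn_output, k, v)``, where ``attn_output`` has shape
        ``[num_tokens, num_heads * hidden_dim]``.
    """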
if not is_fx_tracing():
torch._assert(max_seq_len > 0, "max_seq_len must be larger than 0")
torch._assert(x.dim() == 2, "x must be 2-D")
torch._assert(
x.shape[1] == uvqk_weight.shape[0],
"x.shape[1] must equal uvqk_weight.shape[0]",
)
torch._assert(
uvqk_weight.shape[1] == 2 * num_heads * (hidden_dim + attn_dim),
"uvqk_weight.shape[1] must equal 2 * num_heads * (hidden_dim + attn_dim)",
)
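    # Fused Triton path: preprocessing and attention are handled by a single
    # fused op; K/V are not returned.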
if kernel == Kernel.TRITON and prefill is False:
u, attn_output = triton_hstu_preprocess_and_attention(
x=x,
norm_weight=norm_weight,
norm_bias=norm_bias,
norm_eps=norm_eps,
num_heads=num_heads,
attn_dim=attn_dim,
hidden_dim=hidden_dim,
uvqk_weight=uvqk_weight,
uvqk_bias=uvqk_bias,
max_seq_len=max_seq_len,
seq_offsets=seq_offsets,
attn_alpha=attn_alpha,
causal=causal,
num_targets=num_targets,
max_attn_len=max_attn_len,
contextual_seq_len=contextual_seq_len,
recompute_uvqk_in_backward=recompute_uvqk_in_backward,
recompute_normed_x_in_backward=recompute_normed_x_in_backward,
sort_by_length=sort_by_length,
)
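        # Flatten the per-head attention output to
        # [num_tokens, num_heads * hidden_dim].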
attn_output = attn_output.view(-1, hidden_dim * num_heads)
k = None
v = None
else:
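        # Unfused path: compute U/Q/K/V explicitly, then run multi-head
        # attention; K/V are returned to the caller.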
u, q, k, v = hstu_compute_uqvk(
x=x,
norm_weight=norm_weight,
norm_bias=norm_bias,
norm_eps=norm_eps,
num_heads=num_heads,
attn_dim=attn_dim,
hidden_dim=hidden_dim,
uvqk_weight=uvqk_weight,
uvqk_bias=uvqk_bias,
kernel=kernel,
)
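        # Multi-head attention over the jagged Q/K/V, flattened back to
        # [num_tokens, num_heads * hidden_dim].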
attn_output = hstu_mha(
max_seq_len=max_seq_len,
alpha=attn_alpha,
q=q,
k=k,
v=v,
seq_offsets=seq_offsets,
causal=causal,
dropout_pr=0.0,
training=False,
num_targets=num_targets,
max_attn_len=max_attn_len,
contextual_seq_len=contextual_seq_len,
sort_by_length=sort_by_length,
kernel=kernel,
).view(-1, hidden_dim * num_heads)
return u, attn_output, k, v
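# A minimal sketch of how the two stages are typically chained in an HSTU
# layer forward pass. Illustrative only: the variable names and
# hyper-parameter values below are assumptions, not part of this module.
#
#     u, attn_output, _, _ = hstu_preprocess_and_attention(
#         x=x,
#         norm_weight=input_norm_weight,
#         norm_bias=input_norm_bias,
#         norm_eps=1e-6,
#         num_heads=num_heads,
#         attn_dim=attn_dim,
#         hidden_dim=hidden_dim,
#         uvqk_weight=uvqk_weight,
#         uvqk_bias=uvqk_bias,
#         max_seq_len=max_seq_len,
#         seq_offsets=seq_offsets,
#         attn_alpha=attn_alpha,
#         causal=True,
#         num_targets=None,
#         max_attn_len=0,
#         contextual_seq_len=0,
#         recompute_uvqk_in_backward=True,
#         recompute_normed_x_in_backward=True,
#         sort_by_length=True,
#         kernel=Kernel.TRITON,
#     )
#     y = hstu_compute_output(
#         attn=attn_output,
#         u=u,
#         x=x,
#         norm_weight=output_norm_weight,
#         norm_bias=output_norm_bias,
#         norm_eps=1e-6,
#         output_weight=output_weight,
#         num_heads=num_heads,
#         linear_dim=hidden_dim,
#         dropout_ratio=dropout_ratio,
#         training=True,
#         concat_ux=True,
#         group_norm=False,
#         recompute_y_in_backward=True,
#         kernel=Kernel.TRITON,
#     )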