# tzrec/ops/hstu_attention.py
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We use the hstu_attention ops from generative-recommenders as a starting
# point (https://github.com/facebookresearch/generative-recommenders);
# thanks to their public work.
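"""Dispatch wrappers for HSTU multi-head attention (PyTorch and Triton kernels)."""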
from typing import Optional

import torch
from torch.fx._symbolic_trace import is_fx_tracing

from tzrec.ops import Kernel
from tzrec.ops.pytorch.pt_hstu_attention import (
    pytorch_cached_hstu_mha,
    pytorch_hstu_mha,
)
from tzrec.ops.triton.triton_hstu_attention import (
    triton_cached_hstu_mha,
    triton_hstu_mha,
)
from tzrec.ops.utils import switch_to_contiguous_if_needed


def hstu_mha(
max_seq_len: int,
alpha: float,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_offsets: torch.Tensor,
causal: bool = True,
dropout_pr: float = 0.0,
training: bool = True,
num_targets: Optional[torch.Tensor] = None,
max_attn_len: int = 0,
contextual_seq_len: int = 0,
min_full_attn_seq_len: int = 0,
sort_by_length: bool = False,
kernel: Kernel = Kernel.PYTORCH,
) -> torch.Tensor:
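    """Jagged HSTU multi-head attention, dispatched to a PyTorch or Triton kernel.

    Shape conventions (as implied by the checks below): ``q`` and ``k`` share
    shape ``(L, H, D)`` and ``v`` has shape ``(L, H, Dv)``, where ``L`` is the
    total number of tokens across all sequences in the batch and ``H`` is the
    number of heads; ``seq_offsets`` holds the per-sequence boundary offsets
    (length ``B + 1``). Only causal attention is supported; the Triton path
    additionally requires CUDA tensors, zero dropout, and
    ``min_full_attn_seq_len == 0``.

    Returns:
        The attention output with the same jagged layout as ``q`` (last
        dimension matching ``v``).
    """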
_, H, _ = q.shape
if not is_fx_tracing():
torch._assert(max_seq_len > 0, "max_seq_len must be larger than 0")
torch._assert(q.dim() == 3, "q must be 3-D")
torch._assert(k.shape == q.shape, "k must be the same shape as q")
torch._assert(v.dim() == 3, "v must be 3-D")
torch._assert(v.shape[0] == q.shape[0], "wrong v shape[0]")
torch._assert(v.shape[1] == H, "wrong v shape[1]")
torch._assert(causal, "only support causal attention")
    if not is_fx_tracing() and kernel == Kernel.TRITON:
        torch._assert(q.is_cuda, "q must be CUDA tensor")
        torch._assert(k.is_cuda, "k must be CUDA tensor")
        torch._assert(v.is_cuda, "v must be CUDA tensor")
        torch._assert(seq_offsets.is_cuda, "seq_offsets must be CUDA tensor")
        torch._assert(dropout_pr < 1e-6, "dropout for triton path not implemented")
        torch._assert(
            min_full_attn_seq_len == 0, "min_full_attn_seq_len not implemented"
        )
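    # Both kernels expect contiguous inputs; copies are made only when needed.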
q = switch_to_contiguous_if_needed(q)
k = switch_to_contiguous_if_needed(k)
v = switch_to_contiguous_if_needed(v)
seq_offsets = seq_offsets.contiguous()
if kernel == Kernel.TRITON:
return triton_hstu_mha(
N=max_seq_len,
alpha=alpha,
q=q,
k=k,
v=v,
seq_offsets=seq_offsets,
causal=causal,
num_targets=num_targets,
max_attn_len=max_attn_len,
contextual_seq_len=contextual_seq_len,
sort_by_length=sort_by_length,
)
else:
return pytorch_hstu_mha(
max_seq_len=max_seq_len,
alpha=alpha,
q=q,
k=k,
v=v,
seq_offsets=seq_offsets,
causal=causal,
dropout_pr=dropout_pr,
training=training,
num_targets=num_targets,
max_attn_len=max_attn_len,
contextual_seq_len=contextual_seq_len,
min_full_attn_seq_len=min_full_attn_seq_len,
        )


def delta_hstu_mha(
max_seq_len: int,
alpha: float,
delta_q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_offsets: torch.Tensor,
num_targets: Optional[torch.Tensor] = None,
max_attn_len: int = 0,
contextual_seq_len: int = 0,
kernel: Kernel = Kernel.PYTORCH,
) -> torch.Tensor:
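    """Cached ("delta") HSTU multi-head attention over the new query positions.

    ``delta_q`` holds only the delta (e.g. newly appended) query rows, padded
    so that every sequence contributes the same number of rows, i.e. its total
    length ``L`` is divisible by ``B = seq_offsets.size(0) - 1``; ``k`` and
    ``v`` cover the full sequences described by ``seq_offsets``. Dispatches to
    the Triton or the pure PyTorch cached kernel depending on ``kernel``.

    Returns:
        The attention output for the ``delta_q`` positions.
    """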
L, H, D = delta_q.shape
B = seq_offsets.size(0) - 1
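    # Number of delta query rows per sequence (delta_q is padded uniformly).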
DeltaSize = L // B # NOQA
if not is_fx_tracing():
torch._assert(max_seq_len > 0, "max_seq_len must be larger than 0")
torch._assert(delta_q.dim() == 3, "delta_q must be 3-D")
torch._assert(L % B == 0, "delta_q must be padded")
torch._assert(k.dim() == 3, "k must be 3-D")
torch._assert(k.shape[1] == H, "wrong k shape[1]")
torch._assert(k.shape[2] == D, "wrong k shape[2]")
torch._assert(v.dim() == 3, "v must be 3-D")
torch._assert(v.shape[1] == H, "wrong v shape[1]")
    if not is_fx_tracing() and kernel == Kernel.TRITON:
        torch._assert(delta_q.is_cuda, "delta_q must be CUDA tensor")
        torch._assert(seq_offsets.is_cuda, "seq_offsets must be CUDA tensor")
        if num_targets is not None:
            torch._assert(num_targets.is_cuda, "num_targets must be CUDA tensor")
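    # Make inputs contiguous before handing them to the selected kernel.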
seq_offsets = seq_offsets.contiguous()
delta_q = switch_to_contiguous_if_needed(delta_q)
k = switch_to_contiguous_if_needed(k)
v = switch_to_contiguous_if_needed(v)
if kernel == Kernel.TRITON:
return triton_cached_hstu_mha(
N=max_seq_len,
alpha=alpha,
delta_q=delta_q,
k=k,
v=v,
seq_offsets=seq_offsets,
num_targets=num_targets,
max_attn_len=max_attn_len,
contextual_seq_len=contextual_seq_len,
)
else:
return pytorch_cached_hstu_mha(
max_seq_len=max_seq_len,
alpha=alpha,
delta_q=delta_q,
k=k,
v=v,
seq_offsets=seq_offsets,
num_targets=num_targets,
max_attn_len=max_attn_len,
contextual_seq_len=contextual_seq_len,
)