# tzrec/ops/triton/triton_hstu_preprocess_and_attention.py

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We use the hstu_attention ops from generative-recommenders as a starting point.
# https://github.com/facebookresearch/generative-recommenders
# Thanks to their public work.

from typing import Optional, Tuple

import torch
from torch.nn import functional as F

from tzrec.ops.triton.triton_addmm import (
    triton_addmm_bwd,
    triton_addmm_fwd,
)
from tzrec.ops.triton.triton_hstu_attention import (
    triton_hstu_attention_bwd,
    triton_hstu_attention_fwd,
)
from tzrec.ops.triton.triton_layer_norm import (
    triton_weighted_layer_norm_bwd,
    triton_weighted_layer_norm_fwd,
)


class _HSTUPreprocessAndAttentionFunction(torch.autograd.Function):
    @staticmethod
    # pyre-ignore [14]
    def forward(
        ctx,  # pyre-ignore [2]
        x: torch.Tensor,
        norm_weight: torch.Tensor,
        norm_bias: torch.Tensor,
        norm_eps: float,
        num_heads: int,
        attn_dim: int,
        hidden_dim: int,
        uvqk_weight: torch.Tensor,
        uvqk_bias: torch.Tensor,
        max_seq_len: int,
        seq_offsets: torch.Tensor,
        attn_alpha: float,
        causal: bool,
        num_targets: Optional[torch.Tensor],
        max_attn_len: int,
        contextual_seq_len: int,
        recompute_uvqk_in_backward: bool,
        recompute_normed_x_in_backward: bool,
        sort_by_length: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        normed_x, x_mean, x_rstd, BLOCK_D, num_warps = triton_weighted_layer_norm_fwd(
            x=x,
            weight=norm_weight,
            bias=norm_bias,
            eps=norm_eps,
        )
        # Single fused projection producing [u, v, q, k] concatenated along dim 1.
        uvqk = triton_addmm_fwd(x=normed_x, w=uvqk_weight, y=uvqk_bias).contiguous()
        u, v, q, k = uvqk.split(
            [
                hidden_dim * num_heads,
                hidden_dim * num_heads,
                attn_dim * num_heads,
                attn_dim * num_heads,
            ],
            dim=1,
        )
        q = q.view(-1, num_heads, attn_dim)
        k = k.view(-1, num_heads, attn_dim)
        v = v.view(-1, num_heads, hidden_dim)
        silu_u = F.silu(u)
        sort_by_length_indices = None
        if sort_by_length:
            seq_lengths = seq_offsets[1:] - seq_offsets[:-1]
            _, sort_by_length_indices = torch.sort(
                seq_lengths, descending=True, stable=False
            )
        out = triton_hstu_attention_fwd(
            N=max_seq_len,
            alpha=attn_alpha,
            q=q,
            k=k,
            v=v,
            seq_offsets=seq_offsets,
            causal=causal,
            num_targets=num_targets,
            max_attn_len=max_attn_len,
            contextual_seq_len=contextual_seq_len,
            sort_by_length_indices=sort_by_length_indices,
        )
        # update ctx
        saved_tensors = [
            x,
            norm_weight,
            norm_bias,
            x_mean,
            x_rstd,
            uvqk_weight,
            seq_offsets,
        ]
        if num_targets is not None:
            saved_tensors.append(num_targets)
        if not recompute_normed_x_in_backward:
            saved_tensors.append(normed_x)
        if recompute_uvqk_in_backward:
            saved_tensors.append(uvqk_bias)
        else:
            saved_tensors.append(uvqk)
        if sort_by_length:
            saved_tensors.append(sort_by_length_indices)
        ctx.save_for_backward(*saved_tensors)
        ctx.attn_alpha = attn_alpha
        ctx.causal = causal
        ctx.has_multiple_targets = num_targets is not None
        ctx.max_seq_len = max_seq_len
        ctx.max_attn_len = max_attn_len
        ctx.recompute_normed_x_in_backward = recompute_normed_x_in_backward
        ctx.recompute_uvqk_in_backward = recompute_uvqk_in_backward
        ctx.hidden_dim = hidden_dim
        ctx.attn_dim = attn_dim
        ctx.num_heads = num_heads
        ctx.uvqk_bias_1d = uvqk_bias.dim() == 1
        ctx.norm_eps = norm_eps
        ctx.norm_BLOCK_D = BLOCK_D
        ctx.norm_num_warps = num_warps
        ctx.contextual_seq_len = contextual_seq_len
        ctx.sort_by_length = sort_by_length
        return silu_u, out
    @staticmethod
    # pyre-ignore[14]
    def backward(
        ctx,  # pyre-ignore[2]
        dsilu_u: torch.Tensor,
        dout: torch.Tensor,
    ) -> Tuple[
        torch.Tensor,  # d_x
        torch.Tensor,  # d_norm_weight
        torch.Tensor,  # d_norm_bias
        None,
        None,
        None,
        None,
        torch.Tensor,  # d_uvqk_weight
        torch.Tensor,  # d_uvqk_bias
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
    ]:
        x, norm_weight, norm_bias, x_mean, x_rstd, uvqk_weight, seq_offsets = (
            ctx.saved_tensors[:7]
        )
        idx = 7
        if ctx.has_multiple_targets:
            num_targets = ctx.saved_tensors[idx]
            idx += 1
        else:
            num_targets = None
        if ctx.recompute_normed_x_in_backward:
            # Recompute normed_x from the saved mean/rstd instead of storing it.
            normed_x, _, _, _, _ = triton_weighted_layer_norm_fwd(
                x=x,
                weight=norm_weight,
                bias=norm_bias,
                eps=ctx.norm_eps,
                mean=x_mean,
                rstd=x_rstd,
            )
        else:
            normed_x = ctx.saved_tensors[idx]
            idx += 1
        if ctx.recompute_uvqk_in_backward:
            uvqk_bias = ctx.saved_tensors[idx]
            uvqk = triton_addmm_fwd(x=normed_x, w=uvqk_weight, y=uvqk_bias)
            idx += 1
        else:
            uvqk = ctx.saved_tensors[idx]
            idx += 1
        if ctx.sort_by_length:
            sort_by_length_indices = ctx.saved_tensors[idx]
        else:
            sort_by_length_indices = None
        duvqk = torch.empty_like(uvqk)
        du, dv, dq, dk = duvqk.split(
            [
                ctx.hidden_dim * ctx.num_heads,
                ctx.hidden_dim * ctx.num_heads,
                ctx.attn_dim * ctx.num_heads,
                ctx.attn_dim * ctx.num_heads,
            ],
            dim=1,
        )
        u, v, q, k = uvqk.split(
            [
                ctx.hidden_dim * ctx.num_heads,
                ctx.hidden_dim * ctx.num_heads,
                ctx.attn_dim * ctx.num_heads,
                ctx.attn_dim * ctx.num_heads,
            ],
            dim=1,
        )
        q = q.view(-1, ctx.num_heads, ctx.attn_dim)
        k = k.view(-1, ctx.num_heads, ctx.attn_dim)
        v = v.view(-1, ctx.num_heads, ctx.hidden_dim)
        dq = dq.view(-1, ctx.num_heads, ctx.attn_dim)
        dk = dk.view(-1, ctx.num_heads, ctx.attn_dim)
        dv = dv.view(-1, ctx.num_heads, ctx.hidden_dim)
        # Note: the two operations below update duvqk in place
        (
            _dq,
            _dk,
            _dv,
        ) = triton_hstu_attention_bwd(
            dout=dout,
            q=q,
            k=k,
            v=v,
            dq=dq,
            dk=dk,
            dv=dv,
            seq_offsets=seq_offsets,
            num_targets=num_targets,
            N=ctx.max_seq_len,
            max_attn_len=ctx.max_attn_len,
            alpha=ctx.attn_alpha,
            causal=ctx.causal,
            contextual_seq_len=ctx.contextual_seq_len,
            sort_by_length_indices=sort_by_length_indices,
        )
        if dq.data_ptr() != _dq.data_ptr():
            dq.copy_(_dq)
        if dk.data_ptr() != _dk.data_ptr():
            dk.copy_(_dk)
        if dv.data_ptr() != _dv.data_ptr():
            dv.copy_(_dv)
        # The out-variant of silu_backward writes the SiLU gradient directly
        # into du, which is a view of duvqk.
        torch.ops.aten.silu_backward(dsilu_u, u, grad_input=du)
        d_normed_x, d_uvqk_weight, d_uvqk_bias = triton_addmm_bwd(
            x=normed_x,
            w=uvqk_weight,
            dz=duvqk,
            is_y_1d=ctx.uvqk_bias_1d,
        )
        d_x, d_norm_weight, d_norm_bias = triton_weighted_layer_norm_bwd(
            dy=d_normed_x,
            x=x,
            weight=norm_weight,
            bias=norm_bias,
            mean=x_mean,
            rstd=x_rstd,
            learnable=True,
            eps=ctx.norm_eps,
            BLOCK_D=ctx.norm_BLOCK_D,
            num_warps=ctx.norm_num_warps,
        )
        # pyre-ignore[7]
        return (
            d_x,
            d_norm_weight,
            d_norm_bias,
            None,
            None,
            None,
            None,
            d_uvqk_weight,
            d_uvqk_bias,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )

def triton_hstu_preprocess_and_attention(
    x: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    norm_eps: float,
    num_heads: int,
    attn_dim: int,
    hidden_dim: int,
    uvqk_weight: torch.Tensor,
    uvqk_bias: torch.Tensor,
    max_seq_len: int,
    seq_offsets: torch.Tensor,
    attn_alpha: float,
    causal: bool,
    num_targets: Optional[torch.Tensor],
    max_attn_len: int = 0,
    contextual_seq_len: int = 0,
    recompute_uvqk_in_backward: bool = False,
    recompute_normed_x_in_backward: bool = False,
    sort_by_length: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused HSTU preprocessing (layer norm + UVQK projection) and attention.

    Returns a tuple of (silu(u), attention output). The recompute_* flags
    trade extra backward compute for lower activation memory by recomputing
    normed_x and/or uvqk in the backward pass instead of saving them.
    """
    return _HSTUPreprocessAndAttentionFunction.apply(
        x,
        norm_weight,
        norm_bias,
        norm_eps,
        num_heads,
        attn_dim,
        hidden_dim,
        uvqk_weight,
        uvqk_bias,
        max_seq_len,
        seq_offsets,
        attn_alpha,
        causal,
        num_targets,
        max_attn_len,
        contextual_seq_len,
        recompute_uvqk_in_backward,
        recompute_normed_x_in_backward,
        sort_by_length,
    )
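
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hypothetical
# example of calling triton_hstu_preprocess_and_attention on a jagged batch.
# All shapes and hyperparameters below are illustrative assumptions, not
# values taken from TorchEasyRec; the underlying Triton kernels require a
# CUDA device and may impose their own dtype/shape constraints.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda")
    num_heads, attn_dim, hidden_dim, embedding_dim = 2, 32, 32, 64
    # Jagged batch of two sequences with lengths 3 and 5 (8 rows total),
    # described by cumulative offsets, as expected by the attention kernel.
    seq_offsets = torch.tensor([0, 3, 8], dtype=torch.int64, device=device)
    max_seq_len = 5
    x = torch.randn(8, embedding_dim, device=device, requires_grad=True)
    norm_weight = torch.ones(embedding_dim, device=device, requires_grad=True)
    norm_bias = torch.zeros(embedding_dim, device=device, requires_grad=True)
    # The fused projection emits [u, v, q, k] concatenated along dim 1.
    uvqk_dim = (hidden_dim + hidden_dim + attn_dim + attn_dim) * num_heads
    uvqk_weight = torch.randn(
        embedding_dim, uvqk_dim, device=device, requires_grad=True
    )
    uvqk_bias = torch.zeros(uvqk_dim, device=device, requires_grad=True)
    silu_u, attn_out = triton_hstu_preprocess_and_attention(
        x=x,
        norm_weight=norm_weight,
        norm_bias=norm_bias,
        norm_eps=1e-6,
        num_heads=num_heads,
        attn_dim=attn_dim,
        hidden_dim=hidden_dim,
        uvqk_weight=uvqk_weight,
        uvqk_bias=uvqk_bias,
        max_seq_len=max_seq_len,
        seq_offsets=seq_offsets,
        attn_alpha=1.0 / attn_dim**0.5,  # illustrative scaling choice
        causal=True,
        num_targets=None,
    )
    # silu_u covers the u slice; attn_out follows the layout of v.
    (silu_u.sum() + attn_out.sum()).backward()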