tzrec/ops/pytorch/pt_hstu_attention.py (200 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # We use the hstu_attention ops from generative-recommenders a starting point. # https://github.com/facebookresearch/generative-recommenders # thanks to their public work. from typing import Optional, Tuple import torch import torch.nn.functional as F @torch.fx.wrap def _get_valid_attn_mask( device: torch.device, causal: bool, N: int, seq_lengths: torch.Tensor, num_targets: Optional[torch.Tensor] = None, max_attn_len: int = 0, contextual_seq_len: int = 0, min_full_attn_seq_len: int = 0, ) -> torch.Tensor: ids = torch.arange(0, N, device=device).view(1, N) max_ids = seq_lengths.view(-1, 1, 1) if contextual_seq_len > 0: ids = ids - contextual_seq_len + 1 ids = torch.clamp(ids, min=0) max_ids = max_ids - contextual_seq_len + 1 if num_targets is not None: max_ids = max_ids - num_targets.view(-1, 1, 1) ids = torch.clamp( ids, max=max_ids, ) row_ids = ids.view(-1, N, 1).expand(-1, N, N) col_ids = ids.view(-1, 1, N).expand(-1, N, N) else: row_ids = ids.view(N, 1).expand(N, N) col_ids = row_ids.t() row_ids = row_ids.view(1, N, N) col_ids = col_ids.view(1, N, N) row_col_dist = row_ids - col_ids valid_attn_mask = torch.eye(N, device=device, dtype=torch.bool).view(1, N, N) if not causal: row_col_dist = torch.where(row_col_dist > 0, row_col_dist, -row_col_dist) valid_attn_mask = torch.logical_or(valid_attn_mask, row_col_dist > 0) if max_attn_len > 0: if min_full_attn_seq_len > 0: valid_attn_mask = torch.logical_and( valid_attn_mask, torch.logical_or( row_col_dist <= max_attn_len, row_ids >= max_ids - min_full_attn_seq_len, ), ) else: valid_attn_mask = torch.logical_and( valid_attn_mask, row_col_dist <= max_attn_len ) if contextual_seq_len > 0: valid_attn_mask = torch.logical_or( valid_attn_mask, torch.logical_and(row_ids == 0, col_ids < max_ids) ) return valid_attn_mask def _pad_qkv( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_offsets: torch.Tensor, N: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: L, H, D = q.shape V = v.shape[2] padded_q = ( torch.ops.fbgemm.jagged_to_padded_dense( values=q.reshape(L, H * D), offsets=[seq_offsets], max_lengths=[N], padding_value=0.0, ) .view(-1, N, H, D) .transpose(1, 2) ) # [B, H, N, A] padded_k = ( torch.ops.fbgemm.jagged_to_padded_dense( values=k.reshape(L, H * D), offsets=[seq_offsets], max_lengths=[N], padding_value=0.0, ) .view(-1, N, H, D) .transpose(1, 2) ) # [B, H, N, A] padded_v = ( torch.ops.fbgemm.jagged_to_padded_dense( values=v.reshape(L, H * V), offsets=[seq_offsets], max_lengths=[N], padding_value=0.0, ) .view(-1, N, H, V) .transpose(1, 2) ) # [B, H, N, D] return padded_q, padded_k, padded_v @torch.fx.wrap def pytorch_hstu_mha( max_seq_len: int, alpha: float, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_offsets: torch.Tensor, causal: bool = True, dropout_pr: float = 0.0, training: bool = True, num_targets: Optional[torch.Tensor] = None, max_attn_len: int = 0, contextual_seq_len: int = 0, min_full_attn_seq_len: int = 0, ) -> torch.Tensor: L, H, _ = q.shape V = v.shape[2] q, k, v = _pad_qkv( q, k, v, seq_offsets, max_seq_len ) # [B, H, N, D) and [B, H, N, V] qk_attn = torch.einsum("bhxa,bhya->bhxy", q, k) * alpha qk_attn = F.silu(qk_attn) / max_seq_len valid_attn_mask = _get_valid_attn_mask( device=q.device, causal=causal, N=max_seq_len, seq_lengths=seq_offsets[1:] - seq_offsets[:-1], num_targets=num_targets, max_attn_len=max_attn_len, contextual_seq_len=contextual_seq_len, min_full_attn_seq_len=min_full_attn_seq_len, ) # raise NotImplementedError(valid_attn_mask[0, :, :].to(torch.int32)) qk_attn = qk_attn * valid_attn_mask.unsqueeze(1) if dropout_pr > 0.0: qk_attn = F.dropout(qk_attn, p=dropout_pr, training=training) attn_dense = torch.einsum("bhxd,bhdv->bhxv", qk_attn, v) # [B, H, N, V] return torch.ops.fbgemm.dense_to_jagged( attn_dense.transpose(1, 2).flatten(2, 3), # [B, N, H, V]->[B, N, H * V] [seq_offsets], L, )[0].view(L, H, V) @torch.fx.wrap def pytorch_cached_hstu_mha( max_seq_len: int, alpha: float, delta_q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_offsets: torch.Tensor, num_targets: Optional[torch.Tensor] = None, max_attn_len: int = 0, contextual_seq_len: int = 0, ) -> torch.Tensor: L, H, D = delta_q.shape _, _, V = v.shape B = seq_offsets.size(0) - 1 delta_size = L // B delta_q = delta_q.view(B, -1, H, D).transpose(1, 2) full_k = ( torch.ops.fbgemm.jagged_to_padded_dense( values=k.reshape(-1, H * D), offsets=[seq_offsets], max_lengths=[max_seq_len], padding_value=0.0, ) .view(B, -1, H, D) .transpose(1, 2) ) full_v = ( torch.ops.fbgemm.jagged_to_padded_dense( values=v.reshape(-1, H * V), offsets=[seq_offsets], max_lengths=[max_seq_len], padding_value=0.0, ) .view(B, -1, H, V) .transpose(1, 2) ) qk_attn = torch.einsum("bhxa,bhya->bhxy", delta_q, full_k) * alpha qk_attn = F.silu(qk_attn) / max_seq_len full_valid_attn_mask = _get_valid_attn_mask( device=delta_q.device, causal=True, N=max_seq_len, seq_lengths=seq_offsets[1:] - seq_offsets[:-1], num_targets=num_targets, max_attn_len=max_attn_len, contextual_seq_len=contextual_seq_len, ) seq_lengths = seq_offsets[1:] - seq_offsets[:-1] mask = torch.arange(max_seq_len, device=delta_q.device).view(1, -1) mask = torch.logical_and( mask >= (seq_lengths - delta_size).view(-1, 1), mask < seq_lengths.view(-1, 1), ) valid_attn_mask = ( full_valid_attn_mask.expand(B, -1, -1) .flatten(0, 1)[mask.view(-1), :] .view(-1, delta_size, max_seq_len) ) qk_attn = qk_attn * valid_attn_mask.unsqueeze(1) attn_output = torch.einsum("bhxd,bhdv->bhxv", qk_attn, full_v) return attn_output.transpose(1, 2).reshape(-1, H, V)