tzrec/ops/triton/triton_hstu_attention.py (1,640 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Triton kernels and host wrappers for HSTU attention on jagged sequences.

Sequence ``z`` of a batch occupies rows ``seq_offsets[z]:seq_offsets[z + 1]``
of the packed ``q``/``k``/``v`` tensors. Attention weights are
``silu(alpha * q @ k^T) / MAX_SEQ_LEN`` (no softmax), masked by the causal,
multiple-target, contextual-prefix and max-attention-length rules implemented
in ``_hstu_attn_fwd_one_block`` and ``_hstu_attn_bwd_one_block``.
"""

import itertools
from typing import List, Optional, Tuple

import torch
import triton
import triton.language as tl
from triton.runtime.autotuner import autotune as triton_autotune

from tzrec.ops.utils import (
    autotune_max_seq_len,
    prev_power_of_2,
    switch_to_contiguous_if_needed,
)

try:
    from triton.language.extra.libdevice import fast_dividef  # @manual=//triton:triton
except ImportError:
    try:
        # @manual=//triton:triton
        from triton.language.extra.cuda.libdevice import fast_dividef
    except ImportError:
        # pyre-ignore[21]
        from triton.language.math import fast_dividef  # @manual=//triton:triton


torch.fx.wrap(prev_power_of_2)


def _get_fw_configs() -> List[triton.Config]:
    """Return the autotuning search space shared by the forward kernels.

    On ROCm builds (``torch.version.hip`` is set) this is the Cartesian product
    of the tile sizes, pipeline depths, warp counts and ``matrix_instr_nonkdim``
    values below, with fixed ``waves_per_eu``/``kpack`` options. Otherwise it is
    a hand-picked list of CUDA configurations.
    """
    if torch.version.hip:
        # itertools.product yields tuples in the same order as the equivalent
        # nested for-loops (last iterable varies fastest).
        return [
            triton.Config(
                {
                    "BLOCK_M": block_m,
                    "BLOCK_N": block_n,
                    "matrix_instr_nonkdim": matrix_instr_nonkdim,
                    "waves_per_eu": 0,
                    "kpack": 2,
                },
                num_stages=num_stages,
                num_warps=num_warps,
            )
            for (
                block_m,
                block_n,
                num_stages,
                num_warps,
                matrix_instr_nonkdim,
            ) in itertools.product(
                [32, 64, 128], [32, 64], [1, 2], [4, 8], [16, 32]
            )
        ]
    # (BLOCK_M, BLOCK_N, num_stages, num_warps)
    cuda_params = [
        (16, 32, 2, 2),
        (32, 32, 2, 2),
        (32, 32, 4, 2),
        (32, 32, 2, 4),
        (32, 32, 4, 4),
        (32, 64, 2, 4),
        (32, 64, 4, 4),
        (32, 64, 4, 8),
        (32, 128, 2, 4),
        (32, 128, 2, 8),
        (64, 32, 4, 2),
        (64, 32, 2, 4),
        (64, 32, 4, 4),
        (64, 32, 2, 8),
        (64, 64, 2, 2),
        (64, 64, 2, 4),
        (64, 64, 4, 4),
        (64, 64, 4, 8),
        (128, 32, 2, 2),
        (128, 32, 4, 2),
        (128, 32, 2, 4),
        (128, 32, 4, 4),
        (128, 32, 2, 8),
        (128, 32, 4, 8),
        (128, 64, 2, 4),
        (128, 64, 2, 8),
        (128, 64, 4, 8),
        (128, 128, 4, 4),
        (128, 128, 2, 8),
    ]
    return [
        triton.Config(
            {"BLOCK_M": block_m, "BLOCK_N": block_n},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for block_m, block_n, num_stages, num_warps in cuda_params
    ]


@triton.jit
def _hstu_attn_fwd_one_block(  # noqa: C901
    start_n,
    seq_len,
    offs_m,
    offs_n,
    q,
    K_block_ptr,
    V_block_ptr,
    n_targets,
    alpha,
    MAX_SEQ_LEN,
    contextual_seq_len,
    max_attn_len,
    CAUSAL: tl.constexpr,
    HAS_MULTIPLE_TARGETS: tl.constexpr,
    HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr,
    HAS_MAX_ATTN_LEN: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Return one key block's contribution to the output accumulator.

    Computes ``silu(alpha * q @ k) / MAX_SEQ_LEN`` for query rows ``offs_m``
    against key columns ``offs_n``, zeroes disallowed positions, and multiplies
    by the matching value block. Masking uses remapped position ids:

    * with a contextual prefix, positions ``< contextual_seq_len`` all map to
      id 0 and later positions are shifted so the first one gets id 1;
    * with multiple targets, ids are clamped to ``max_ids`` so every target
      position shares one id.

    A pair is kept on the exact diagonal (original indices equal) or when the
    row id is greater than the column id (``CAUSAL``) / differs from it (not
    ``CAUSAL``). ``max_attn_len`` then bounds the id distance, and contextual
    rows (id 0) additionally attend to every column with id ``< max_ids``.
    """
    start_n = tl.multiple_of(start_n, BLOCK_N)
    # -- compute qk ----
    k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
    qk = tl.dot(q, k, allow_tf32=ALLOW_TF32) * alpha
    # True where the weight is kept (used as the `tl.where` condition below).
    valid_mask = offs_m[:, None] == offs_n[None, :]
    max_ids = seq_len
    if HAS_CONTEXTUAL_SEQ_LEN:
        offs_m = offs_m - contextual_seq_len + 1
        offs_m = tl.where(
            offs_m > 0,
            offs_m,
            0,
        )
        offs_n = offs_n - contextual_seq_len + 1
        offs_n = tl.where(
            offs_n > 0,
            offs_n,
            0,
        )
        max_ids = max_ids - contextual_seq_len + 1
    if HAS_MULTIPLE_TARGETS:
        max_ids = max_ids - n_targets
        offs_m = tl.where(
            offs_m < max_ids,
            offs_m,
            max_ids,
        )
        offs_n = tl.where(
            offs_n < max_ids,
            offs_n,
            max_ids,
        )
    offs_m_minus_n = offs_m[:, None] - offs_n[None, :]
    if not CAUSAL:
        offs_m_minus_n = tl.where(offs_m_minus_n > 0, offs_m_minus_n, -offs_m_minus_n)
    valid_mask = valid_mask or (offs_m_minus_n > 0)
    if HAS_MAX_ATTN_LEN:
        valid_mask = valid_mask and offs_m_minus_n <= max_attn_len
    if HAS_CONTEXTUAL_SEQ_LEN:
        valid_mask = valid_mask or (offs_m[:, None] == 0 and offs_n[None, :] < max_ids)
    # pyre-fixme[16]: Module `math` has no attribute `fast_dividef`.
    silu = fast_dividef(qk, 1.0 + tl.exp(-qk)) * (1.0 / MAX_SEQ_LEN)
    silu = tl.where(valid_mask, silu, 0)
    v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
    silu = silu.to(v.dtype)
    return tl.dot(silu, v, allow_tf32=ALLOW_TF32)


@triton.jit
def _hstu_attn_fwd_compute(  # noqa C901
    Q,
    K,
    V,
    seq_offsets,
    num_targets,
    Out,
    stride_qm,
    stride_qh,
    stride_kn,
    stride_kh,
    stride_vn,
    stride_vh,
    stride_om,
    stride_oh,
    alpha,
    MAX_SEQ_LEN,
    DeltaSize,
    contextual_seq_len,
    max_attn_len,
    off_z,
    off_h,
    pid,
    CAUSAL: tl.constexpr,
    HAS_MULTIPLE_TARGETS: tl.constexpr,
    IS_DELTA_Q: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    BLOCK_D_Q: tl.constexpr,
    BLOCK_D_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr,
    HAS_MAX_ATTN_LEN: tl.constexpr,
):
    """Compute the BLOCK_M-row output tile ``pid`` for sequence ``off_z``, head ``off_h``.

    With ``IS_DELTA_Q`` the query tensor holds only the last ``DeltaSize`` rows
    of each sequence, packed at ``off_z * DeltaSize``; the output uses the same
    packed layout. Tiles whose first row is at or beyond ``seq_len`` write
    nothing. The innermost dimension of Q/K/V/Out is addressed with unit stride.
    """
    seq_start = tl.load(seq_offsets + off_z).to(tl.int64)
    off_h = off_h.to(tl.int64)
    off_z = off_z.to(tl.int64)
    seq_end = tl.load(seq_offsets + off_z + 1)
    seq_len = (seq_end - seq_start).to(tl.int32)
    if IS_DELTA_Q:
        start_m_delta = pid * BLOCK_M
        start_m = (start_m_delta + seq_len - DeltaSize).to(tl.int32)
    else:
        start_m_delta = 0
        start_m = pid * BLOCK_M
    if start_m < seq_len:
        if HAS_MULTIPLE_TARGETS:
            n_targets = tl.load(num_targets + off_z).to(tl.int32)
        else:
            n_targets = None

        # initialize offsets
        offs_m = start_m + tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        if IS_DELTA_Q:
            Q_block_ptr = tl.make_block_ptr(
                base=Q + off_h * stride_qh + off_z * DeltaSize * stride_qm,
                shape=(DeltaSize, BLOCK_D_Q),
                strides=(stride_qm, 1),
                offsets=(start_m_delta, 0),
                block_shape=(BLOCK_M, BLOCK_D_Q),
                order=(1, 0),
            )
        else:
            Q_block_ptr = tl.make_block_ptr(
                base=Q + off_h * stride_qh + seq_start * stride_qm,
                shape=(seq_len, BLOCK_D_Q),
                strides=(stride_qm, 1),
                offsets=(start_m, 0),
                block_shape=(BLOCK_M, BLOCK_D_Q),
                order=(1, 0),
            )
        # K is read transposed: (BLOCK_D_Q, BLOCK_N) tiles.
        K_block_ptr = tl.make_block_ptr(
            base=K + off_h * stride_kh + seq_start * stride_kn,
            shape=(BLOCK_D_Q, seq_len),
            strides=(1, stride_kn),
            offsets=(0, 0),
            block_shape=(BLOCK_D_Q, BLOCK_N),
            order=(0, 1),
        )
        V_block_ptr = tl.make_block_ptr(
            base=V + off_h * stride_vh + seq_start * stride_vn,
            shape=(seq_len, BLOCK_D_V),
            strides=(stride_vn, 1),
            offsets=(0, 0),
            block_shape=(BLOCK_N, BLOCK_D_V),
            order=(1, 0),
        )
        q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
        acc = tl.zeros([BLOCK_M, BLOCK_D_V], dtype=tl.float32)
        # Determine the key-column range [low, high) scanned by the main loop.
        if CAUSAL:
            if HAS_MULTIPLE_TARGETS:
                uih_end = seq_len - n_targets
            else:
                uih_end = seq_len
            if HAS_CONTEXTUAL_SEQ_LEN is True and start_m < contextual_seq_len:
                # uih_end must be larger than start_m
                low = 0
                high = seq_len
            else:
                low = 0
                high = start_m + BLOCK_M
                if HAS_MAX_ATTN_LEN:
                    # Skip keys farther than max_attn_len behind the tile; rows
                    # past the history measure from its end (uih_end).
                    if start_m > uih_end:
                        low = uih_end - max_attn_len
                    else:
                        low = start_m - max_attn_len
                    # Never skip into the contextual prefix.
                    if HAS_CONTEXTUAL_SEQ_LEN:
                        low = low if low > contextual_seq_len else 0
                    else:
                        low = low if low > 0 else 0
                if HAS_MULTIPLE_TARGETS:
                    uih_end = (uih_end + BLOCK_N - 1) // BLOCK_N * BLOCK_N
                    # Tile lies in the target region: the main loop only scans
                    # history columns; the diagonal is handled further below.
                    if uih_end < start_m:
                        high = seq_len - n_targets
        else:
            low = 0
            high = seq_len

        if low > 0:
            K_block_ptr = tl.advance(K_block_ptr, (0, low))
            V_block_ptr = tl.advance(V_block_ptr, (low, 0))
        # end_n tracks the column the block pointers currently point at.
        end_n = low
        for start_n in range(low, high, BLOCK_N):
            acc += _hstu_attn_fwd_one_block(
                start_n=start_n,
                seq_len=seq_len,
                offs_m=offs_m,
                offs_n=offs_n + start_n,
                q=q,
                K_block_ptr=K_block_ptr,
                V_block_ptr=V_block_ptr,
                n_targets=n_targets if HAS_MULTIPLE_TARGETS else None,
                alpha=alpha,
                MAX_SEQ_LEN=MAX_SEQ_LEN,
                contextual_seq_len=contextual_seq_len,
                max_attn_len=max_attn_len,
                CAUSAL=CAUSAL,
                HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS,
                HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN,
                HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN,
                ALLOW_TF32=ALLOW_TF32,
                BLOCK_N=BLOCK_N,
            )
            K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
            V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
            end_n += BLOCK_N

        if HAS_MULTIPLE_TARGETS and CAUSAL:
            # Target rows share one clamped id, so among targets only the exact
            # diagonal is kept: jump the pointers to the tile's own columns.
            # pyre-ignore[61]
            if uih_end < start_m:
                low_delta = start_m
                high_delta = start_m + BLOCK_M
                offset = (low_delta - end_n).to(tl.int32)
                K_block_ptr = tl.advance(K_block_ptr, (0, offset))
                V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
                for start_delta in tl.range(
                    low_delta, high_delta, BLOCK_N, num_stages=0
                ):
                    acc += _hstu_attn_fwd_one_block(
                        start_n=start_delta,
                        seq_len=seq_len,
                        offs_m=offs_m,
                        offs_n=offs_n + start_delta,
                        q=q,
                        K_block_ptr=K_block_ptr,
                        V_block_ptr=V_block_ptr,
                        n_targets=n_targets if HAS_MULTIPLE_TARGETS else None,
                        alpha=alpha,
                        MAX_SEQ_LEN=MAX_SEQ_LEN,
                        contextual_seq_len=contextual_seq_len,
                        max_attn_len=max_attn_len,
                        CAUSAL=CAUSAL,
                        HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS,
                        HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN,
                        HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN,
                        ALLOW_TF32=ALLOW_TF32,
                        BLOCK_N=BLOCK_N,
                    )
                    K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
                    V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

        if IS_DELTA_Q:
            start_m_delta = pid * BLOCK_M
            offs_m_delta = start_m_delta + tl.arange(0, BLOCK_M)
            offs_v_d = tl.arange(0, BLOCK_D_V)
            off_o = Out + off_z * DeltaSize * stride_om + off_h * stride_oh
            out_ptrs = off_o + offs_m_delta[:, None] * stride_om + offs_v_d[None, :]
            tl.store(out_ptrs, acc, mask=(offs_m_delta < DeltaSize)[:, None])
        else:
            # rematerialize offsets to save registers
            start_m = pid * BLOCK_M
            offs_m = start_m + tl.arange(0, BLOCK_M)
            offs_v_d = tl.arange(0, BLOCK_D_V)
            off_o = Out + seq_start * stride_om + off_h * stride_oh
            out_ptrs = off_o + offs_m[:, None] * stride_om + offs_v_d[None, :]
            tl.store(out_ptrs, acc, mask=(offs_m < seq_len)[:, None])


@triton_autotune(
    configs=_get_fw_configs(),
    key=[
        "AUTOTUNE_Z",
        "H",
        "AUTOTUNE_MAX_SEQ_LEN",
        "DimQ",
        "DimV",
        "DeltaSize",
        "IS_DELTA_Q",
    ],
)
@triton.jit
def _hstu_attn_fwd(  # noqa C901
    Q,
    K,
    V,
    sort_by_length_indices,
    seq_offsets,
    num_targets,
    Out,
    stride_qm,
    stride_qh,
    stride_kn,
    stride_kh,
    stride_vn,
    stride_vh,
    stride_om,
    stride_oh,
    alpha,
    Z,
    AUTOTUNE_Z,
    H,
    MAX_SEQ_LEN,
    AUTOTUNE_MAX_SEQ_LEN,  # Quantized MAX_SEQ_LEN used as an autotuning key
    DimQ,
    DimV,
    DeltaSize,
    contextual_seq_len,
    max_attn_len,
    CAUSAL: tl.constexpr,
    HAS_MULTIPLE_TARGETS: tl.constexpr,
    IS_DELTA_Q: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    BLOCK_D_Q: tl.constexpr,
    BLOCK_D_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr,
    HAS_MAX_ATTN_LEN: tl.constexpr,
    HAS_SORT_BY_LENGTH_INDICES: tl.constexpr,
):
    """Forward entry point: one program per (row tile, batch * head).

    Grid axis 0 selects the BLOCK_M row tile; axis 1 encodes
    ``off_z * H + off_h``. With ``HAS_SORT_BY_LENGTH_INDICES`` the batch index
    is looked up in ``sort_by_length_indices`` (presumably a permutation of
    batch indices, e.g. from ``torch.sort`` — confirm at call sites).
    """
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    if HAS_SORT_BY_LENGTH_INDICES:
        off_z = tl.load(sort_by_length_indices + off_z)
    off_h = off_hz % H
    pid = tl.program_id(0)
    _hstu_attn_fwd_compute(
        Q=Q,
        K=K,
        V=V,
        seq_offsets=seq_offsets,
        num_targets=num_targets,
        Out=Out,
        stride_qm=stride_qm,
        stride_qh=stride_qh,
        stride_kn=stride_kn,
        stride_kh=stride_kh,
        stride_vn=stride_vn,
        stride_vh=stride_vh,
        stride_om=stride_om,
        stride_oh=stride_oh,
        alpha=alpha,
        MAX_SEQ_LEN=MAX_SEQ_LEN,
        DeltaSize=DeltaSize,
        contextual_seq_len=contextual_seq_len,
        max_attn_len=max_attn_len,
        off_z=off_z,
        off_h=off_h,
        pid=pid,
        CAUSAL=CAUSAL,
        HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS,
        IS_DELTA_Q=IS_DELTA_Q,
        ALLOW_TF32=ALLOW_TF32,
        BLOCK_D_Q=BLOCK_D_Q,
        BLOCK_D_V=BLOCK_D_V,
        HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN,
        HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
    )


@triton_autotune(
    configs=_get_fw_configs(),
    key=[
        "AUTOTUNE_Z",
        "H",
        "AUTOTUNE_MAX_SEQ_LEN",
        "DimQ",
        "DimV",
        "DeltaSize",
        "IS_DELTA_Q",
    ],
)
@triton.jit
def _hstu_attn_fwd_persistent(  # noqa C901
    Q,
    K,
    V,
    sort_by_length_indices,
    seq_offsets,
    num_targets,
    Out,
    stride_qm,
    stride_qh,
    stride_kn,
    stride_kh,
    stride_vn,
    stride_vh,
    stride_om,
    stride_oh,
    alpha,
    Z,
    AUTOTUNE_Z,
    H,
    MAX_SEQ_LEN,
    AUTOTUNE_MAX_SEQ_LEN,  # Quantized MAX_SEQ_LEN used as an autotuning key
    DimQ,
    DimV,
    DeltaSize,
    contextual_seq_len,
    max_attn_len,
    CAUSAL: tl.constexpr,
    HAS_MULTIPLE_TARGETS: tl.constexpr,
    IS_DELTA_Q: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    BLOCK_D_Q: tl.constexpr,
    BLOCK_D_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr,
    HAS_MAX_ATTN_LEN: tl.constexpr,
    HAS_SORT_BY_LENGTH_INDICES: tl.constexpr,
):
    """Persistent forward variant: a fixed number of programs loop over all tiles.

    There are ``cdiv(MAX_SEQ_LEN, BLOCK_M) * Z * H`` tiles; program ``p`` handles
    tiles ``p, p + num_programs, ...``. Tile indices are walked in reverse, so
    tiles with larger row-tile index ``pid`` are processed first.
    NOTE(review): ``sort_by_length_indices`` / ``HAS_SORT_BY_LENGTH_INDICES``
    are accepted but not used here.
    """
    n_tile_num = tl.cdiv(MAX_SEQ_LEN, BLOCK_M)
    prog_id = tl.program_id(0)
    num_progs = tl.num_programs(0)

    total_tiles = n_tile_num * Z * H

    # Spread the remainder over the first `total_tiles % num_progs` programs.
    tiles_per_sm = total_tiles // num_progs
    if prog_id < total_tiles % num_progs:
        tiles_per_sm += 1

    tile_idx = prog_id
    for _ in range(0, tiles_per_sm):
        pid = (total_tiles - tile_idx - 1) // (Z * H)
        off_hz = (total_tiles - tile_idx - 1) % (Z * H)
        off_z = off_hz // H
        off_h = off_hz % H
        _hstu_attn_fwd_compute(
            Q=Q,
            K=K,
            V=V,
            seq_offsets=seq_offsets,
            num_targets=num_targets,
            Out=Out,
            stride_qm=stride_qm,
            stride_qh=stride_qh,
            stride_kn=stride_kn,
            stride_kh=stride_kh,
            stride_vn=stride_vn,
            stride_vh=stride_vh,
            stride_om=stride_om,
            stride_oh=stride_oh,
            alpha=alpha,
            MAX_SEQ_LEN=MAX_SEQ_LEN,
            DeltaSize=DeltaSize,
            contextual_seq_len=contextual_seq_len,
            max_attn_len=max_attn_len,
            off_z=off_z,
            off_h=off_h,
            pid=pid,
            CAUSAL=CAUSAL,
            HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS,
            IS_DELTA_Q=IS_DELTA_Q,
            ALLOW_TF32=ALLOW_TF32,
            BLOCK_D_Q=BLOCK_D_Q,
            BLOCK_D_V=BLOCK_D_V,
            HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN,
            HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )
        tile_idx += num_progs


@triton.jit
def _hstu_attn_bwd_one_block(  # noqa C901
    start_m,
    offs_n,
    offs_m,
    q_ptrs_trans,
    dq_ptrs_trans,
    mask_n,
    do_ptrs,
    dk,
    dv,
    k,
    v,
    pos_offs_n,
    seq_len,
    n_targets,
    max_ids,
    contextual_seq_len,
    max_attn_len,
    LOCK,
    stride_qm,
    stride_dom,
    stride_dqm,
    alpha,
    MAX_SEQ_LEN,
    CAUSAL: tl.constexpr,
    HAS_MULTIPLE_TARGETS: tl.constexpr,
    HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr,
    HAS_MAX_ATTN_LEN: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    ATOMIC_ADD: tl.constexpr,
):
    """Accumulate gradients from one BLOCK_M query-row tile; return (dk, dv).

    Works in the key-major (transposed) layout: ``qk_trans`` is
    (BLOCK_N, BLOCK_M). Recomputes the masked SiLU weights of the forward pass
    (same masking rules, applied to ``pos_offs_m``/``pos_offs_n``), then:

      dv  += silu^T @ do
      dqk  = (v @ do^T) * silu'(alpha * qk) / MAX_SEQ_LEN, masked
      dk  += dqk @ q             (the ``alpha`` factor is applied by the caller)
      dq^T += alpha * k^T @ dqk  (read-modify-write of DQ in memory)

    with silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))). When ``ATOMIC_ADD``
    is set, several programs may update the same DQ rows, so the update is
    guarded by a spin lock in ``LOCK`` keyed by (program_id(0), row tile).
    """
    pos_offs_m = offs_m + start_m
    mask_m = pos_offs_m < seq_len
    # True where the weight is kept (used as the `tl.where` condition below).
    valid_mask_trans = pos_offs_m[None, :] == offs_n[:, None]
    # recompute qk and silu
    if HAS_CONTEXTUAL_SEQ_LEN:
        pos_offs_m = pos_offs_m - contextual_seq_len + 1
        pos_offs_m = tl.where(
            pos_offs_m > 0,
            pos_offs_m,
            0,
        )
    if HAS_MULTIPLE_TARGETS:
        pos_offs_m = tl.where(
            pos_offs_m < max_ids,
            pos_offs_m,
            max_ids,
        )
    q_trans = tl.load(
        q_ptrs_trans + start_m * stride_qm,
        mask=mask_m[None, :],
        other=0.0,
    )
    qk_trans = tl.dot(k, q_trans, allow_tf32=ALLOW_TF32) * alpha
    # pyre-fixme[16]: Module `math` has no attribute `fast_dividef`.
    sig_trans = fast_dividef(1.0, 1.0 + tl.exp(-qk_trans))
    silu_trans = qk_trans * sig_trans * (1.0 / MAX_SEQ_LEN)
    pos_offs_m_minus_n = pos_offs_m[None, :] - pos_offs_n[:, None]
    if not CAUSAL:
        pos_offs_m_minus_n = tl.where(
            pos_offs_m_minus_n > 0, pos_offs_m_minus_n, -pos_offs_m_minus_n
        )
    valid_mask_trans = valid_mask_trans or (pos_offs_m_minus_n > 0)
    if HAS_MAX_ATTN_LEN:
        valid_mask_trans = valid_mask_trans and pos_offs_m_minus_n <= max_attn_len
    if HAS_CONTEXTUAL_SEQ_LEN:
        valid_mask_trans = valid_mask_trans or (
            pos_offs_m[None, :] == 0 and pos_offs_n[:, None] < max_ids
        )
    silu_trans = tl.where(valid_mask_trans, silu_trans, 0)
    silu_trans = silu_trans.to(k.dtype)
    # compute dv
    do = tl.load(
        do_ptrs + start_m * stride_dom,
        mask=mask_m[:, None],
        other=0.0,
    )
    dv += tl.dot(silu_trans, do, allow_tf32=ALLOW_TF32)

    # compute dk and dq
    dqk_trans = tl.dot(v, tl.trans(do), allow_tf32=ALLOW_TF32)
    dqk_trans = (
        dqk_trans * sig_trans * (1 + qk_trans * (1 - sig_trans)) * (1.0 / MAX_SEQ_LEN)
    )
    dqk_trans = tl.where(valid_mask_trans, dqk_trans, 0)
    dqk_trans = dqk_trans.to(k.dtype)

    # Note: the factor `alpha` is delayed until the end of the function to reduce the cost  # NOQA
    dk += tl.dot(dqk_trans, tl.trans(q_trans), allow_tf32=ALLOW_TF32)
    if ATOMIC_ADD:
        # One lock per (program, BLOCK_M row tile); spin until acquired.
        lock_id = start_m // BLOCK_M
        stride_lock = tl.cdiv(MAX_SEQ_LEN, BLOCK_M)
        lock = LOCK + tl.program_id(0) * stride_lock + lock_id
        tl.debug_barrier()  # add a barrier to force sync
        while tl.atomic_cas(lock, 0, 1) == 1:
            pass
    dq_trans = tl.load(
        dq_ptrs_trans + start_m * stride_dqm,
        mask=mask_m[None, :],
        other=0.0,
        eviction_policy="evict_last",
    )
    dq_trans += tl.dot(tl.trans(k), dqk_trans, allow_tf32=ALLOW_TF32) * alpha
    dq_trans = dq_trans.to(k.dtype)
    tl.store(
        dq_ptrs_trans + start_m * stride_dqm,
        dq_trans,
        mask=mask_m[None, :],
        eviction_policy="evict_last",
    )
    if ATOMIC_ADD:
        tl.atomic_xchg(lock, 0)  # pyre-ignore [61]
    return dk, dv


@triton.jit
def _hstu_attn_bwd_one_col_block(  # noqa C901
    start_n,
    seq_len,
    n_targets,
    contextual_seq_len,
    max_attn_len,
    Q,
    K,
    V,
    DOut,
    DQ,
    DK,
    DV,
    LOCK,
    stride_qm,
    stride_kn,
    stride_vn,
    stride_dom,
    stride_dqm,
    stride_dkn,
    stride_dvn,
    alpha,
    MAX_SEQ_LEN,
    CAUSAL: tl.constexpr,
    HAS_MULTIPLE_TARGETS: tl.constexpr,
    HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr,
    HAS_MAX_ATTN_LEN: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    BLOCK_D_Q: tl.constexpr,
    BLOCK_D_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    UNROLL: tl.constexpr,
    ATOMIC_ADD: tl.constexpr,
):
    """Compute dk/dv for key rows [start_n, start_n + BLOCK_N) and add into DQ.

    Loops over the query-row tiles that can attend to this column tile,
    accumulating dk/dv in fp32, then stores dk (scaled by ``alpha``) and dv.
    Q/K/V/DOut/DQ/DK/DV are expected to already point at the current sequence
    and head (see ``_hstu_attn_bwd``).
    """
    # Work on the subsequence dv[start_n, start_n + BLOCK_N, :]
    # Query-row range [low, high) scanned by the main loop.
    if CAUSAL:
        # Rows before start_n never attend to these (later) columns.
        low = start_n
        if HAS_MAX_ATTN_LEN:
            # Rows more than max_attn_len past the tile cannot reach it.
            high = start_n + max_attn_len + BLOCK_N
            if HAS_MULTIPLE_TARGETS:
                # Once the window plus the trailing targets reaches the end of
                # the sequence, scan to the end so target rows are included.
                high = high if high + n_targets < seq_len else seq_len
            else:
                high = high if high < seq_len else seq_len
        else:
            high = seq_len
        if HAS_CONTEXTUAL_SEQ_LEN:
            # Contextual-prefix rows are handled by a dedicated loop below.
            contextual_block_end = tl.cdiv(contextual_seq_len, BLOCK_M) * BLOCK_M
            if low < contextual_block_end:
                low = contextual_block_end
    else:
        # Without causal masking every row may attend to every column (the
        # forward pass scans all of [0, seq_len) too), so visit all rows.
        low = 0
        high = seq_len

    # initialize row/col offsets
    offs_m = tl.arange(0, BLOCK_M)
    offs_qk_d = tl.arange(0, BLOCK_D_Q)
    offs_v_d = tl.arange(0, BLOCK_D_V)
    offs_n = start_n + tl.arange(0, BLOCK_N)
    # initialize pointers to value-like data
    q_ptrs_trans = Q + (offs_m[None, :] * stride_qm + offs_qk_d[:, None])
    dq_ptrs_trans = DQ + (offs_m[None, :] * stride_dqm + offs_qk_d[:, None])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_qk_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_v_d[None, :])
    mask_n = offs_n < seq_len

    do_ptrs = DOut + (offs_m[:, None] * stride_dom + offs_v_d[None, :])
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_D_V], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_D_Q], dtype=tl.float32)
    # k and v stay in SRAM throughout
    k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)
    v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)
    # Remap column positions the same way the forward pass does.
    max_ids = seq_len
    if HAS_CONTEXTUAL_SEQ_LEN:
        pos_offs_n = offs_n - contextual_seq_len + 1
        pos_offs_n = tl.where(
            pos_offs_n > 0,
            pos_offs_n,
            0,
        )
        max_ids = max_ids - contextual_seq_len + 1
    else:
        pos_offs_n = offs_n
    if HAS_MULTIPLE_TARGETS:
        max_ids = max_ids - n_targets
        pos_offs_n = tl.where(
            pos_offs_n < max_ids,
            pos_offs_n,
            max_ids,
        )
    # loop over rows
    if HAS_CONTEXTUAL_SEQ_LEN and CAUSAL:
        for start_m in range(0, contextual_seq_len, BLOCK_M):
            start_m = tl.multiple_of(start_m, BLOCK_M)
            dk, dv = _hstu_attn_bwd_one_block(
                start_m=start_m,
                offs_n=offs_n,
                offs_m=offs_m,
                q_ptrs_trans=q_ptrs_trans,
                dq_ptrs_trans=dq_ptrs_trans,
                mask_n=mask_n,
                do_ptrs=do_ptrs,
                dk=dk,
                dv=dv,
                k=k,
                v=v,
                pos_offs_n=pos_offs_n,
                seq_len=seq_len,
                n_targets=n_targets,
                max_ids=max_ids,
                contextual_seq_len=contextual_seq_len,
                max_attn_len=max_attn_len,
                LOCK=LOCK,
                stride_qm=stride_qm,
                stride_dom=stride_dom,
                stride_dqm=stride_dqm,
                alpha=alpha,
                MAX_SEQ_LEN=MAX_SEQ_LEN,
                CAUSAL=CAUSAL,
                HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS,
                HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN,
                HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN,
                ALLOW_TF32=ALLOW_TF32,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
                ATOMIC_ADD=ATOMIC_ADD,
            )
    for start_m in tl.range(low, high, BLOCK_M, loop_unroll_factor=UNROLL):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        dk, dv = _hstu_attn_bwd_one_block(
            start_m=start_m,
            offs_n=offs_n,
            offs_m=offs_m,
            q_ptrs_trans=q_ptrs_trans,
            dq_ptrs_trans=dq_ptrs_trans,
            mask_n=mask_n,
            do_ptrs=do_ptrs,
            dk=dk,
            dv=dv,
            k=k,
            v=v,
            pos_offs_n=pos_offs_n,
            seq_len=seq_len,
            n_targets=n_targets,
            max_ids=max_ids,
            contextual_seq_len=contextual_seq_len,
            max_attn_len=max_attn_len,
            LOCK=LOCK,
            stride_qm=stride_qm,
            stride_dom=stride_dom,
            stride_dqm=stride_dqm,
            alpha=alpha,
            MAX_SEQ_LEN=MAX_SEQ_LEN,
            CAUSAL=CAUSAL,
            HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS,
            HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN,
            HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN,
            ALLOW_TF32=ALLOW_TF32,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            ATOMIC_ADD=ATOMIC_ADD,
        )
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_v_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_qk_d[None, :])
    dk = dk * alpha
    tl.store(dv_ptrs, dv.to(k.dtype), mask=mask_n[:, None])
    tl.store(dk_ptrs, dk.to(k.dtype), mask=mask_n[:, None])


def _bwd_pre_hook(nargs):
    """Reset backward scratch buffers; installed as every bwd config's pre_hook.

    DQ is accumulated with load/add/store, so it must start at zero. In
    sequence-parallel mode the spin locks in LOCK must start unlocked (0).
    """
    nargs["DQ"].zero_()
    if nargs["SEQUENCE_PARALLEL"]:
        nargs["LOCK"].zero_()


def _get_bw_configs() -> List[triton.Config]:
    """Return the autotuning search space for the backward kernel.

    Every config carries ``SEQUENCE_PARALLEL``/``UNROLL`` meta-parameters and
    uses ``_bwd_pre_hook``. On ROCm builds this is a Cartesian product of the
    options below; otherwise a hand-picked list of CUDA configurations.
    """
    if torch.version.hip:
        # itertools.product preserves the nested-loop order.
        return [
            triton.Config(
                {
                    "BLOCK_M": block_m,
                    "BLOCK_N": block_n,
                    "matrix_instr_nonkdim": matrix_instr_nonkdim,  # NOQA
                    "waves_per_eu": waves_per_eu,
                    "SEQUENCE_PARALLEL": sp,
                    "UNROLL": 1,
                },
                num_stages=num_stages,
                num_warps=num_warps,
                pre_hook=_bwd_pre_hook,
            )
            for (
                block_m,
                block_n,
                num_stages,
                num_warps,
                matrix_instr_nonkdim,
                waves_per_eu,
                sp,
            ) in itertools.product(
                [32, 64],
                [32, 64],
                [1, 2],
                [4, 8],
                [16, 32],
                [0, 2, 4],
                [True, False],
            )
        ]
    # (BLOCK_M, BLOCK_N, SEQUENCE_PARALLEL, UNROLL, num_stages, num_warps)
    cuda_params = [
        (16, 32, False, 1, 2, 2),
        (16, 16, False, 1, 2, 2),
        (16, 32, False, 1, 2, 4),
        (16, 32, False, 1, 1, 8),
        (16, 64, False, 1, 1, 4),
        (32, 32, False, 1, 1, 4),
        (32, 32, False, 1, 2, 4),
        (32, 64, False, 1, 1, 4),
        (32, 64, False, 1, 2, 4),
        (32, 64, False, 1, 1, 8),
        (32, 64, False, 1, 2, 8),
        (64, 64, False, 1, 1, 4),
        (64, 64, False, 1, 2, 4),
        (64, 64, False, 1, 1, 8),
        (64, 64, False, 1, 2, 8),
        (32, 128, False, 1, 2, 8),
        (32, 128, False, 1, 3, 8),
        (32, 128, False, 2, 2, 8),
        (32, 128, False, 4, 2, 8),
        (16, 32, True, 1, 2, 2),
        (32, 32, True, 1, 1, 4),
        (32, 32, True, 1, 2, 4),
        (32, 64, True, 1, 1, 4),
        (32, 64, True, 1, 2, 4),
        (32, 64, True, 1, 1, 8),
        (64, 64, True, 1, 1, 4),
        (64, 64, True, 1, 2, 4),
        (32, 128, True, 1, 3, 8),
    ]
    return [
        triton.Config(
            {
                "BLOCK_M": block_m,
                "BLOCK_N": block_n,
                "SEQUENCE_PARALLEL": sp,
                "UNROLL": unroll,
            },
            num_stages=num_stages,
            num_warps=num_warps,
            pre_hook=_bwd_pre_hook,
        )
        for block_m, block_n, sp, unroll, num_stages, num_warps in cuda_params
    ]


# Smallest BLOCK_M any backward config may use. The lock buffer needs
# cdiv(N, BLOCK_M) slots per (batch, head) program for the chosen config, so
# it is sized with this minimum.
_BWD_MIN_BLOCK_M = min(config.kwargs["BLOCK_M"] for config in _get_bw_configs())


@triton_autotune(
    configs=_get_bw_configs(),
    key=[
        "AUTOTUNE_Z",
        "H",
        "AUTOTUNE_MAX_SEQ_LEN",
        "DimQ",
        "DimV",
    ],
)
@triton.jit
def _hstu_attn_bwd(  # noqa C901
    Q,
    K,
    V,
    sort_by_length_indices,
    seq_offsets,
    num_targets,
    DOut,
    DQ,
    DK,
    DV,
    LOCK,
    stride_qm,
    stride_qh,
    stride_kn,
    stride_kh,
    stride_vn,
    stride_vh,
    stride_dom,
    stride_doh,
    stride_dqm,
    stride_dqh,
    stride_dkn,
    stride_dkh,
    stride_dvn,
    stride_dvh,
    alpha,
    contextual_seq_len,
    max_attn_len,
    Z,
    AUTOTUNE_Z,
    H,
    MAX_SEQ_LEN,
    AUTOTUNE_MAX_SEQ_LEN,  # Quantized MAX_SEQ_LEN used as an autotuning key
    DimQ,
    DimV,
    CAUSAL: tl.constexpr,
    HAS_MULTIPLE_TARGETS: tl.constexpr,
    HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr,
    HAS_MAX_ATTN_LEN: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    BLOCK_D_Q: tl.constexpr,
    BLOCK_D_V: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    UNROLL: tl.constexpr,
    HAS_SORT_BY_LENGTH_INDICES: tl.constexpr,
):
    """Backward entry point.

    Grid axis 0 encodes ``off_z * H + off_h``. With ``SEQUENCE_PARALLEL`` each
    program along axis 1 handles one BLOCK_N column tile, and DQ updates are
    lock-protected. Otherwise a single program loops over all column tiles of
    its sequence.
    """
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    if HAS_SORT_BY_LENGTH_INDICES:
        off_z = tl.load(sort_by_length_indices + off_z)
    off_h = off_hz % H
    off_h = off_h.to(tl.int64)
    seq_start = tl.load(seq_offsets + off_z).to(tl.int64)
    seq_end = tl.load(seq_offsets + off_z + 1)
    seq_len = (seq_end - seq_start).to(tl.int32)
    if HAS_MULTIPLE_TARGETS:
        n_targets = tl.load(num_targets + off_z).to(tl.int32)
    else:
        n_targets = None
    # offset pointers for batch/head
    Q = Q + seq_start * stride_qm + off_h * stride_qh
    K = K + seq_start * stride_kn + off_h * stride_kh
    V = V + seq_start * stride_vn + off_h * stride_vh
    DOut = DOut + seq_start * stride_dom + off_h * stride_doh
    DQ = DQ + seq_start * stride_dqm + off_h * stride_dqh
    DK = DK + seq_start * stride_dkn + off_h * stride_dkh
    DV = DV + seq_start * stride_dvn + off_h * stride_dvh
    if SEQUENCE_PARALLEL:
        start_n = tl.program_id(1) * BLOCK_N
        if start_n >= seq_len:
            return
        _hstu_attn_bwd_one_col_block(
            start_n=start_n,
            seq_len=seq_len,
            n_targets=n_targets,
            contextual_seq_len=contextual_seq_len,
            max_attn_len=max_attn_len,
            Q=Q,
            K=K,
            V=V,
            DOut=DOut,
            DQ=DQ,
            DK=DK,
            DV=DV,
            LOCK=LOCK,
            stride_qm=stride_qm,
            stride_kn=stride_kn,
            stride_vn=stride_vn,
            stride_dom=stride_dom,
            stride_dqm=stride_dqm,
            stride_dkn=stride_dkn,
            stride_dvn=stride_dvn,
            alpha=alpha,
            MAX_SEQ_LEN=MAX_SEQ_LEN,
            CAUSAL=CAUSAL,
            HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS,
            HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN,
            HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN,
            ALLOW_TF32=ALLOW_TF32,
            BLOCK_D_Q=BLOCK_D_Q,
            BLOCK_D_V=BLOCK_D_V,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            UNROLL=UNROLL,
            ATOMIC_ADD=True,
        )
    else:
        for start_n in range(0, seq_len, BLOCK_N):
            _hstu_attn_bwd_one_col_block(
                start_n=start_n,
                seq_len=seq_len,
                n_targets=n_targets,
                contextual_seq_len=contextual_seq_len,
                max_attn_len=max_attn_len,
                Q=Q,
                K=K,
                V=V,
                DOut=DOut,
                DQ=DQ,
                DK=DK,
                DV=DV,
                LOCK=LOCK,
                stride_qm=stride_qm,
                stride_kn=stride_kn,
                stride_vn=stride_vn,
                stride_dom=stride_dom,
                stride_dqm=stride_dqm,
                stride_dkn=stride_dkn,
                stride_dvn=stride_dvn,
                alpha=alpha,
                MAX_SEQ_LEN=MAX_SEQ_LEN,
                CAUSAL=CAUSAL,
                HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS,
                HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN,
                HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN,
                ALLOW_TF32=ALLOW_TF32,
                BLOCK_D_Q=BLOCK_D_Q,
                BLOCK_D_V=BLOCK_D_V,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
                UNROLL=UNROLL,
                ATOMIC_ADD=False,
            )


def triton_hstu_attention_fwd(
    N: int,
    alpha: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    causal: bool,
    num_targets: Optional[torch.Tensor],
    max_attn_len: int,
    contextual_seq_len: int,
    sort_by_length_indices: Optional[torch.Tensor],
) -> torch.Tensor:
    """Launch the HSTU attention forward kernel.

    Args:
        N: maximum sequence length; sets the number of BLOCK_M row tiles per
            (batch, head) and the 1 / N weight scaling.
        alpha: scale applied to ``q @ k^T`` before the SiLU.
        q: packed queries, shape (L, H, DimQ).
        k: packed keys; the kernel reads DimQ features per row like ``q``.
        v: packed values, shape (L, H, DimV).
        seq_offsets: Z + 1 offsets; sequence z is rows
            ``seq_offsets[z]:seq_offsets[z + 1]``.
        causal: apply causal masking.
        num_targets: optional per-sequence count of trailing target positions.
        max_attn_len: if > 0, bounds the attention distance.
        contextual_seq_len: if > 0, length of the contextual prefix.
        sort_by_length_indices: optional batch order used to assign programs.

    Returns:
        Output tensor created with ``torch.empty_like(v)``.

    NOTE(review): the kernels address the last dimension of q/k/v with unit
    stride, and DimQ/DimV are used as Triton block sizes (presumably powers of
    two) — confirm callers guarantee both.
    """
    Z = seq_offsets.numel() - 1
    AUTOTUNE_Z = prev_power_of_2(Z)
    L, H, DimQ = q.shape
    _, _, DimV = v.shape
    out = torch.empty_like(v)
    if L == 0:
        return out
    has_multiple_targets = num_targets is not None
    has_contextual_seq_len = contextual_seq_len > 0
    has_max_attn_len = max_attn_len > 0
    has_sort_by_length_indices = sort_by_length_indices is not None

    grid = lambda meta: (  # noqa E731
        triton.cdiv(N, meta["BLOCK_M"]),
        Z * H,
    )

    _hstu_attn_fwd[grid](
        Q=q,
        K=k,
        V=v,
        sort_by_length_indices=sort_by_length_indices,
        seq_offsets=seq_offsets,
        num_targets=num_targets,
        Out=out,
        stride_qm=q.stride(0),
        stride_qh=q.stride(1),
        stride_kn=k.stride(0),
        stride_kh=k.stride(1),
        stride_vn=v.stride(0),
        stride_vh=v.stride(1),
        stride_om=out.stride(0),
        stride_oh=out.stride(1),
        alpha=alpha,
        Z=Z,
        AUTOTUNE_Z=AUTOTUNE_Z,
        H=H,
        MAX_SEQ_LEN=N,
        AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(N),
        DimQ=DimQ,
        DimV=DimV,
        DeltaSize=0,
        contextual_seq_len=contextual_seq_len,
        max_attn_len=max_attn_len,
        CAUSAL=causal,
        HAS_MULTIPLE_TARGETS=has_multiple_targets,
        IS_DELTA_Q=False,
        ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
        BLOCK_D_Q=DimQ,
        BLOCK_D_V=DimV,
        HAS_CONTEXTUAL_SEQ_LEN=has_contextual_seq_len,
        HAS_MAX_ATTN_LEN=has_max_attn_len,
        HAS_SORT_BY_LENGTH_INDICES=has_sort_by_length_indices,
    )
    return out


def triton_hstu_attention_bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dq: torch.Tensor,
    dk: torch.Tensor,
    dv: torch.Tensor,
    seq_offsets: torch.Tensor,
    num_targets: Optional[torch.Tensor],
    N: int,
    alpha: float,
    max_attn_len: int,
    causal: bool,
    contextual_seq_len: int,
    sort_by_length_indices: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the HSTU attention backward kernel.

    ``dout``/``dq``/``dk``/``dv`` are passed through
    ``switch_to_contiguous_if_needed`` first, so the returned gradient tensors
    may differ from the buffers passed in. ``dq`` is zeroed by the autotuner
    pre-hook before accumulation. When ``dout`` has no rows, fresh zero tensors
    shaped like q/k/v are returned instead.

    Returns:
        ``(dq, dk, dv)``.
    """
    dout = switch_to_contiguous_if_needed(dout)
    dq = switch_to_contiguous_if_needed(dq)
    dk = switch_to_contiguous_if_needed(dk)
    dv = switch_to_contiguous_if_needed(dv)
    if dout.shape[0] == 0:
        return torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
    Z = seq_offsets.numel() - 1
    _, H, DimQ = q.shape
    _, _, DimV = v.shape
    grid = lambda meta: (  # noqa E731
        Z * H,
        (triton.cdiv(N, meta["BLOCK_N"]) if meta["SEQUENCE_PARALLEL"] else 1),
    )
    # One spin-lock slot per BLOCK_M row tile per (batch, head) program, sized
    # for the smallest BLOCK_M used in `_get_bw_configs`.
    lock = torch.empty(
        (Z * H, triton.cdiv(N, _BWD_MIN_BLOCK_M)),
        dtype=torch.int32,
        device=q.device,
    )
    AUTOTUNE_Z = prev_power_of_2(Z)
    _hstu_attn_bwd[grid](
        Q=q,
        K=k,
        V=v,
        sort_by_length_indices=sort_by_length_indices,
        seq_offsets=seq_offsets,
        num_targets=num_targets,
        DOut=dout,
        DQ=dq,
        DK=dk,
        DV=dv,
        LOCK=lock,
        stride_qm=q.stride(0),
        stride_qh=q.stride(1),
        stride_kn=k.stride(0),
        stride_kh=k.stride(1),
        stride_vn=v.stride(0),
        stride_vh=v.stride(1),
        stride_dom=dout.stride(0),
        stride_doh=dout.stride(1),
        stride_dqm=dq.stride(0),
        stride_dqh=dq.stride(1),
        stride_dkn=dk.stride(0),
        stride_dkh=dk.stride(1),
        stride_dvn=dv.stride(0),
        stride_dvh=dv.stride(1),
        alpha=alpha,
        contextual_seq_len=contextual_seq_len,
        max_attn_len=max_attn_len,
        Z=Z,
        AUTOTUNE_Z=AUTOTUNE_Z,
        H=H,
        MAX_SEQ_LEN=N,
        AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(N),
        DimQ=DimQ,
        DimV=DimV,
        CAUSAL=causal,
        HAS_MULTIPLE_TARGETS=num_targets is not None,
        HAS_CONTEXTUAL_SEQ_LEN=contextual_seq_len > 0,
        HAS_MAX_ATTN_LEN=max_attn_len > 0,
        ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
        BLOCK_D_Q=DimQ,
        BLOCK_D_V=DimV,
        HAS_SORT_BY_LENGTH_INDICES=sort_by_length_indices is not None,
    )

    return dq, dk, dv


class _AttentionFunction(torch.autograd.Function):
    @staticmethod
    # pyre-ignore[14]
    def forward(
        ctx,
        N: int,
        alpha: float,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seq_offsets: torch.Tensor,
        causal: bool,
        num_targets: Optional[torch.Tensor],
        max_attn_len: int,
        contextual_seq_len: int,
        sort_by_length: bool,
    ) -> torch.Tensor:
        sort_by_length_indices = None
        if sort_by_length:
            seq_lengths = seq_offsets[1:] - seq_offsets[:-1]
            _, sort_by_length_indices = torch.sort(
                seq_lengths, descending=True, stable=False
            )
        saved_tensors = [q, k, v, seq_offsets]
        if num_targets is not None:
            saved_tensors.append(num_targets)
        if sort_by_length_indices is not None:
            saved_tensors.append(sort_by_length_indices)
        ctx.save_for_backward(*saved_tensors)
        ctx.alpha = alpha
        ctx.causal = causal
        ctx.has_multiple_targets = num_targets is not None
ctx.max_attn_len = max_attn_len ctx.N = N ctx.contextual_seq_len = contextual_seq_len ctx.sort_by_length = sort_by_length return triton_hstu_attention_fwd( N=N, alpha=alpha, q=q, k=k, v=v, seq_offsets=seq_offsets, causal=causal, num_targets=num_targets, max_attn_len=max_attn_len, contextual_seq_len=contextual_seq_len, sort_by_length_indices=sort_by_length_indices, ) @staticmethod # pyre-ignore[14] def backward( ctx, dout: torch.Tensor ) -> Tuple[ None, None, torch.Tensor, torch.Tensor, torch.Tensor, None, None, None, None, None, None, ]: with torch.inference_mode(): q, k, v, seq_offsets = ctx.saved_tensors[:4] idx = 4 if ctx.has_multiple_targets: num_targets = ctx.saved_tensors[idx] idx += 1 else: num_targets = None if ctx.sort_by_length: sort_by_length_indices = ctx.saved_tensors[idx] else: sort_by_length_indices = None dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) dq, dk, dv = triton_hstu_attention_bwd( dout=dout, q=q, k=k, v=v, dq=dq, dk=dk, dv=dv, seq_offsets=seq_offsets, num_targets=num_targets, N=ctx.N, alpha=ctx.alpha, max_attn_len=ctx.max_attn_len, causal=ctx.causal, contextual_seq_len=ctx.contextual_seq_len, sort_by_length_indices=sort_by_length_indices, ) return ( None, None, dq, dk, dv, None, None, None, None, None, None, ) @torch.fx.wrap def triton_hstu_mha( N: int, alpha: float, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_offsets: torch.Tensor, causal: bool, num_targets: Optional[torch.Tensor] = None, max_attn_len: int = 0, contextual_seq_len: int = 0, sort_by_length: bool = False, ) -> torch.Tensor: return _AttentionFunction.apply( N, alpha, q, k, v, seq_offsets, causal, num_targets, max_attn_len, contextual_seq_len, sort_by_length, ) @torch.fx.wrap def triton_cached_hstu_mha( N: int, alpha: float, delta_q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_offsets: torch.Tensor, num_targets: Optional[torch.Tensor] = None, max_attn_len: int = 0, contextual_seq_len: int = 0, ) -> torch.Tensor: Z = 
seq_offsets.size(0) - 1 AUTOTUNE_Z = prev_power_of_2(Z) L, H, DimQ = delta_q.shape DeltaSize = L // Z _, _, DimV = v.shape out = torch.empty((L, H, DimV), dtype=delta_q.dtype, device=delta_q.device) grid = lambda meta: ( # noqa E731 triton.cdiv(DeltaSize, meta["BLOCK_M"]), Z * H, ) has_contextual_seq_len = contextual_seq_len > 0 has_max_attn_len = max_attn_len > 0 _hstu_attn_fwd[grid]( Q=delta_q, K=k, V=v, sort_by_length_indices=None, seq_offsets=seq_offsets, num_targets=num_targets, Out=out, stride_qm=delta_q.stride(0), stride_qh=delta_q.stride(1), stride_kn=k.stride(0), stride_kh=k.stride(1), stride_vn=v.stride(0), stride_vh=v.stride(1), stride_om=out.stride(0), stride_oh=out.stride(1), alpha=alpha, contextual_seq_len=contextual_seq_len, max_attn_len=max_attn_len, Z=Z, AUTOTUNE_Z=AUTOTUNE_Z, H=H, MAX_SEQ_LEN=N, AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(N), DimQ=DimQ, DimV=DimV, DeltaSize=DeltaSize, CAUSAL=True, HAS_MULTIPLE_TARGETS=num_targets is not None, IS_DELTA_Q=True, ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, BLOCK_D_Q=DimQ, BLOCK_D_V=DimV, HAS_CONTEXTUAL_SEQ_LEN=has_contextual_seq_len, HAS_MAX_ATTN_LEN=has_max_attn_len, HAS_SORT_BY_LENGTH_INDICES=False, ) return out