tzrec/ops/triton/triton_jagged_tensors.py (631 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We use the jagged_tensors ops from generative-recommenders as a starting point.
# https://github.com/facebookresearch/generative-recommenders
# thanks to their public work.

import itertools
from typing import List, Optional, Tuple

import torch
import triton
import triton.language as tl
from triton.runtime.autotuner import autotune as triton_autotune

from tzrec.ops.utils import autotune_max_seq_len, switch_to_contiguous_if_needed


def _get_bmm_configs() -> List[triton.Config]:
    """Build the autotuning search space shared by the two bmm kernels.

    Returns:
        One ``triton.Config`` per combination of BLOCK_M in {64, 128},
        BLOCK_N in {64, 128}, BLOCK_K in {32, 64}, num_stages in {2, 3} and
        num_warps in {4, 8}; BLOCK_M varies slowest and num_warps fastest.
    """
    return [
        triton.Config(
            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for block_m, block_n, block_k, num_stages, num_warps in itertools.product(
            [64, 128], [64, 128], [32, 64], [2, 3], [4, 8]
        )
    ]


@triton.jit
def _concat_2D_jagged(
    ValuesA,
    ValuesB,
    OffsetsA,
    OffsetsB,
    MaxLenA,
    MaxLenB,
    Out,
    D,
    stride_ad,
    stride_bd,
    stride_od,
    n_prefix_from_B,
    IS_DENSE_A: tl.constexpr,
    IS_DENSE_B: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Copy one row of the per-sequence concatenation of A and B into Out.

    Each program handles output position ``program_id(0)`` of sequence
    ``program_id(1)``. Within a sequence the output rows are laid out as: the
    first ``n_prefix_from_B`` rows of B, then all rows of A, then the remaining
    rows of B. When ``IS_DENSE_X`` is set, every sequence of X has ``MaxLenX``
    rows and the offsets pointer is not read; otherwise sequence ``z`` of X
    spans rows ``Offsets[z]`` to ``Offsets[z + 1]``.

    Rows are addressed as ``base + row * stride + offs_d``, i.e. the feature
    dimension is read and written with an implicit stride of one element.
    """
    off_z = tl.program_id(1)
    off_n = tl.program_id(0)
    if IS_DENSE_A:
        seq_start_a = off_z * MaxLenA
        seq_len_a = MaxLenA
    else:
        seq_start_a = tl.load(OffsetsA + off_z)
        seq_end_a = tl.load(OffsetsA + off_z + 1)
        seq_len_a = seq_end_a - seq_start_a
    if IS_DENSE_B:
        seq_start_b = off_z * MaxLenB
        seq_len_b = MaxLenB
    else:
        seq_start_b = tl.load(OffsetsB + off_z)
        seq_end_b = tl.load(OffsetsB + off_z + 1)
        seq_len_b = seq_end_b - seq_start_b
    seq_len = seq_len_a + seq_len_b
    # The grid is sized for the longest sequence; shorter ones exit early.
    if off_n >= seq_len:
        return
    offs_d = tl.arange(0, BLOCK_D)
    out_seq_start = seq_start_a + seq_start_b + off_n
    out_ptrs = Out + out_seq_start.to(tl.int64) * stride_od + offs_d
    if off_n < n_prefix_from_B:
        # Leading rows come from the prefix of B.
        in_ptrs = ValuesB + (off_n + seq_start_b).to(tl.int64) * stride_bd + offs_d
    elif off_n < seq_len_a + n_prefix_from_B:
        # Middle rows come from A, shifted past the B prefix.
        in_ptrs = (
            ValuesA
            + (off_n - n_prefix_from_B + seq_start_a).to(tl.int64) * stride_ad
            + offs_d
        )
    else:
        # Trailing rows come from the rest of B.
        in_ptrs = (
            ValuesB
            + (off_n - seq_len_a + seq_start_b).to(tl.int64) * stride_bd
            + offs_d
        )
    v = tl.load(in_ptrs, mask=offs_d < D)
    tl.store(out_ptrs, v, mask=offs_d < D)


@triton.jit
def _split_2D_jagged(
    JaggedIn,
    OffsetsA,
    OffsetsB,
    MaxLenA,
    MaxLenB,
    OutA,
    OutB,
    D,
    stride_id,
    stride_ad,
    stride_bd,
    n_prefix_to_B,
    IS_DENSE_A: tl.constexpr,
    IS_DENSE_B: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Scatter one row of a concatenated jagged tensor back into OutA or OutB.

    This is the inverse of ``_concat_2D_jagged``: for input position
    ``program_id(0)`` of sequence ``program_id(1)``, the first
    ``n_prefix_to_B`` rows go to B, the next ``seq_len_a`` rows go to A, and the
    remaining rows go to B. Sequence boundaries are derived exactly as in the
    concat kernel (fixed ``MaxLenX`` when ``IS_DENSE_X``, otherwise offsets).

    Rows are addressed as ``base + row * stride + offs_d``, i.e. the feature
    dimension is read and written with an implicit stride of one element.
    """
    off_z = tl.program_id(1)
    off_n = tl.program_id(0)
    if IS_DENSE_A:
        seq_start_a = off_z * MaxLenA
        seq_len_a = MaxLenA
    else:
        seq_start_a = tl.load(OffsetsA + off_z)
        seq_end_a = tl.load(OffsetsA + off_z + 1)
        seq_len_a = seq_end_a - seq_start_a
    if IS_DENSE_B:
        seq_start_b = off_z * MaxLenB
        seq_len_b = MaxLenB
    else:
        seq_start_b = tl.load(OffsetsB + off_z)
        seq_end_b = tl.load(OffsetsB + off_z + 1)
        seq_len_b = seq_end_b - seq_start_b
    seq_len = seq_len_a + seq_len_b
    # The grid is sized for the longest sequence; shorter ones exit early.
    if off_n >= seq_len:
        return
    seq_start = seq_start_a + seq_start_b
    offs_d = tl.arange(0, BLOCK_D)
    in_ptrs = JaggedIn + (seq_start + off_n).to(tl.int64) * stride_id + offs_d
    if off_n < n_prefix_to_B:
        # Leading rows belong to the prefix of B.
        out_ptrs = OutB + (off_n + seq_start_b).to(tl.int64) * stride_bd + offs_d
    elif off_n < seq_len_a + n_prefix_to_B:
        # Middle rows belong to A, shifted past the B prefix.
        out_ptrs = (
            OutA
            + (off_n - n_prefix_to_B + seq_start_a).to(tl.int64) * stride_ad
            + offs_d
        )
    else:
        # Trailing rows belong to the rest of B.
        out_ptrs = (
            OutB
            + (off_n - seq_len_a + seq_start_b).to(tl.int64) * stride_bd
            + offs_d
        )
    v = tl.load(in_ptrs, mask=offs_d < D)
    tl.store(out_ptrs, v, mask=offs_d < D)


@triton_autotune(
    configs=_get_bmm_configs(),
    key=["AUTOTUNE_MAX_SEQ_LEN", "N", "K"],
)
@triton.jit
def jagged_dense_bmm_broadcast_add_kernel(
    seq_offsets,
    Jagged,
    Dense,
    Bias,
    Out,
    AUTOTUNE_MAX_SEQ_LEN,
    N,
    K,
    stride_jm,
    stride_db,
    stride_dk,
    stride_dn,
    stride_bias_b,
    stride_om,
    HAS_BIAS: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Computing bmm Out = Jagged x Dense + Bias.

    M is the jagged dimension.
    Jagged has shape (sum_B(M_i), K), Dense has shape (B, K, N),
    Bias has shape (B, N), and Out has shape (sum_B(M_i), N).

    Grid axes are (N tiles, M tiles, batch). Jagged, Bias and Out are indexed
    with an implicit unit stride on their last dimension; Dense uses the
    explicit ``stride_dk`` / ``stride_dn``. ``Bias`` is only read when
    ``HAS_BIAS`` is set. ``AUTOTUNE_MAX_SEQ_LEN`` is used only as an autotune
    key.
    """
    off_n = tl.program_id(0)
    off_m = tl.program_id(1)
    off_b = tl.program_id(2)

    seq_start = tl.load(seq_offsets + off_b).to(tl.int64)
    seq_end = tl.load(seq_offsets + off_b + 1)
    seq_len = seq_end - seq_start
    start_m = off_m * BLOCK_M
    start_n = off_n * BLOCK_N
    # Tiles that start past the end of this sequence have nothing to compute.
    if start_m >= seq_len:
        return

    Jagged += seq_start * stride_jm
    Dense += off_b.to(tl.int64) * stride_db
    Out += seq_start * stride_om

    offs_m = start_m + tl.arange(0, BLOCK_M)
    offs_n = start_n + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    jg_ptrs = Jagged + offs_m[:, None] * stride_jm + offs_k[None, :]
    dn_ptrs = Dense + offs_k[:, None] * stride_dk + offs_n[None, :] * stride_dn

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        jg = tl.load(
            jg_ptrs,
            # pyre-fixme[16]: `int` has no attribute `__getitem__`.
            mask=(offs_m[:, None] < seq_len) & ((k + offs_k)[None, :] < K),
            other=0.0,
        )
        dn = tl.load(
            dn_ptrs,
            mask=((k + offs_k)[:, None] < K) & (offs_n[None, :] < N),
            other=0.0,
        )
        accumulator += tl.dot(jg, dn, allow_tf32=ALLOW_TF32)
        jg_ptrs += BLOCK_K
        dn_ptrs += BLOCK_K * stride_dk

    if HAS_BIAS:
        # One bias row per batch element, broadcast over the M rows of the tile.
        bias_ptrs = Bias + off_b * stride_bias_b + offs_n
        bias = tl.load(bias_ptrs, mask=offs_n < N)
        accumulator += bias[None, :].to(tl.float32)

    out = accumulator.to(Out.dtype.element_ty)

    offs_m = start_m + tl.arange(0, BLOCK_M)
    offs_n = start_n + tl.arange(0, BLOCK_N)
    out_ptrs = Out + offs_m[:, None] * stride_om + offs_n[None, :]
    tl.store(out_ptrs, out, mask=(offs_m[:, None] < seq_len) & (offs_n[None, :] < N))


@triton_autotune(
    configs=_get_bmm_configs(),
    key=["M", "N", "AUTOTUNE_MAX_SEQ_LEN"],
)
@triton.jit
def _jagged_jagged_bmm_reduce_sum(
    seq_offsets,
    JaggedA,
    JaggedB,
    Out,
    ReduceOut,
    M,
    N,
    AUTOTUNE_MAX_SEQ_LEN,
    stride_ak,
    stride_bk,
    stride_ob,
    stride_om,
    stride_on,
    stride_orb,
    stride_orn,
    REDUCE_JAGGEDB: tl.constexpr,
    ALLOW_TF32: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Computing bmm Out = Jagged x Jagged.

    K is the jagged dimension.
    JaggedA has shape (sum_B(K_i), M), JaggedB has shape (sum_B(K_i), N),
    and Out has shape (B, M, N).

    Grid axes are (batch, M tiles, N tiles). JaggedA is loaded as a
    (BLOCK_M, BLOCK_K) tile, i.e. transposed relative to its storage. When
    ``REDUCE_JAGGEDB`` is set, the programs with ``off_m == 0`` additionally
    write the column sums of JaggedB over the sequence into ``ReduceOut``.
    JaggedA and JaggedB are indexed with an implicit unit stride on their last
    dimension. ``AUTOTUNE_MAX_SEQ_LEN`` is used only as an autotune key.
    """
    off_b = tl.program_id(0)
    off_m = tl.program_id(1)
    off_n = tl.program_id(2)

    seq_start = tl.load(seq_offsets + off_b).to(tl.int64)
    seq_end = tl.load(seq_offsets + off_b + 1)
    seq_len = seq_end - seq_start

    start_m = off_m * BLOCK_M
    start_n = off_n * BLOCK_N

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    Out += off_b.to(tl.int64) * stride_ob
    offs_m = start_m + tl.arange(0, BLOCK_M)
    offs_n = start_n + tl.arange(0, BLOCK_N)
    out_ptrs = Out + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    if REDUCE_JAGGEDB:
        out_reduce_ptrs = ReduceOut + off_b * stride_orb + offs_n * stride_orn
        acc_reduce = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if seq_len == 0:
        # Empty sequence: the outputs still have to be written (as zeros).
        out = accumulator.to(Out.dtype.element_ty)
        tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
        if REDUCE_JAGGEDB:
            if off_m == 0:
                tl.store(
                    out_reduce_ptrs,  # pyre-ignore [61]
                    acc_reduce.to(ReduceOut.dtype.element_ty),
                    mask=(offs_n < N),
                )
        return

    JaggedA += seq_start * stride_ak
    JaggedB += seq_start * stride_bk
    offs_k = tl.arange(0, BLOCK_K)
    jg_a_ptrs = JaggedA + offs_k[None, :] * stride_ak + offs_m[:, None]
    jg_b_ptrs = JaggedB + offs_k[:, None] * stride_bk + offs_n[None, :]

    for k in range(0, seq_len, BLOCK_K):
        jg_a = tl.load(
            jg_a_ptrs,
            # pyre-fixme[16]: `int` has no attribute `__getitem__`.
            mask=(offs_m[:, None] < M) & ((k + offs_k)[None, :] < seq_len),
            other=0.0,
        )
        jg_b = tl.load(
            jg_b_ptrs,
            mask=(offs_n[None, :] < N) & ((k + offs_k)[:, None] < seq_len),
            other=0.0,
        )

        accumulator += tl.dot(jg_a, jg_b, allow_tf32=ALLOW_TF32)
        if REDUCE_JAGGEDB:
            # Only the first M tile accumulates the reduction so that each
            # (batch, N tile) slot of ReduceOut is written exactly once.
            if off_m == 0:
                acc_reduce += tl.sum(jg_b.to(tl.float32), axis=0)

        jg_a_ptrs += BLOCK_K * stride_ak
        jg_b_ptrs += BLOCK_K * stride_bk

    out = accumulator.to(Out.dtype.element_ty)
    tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
    if REDUCE_JAGGEDB:
        if off_m == 0:
            tl.store(
                out_reduce_ptrs,  # pyre-ignore [61]
                acc_reduce.to(ReduceOut.dtype.element_ty),
                mask=(offs_n < N),
            )


class _Concat2DJaggedFunction(torch.autograd.Function):
    """Autograd wrapper around the ``_concat_2D_jagged`` kernel.

    The backward pass splits the incoming gradient back into the two inputs
    with ``_split_2D_jagged``.
    """

    @staticmethod
    # pyre-ignore[14]
    def forward(
        ctx,
        values_a: torch.Tensor,
        values_b: torch.Tensor,
        max_len_a: int,
        max_len_b: int,
        offsets_a: Optional[torch.Tensor],
        offsets_b: Optional[torch.Tensor],
        n_prefix_from_B: int,
    ):
        """Concatenate two 2D jagged (or dense) tensors sequence by sequence.

        Args:
            values_a: tensor of shape (total_len_a, D).
            values_b: tensor of shape (total_len_b, D).
            max_len_a: per-sequence length of A when ``offsets_a`` is None;
                otherwise only contributes to the launch grid size.
            max_len_b: same as ``max_len_a`` for B.
            offsets_a: sequence offsets of A with B + 1 entries, or None when
                A is dense.
            offsets_b: sequence offsets of B with B + 1 entries, or None when
                B is dense.
            n_prefix_from_B: number of leading rows of each B sequence that
                are placed before the rows of A.

        Returns:
            Tensor of shape (total_len_a + total_len_b, D) with the dtype and
            device of ``values_a``.
        """
        values_a = switch_to_contiguous_if_needed(values_a)
        values_b = switch_to_contiguous_if_needed(values_b)
        is_dense_a = offsets_a is None
        is_dense_b = offsets_b is None
        total_len_a, D = values_a.shape
        total_len_b, _ = values_b.shape
        # The batch size is taken from B's description of the batch.
        if offsets_b is not None:
            B = offsets_b.shape[0] - 1
        else:
            B = total_len_b // max_len_b
        total_seq_len = total_len_a + total_len_b
        max_seq_len = max_len_a + max_len_b
        BLOCK_D = triton.next_power_of_2(D)
        values_out = torch.empty(
            (total_seq_len, D), device=values_a.device, dtype=values_a.dtype
        )
        _concat_2D_jagged[(max_seq_len, B)](
            ValuesA=values_a,
            ValuesB=values_b,
            OffsetsA=offsets_a,
            OffsetsB=offsets_b,
            MaxLenA=max_len_a,
            MaxLenB=max_len_b,
            Out=values_out,
            D=D,
            stride_ad=values_a.stride(-2),
            stride_bd=values_b.stride(-2),
            stride_od=values_out.stride(-2),
            n_prefix_from_B=n_prefix_from_B,
            # pyre-ignore[6]
            IS_DENSE_A=is_dense_a,
            # pyre-ignore[6]
            IS_DENSE_B=is_dense_b,
            BLOCK_D=BLOCK_D,
        )
        ctx.save_for_backward(offsets_a, offsets_b)
        ctx.max_seq_len = max_seq_len
        ctx.total_len_a = total_len_a
        ctx.total_len_b = total_len_b
        ctx.is_dense_a = is_dense_a
        ctx.is_dense_b = is_dense_b
        ctx.max_len_a = max_len_a
        ctx.max_len_b = max_len_b
        ctx.B = B
        ctx.n_prefix_from_B = n_prefix_from_B
        return values_out

    @staticmethod
    # pyre-ignore[14]
    def backward(
        ctx, d_out: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None, None, None]:
        """Split the output gradient into gradients for A and B.

        Args:
            d_out: gradient of the concatenated output, shape (total_len, D).

        Returns:
            ``(d_values_a, d_values_b)`` followed by ``None`` for each of the
            five non-tensor / offset arguments of ``forward``.
        """
        offsets_a, offsets_b = ctx.saved_tensors
        # The kernel indexes the feature dimension with an implicit unit
        # stride, and autograd may hand over a non-contiguous gradient (for
        # example an expanded one). Normalize it the same way forward
        # normalizes its inputs.
        d_out = switch_to_contiguous_if_needed(d_out)
        _, D = d_out.shape
        BLOCK_D = triton.next_power_of_2(D)
        d_values_a = torch.zeros(
            (ctx.total_len_a, D), device=d_out.device, dtype=d_out.dtype
        )
        d_values_b = torch.empty(
            (ctx.total_len_b, D), device=d_out.device, dtype=d_out.dtype
        )
        _split_2D_jagged[(ctx.max_seq_len, ctx.B)](
            JaggedIn=d_out,
            OffsetsA=offsets_a,
            OffsetsB=offsets_b,
            MaxLenA=ctx.max_len_a,
            MaxLenB=ctx.max_len_b,
            OutA=d_values_a,
            OutB=d_values_b,
            D=D,
            stride_id=d_out.stride(-2),
            stride_ad=d_values_a.stride(-2),
            stride_bd=d_values_b.stride(-2),
            n_prefix_to_B=ctx.n_prefix_from_B,
            BLOCK_D=BLOCK_D,
            IS_DENSE_A=ctx.is_dense_a,
            IS_DENSE_B=ctx.is_dense_b,
        )
        return d_values_a, d_values_b, None, None, None, None, None


class _Split2DJaggedFunction(torch.autograd.Function):
    """Autograd wrapper around the ``_split_2D_jagged`` kernel.

    The backward pass re-concatenates the two gradients with
    ``_concat_2D_jagged``.
    """

    @staticmethod
    # pyre-ignore[14]
    def forward(
        ctx,
        max_seq_len: int,
        values: torch.Tensor,
        max_len_a: Optional[int],
        max_len_b: Optional[int],
        offsets_a: Optional[torch.Tensor],
        offsets_b: Optional[torch.Tensor],
        n_prefix_to_B: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a concatenated 2D jagged tensor into its A and B parts.

        Args:
            max_seq_len: first dimension of the launch grid; positions at or
                beyond a sequence's combined length exit early in the kernel.
            values: tensor of shape (total_seq_len, D).
            max_len_a: per-sequence length of A; required when ``offsets_a``
                is None.
            max_len_b: per-sequence length of B; required when ``offsets_b``
                is None.
            offsets_a: sequence offsets of A with B + 1 entries, or None when
                A is dense.
            offsets_b: sequence offsets of B with B + 1 entries, or None when
                B is dense.
            n_prefix_to_B: number of leading rows of each sequence that are
                routed to B before the rows of A.

        Returns:
            ``(values_a, values_b)`` with shapes (total_len_a, D) and
            (total_len_b, D).

        Raises:
            AssertionError: if both ``offsets_a`` and ``offsets_b`` are None,
                or if the ``max_len`` of a dense side is None.
        """
        values = switch_to_contiguous_if_needed(values)
        is_dense_a: bool = offsets_a is None
        is_dense_b: bool = offsets_b is None
        total_seq_len, D = values.shape
        if is_dense_a:
            assert is_dense_b is False
            assert offsets_b is not None
            assert max_len_a is not None
            B = offsets_b.shape[0] - 1
            total_len_a = max_len_a * B
            total_len_b = total_seq_len - total_len_a
        elif is_dense_b:
            assert is_dense_a is False
            assert offsets_a is not None
            assert max_len_b is not None
            B = offsets_a.shape[0] - 1
            total_len_b = max_len_b * B
            total_len_a = total_seq_len - total_len_b
        else:
            assert offsets_a is not None and offsets_b is not None
            B = offsets_a.shape[0] - 1
            # The last offset is the total row count; reading it with
            # .item() copies the value to the host.
            total_len_a = int(offsets_a[-1].item())
            total_len_b = int(offsets_b[-1].item())
        BLOCK_D = triton.next_power_of_2(D)
        values_a = torch.empty(
            (total_len_a, D), device=values.device, dtype=values.dtype
        )
        values_b = torch.empty(
            (total_len_b, D), device=values.device, dtype=values.dtype
        )
        _split_2D_jagged[(max_seq_len, B)](
            JaggedIn=values,
            OffsetsA=offsets_a,
            OffsetsB=offsets_b,
            MaxLenA=max_len_a,
            MaxLenB=max_len_b,
            OutA=values_a,
            OutB=values_b,
            D=D,
            stride_id=values.stride(0),
            stride_ad=values_a.stride(0),
            stride_bd=values_b.stride(0),
            n_prefix_to_B=n_prefix_to_B,
            # pyre-ignore[6]
            IS_DENSE_A=is_dense_a,
            # pyre-ignore[6]
            IS_DENSE_B=is_dense_b,
            BLOCK_D=BLOCK_D,
        )
        ctx.save_for_backward(offsets_a, offsets_b)
        ctx.max_seq_len = max_seq_len
        ctx.total_seq_len = total_seq_len
        ctx.max_len_a = max_len_a
        ctx.max_len_b = max_len_b
        ctx.is_dense_a = is_dense_a
        ctx.is_dense_b = is_dense_b
        ctx.B = B
        ctx.D = D
        ctx.n_prefix_to_B = n_prefix_to_B
        return values_a, values_b

    @staticmethod
    def backward(
        ctx, *d_values
    ) -> Tuple[None, torch.Tensor, None, None, None, None, None]:
        """Concatenate the gradients of A and B into the gradient of values.

        Args:
            *d_values: exactly two tensors, the gradients of ``values_a`` and
                ``values_b``.

        Returns:
            ``None`` for ``max_seq_len``, the gradient of ``values`` with
            shape (total_seq_len, D), then ``None`` for the remaining five
            arguments of ``forward``.
        """
        offsets_a, offsets_b = ctx.saved_tensors
        d_values_a, d_values_b = d_values
        # The kernel indexes the feature dimension with an implicit unit
        # stride, and autograd may hand over non-contiguous gradients.
        # Normalize them the same way forward normalizes its input.
        d_values_a = switch_to_contiguous_if_needed(d_values_a)
        d_values_b = switch_to_contiguous_if_needed(d_values_b)
        BLOCK_D = triton.next_power_of_2(ctx.D)
        d_jagged_in = torch.empty(
            (ctx.total_seq_len, ctx.D),
            device=d_values_a.device,
            dtype=d_values_a.dtype,
        )
        _concat_2D_jagged[(ctx.max_seq_len, ctx.B)](
            ValuesA=d_values_a,
            ValuesB=d_values_b,
            OffsetsA=offsets_a,
            OffsetsB=offsets_b,
            MaxLenA=ctx.max_len_a,
            MaxLenB=ctx.max_len_b,
            Out=d_jagged_in,
            D=ctx.D,
            stride_ad=d_values_a.stride(-2),
            stride_bd=d_values_b.stride(-2),
            stride_od=d_jagged_in.stride(-2),
            n_prefix_from_B=ctx.n_prefix_to_B,
            IS_DENSE_A=ctx.is_dense_a,
            IS_DENSE_B=ctx.is_dense_b,
            BLOCK_D=BLOCK_D,
        )

        return None, d_jagged_in, None, None, None, None, None


class _JaggedDenseBmmBroadcastAddFunction(torch.autograd.Function):
    """Autograd wrapper computing ``out = jagged x dense + bias`` per sequence."""

    @staticmethod
    # pyre-ignore[14]
    def forward(
        ctx,
        max_seq_len: int,
        seq_offsets: torch.Tensor,
        jagged: torch.Tensor,
        dense: torch.Tensor,
        bias: torch.Tensor,
    ):
        """Run the jagged-dense bmm with a per-batch bias row added.

        Args:
            max_seq_len: upper bound used to size the M axis of the launch
                grid and to derive the autotune key.
            seq_offsets: sequence offsets into ``jagged``; entries ``b`` and
                ``b + 1`` delimit sequence ``b``.
            jagged: tensor of shape (L, K).
            dense: tensor of shape (B, K, N).
            bias: tensor whose row ``b`` (N values) is added to every output
                row of sequence ``b``.

        Returns:
            Tensor of shape (L, N) with the dtype and device of ``jagged``.
        """
        jagged = switch_to_contiguous_if_needed(jagged)
        bias = switch_to_contiguous_if_needed(bias)
        L, K = jagged.shape
        B, _, N = dense.shape
        out = torch.empty((L, N), dtype=jagged.dtype, device=jagged.device)

        grid = lambda meta: (  # noqa E731
            triton.cdiv(N, meta["BLOCK_N"]),
            triton.cdiv(max_seq_len, meta["BLOCK_M"]),
            B,
        )

        jagged_dense_bmm_broadcast_add_kernel[grid](
            seq_offsets=seq_offsets,
            Jagged=jagged,
            Dense=dense,
            Bias=bias,
            Out=out,
            AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len),
            N=N,
            K=K,
            stride_jm=jagged.stride(0),
            stride_db=dense.stride(0),
            stride_dk=dense.stride(1),
            stride_dn=dense.stride(2),
            stride_bias_b=bias.stride(0),
            stride_om=out.stride(0),
            HAS_BIAS=True,
            ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
        )

        ctx.save_for_backward(seq_offsets, jagged, dense)
        ctx.B = B
        ctx.max_seq_len = max_seq_len
        ctx.K = K
        ctx.N = N
        return out

    @staticmethod
    # pyre-ignore[14]
    def backward(
        ctx, d_out: torch.Tensor
    ) -> Tuple[None, None, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute gradients for ``jagged``, ``dense`` and ``bias``.

        Args:
            d_out: gradient of the forward output, shape (L, N).

        Returns:
            ``None`` for ``max_seq_len`` and ``seq_offsets``, followed by
            ``d_jagged`` (L, K), ``d_dense`` (B, K, N) and ``d_bias`` (B, N).
        """
        seq_offsets, jagged, dense = ctx.saved_tensors
        # Both kernels below index the last dimension of d_out with an
        # implicit unit stride, and autograd may hand over a non-contiguous
        # gradient (for example an expanded one). Normalize it the same way
        # forward normalizes its inputs.
        d_out = switch_to_contiguous_if_needed(d_out)
        d_jagged = torch.empty_like(jagged)
        d_dense = torch.empty_like(dense)
        d_bias = torch.empty((ctx.B, ctx.N), device=d_out.device, dtype=d_out.dtype)

        grid = lambda meta: (  # noqa E731
            triton.cdiv(ctx.K, meta["BLOCK_N"]),
            triton.cdiv(ctx.max_seq_len, meta["BLOCK_M"]),
            ctx.B,
        )
        # d_jagged = d_out x dense^T: reuse the forward kernel with the K and N
        # roles (and the matching dense strides) swapped and no bias.
        jagged_dense_bmm_broadcast_add_kernel[grid](
            seq_offsets=seq_offsets,
            Jagged=d_out,
            Dense=dense,
            Bias=None,
            Out=d_jagged,
            AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len),
            N=ctx.K,
            K=ctx.N,
            stride_jm=d_out.stride(0),
            stride_db=dense.stride(0),
            stride_dk=dense.stride(2),
            stride_dn=dense.stride(1),
            stride_bias_b=0,
            stride_om=d_jagged.stride(0),
            HAS_BIAS=False,
            ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
        )

        grid = lambda meta: (  # noqa E731
            ctx.B,
            triton.cdiv(ctx.K, meta["BLOCK_M"]),
            triton.cdiv(ctx.N, meta["BLOCK_N"]),
        )
        # d_dense[b] = jagged[b]^T x d_out[b]; the same launch also reduces
        # d_out over each sequence to produce d_bias[b].
        _jagged_jagged_bmm_reduce_sum[grid](
            seq_offsets=seq_offsets,
            JaggedA=jagged,
            JaggedB=d_out,
            Out=d_dense,
            ReduceOut=d_bias,
            M=ctx.K,
            N=ctx.N,
            AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len),
            stride_ak=jagged.stride(0),
            stride_bk=d_out.stride(0),
            stride_ob=d_dense.stride(0),
            stride_om=d_dense.stride(1),
            stride_on=d_dense.stride(2),
            stride_orb=d_bias.stride(0),
            stride_orn=d_bias.stride(1),
            REDUCE_JAGGEDB=True,
            ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
        )

        return None, None, d_jagged, d_dense, d_bias


@torch.fx.wrap
def triton_concat_2D_jagged(
    values_left: torch.Tensor,
    values_right: torch.Tensor,
    max_len_left: int,
    max_len_right: int,
    offsets_left: Optional[torch.Tensor],
    offsets_right: Optional[torch.Tensor],
    n_prefix_from_right: int = 0,
) -> torch.Tensor:
    """Concatenate two 2D jagged (or dense) tensors sequence by sequence.

    Args:
        values_left: tensor of shape (total_len_left, D).
        values_right: tensor of shape (total_len_right, D).
        max_len_left: per-sequence length of the left tensor when
            ``offsets_left`` is None; also contributes to the launch grid size.
        max_len_right: same as ``max_len_left`` for the right tensor.
        offsets_left: sequence offsets of the left tensor, or None when it is
            dense.
        offsets_right: sequence offsets of the right tensor, or None when it
            is dense.
        n_prefix_from_right: number of leading rows of each right sequence
            that are placed before the rows of the left sequence.

    Returns:
        Tensor of shape (total_len_left + total_len_right, D).
    """
    return _Concat2DJaggedFunction.apply(
        values_left,
        values_right,
        max_len_left,
        max_len_right,
        offsets_left,
        offsets_right,
        n_prefix_from_right,
    )


@torch.fx.wrap
def triton_split_2D_jagged(
    max_seq_len: int,
    values: torch.Tensor,
    max_len_left: Optional[int],
    max_len_right: Optional[int],
    offsets_left: Optional[torch.Tensor],
    offsets_right: Optional[torch.Tensor],
    n_prefix_to_right: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a concatenated 2D jagged tensor into its left and right parts.

    Args:
        max_seq_len: first dimension of the kernel launch grid.
        values: tensor of shape (total_seq_len, D).
        max_len_left: per-sequence length of the left part; required when
            ``offsets_left`` is None.
        max_len_right: per-sequence length of the right part; required when
            ``offsets_right`` is None.
        offsets_left: sequence offsets of the left part, or None when it is
            dense.
        offsets_right: sequence offsets of the right part, or None when it is
            dense.
        n_prefix_to_right: number of leading rows of each sequence that are
            routed to the right part before the rows of the left part.

    Returns:
        ``(values_left, values_right)``.
    """
    return _Split2DJaggedFunction.apply(
        max_seq_len,
        values,
        max_len_left,
        max_len_right,
        offsets_left,
        offsets_right,
        n_prefix_to_right,
    )


@torch.fx.wrap
def triton_jagged_dense_bmm_broadcast_add(
    max_seq_len: int,
    seq_offsets: torch.Tensor,
    jagged: torch.Tensor,
    dense: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """Compute ``jagged x dense + bias`` independently for every sequence.

    Args:
        max_seq_len: upper bound used to size the kernel launch grid.
        seq_offsets: sequence offsets into ``jagged``.
        jagged: tensor of shape (L, K).
        dense: tensor of shape (B, K, N).
        bias: tensor whose row ``b`` is added to every output row of
            sequence ``b``.

    Returns:
        Tensor of shape (L, N).
    """
    return _JaggedDenseBmmBroadcastAddFunction.apply(
        max_seq_len, seq_offsets, jagged, dense, bias
    )