# tzrec/ops/triton/triton_jagged_tensors.py
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# We use the jagged_tensors ops from generative-recommenders as a starting
# point: https://github.com/facebookresearch/generative-recommenders.
# Thanks for their public work.
from typing import List, Optional, Tuple

import torch
import triton
import triton.language as tl
from triton.runtime.autotuner import autotune as triton_autotune

from tzrec.ops.utils import autotune_max_seq_len, switch_to_contiguous_if_needed


def _get_bmm_configs() -> List[triton.Config]:
configs = []
for BLOCK_M in [64, 128]:
for BLOCK_N in [64, 128]:
for BLOCK_K in [32, 64]:
for num_stages in [2, 3]:
for num_warps in [4, 8]:
configs.append(
triton.Config(
{
"BLOCK_M": BLOCK_M,
"BLOCK_N": BLOCK_N,
"BLOCK_K": BLOCK_K,
},
num_stages=num_stages,
num_warps=num_warps,
)
)
return configs
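
# The sweep above yields 2 * 2 * 2 * 2 * 2 = 32 candidate configs; the Triton
# autotuner benchmarks them on the first launch for each autotune key and
# caches the fastest.

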
@triton.jit
def _concat_2D_jagged(
ValuesA,
ValuesB,
OffsetsA,
OffsetsB,
MaxLenA,
MaxLenB,
Out,
D,
stride_ad,
stride_bd,
stride_od,
n_prefix_from_B,
IS_DENSE_A: tl.constexpr,
IS_DENSE_B: tl.constexpr,
BLOCK_D: tl.constexpr,
):
off_z = tl.program_id(1)
off_n = tl.program_id(0)
if IS_DENSE_A:
seq_start_a = off_z * MaxLenA
seq_len_a = MaxLenA
else:
seq_start_a = tl.load(OffsetsA + off_z)
seq_end_a = tl.load(OffsetsA + off_z + 1)
seq_len_a = seq_end_a - seq_start_a
if IS_DENSE_B:
seq_start_b = off_z * MaxLenB
seq_len_b = MaxLenB
else:
seq_start_b = tl.load(OffsetsB + off_z)
seq_end_b = tl.load(OffsetsB + off_z + 1)
seq_len_b = seq_end_b - seq_start_b
seq_len = seq_len_a + seq_len_b
if off_n >= seq_len:
return
offs_d = tl.arange(0, BLOCK_D)
out_seq_start = seq_start_a + seq_start_b + off_n
out_ptrs = Out + out_seq_start.to(tl.int64) * stride_od + offs_d
if off_n < n_prefix_from_B:
in_ptrs = ValuesB + (off_n + seq_start_b).to(tl.int64) * stride_bd + offs_d
elif off_n < seq_len_a + n_prefix_from_B:
in_ptrs = (
ValuesA
+ (off_n - n_prefix_from_B + seq_start_a).to(tl.int64) * stride_ad
+ offs_d
)
else:
in_ptrs = (
ValuesB
+ (off_n - seq_len_a + seq_start_b).to(tl.int64) * stride_bd
+ offs_d
)
v = tl.load(in_ptrs, mask=offs_d < D)
tl.store(out_ptrs, v, mask=offs_d < D)
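
# A PyTorch reference for the kernel above (an illustrative sketch for the
# jagged/jagged case; a dense side corresponds to offsets of the form
# torch.arange(0, (B + 1) * max_len, max_len)):
#
#     def concat_2D_jagged_ref(values_a, values_b, offsets_a, offsets_b, n_prefix=0):
#         out = []
#         for z in range(offsets_a.numel() - 1):
#             a = values_a[offsets_a[z] : offsets_a[z + 1]]
#             b = values_b[offsets_b[z] : offsets_b[z + 1]]
#             # Per batch: first n_prefix rows of B, then all of A, then the rest of B.
#             out.append(torch.cat([b[:n_prefix], a, b[n_prefix:]], dim=0))
#         return torch.cat(out, dim=0)

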
@triton.jit
def _split_2D_jagged(
JaggedIn,
OffsetsA,
OffsetsB,
MaxLenA,
MaxLenB,
OutA,
OutB,
D,
stride_id,
stride_ad,
stride_bd,
n_prefix_to_B,
IS_DENSE_A: tl.constexpr,
IS_DENSE_B: tl.constexpr,
BLOCK_D: tl.constexpr,
):
off_z = tl.program_id(1)
off_n = tl.program_id(0)
if IS_DENSE_A:
seq_start_a = off_z * MaxLenA
seq_len_a = MaxLenA
else:
seq_start_a = tl.load(OffsetsA + off_z)
seq_end_a = tl.load(OffsetsA + off_z + 1)
seq_len_a = seq_end_a - seq_start_a
if IS_DENSE_B:
seq_start_b = off_z * MaxLenB
seq_len_b = MaxLenB
else:
seq_start_b = tl.load(OffsetsB + off_z)
seq_end_b = tl.load(OffsetsB + off_z + 1)
seq_len_b = seq_end_b - seq_start_b
seq_len = seq_len_a + seq_len_b
if off_n >= seq_len:
return
seq_start = seq_start_a + seq_start_b
offs_d = tl.arange(0, BLOCK_D)
in_ptrs = JaggedIn + (seq_start + off_n).to(tl.int64) * stride_id + offs_d
if off_n < n_prefix_to_B:
out_ptrs = OutB + (off_n + seq_start_b).to(tl.int64) * stride_bd + offs_d
elif off_n < seq_len_a + n_prefix_to_B:
out_ptrs = (
OutA
+ (off_n - n_prefix_to_B + seq_start_a).to(tl.int64) * stride_ad
+ offs_d
)
else:
out_ptrs = (
OutB + (off_n - seq_len_a + seq_start_b).to(tl.int64) * stride_bd + offs_d
)
v = tl.load(in_ptrs, mask=offs_d < D)
tl.store(out_ptrs, v, mask=offs_d < D)
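
# The inverse of the concat reference above (illustrative sketch, jagged/jagged
# case): the first n_prefix rows of each combined sequence go to B, the next
# len_a rows to A, and the remainder to B.
#
#     def split_2D_jagged_ref(values, offsets_a, offsets_b, n_prefix=0):
#         outs_a, outs_b = [], []
#         for z in range(offsets_a.numel() - 1):
#             len_a = int(offsets_a[z + 1] - offsets_a[z])
#             len_b = int(offsets_b[z + 1] - offsets_b[z])
#             start = int(offsets_a[z] + offsets_b[z])
#             seq = values[start : start + len_a + len_b]
#             outs_b.append(seq[:n_prefix])
#             outs_a.append(seq[n_prefix : n_prefix + len_a])
#             outs_b.append(seq[n_prefix + len_a :])
#         return torch.cat(outs_a, dim=0), torch.cat(outs_b, dim=0)

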
@triton_autotune(
configs=_get_bmm_configs(),
key=["AUTOTUNE_MAX_SEQ_LEN", "N", "K"],
)
@triton.jit
def jagged_dense_bmm_broadcast_add_kernel(
seq_offsets,
Jagged,
Dense,
Bias,
Out,
AUTOTUNE_MAX_SEQ_LEN,
N,
K,
stride_jm,
stride_db,
stride_dk,
stride_dn,
stride_bias_b,
stride_om,
HAS_BIAS: tl.constexpr,
ALLOW_TF32: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""Computing bmm Out = Jagged x Dense + Bias.
M is the jagged dimension
Jagged has shape (sum_B(M_i), K), Dense has shape (B, K, N), Bias has shape (B, N),
and Out has shape (sum_B(M_i), N)
"""
off_n = tl.program_id(0)
off_m = tl.program_id(1)
off_b = tl.program_id(2)
seq_start = tl.load(seq_offsets + off_b).to(tl.int64)
seq_end = tl.load(seq_offsets + off_b + 1)
seq_len = seq_end - seq_start
start_m = off_m * BLOCK_M
start_n = off_n * BLOCK_N
if start_m >= seq_len:
return
Jagged += seq_start * stride_jm
Dense += off_b.to(tl.int64) * stride_db
Out += seq_start * stride_om
offs_m = start_m + tl.arange(0, BLOCK_M)
offs_n = start_n + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
jg_ptrs = Jagged + offs_m[:, None] * stride_jm + offs_k[None, :]
dn_ptrs = Dense + offs_k[:, None] * stride_dk + offs_n[None, :] * stride_dn
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
jg = tl.load(
jg_ptrs,
# pyre-fixme[16]: `int` has no attribute `__getitem__`.
            mask=(offs_m[:, None] < seq_len) & ((k + offs_k)[None, :] < K),
other=0.0,
)
dn = tl.load(
dn_ptrs,
            mask=((k + offs_k)[:, None] < K) & (offs_n[None, :] < N),
other=0.0,
)
accumulator += tl.dot(jg, dn, allow_tf32=ALLOW_TF32)
jg_ptrs += BLOCK_K
dn_ptrs += BLOCK_K * stride_dk
if HAS_BIAS:
bias_ptrs = Bias + off_b * stride_bias_b + offs_n
bias = tl.load(bias_ptrs, mask=offs_n < N)
accumulator += bias[None, :].to(tl.float32)
out = accumulator.to(Out.dtype.element_ty)
offs_m = start_m + tl.arange(0, BLOCK_M)
offs_n = start_n + tl.arange(0, BLOCK_N)
out_ptrs = Out + offs_m[:, None] * stride_om + offs_n[None, :]
tl.store(out_ptrs, out, mask=(offs_m[:, None] < seq_len) & (offs_n[None, :] < N))
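
# Reference semantics for the kernel above (illustrative sketch):
#
#     for b in range(B):
#         rows = slice(int(seq_offsets[b]), int(seq_offsets[b + 1]))
#         out[rows] = jagged[rows] @ dense[b] + bias[b]

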
@triton_autotune(
configs=_get_bmm_configs(),
key=["M", "N", "AUTOTUNE_MAX_SEQ_LEN"],
)
@triton.jit
def _jagged_jagged_bmm_reduce_sum(
seq_offsets,
JaggedA,
JaggedB,
Out,
ReduceOut,
M,
N,
AUTOTUNE_MAX_SEQ_LEN,
stride_ak,
stride_bk,
stride_ob,
stride_om,
stride_on,
stride_orb,
stride_orn,
REDUCE_JAGGEDB: tl.constexpr,
ALLOW_TF32: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""Computing bmm Out = Jagged x Jagged.
K is the jagged dimension
JaggedA has shape (sum_B(K_i), M), JaggedB has shape (sum_B(K_i), N),
and Out has shape (B, M, N)
"""
off_b = tl.program_id(0)
off_m = tl.program_id(1)
off_n = tl.program_id(2)
seq_start = tl.load(seq_offsets + off_b).to(tl.int64)
seq_end = tl.load(seq_offsets + off_b + 1)
seq_len = seq_end - seq_start
start_m = off_m * BLOCK_M
start_n = off_n * BLOCK_N
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
Out += off_b.to(tl.int64) * stride_ob
offs_m = start_m + tl.arange(0, BLOCK_M)
offs_n = start_n + tl.arange(0, BLOCK_N)
out_ptrs = Out + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
if REDUCE_JAGGEDB:
out_reduce_ptrs = ReduceOut + off_b * stride_orb + offs_n * stride_orn
acc_reduce = tl.zeros((BLOCK_N,), dtype=tl.float32)
if seq_len == 0:
out = accumulator.to(Out.dtype.element_ty)
tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
if REDUCE_JAGGEDB:
if off_m == 0:
tl.store(
out_reduce_ptrs, # pyre-ignore [61]
acc_reduce.to(ReduceOut.dtype.element_ty),
mask=(offs_n < N),
)
return
JaggedA += seq_start * stride_ak
JaggedB += seq_start * stride_bk
offs_k = tl.arange(0, BLOCK_K)
jg_a_ptrs = JaggedA + offs_k[None, :] * stride_ak + offs_m[:, None]
jg_b_ptrs = JaggedB + offs_k[:, None] * stride_bk + offs_n[None, :]
for k in range(0, seq_len, BLOCK_K):
jg_a = tl.load(
jg_a_ptrs,
# pyre-fixme[16]: `int` has no attribute `__getitem__`.
            mask=(offs_m[:, None] < M) & ((k + offs_k)[None, :] < seq_len),
other=0.0,
)
jg_b = tl.load(
jg_b_ptrs,
            mask=(offs_n[None, :] < N) & ((k + offs_k)[:, None] < seq_len),
other=0.0,
)
accumulator += tl.dot(jg_a, jg_b, allow_tf32=ALLOW_TF32)
if REDUCE_JAGGEDB:
if off_m == 0:
acc_reduce += tl.sum(jg_b.to(tl.float32), axis=0)
jg_a_ptrs += BLOCK_K * stride_ak
jg_b_ptrs += BLOCK_K * stride_bk
out = accumulator.to(Out.dtype.element_ty)
tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
if REDUCE_JAGGEDB:
if off_m == 0:
tl.store(
out_reduce_ptrs, # pyre-ignore [61]
acc_reduce.to(ReduceOut.dtype.element_ty),
mask=(offs_n < N),
)
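
# Reference semantics for the kernel above (illustrative sketch):
#
#     for b in range(B):
#         rows = slice(int(seq_offsets[b]), int(seq_offsets[b + 1]))
#         out[b] = jagged_a[rows].t() @ jagged_b[rows]      # (M, N)
#         if REDUCE_JAGGEDB:
#             reduce_out[b] = jagged_b[rows].sum(dim=0)     # (N,)

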
class _Concat2DJaggedFunction(torch.autograd.Function):
@staticmethod
# pyre-ignore[14]
def forward(
ctx,
values_a: torch.Tensor,
values_b: torch.Tensor,
max_len_a: int,
max_len_b: int,
offsets_a: Optional[torch.Tensor],
offsets_b: Optional[torch.Tensor],
n_prefix_from_B: int,
    ) -> torch.Tensor:
values_a = switch_to_contiguous_if_needed(values_a)
values_b = switch_to_contiguous_if_needed(values_b)
is_dense_a = offsets_a is None
is_dense_b = offsets_b is None
total_len_a, D = values_a.shape
total_len_b, _ = values_b.shape
if is_dense_a:
B = total_len_a // max_len_a
else:
assert offsets_a is not None
B = offsets_a.shape[0] - 1
if is_dense_b:
B = total_len_b // max_len_b
else:
assert offsets_b is not None
B = offsets_b.shape[0] - 1
total_seq_len = total_len_a + total_len_b
max_seq_len = max_len_a + max_len_b
BLOCK_D = triton.next_power_of_2(D)
values_out = torch.empty(
(total_seq_len, D), device=values_a.device, dtype=values_a.dtype
)
_concat_2D_jagged[(max_seq_len, B)](
ValuesA=values_a,
ValuesB=values_b,
OffsetsA=offsets_a,
OffsetsB=offsets_b,
MaxLenA=max_len_a,
MaxLenB=max_len_b,
Out=values_out,
D=D,
stride_ad=values_a.stride(-2),
stride_bd=values_b.stride(-2),
stride_od=values_out.stride(-2),
n_prefix_from_B=n_prefix_from_B,
# pyre-ignore[6]
IS_DENSE_A=is_dense_a,
# pyre-ignore[6]
IS_DENSE_B=is_dense_b,
BLOCK_D=BLOCK_D,
)
ctx.save_for_backward(offsets_a, offsets_b)
ctx.max_seq_len = max_seq_len
ctx.total_len_a = total_len_a
ctx.total_len_b = total_len_b
ctx.is_dense_a = is_dense_a
ctx.is_dense_b = is_dense_b
ctx.max_len_a = max_len_a
ctx.max_len_b = max_len_b
ctx.B = B
ctx.n_prefix_from_B = n_prefix_from_B
        return values_out

    @staticmethod
# pyre-ignore[14]
def backward(
ctx, d_out: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, None, None, None, None, None]:
offsets_a, offsets_b = ctx.saved_tensors
_, D = d_out.shape
BLOCK_D = triton.next_power_of_2(D)
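        # The gradient of a concat is a split: each row of d_out is routed
        # back to the input (values_a or values_b) it was copied from.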
d_values_a = torch.zeros(
(ctx.total_len_a, D), device=d_out.device, dtype=d_out.dtype
)
d_values_b = torch.empty(
(ctx.total_len_b, D), device=d_out.device, dtype=d_out.dtype
)
_split_2D_jagged[(ctx.max_seq_len, ctx.B)](
JaggedIn=d_out,
OffsetsA=offsets_a,
OffsetsB=offsets_b,
MaxLenA=ctx.max_len_a,
MaxLenB=ctx.max_len_b,
OutA=d_values_a,
OutB=d_values_b,
D=D,
stride_id=d_out.stride(-2),
stride_ad=d_values_a.stride(-2),
stride_bd=d_values_b.stride(-2),
n_prefix_to_B=ctx.n_prefix_from_B,
BLOCK_D=BLOCK_D,
IS_DENSE_A=ctx.is_dense_a,
IS_DENSE_B=ctx.is_dense_b,
)
        return d_values_a, d_values_b, None, None, None, None, None


class _Split2DJaggedFunction(torch.autograd.Function):
@staticmethod
# pyre-ignore[14]
def forward(
ctx,
max_seq_len: int,
values: torch.Tensor,
max_len_a: Optional[int],
max_len_b: Optional[int],
offsets_a: Optional[torch.Tensor],
offsets_b: Optional[torch.Tensor],
n_prefix_to_B: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
values = switch_to_contiguous_if_needed(values)
is_dense_a: bool = offsets_a is None
is_dense_b: bool = offsets_b is None
total_seq_len, D = values.shape
if is_dense_a:
assert is_dense_b is False
assert offsets_b is not None
assert max_len_a is not None
B = offsets_b.shape[0] - 1
total_len_a = max_len_a * B
total_len_b = total_seq_len - total_len_a
elif is_dense_b:
assert is_dense_a is False
assert offsets_a is not None
assert max_len_b is not None
B = offsets_a.shape[0] - 1
total_len_b = max_len_b * B
total_len_a = total_seq_len - total_len_b
else:
assert offsets_a is not None and offsets_b is not None
B = offsets_a.shape[0] - 1
total_len_a = int(offsets_a[-1].item())
total_len_b = int(offsets_b[-1].item())
BLOCK_D = triton.next_power_of_2(D)
values_a = torch.empty(
(total_len_a, D), device=values.device, dtype=values.dtype
)
values_b = torch.empty(
(total_len_b, D), device=values.device, dtype=values.dtype
)
_split_2D_jagged[(max_seq_len, B)](
JaggedIn=values,
OffsetsA=offsets_a,
OffsetsB=offsets_b,
MaxLenA=max_len_a,
MaxLenB=max_len_b,
OutA=values_a,
OutB=values_b,
D=D,
stride_id=values.stride(0),
stride_ad=values_a.stride(0),
stride_bd=values_b.stride(0),
n_prefix_to_B=n_prefix_to_B,
# pyre-ignore[6]
IS_DENSE_A=is_dense_a,
# pyre-ignore[6]
IS_DENSE_B=is_dense_b,
BLOCK_D=BLOCK_D,
)
ctx.save_for_backward(offsets_a, offsets_b)
ctx.max_seq_len = max_seq_len
ctx.total_seq_len = total_seq_len
ctx.max_len_a = max_len_a
ctx.max_len_b = max_len_b
ctx.is_dense_a = is_dense_a
ctx.is_dense_b = is_dense_b
ctx.B = B
ctx.D = D
ctx.n_prefix_to_B = n_prefix_to_B
        return values_a, values_b

    @staticmethod
def backward(
ctx, *d_values
) -> Tuple[None, torch.Tensor, None, None, None, None, None]:
offsets_a, offsets_b = ctx.saved_tensors
d_values_a, d_values_b = d_values
BLOCK_D = triton.next_power_of_2(ctx.D)
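        # The gradient of a split is a concat: interleave d_values_a and
        # d_values_b back into the row layout of the original jagged input.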
d_jagged_in = torch.empty(
(ctx.total_seq_len, ctx.D),
device=d_values_a.device,
dtype=d_values_a.dtype,
)
_concat_2D_jagged[(ctx.max_seq_len, ctx.B)](
ValuesA=d_values_a,
ValuesB=d_values_b,
OffsetsA=offsets_a,
OffsetsB=offsets_b,
MaxLenA=ctx.max_len_a,
MaxLenB=ctx.max_len_b,
Out=d_jagged_in,
D=ctx.D,
stride_ad=d_values_a.stride(-2),
stride_bd=d_values_b.stride(-2),
stride_od=d_jagged_in.stride(-2),
n_prefix_from_B=ctx.n_prefix_to_B,
IS_DENSE_A=ctx.is_dense_a,
IS_DENSE_B=ctx.is_dense_b,
BLOCK_D=BLOCK_D,
)
        return None, d_jagged_in, None, None, None, None, None


class _JaggedDenseBmmBroadcastAddFunction(torch.autograd.Function):
@staticmethod
# pyre-ignore[14]
def forward(
ctx,
max_seq_len: int,
seq_offsets: torch.Tensor,
jagged: torch.Tensor,
dense: torch.Tensor,
bias: torch.Tensor,
    ) -> torch.Tensor:
jagged = switch_to_contiguous_if_needed(jagged)
bias = switch_to_contiguous_if_needed(bias)
L, K = jagged.shape
B, _, N = dense.shape
out = torch.empty((L, N), dtype=jagged.dtype, device=jagged.device)
grid = lambda meta: ( # noqa E731
triton.cdiv(N, meta["BLOCK_N"]),
triton.cdiv(max_seq_len, meta["BLOCK_M"]),
B,
)
jagged_dense_bmm_broadcast_add_kernel[grid](
seq_offsets=seq_offsets,
Jagged=jagged,
Dense=dense,
Bias=bias,
Out=out,
AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len),
N=N,
K=K,
stride_jm=jagged.stride(0),
stride_db=dense.stride(0),
stride_dk=dense.stride(1),
stride_dn=dense.stride(2),
stride_bias_b=bias.stride(0),
stride_om=out.stride(0),
HAS_BIAS=True,
ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
)
ctx.save_for_backward(seq_offsets, jagged, dense)
ctx.B = B
ctx.max_seq_len = max_seq_len
ctx.K = K
ctx.N = N
        return out

    @staticmethod
# pyre-ignore[14]
def backward(
ctx, d_out: torch.Tensor
) -> Tuple[None, None, torch.Tensor, torch.Tensor, torch.Tensor]:
seq_offsets, jagged, dense = ctx.saved_tensors
d_jagged = torch.empty_like(jagged)
d_dense = torch.empty_like(dense)
d_bias = torch.empty((ctx.B, ctx.N), device=d_out.device, dtype=d_out.dtype)
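        # With Out_b = Jagged_b @ Dense[b] + Bias[b], the gradients are
        #   d_jagged_b = d_out_b @ Dense[b]^T (forward kernel reused below with
        #                                      the K/N strides of Dense swapped),
        #   d_dense[b] = Jagged_b^T @ d_out_b,
        #   d_bias[b]  = d_out_b.sum(dim=0).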
grid = lambda meta: ( # noqa E731
triton.cdiv(ctx.K, meta["BLOCK_N"]),
triton.cdiv(ctx.max_seq_len, meta["BLOCK_M"]),
ctx.B,
)
jagged_dense_bmm_broadcast_add_kernel[grid](
seq_offsets=seq_offsets,
Jagged=d_out,
Dense=dense,
Bias=None,
Out=d_jagged,
AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len),
N=ctx.K,
K=ctx.N,
stride_jm=d_out.stride(0),
stride_db=dense.stride(0),
stride_dk=dense.stride(2),
stride_dn=dense.stride(1),
stride_bias_b=0,
stride_om=d_jagged.stride(0),
HAS_BIAS=False,
ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
)
grid = lambda meta: ( # noqa E731
ctx.B,
triton.cdiv(ctx.K, meta["BLOCK_M"]),
triton.cdiv(ctx.N, meta["BLOCK_N"]),
)
_jagged_jagged_bmm_reduce_sum[grid](
seq_offsets=seq_offsets,
JaggedA=jagged,
JaggedB=d_out,
Out=d_dense,
ReduceOut=d_bias,
M=ctx.K,
N=ctx.N,
AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len),
stride_ak=jagged.stride(0),
stride_bk=d_out.stride(0),
stride_ob=d_dense.stride(0),
stride_om=d_dense.stride(1),
stride_on=d_dense.stride(2),
stride_orb=d_bias.stride(0),
stride_orn=d_bias.stride(1),
REDUCE_JAGGEDB=True,
ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
)
        return None, None, d_jagged, d_dense, d_bias


@torch.fx.wrap
def triton_concat_2D_jagged(
values_left: torch.Tensor,
values_right: torch.Tensor,
max_len_left: int,
max_len_right: int,
offsets_left: Optional[torch.Tensor],
offsets_right: Optional[torch.Tensor],
n_prefix_from_right: int = 0,
) -> torch.Tensor:
return _Concat2DJaggedFunction.apply(
values_left,
values_right,
max_len_left,
max_len_right,
offsets_left,
offsets_right,
n_prefix_from_right,
)
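
# Example usage (illustrative; tensors and shapes are hypothetical):
#
#     offsets_l = torch.tensor([0, 2, 5], device="cuda")  # lengths 2 and 3
#     offsets_r = torch.tensor([0, 1, 3], device="cuda")  # lengths 1 and 2
#     left = torch.randn(5, 8, device="cuda")
#     right = torch.randn(3, 8, device="cuda")
#     out = triton_concat_2D_jagged(left, right, 3, 2, offsets_l, offsets_r)
#     # out has shape (8, 8); per batch: rows of `left`, then rows of `right`.

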
@torch.fx.wrap
def triton_split_2D_jagged(
max_seq_len: int,
values: torch.Tensor,
max_len_left: Optional[int],
max_len_right: Optional[int],
offsets_left: Optional[torch.Tensor],
offsets_right: Optional[torch.Tensor],
n_prefix_to_right: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
return _Split2DJaggedFunction.apply(
max_seq_len,
values,
max_len_left,
max_len_right,
offsets_left,
offsets_right,
n_prefix_to_right,
)
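
# Example usage (illustrative): splitting with the same offsets inverts the
# concat example above.
#
#     a, b = triton_split_2D_jagged(5, out, None, None, offsets_l, offsets_r)
#     # torch.equal(a, left) and torch.equal(b, right) both hold.

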
@torch.fx.wrap
def triton_jagged_dense_bmm_broadcast_add(
max_seq_len: int,
seq_offsets: torch.Tensor,
jagged: torch.Tensor,
dense: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
return _JaggedDenseBmmBroadcastAddFunction.apply(
max_seq_len, seq_offsets, jagged, dense, bias
)
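
# Example usage (illustrative; shapes are hypothetical):
#
#     seq_offsets = torch.tensor([0, 2, 5], device="cuda")  # B = 2
#     jagged = torch.randn(5, 16, device="cuda")            # (sum_B(M_i), K)
#     dense = torch.randn(2, 16, 32, device="cuda")         # (B, K, N)
#     bias = torch.randn(2, 32, device="cuda")              # (B, N)
#     out = triton_jagged_dense_bmm_broadcast_add(
#         3, seq_offsets, jagged, dense, bias
#     )
#     # out has shape (5, 32); batch b rows equal jagged_b @ dense[b] + bias[b].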