# tzrec/ops/triton/triton_position.py
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# We use the position encoder ops from generative-recommenders as a starting point:
# https://github.com/facebookresearch/generative-recommenders
# Thanks to their public work.
from typing import List, Optional, Tuple
import torch
import triton
import triton.language as tl
from triton.runtime.autotuner import autotune as triton_autotune
from tzrec.ops.utils import (
autotune_max_seq_len,
prev_power_of_2,
switch_to_contiguous_if_needed,
)
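# Mark prev_power_of_2 as a leaf function for torch.fx so symbolic tracing
# records a call to it instead of tracing through its body.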
torch.fx.wrap(prev_power_of_2)
def _add_position_embeddings_configs() -> List[triton.Config]:
configs = []
for BLOCK_N in [16, 32, 64]:
for num_stages in [1, 2]:
for num_warps in [2, 4, 8]:
configs.append(
triton.Config(
{
"BLOCK_N": BLOCK_N,
},
num_stages=num_stages,
num_warps=num_warps,
)
)
return configs
@triton_autotune(
configs=_add_position_embeddings_configs(),
key=["AUTOTUNE_MAX_SEQ_LEN"],
)
@triton.jit
def _add_position_embeddings_kernel(
Jagged,
seq_offsets,
high_inds,
Dense,
Out,
AUTOTUNE_MAX_SEQ_LEN,
D,
scale,
stride_jn,
stride_dk,
stride_on,
SCALE_JAGGED: tl.constexpr,
BLOCK_D: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""Add position embedding.
Jagged has shape (sum_B(N_i), D),
Dense has shape (K, D),
Out has shape (sum_B(N_i), D)
"""
off_b = tl.program_id(0)
off_n = tl.program_id(1)
seq_start = tl.load(seq_offsets + off_b)
seq_end = tl.load(seq_offsets + off_b + 1)
max_ind = tl.load(high_inds + off_b)
seq_len = seq_end - seq_start
start_n = off_n * BLOCK_N
if start_n >= seq_len:
return
offs_n = start_n + tl.arange(0, BLOCK_N)
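    # Positions at or beyond high_inds[b] are clamped to max_ind so they all
    # read the same dense (position-embedding) row.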
clamped_offs_n = tl.where(offs_n >= max_ind, max_ind, offs_n)
offs_d = tl.arange(0, BLOCK_D)
Jagged += seq_start.to(tl.int64) * stride_jn
jagged_ptr_offsets = offs_n[:, None] * stride_jn + offs_d[None, :]
Out += seq_start.to(tl.int64) * stride_on
out_ptrs = Out + offs_n[:, None] * stride_on + offs_d[None, :]
dense_ptrs = Dense + clamped_offs_n[:, None] * stride_dk + offs_d[None, :]
for _d in range(0, D, BLOCK_D):
        mask = (offs_n[:, None] < seq_len) & (offs_d[None, :] < D)
jg = tl.load(Jagged + jagged_ptr_offsets, mask=mask)
if SCALE_JAGGED:
jg = jg * scale
dn = tl.load(dense_ptrs, mask=mask)
jg += dn
tl.store(out_ptrs, jg, mask=mask)
dense_ptrs += BLOCK_D
out_ptrs += BLOCK_D
offs_d += BLOCK_D
jagged_ptr_offsets += BLOCK_D
@triton.jit
def _add_position_embeddings_bwd_kernel(
Jagged,
seq_offsets,
high_inds,
DenseOut,
JaggedOut,
B,
D,
scale,
stride_jn,
stride_jon,
stride_don,
SCALE_JAGGED: tl.constexpr,
BLOCK_D: tl.constexpr,
):
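    """Backward pass for the jagged + position-embedding add.

    ``Jagged`` holds the upstream gradient. Each program owns one dense row
    ``off_k`` and sums the gradient rows that read it in the forward pass:
    position ``off_k`` of every sequence with ``off_k < high_inds[b]``, plus
    all clamped positions ``>= high_inds[b]`` when ``off_k == high_inds[b]``.
    When SCALE_JAGGED is set, the scaled upstream gradient is also written to
    ``JaggedOut``.
    """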
off_k = tl.program_id(0)
offs_d = tl.arange(0, BLOCK_D)
accumulator = tl.zeros((BLOCK_D,), dtype=tl.float32)
for off_b in range(0, B):
max_ind = tl.load(high_inds + off_b)
if off_k < max_ind:
seq_start = tl.load(seq_offsets + off_b)
jagged_ptr = (
Jagged
+ seq_start.to(tl.int64) * stride_jn
+ off_k.to(tl.int64) * stride_jn
)
jagged_ptrs = jagged_ptr + offs_d
jg = tl.load(
jagged_ptrs,
mask=offs_d < D,
)
accumulator += jg
if SCALE_JAGGED:
out_jagged_ptr = (
JaggedOut
+ seq_start.to(tl.int64) * stride_jon
+ off_k.to(tl.int64) * stride_jon
)
out_jagged_ptrs = out_jagged_ptr + offs_d
tl.store(
out_jagged_ptrs,
jg * scale,
mask=offs_d < D,
)
elif off_k == max_ind:
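            # All positions >= high_inds[b] were clamped to this dense row in
            # the forward pass, so sum their gradients here as well.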
seq_start = tl.load(seq_offsets + off_b).to(tl.int64)
seq_end = tl.load(seq_offsets + off_b + 1)
for k in range(seq_start + max_ind, seq_end):
jagged_ptr = Jagged + k * stride_jn
jagged_ptrs = jagged_ptr + offs_d
jg = tl.load(
jagged_ptrs,
mask=offs_d < D,
)
accumulator += jg
if SCALE_JAGGED:
out_jagged_ptr = JaggedOut + k * stride_jon
out_jagged_ptrs = out_jagged_ptr + offs_d
tl.store(
out_jagged_ptrs,
jg * scale,
mask=offs_d < D,
)
out = accumulator.to(DenseOut.dtype.element_ty)
out_ptrs = DenseOut + off_k * stride_don + offs_d
tl.store(
out_ptrs,
out,
mask=offs_d < D,
)
class _AddPositionEmbeddingsFunction(torch.autograd.Function):
@staticmethod
# pyre-ignore[14]
def forward(
ctx,
jagged: torch.Tensor,
jagged_offsets: torch.Tensor,
high_inds: torch.Tensor,
max_seq_len: int,
dense: torch.Tensor,
scale: float = 1.0,
):
jagged = switch_to_contiguous_if_needed(jagged)
dense = switch_to_contiguous_if_needed(dense)
L, D = jagged.shape
assert len(dense.shape) == 2
out = torch.empty_like(jagged)
B = high_inds.size(0)
grid = lambda meta: ( # noqa E731
B,
triton.cdiv(max_seq_len, meta["BLOCK_N"]),
)
BLOCK_D = triton.next_power_of_2(D) if D < 64 else 64
_add_position_embeddings_kernel[grid](
Jagged=jagged,
seq_offsets=jagged_offsets,
high_inds=high_inds,
Dense=dense,
Out=out,
AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len),
D=D,
scale=scale,
stride_jn=jagged.stride(0),
stride_dk=dense.stride(0),
stride_on=out.stride(0),
SCALE_JAGGED=scale != 1.0,
BLOCK_D=BLOCK_D,
)
ctx.save_for_backward(jagged_offsets, high_inds)
ctx.B = B
ctx.D = D
ctx.scale = scale
ctx.K = dense.size(0)
ctx.BLOCK_D = BLOCK_D
return out
@staticmethod
# pyre-ignore[14]
def backward(
ctx, d_out: torch.Tensor
) -> Tuple[torch.Tensor, None, None, None, torch.Tensor, None]:
jagged_offsets, high_inds = ctx.saved_tensors
d_dense = torch.empty((ctx.K, ctx.D), device=d_out.device, dtype=d_out.dtype)
scale_jagged = ctx.scale != 1.0
if scale_jagged:
d_jagged = torch.empty_like(d_out)
BLOCK_D = triton.next_power_of_2(ctx.D)
_add_position_embeddings_bwd_kernel[(ctx.K,)](
Jagged=d_out,
seq_offsets=jagged_offsets,
high_inds=high_inds,
DenseOut=d_dense,
JaggedOut=d_jagged if scale_jagged else None, # pyre-ignore[61]
B=ctx.B,
D=ctx.D,
scale=ctx.scale,
stride_jn=d_out.stride(0),
stride_jon=d_jagged.stride(0) if scale_jagged else 0,
stride_don=d_dense.stride(0),
SCALE_JAGGED=scale_jagged,
BLOCK_D=BLOCK_D,
)
# pyre-ignore[61]
return d_jagged if scale_jagged else d_out, None, None, None, d_dense, None
@triton_autotune(
configs=_add_position_embeddings_configs(),
key=["AUTOTUNE_MAX_SEQ_LEN"],
)
@triton.jit
def _add_timestamp_position_embeddings_kernel(
SeqEmb,
Offsets,
Lengths,
PosEmb,
TsEmb,
Out,
TS,
PosInds,
TsInds,
NumTargets,
AUTOTUNE_MAX_SEQ_LEN,
D,
num_time_buckets,
time_bucket_increments,
time_bucket_scale,
time_delta,
max_contextual_seq_len,
max_pos_ind,
stride_sn,
stride_pn,
stride_tn,
stride_on,
TRAINING: tl.constexpr,
HAS_MULTIPLE_TARGETS: tl.constexpr,
INTERLEAVE_TARGETS: tl.constexpr,
TIME_BUCKET_FN: tl.constexpr,
BLOCK_D: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""Add timestamp position embedding.
SeqEmb has shape (sum_B(N_i), D),
PosEmb has shape (N_p, D),
TsEmb has shape (N_t, D),
Out has shape (sum_B(N_i), D)
"""
off_b = tl.program_id(0)
off_n = tl.program_id(1)
seq_start = tl.load(Offsets + off_b)
seq_end = tl.load(Offsets + off_b + 1)
seq_len = seq_end - seq_start
start_n = off_n * BLOCK_N
if start_n >= seq_len:
return
offs_n = start_n + tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_D)
seq_emb_offsets = offs_n[:, None] * stride_sn + offs_d[None, :]
SeqEmb += seq_start.to(tl.int64) * stride_sn
mask_n = offs_n < seq_len
# position encoding
seq_len = tl.load(Lengths + off_b)
if HAS_MULTIPLE_TARGETS:
num_targets = tl.load(NumTargets + off_b)
if INTERLEAVE_TARGETS:
high_ind = seq_len - num_targets * 2
else:
high_ind = seq_len - num_targets
else:
high_ind = seq_len
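    # Reverse relative positions: non-target items get index
    # high_ind - n + max_contextual_seq_len (targets collapse to
    # max_contextual_seq_len), clamped to max_pos_ind - 1; the leading
    # contextual slots (n < max_contextual_seq_len) keep their absolute index.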
pos_inds = tl.where(offs_n < high_ind, offs_n, high_ind)
pos_inds = high_ind - pos_inds + max_contextual_seq_len
pos_inds = tl.where(pos_inds < max_pos_ind - 1, pos_inds, max_pos_ind - 1)
pos_inds = tl.where(offs_n < max_contextual_seq_len, offs_n, pos_inds)
if TRAINING:
tl.store(PosInds + seq_start + offs_n, pos_inds, mask=mask_n)
pos_emb_offsets = pos_inds[:, None] * stride_pn + offs_d[None, :]
# timestamp encoding
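    # Bucketize each event's time delta to the query time (the last event):
    # scale by 1 / time_bucket_increments, apply log or sqrt, multiply by
    # time_bucket_scale, then clamp the bucket index to [0, num_time_buckets].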
ts = tl.load(TS + seq_start + offs_n, mask=mask_n)
query_time = tl.load(TS + seq_end - 1)
ts = query_time - ts + time_delta
ts = tl.where(ts > 1e-6, ts, 1e-6) / time_bucket_increments
if TIME_BUCKET_FN == "log":
ts = tl.log(ts)
else:
ts = tl.sqrt(ts)
ts = ts * time_bucket_scale
ts = ts.to(tl.int32)
ts = tl.where(ts > 0, ts, 0)
ts = tl.where(ts < num_time_buckets, ts, num_time_buckets)
if TRAINING:
tl.store(TsInds + seq_start + offs_n, ts, mask=mask_n)
ts_emb_offsets = ts[:, None] * stride_tn + offs_d[None, :]
Out += seq_start.to(tl.int64) * stride_on
out_offsets = Out + offs_n[:, None] * stride_on + offs_d[None, :]
for _d in range(0, D, BLOCK_D):
        mask = (offs_n[:, None] < seq_len) & (offs_d[None, :] < D)
seq_emb = tl.load(SeqEmb + seq_emb_offsets, mask=mask)
pos_emb = tl.load(PosEmb + pos_emb_offsets, mask=mask)
ts_emb = tl.load(TsEmb + ts_emb_offsets, mask=mask)
tl.store(out_offsets, seq_emb + (pos_emb + ts_emb).to(seq_emb.dtype), mask=mask)
seq_emb_offsets += BLOCK_D
pos_emb_offsets += BLOCK_D
ts_emb_offsets += BLOCK_D
out_offsets += BLOCK_D
offs_d += BLOCK_D
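# Autotuning pre-hook: the backward kernel accumulates into Out with
# atomic_add, so Out must be zeroed before every (re-)run of a config.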
def bwd_pre_hook(nargs):
nargs["Out"].zero_()
def _add_embeddings_bwd_configs() -> List[triton.Config]:
configs = []
for BLOCK in [32, 64, 128]:
for num_stages in [2, 3, 4]:
for num_warps in [2, 4, 8]:
configs.append(
triton.Config(
{
"BLOCK": BLOCK,
},
num_stages=num_stages,
num_warps=num_warps,
pre_hook=bwd_pre_hook,
)
)
return configs
@triton_autotune(
configs=_add_embeddings_bwd_configs(),
key=["AUTOTUNE_MAX_SEQ_LEN", "AUTOTUNE_B", "D"],
)
@triton.jit
def _add_embeddings_bwd_kernel(
In,
KeyInds,
ValueInds,
Out,
AUTOTUNE_MAX_SEQ_LEN,
B,
AUTOTUNE_B,
D,
jagged_size,
stride_in,
stride_on,
BLOCK_D: tl.constexpr,
BLOCK: tl.constexpr,
):
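    """Scatter-reduce jagged gradients into embedding-table gradients.

    ``KeyInds`` holds embedding-row indices sorted in ascending order and
    ``ValueInds`` the corresponding row positions in ``In``. Each program
    scans a BLOCK-sized chunk, locally sums consecutive entries that share
    the same key, and flushes each partial sum into ``Out`` with a relaxed
    atomic_add.
    """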
off_block = tl.program_id(0)
offs_d = tl.arange(0, BLOCK_D)
mask_d = offs_d < D
key_ind = -1
key_ind = key_ind.to(KeyInds.dtype.element_ty) # pyre-ignore[16]
accumulator = tl.zeros((BLOCK_D,), dtype=In.dtype.element_ty)
for off_i in range(0, BLOCK):
off = off_block * BLOCK + off_i
if off < jagged_size:
value_ind = tl.load(ValueInds + off)
in_offset = In + value_ind.to(tl.int64) * stride_in
jagged_in = tl.load(in_offset + offs_d, mask=mask_d)
key_ind_new = tl.load(KeyInds + off)
if key_ind == key_ind_new:
accumulator += jagged_in
else:
if key_ind >= 0:
out_offset = Out + key_ind.to(tl.int64) * stride_on
tl.atomic_add(
out_offset + offs_d,
accumulator.to(Out.dtype.element_ty),
mask=mask_d,
sem="relaxed",
)
key_ind = key_ind_new
accumulator = jagged_in
if key_ind >= 0:
out_offset = Out + key_ind.to(tl.int64) * stride_on
tl.atomic_add(
out_offset + offs_d,
accumulator.to(Out.dtype.element_ty),
mask=mask_d,
sem="relaxed",
)
class _AddTimestampPositionEmbeddingsFunction(torch.autograd.Function):
@staticmethod
# pyre-ignore[14]
def forward(
ctx,
seq_embeddings: torch.Tensor,
seq_offsets: torch.Tensor,
pos_embeddings: torch.Tensor,
ts_embeddings: torch.Tensor,
timestamps: torch.Tensor,
max_seq_len: int,
max_contextual_seq_len: int,
seq_lengths: torch.Tensor,
num_targets: Optional[torch.Tensor],
interleave_targets: bool,
time_bucket_fn: str,
):
seq_embeddings = switch_to_contiguous_if_needed(seq_embeddings)
pos_embeddings = switch_to_contiguous_if_needed(pos_embeddings)
ts_embeddings = switch_to_contiguous_if_needed(ts_embeddings)
max_pos_ind = pos_embeddings.shape[0]
B = seq_lengths.shape[0]
N, D = seq_embeddings.shape
assert len(pos_embeddings.shape) == 2
assert len(ts_embeddings.shape) == 2
        assert pos_embeddings.shape[1] == D, (
            "shape[1] of pos_embeddings must match seq_embeddings"
        )
        assert ts_embeddings.shape[1] == D, (
            "shape[1] of ts_embeddings must match seq_embeddings"
        )
out = torch.empty_like(seq_embeddings)
timestamps = switch_to_contiguous_if_needed(timestamps)
ts_inds = torch.empty_like(seq_embeddings[:, 0], dtype=torch.int32)
pos_inds = torch.empty_like(seq_embeddings[:, 0], dtype=torch.int32)
grid = lambda meta: ( # noqa E731
B,
triton.cdiv(max_seq_len, meta["BLOCK_N"]),
)
BLOCK_D = triton.next_power_of_2(D) if D < 64 else 64
_add_timestamp_position_embeddings_kernel[grid](
SeqEmb=seq_embeddings,
Offsets=seq_offsets,
Lengths=seq_lengths,
PosEmb=pos_embeddings,
TsEmb=ts_embeddings,
Out=out,
TS=timestamps,
PosInds=pos_inds,
TsInds=ts_inds,
NumTargets=num_targets,
AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len),
D=D,
num_time_buckets=2048,
time_bucket_increments=60.0,
time_bucket_scale=1.0,
time_delta=0,
max_contextual_seq_len=max_contextual_seq_len,
max_pos_ind=max_pos_ind,
stride_sn=seq_embeddings.stride(0),
stride_pn=pos_embeddings.stride(0),
stride_tn=ts_embeddings.stride(0),
stride_on=out.stride(0),
TRAINING=True,
HAS_MULTIPLE_TARGETS=num_targets is not None,
INTERLEAVE_TARGETS=interleave_targets,
TIME_BUCKET_FN=time_bucket_fn,
BLOCK_D=BLOCK_D,
)
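        # Group gradient rows by embedding index for the backward segmented
        # sum: prefer the fused hammer.sort_kv_pairs op and fall back to
        # torch.sort (whose returned indices play the role of the value inds).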
try:
values = torch.arange(0, N, dtype=torch.int32, device=timestamps.device)
sorted_ts_key_inds, sorted_ts_value_inds = torch.ops.hammer.sort_kv_pairs(
ts_inds, values
)
sorted_pos_key_inds, sorted_pos_value_inds = torch.ops.hammer.sort_kv_pairs(
pos_inds, values
)
except Exception:
sorted_ts_key_inds, sorted_ts_value_inds = torch.sort(ts_inds)
sorted_pos_key_inds, sorted_pos_value_inds = torch.sort(pos_inds)
ctx.save_for_backward(
sorted_pos_key_inds,
sorted_pos_value_inds,
sorted_ts_key_inds,
sorted_ts_value_inds,
)
ctx.B = B
ctx.D = D
ctx.max_seq_len = max_seq_len
ctx.pos_emb_size = pos_embeddings.shape[0]
ctx.ts_emb_size = ts_embeddings.shape[0]
ctx.pos_dtype = pos_embeddings.dtype
ctx.ts_dtype = ts_embeddings.dtype
return out
@staticmethod
# pyre-ignore[14]
def backward(
ctx, d_out: torch.Tensor
) -> Tuple[
torch.Tensor,
None,
torch.Tensor,
torch.Tensor,
None,
None,
None,
None,
None,
None,
None,
]:
(
sorted_pos_key_inds,
sorted_pos_value_inds,
sorted_ts_key_inds,
sorted_ts_value_inds,
) = ctx.saved_tensors
d_pos_embeddings = torch.empty(
(ctx.pos_emb_size, ctx.D), device=d_out.device, dtype=torch.float32
)
d_ts_embeddings = torch.empty(
(ctx.ts_emb_size, ctx.D), device=d_out.device, dtype=torch.float32
)
grid = lambda meta: (triton.cdiv(d_out.shape[0], meta["BLOCK"]),) # noqa E731
B = ctx.B
AUTOTUNE_B = prev_power_of_2(B)
_add_embeddings_bwd_kernel[grid](
In=d_out,
KeyInds=sorted_pos_key_inds,
ValueInds=sorted_pos_value_inds,
Out=d_pos_embeddings,
AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len),
B=B,
AUTOTUNE_B=AUTOTUNE_B,
D=ctx.D,
jagged_size=d_out.shape[0],
stride_in=d_out.stride(0),
stride_on=d_pos_embeddings.stride(0),
BLOCK_D=triton.next_power_of_2(ctx.D),
)
_add_embeddings_bwd_kernel[grid](
In=d_out,
KeyInds=sorted_ts_key_inds,
ValueInds=sorted_ts_value_inds,
Out=d_ts_embeddings,
AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len),
B=B,
AUTOTUNE_B=AUTOTUNE_B,
D=ctx.D,
jagged_size=d_out.shape[0],
stride_in=d_out.stride(0),
stride_on=d_ts_embeddings.stride(0),
BLOCK_D=triton.next_power_of_2(ctx.D),
)
return (
d_out,
None,
d_pos_embeddings.to(ctx.pos_dtype),
d_ts_embeddings.to(ctx.ts_dtype),
None,
None,
None,
None,
None,
None,
None,
)
@torch.fx.wrap
def triton_add_position_embeddings(
jagged: torch.Tensor,
jagged_offsets: torch.Tensor,
high_inds: torch.Tensor,
max_seq_len: int,
dense: torch.Tensor,
scale: float = 1.0,
) -> torch.Tensor:
return _AddPositionEmbeddingsFunction.apply(
jagged, jagged_offsets, high_inds, max_seq_len, dense, scale
)
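# Illustrative usage (a minimal sketch; shapes and values below are assumptions):
#   jagged = torch.randn(10, 64, device="cuda")        # sum_B(N_i) = 10, D = 64
#   offsets = torch.tensor([0, 4, 10], device="cuda")  # B = 2 sequences
#   high_inds = torch.tensor([4, 6], device="cuda")    # per-sequence clamp index
#   dense = torch.randn(7, 64, device="cuda")          # K rows, K > max(high_inds)
#   out = triton_add_position_embeddings(jagged, offsets, high_inds, 6, dense)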
@torch.fx.wrap
def triton_add_timestamp_positional_embeddings(
seq_embeddings: torch.Tensor,
seq_offsets: torch.Tensor,
pos_embeddings: torch.Tensor,
ts_embeddings: torch.Tensor,
timestamps: torch.Tensor,
max_seq_len: int,
max_contextual_seq_len: int,
seq_lengths: torch.Tensor,
num_targets: Optional[torch.Tensor],
interleave_targets: bool,
time_bucket_fn: str,
) -> torch.Tensor:
return _AddTimestampPositionEmbeddingsFunction.apply(
seq_embeddings,
seq_offsets,
pos_embeddings,
ts_embeddings,
timestamps,
max_seq_len,
max_contextual_seq_len,
seq_lengths,
num_targets,
interleave_targets,
time_bucket_fn,
)
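# Illustrative usage (a minimal sketch; shapes and values below are assumptions):
#   seq_emb = torch.randn(10, 64, device="cuda")            # sum_B(N_i) = 10, D = 64
#   offsets = torch.tensor([0, 4, 10], device="cuda")
#   lengths = torch.tensor([4, 6], device="cuda")
#   pos_emb = torch.randn(128, 64, device="cuda")           # max_pos_ind = 128
#   ts_emb = torch.randn(2049, 64, device="cuda")           # num_time_buckets + 1 rows
#   timestamps = torch.arange(10, device="cuda", dtype=torch.float32)
#   out = triton_add_timestamp_positional_embeddings(
#       seq_emb, offsets, pos_emb, ts_emb, timestamps,
#       max_seq_len=6, max_contextual_seq_len=0, seq_lengths=lengths,
#       num_targets=None, interleave_targets=False, time_bucket_fn="log",
#   )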