# tzrec/ops/jagged_tensors.py
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# We use the jagged_tensors ops from generative-recommenders as a starting point:
# https://github.com/facebookresearch/generative-recommenders
# Thanks for their public work.

from typing import Optional, Tuple

import torch
from torch.fx._symbolic_trace import is_fx_tracing

from tzrec.ops import Kernel
from tzrec.ops.pytorch.pt_jagged_tensors import (
pytorch_concat_2D_jagged,
pytorch_hstu_concat_l2_embeddings,
pytorch_hstu_split_l2_embeddings,
pytorch_jagged_dense_bmm_broadcast_add,
pytorch_split_2D_jagged,
)
from tzrec.ops.triton.triton_jagged_tensors import (
triton_concat_2D_jagged,
triton_jagged_dense_bmm_broadcast_add,
triton_split_2D_jagged,
)
torch.fx.wrap("triton_concat_2D_jagged")
torch.fx.wrap("triton_split_2D_jagged")
def concat_2D_jagged(
values_left: torch.Tensor,
values_right: torch.Tensor,
max_len_left: int,
max_len_right: int,
offsets_left: Optional[torch.Tensor],
offsets_right: Optional[torch.Tensor],
kernel: Kernel = Kernel.PYTORCH,
) -> torch.Tensor:
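    """Concatenate two 2D jagged tensors along the sequence dimension.

    For each batch element, the rows of ``values_left`` come first, followed
    by the rows of ``values_right``. Per-batch row boundaries are given by
    ``offsets_left`` / ``offsets_right``; if an offsets tensor is None, that
    side is treated as dense with ``max_len_left`` / ``max_len_right`` rows
    per batch.
    """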
if not is_fx_tracing():
torch._assert(values_left.dim() == 2, "values_left must be 2D")
torch._assert(values_right.dim() == 2, "values_right must be 2D")
        torch._assert(
            values_right.shape[1] == values_left.shape[1],
            f"values_left shape[1] must equal values_right shape[1]: "
            f"{values_left.shape[1]} vs {values_right.shape[1]}",
        )
if kernel == Kernel.TRITON:
return triton_concat_2D_jagged(
values_left=values_left,
values_right=values_right,
max_len_left=max_len_left,
max_len_right=max_len_right,
offsets_left=offsets_left,
offsets_right=offsets_right,
)
else:
return pytorch_concat_2D_jagged(
values_left=values_left,
values_right=values_right,
max_len_left=max_len_left,
max_len_right=max_len_right,
offsets_left=offsets_left,
offsets_right=offsets_right,
)
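

# A minimal usage sketch for concat_2D_jagged, with illustrative shapes and the
# default PyTorch kernel (not taken from the original tests):
#
#   offsets_left = torch.tensor([0, 2, 3])   # per-batch left lengths [2, 1]
#   offsets_right = torch.tensor([0, 1, 3])  # per-batch right lengths [1, 2]
#   values_left = torch.randn(3, 8)          # 3 = sum of left lengths, D = 8
#   values_right = torch.randn(3, 8)         # 3 = sum of right lengths
#   out = concat_2D_jagged(
#       values_left=values_left,
#       values_right=values_right,
#       max_len_left=2,
#       max_len_right=2,
#       offsets_left=offsets_left,
#       offsets_right=offsets_right,
#   )
#   # out has 6 rows, laid out per batch:
#   # [left_0 (2 rows), right_0 (1 row), left_1 (1 row), right_1 (2 rows)]
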
def split_2D_jagged(
max_seq_len: int,
values: torch.Tensor,
max_len_left: Optional[int] = None,
max_len_right: Optional[int] = None,
offsets_left: Optional[torch.Tensor] = None,
offsets_right: Optional[torch.Tensor] = None,
kernel: Kernel = Kernel.PYTORCH,
) -> Tuple[torch.Tensor, torch.Tensor]:
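    """Split a 2D jagged tensor into left and right parts per batch.

    The inverse of ``concat_2D_jagged``: for each batch element, the leading
    rows go to the left output and the remaining rows to the right output.
    At least one of ``offsets_left`` / ``offsets_right`` must be provided;
    for a side without offsets, ``max_len_left`` / ``max_len_right`` gives
    its fixed per-batch length.
    """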
if not is_fx_tracing():
torch._assert(values.dim() == 2, "values must be 2D")
        torch._assert(
            offsets_left is not None or offsets_right is not None,
            "offsets_left and offsets_right must not both be None",
        )
if offsets_left is None:
torch._assert(
max_len_left is not None,
"max_len_left must be provided when offsets_left is None",
)
if offsets_right is None:
torch._assert(
max_len_right is not None,
"max_len_right must be provided when offsets_right is None",
)
if offsets_left is not None and offsets_right is not None:
torch._assert(
offsets_left.shape[0] == offsets_right.shape[0],
"offsets_left shape[0] must be equal to offsets_right shape[0]",
)
if kernel == Kernel.TRITON:
return triton_split_2D_jagged(
max_seq_len=max_seq_len,
values=values,
max_len_left=max_len_left,
max_len_right=max_len_right,
offsets_left=offsets_left,
offsets_right=offsets_right,
)
else:
return pytorch_split_2D_jagged(
max_seq_len=max_seq_len,
values=values,
max_len_left=max_len_left,
max_len_right=max_len_right,
offsets_left=offsets_left,
offsets_right=offsets_right,
)
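

# A minimal round-trip sketch for split_2D_jagged, continuing the illustrative
# concat_2D_jagged example above (default PyTorch kernel assumed): splitting the
# concatenated output with the same offsets recovers both inputs.
#
#   left, right = split_2D_jagged(
#       max_seq_len=3,          # max combined (left + right) length per batch
#       values=out,
#       offsets_left=offsets_left,
#       offsets_right=offsets_right,
#   )
#   # left == values_left and right == values_right, row for row
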
def hstu_split_l2_embeddings(
max_seq_len: int,
x: torch.Tensor,
minus_l2_offsets: torch.Tensor,
l2_offsets: torch.Tensor,
contextual_seq_len: int,
kernel: Kernel = Kernel.PYTORCH,
) -> Tuple[torch.Tensor, torch.Tensor]:
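    """Split HSTU embeddings into minus-l2 and l2 parts.

    ``x`` is a 2D jagged tensor whose per-batch lengths are the sum of the
    ``minus_l2_offsets`` and ``l2_offsets`` lengths; the first
    ``contextual_seq_len`` rows of each sequence are routed to the l2 part.
    """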
if kernel == Kernel.TRITON:
return triton_split_2D_jagged(
max_seq_len=max_seq_len,
values=x,
max_len_left=None,
max_len_right=None,
offsets_left=minus_l2_offsets,
offsets_right=l2_offsets,
n_prefix_to_right=contextual_seq_len,
)
else:
return pytorch_hstu_split_l2_embeddings(
max_seq_len=max_seq_len,
x=x,
minus_l2_offsets=minus_l2_offsets,
l2_offsets=l2_offsets,
contextual_seq_len=contextual_seq_len,
)


def hstu_concat_l2_embeddings(
max_minus_l2_len: int,
minus_l2_x: torch.Tensor,
minus_l2_offsets: torch.Tensor,
max_l2_len: int,
l2_x: torch.Tensor,
l2_offsets: torch.Tensor,
contextual_seq_len: int,
kernel: Kernel = Kernel.PYTORCH,
) -> torch.Tensor:
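    """Concatenate minus-l2 and l2 embeddings back into one jagged tensor.

    The inverse of ``hstu_split_l2_embeddings``: for each batch element, the
    first ``contextual_seq_len`` rows of the output are taken from the l2
    side before the remaining minus-l2 and l2 rows are concatenated.
    """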
if kernel == Kernel.TRITON:
return triton_concat_2D_jagged(
values_left=minus_l2_x,
values_right=l2_x,
max_len_left=max_minus_l2_len,
max_len_right=max_l2_len,
offsets_left=minus_l2_offsets,
offsets_right=l2_offsets,
n_prefix_from_right=contextual_seq_len,
)
else:
return pytorch_hstu_concat_l2_embeddings(
contextual_seq_len=contextual_seq_len,
max_minus_l2_len=max_minus_l2_len,
minus_l2_x=minus_l2_x,
minus_l2_offsets=minus_l2_offsets,
max_l2_len=max_l2_len,
l2_x=l2_x,
l2_offsets=l2_offsets,
)
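

# Note: hstu_concat_l2_embeddings is the inverse of hstu_split_l2_embeddings.
# Splitting x with (minus_l2_offsets, l2_offsets, contextual_seq_len) and then
# concatenating the two halves back with the same offsets reproduces x, with
# the contextual_seq_len prefix rows routed to/from the l2 side.
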
def jagged_dense_bmm_broadcast_add(
max_seq_len: int,
seq_offsets: torch.Tensor,
jagged: torch.Tensor,
dense: torch.Tensor,
bias: torch.Tensor,
kernel: Kernel = Kernel.PYTORCH,
) -> torch.Tensor:
"""Computing out = jagged x dense + bias.
jagged has shape (sum_B(M_i), K), dense has shape (B, K, N), and bias has
shape (B, N), out has shape (sum_B(M_i), N)
"""
if not is_fx_tracing():
_, K = jagged.shape
B, _, N = dense.shape
torch._assert(dense.shape[1] == K, "wrong dense shape[1]")
torch._assert(seq_offsets.shape[0] == B + 1, "wrong seq_offsets shape[0]")
torch._assert(bias.shape[0] == B, "wrong bias shape[0]")
torch._assert(bias.shape[1] == N, "wrong bias shape[1]")
if kernel == Kernel.TRITON:
return triton_jagged_dense_bmm_broadcast_add(
max_seq_len=max_seq_len,
seq_offsets=seq_offsets,
jagged=jagged,
dense=dense,
bias=bias,
)
else:
return pytorch_jagged_dense_bmm_broadcast_add(
max_seq_len=max_seq_len,
seq_offsets=seq_offsets,
jagged=jagged,
dense=dense,
bias=bias,
)
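

if __name__ == "__main__":
    # A minimal smoke test for jagged_dense_bmm_broadcast_add with illustrative
    # shapes, using the default PyTorch kernel: B = 2 sequences of lengths
    # [2, 3], K = 4, N = 5.
    seq_offsets = torch.tensor([0, 2, 5])
    jagged = torch.randn(5, 4)    # (sum_B(M_i), K)
    dense = torch.randn(2, 4, 5)  # (B, K, N)
    bias = torch.randn(2, 5)      # (B, N)
    out = jagged_dense_bmm_broadcast_add(
        max_seq_len=3,
        seq_offsets=seq_offsets,
        jagged=jagged,
        dense=dense,
        bias=bias,
    )
    # Each sequence is multiplied by its batch's dense matrix, then that
    # batch's bias row is broadcast-added over the sequence's rows.
    assert out.shape == (5, 5)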