tzrec/ops/jagged_tensors.py (187 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We use the jagged_tensors ops from generative-recommenders as a starting point.
# https://github.com/facebookresearch/generative-recommenders
# thanks to their public work.

from typing import Optional, Tuple

import torch
from torch.fx._symbolic_trace import is_fx_tracing

from tzrec.ops import Kernel
from tzrec.ops.pytorch.pt_jagged_tensors import (
    pytorch_concat_2D_jagged,
    pytorch_hstu_concat_l2_embeddings,
    pytorch_hstu_split_l2_embeddings,
    pytorch_jagged_dense_bmm_broadcast_add,
    pytorch_split_2D_jagged,
)
from tzrec.ops.triton.triton_jagged_tensors import (
    triton_concat_2D_jagged,
    triton_jagged_dense_bmm_broadcast_add,
    triton_split_2D_jagged,
)

# Register these Triton entry points as FX leaf functions: symbolic tracing
# records a call to them rather than tracing into their bodies.
torch.fx.wrap("triton_concat_2D_jagged")
torch.fx.wrap("triton_split_2D_jagged")


def concat_2D_jagged(
    values_left: torch.Tensor,
    values_right: torch.Tensor,
    max_len_left: int,
    max_len_right: int,
    offsets_left: Optional[torch.Tensor],
    offsets_right: Optional[torch.Tensor],
    kernel: Kernel = Kernel.PYTORCH,
) -> torch.Tensor:
    """Concatenate two 2D jagged tensors using the selected kernel.

    Args:
        values_left: 2D values of the left operand.
        values_right: 2D values of the right operand; ``shape[1]`` must equal
            ``values_left.shape[1]``.
        max_len_left: maximum sequence length of the left operand.
        max_len_right: maximum sequence length of the right operand.
        offsets_left: offsets of the left operand, or None. NOTE(review):
            presumably None means fixed-length rows of ``max_len_left`` —
            confirm in the kernel implementations.
        offsets_right: offsets of the right operand, or None (same caveat).
        kernel: ``Kernel.TRITON`` selects the Triton kernel; any other value
            uses the PyTorch implementation.

    Returns:
        The concatenated values tensor produced by the selected kernel.
    """
    # Input validation is skipped while FX symbolic tracing is active.
    if not is_fx_tracing():
        torch._assert(values_left.dim() == 2, "values_left must be 2D")
        torch._assert(values_right.dim() == 2, "values_right must be 2D")
        torch._assert(
            values_right.shape[1] == values_left.shape[1],
            f"values_left shape[1] must be equal to values_right shape[1] {values_left.shape[1]} vs {values_right.shape[1]}",  # NOQA
        )
    if kernel == Kernel.TRITON:
        return triton_concat_2D_jagged(
            values_left=values_left,
            values_right=values_right,
            max_len_left=max_len_left,
            max_len_right=max_len_right,
            offsets_left=offsets_left,
            offsets_right=offsets_right,
        )
    else:
        return pytorch_concat_2D_jagged(
            values_left=values_left,
            values_right=values_right,
            max_len_left=max_len_left,
            max_len_right=max_len_right,
            offsets_left=offsets_left,
            offsets_right=offsets_right,
        )


def split_2D_jagged(
    max_seq_len: int,
    values: torch.Tensor,
    max_len_left: Optional[int] = None,
    max_len_right: Optional[int] = None,
    offsets_left: Optional[torch.Tensor] = None,
    offsets_right: Optional[torch.Tensor] = None,
    kernel: Kernel = Kernel.PYTORCH,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a 2D jagged tensor into a left and a right part.

    At least one of ``offsets_left`` / ``offsets_right`` must be given.
    Whenever one of them is None, the matching ``max_len_*`` must be given
    instead. When both offsets are given, they must have the same
    ``shape[0]``.

    Args:
        max_seq_len: maximum sequence length of ``values``.
        values: 2D values to split.
        max_len_left: maximum length of the left part (required when
            ``offsets_left`` is None).
        max_len_right: maximum length of the right part (required when
            ``offsets_right`` is None).
        offsets_left: offsets of the left part, or None.
        offsets_right: offsets of the right part, or None.
        kernel: ``Kernel.TRITON`` selects the Triton kernel; any other value
            uses the PyTorch implementation.

    Returns:
        A ``(left, right)`` tuple of tensors from the selected kernel.
    """
    # Input validation is skipped while FX symbolic tracing is active.
    if not is_fx_tracing():
        torch._assert(values.dim() == 2, "values must be 2D")
        torch._assert(
            offsets_left is not None or offsets_right is not None,
            "offsets_left and offsets_right cannot be None at the same time",
        )
        if offsets_left is None:
            torch._assert(
                max_len_left is not None,
                "max_len_left must be provided when offsets_left is None",
            )
        if offsets_right is None:
            torch._assert(
                max_len_right is not None,
                "max_len_right must be provided when offsets_right is None",
            )
        if offsets_left is not None and offsets_right is not None:
            torch._assert(
                offsets_left.shape[0] == offsets_right.shape[0],
                "offsets_left shape[0] must be equal to offsets_right shape[0]",
            )
    if kernel == Kernel.TRITON:
        return triton_split_2D_jagged(
            max_seq_len=max_seq_len,
            values=values,
            max_len_left=max_len_left,
            max_len_right=max_len_right,
            offsets_left=offsets_left,
            offsets_right=offsets_right,
        )
    else:
        return pytorch_split_2D_jagged(
            max_seq_len=max_seq_len,
            values=values,
            max_len_left=max_len_left,
            max_len_right=max_len_right,
            offsets_left=offsets_left,
            offsets_right=offsets_right,
        )


def hstu_split_l2_embeddings(
    max_seq_len: int,
    x: torch.Tensor,
    minus_l2_offsets: torch.Tensor,
    l2_offsets: torch.Tensor,
    contextual_seq_len: int,
    kernel: Kernel = Kernel.PYTORCH,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split jagged ``x`` into its "minus L2" and "L2" parts.

    The Triton path reuses the generic ``triton_split_2D_jagged`` kernel with
    ``minus_l2_offsets`` as the left offsets, ``l2_offsets`` as the right
    offsets, and ``n_prefix_to_right=contextual_seq_len``. The PyTorch path
    calls ``pytorch_hstu_split_l2_embeddings``.

    Args:
        max_seq_len: maximum sequence length of ``x``.
        x: jagged values to split.
        minus_l2_offsets: offsets of the "minus L2" (left) part.
        l2_offsets: offsets of the "L2" (right) part.
        contextual_seq_len: number of contextual prefix positions.
            NOTE(review): presumably these prefix rows are routed to the L2
            side, judging by ``n_prefix_to_right`` — confirm in the kernel.
        kernel: ``Kernel.TRITON`` selects the Triton kernel; any other value
            uses the PyTorch implementation.

    Returns:
        A ``(minus_l2, l2)`` tuple of tensors from the selected kernel.
    """
    if kernel == Kernel.TRITON:
        return triton_split_2D_jagged(
            max_seq_len=max_seq_len,
            values=x,
            max_len_left=None,
            max_len_right=None,
            offsets_left=minus_l2_offsets,
            offsets_right=l2_offsets,
            n_prefix_to_right=contextual_seq_len,
        )
    else:
        return pytorch_hstu_split_l2_embeddings(
            max_seq_len=max_seq_len,
            x=x,
            minus_l2_offsets=minus_l2_offsets,
            l2_offsets=l2_offsets,
            contextual_seq_len=contextual_seq_len,
        )


def hstu_concat_l2_embeddings(
    max_minus_l2_len: int,
    minus_l2_x: torch.Tensor,
    minus_l2_offsets: torch.Tensor,
    max_l2_len: int,
    l2_x: torch.Tensor,
    l2_offsets: torch.Tensor,
    contextual_seq_len: int,
    kernel: Kernel = Kernel.PYTORCH,
) -> torch.Tensor:
    """Concatenate "minus L2" and "L2" jagged embeddings back together.

    This is the counterpart of ``hstu_split_l2_embeddings``. The Triton path
    reuses the generic ``triton_concat_2D_jagged`` kernel with
    ``n_prefix_from_right=contextual_seq_len``. The PyTorch path calls
    ``pytorch_hstu_concat_l2_embeddings``.

    Args:
        max_minus_l2_len: maximum length of the "minus L2" (left) part.
        minus_l2_x: jagged values of the "minus L2" part.
        minus_l2_offsets: offsets of the "minus L2" part.
        max_l2_len: maximum length of the "L2" (right) part.
        l2_x: jagged values of the "L2" part.
        l2_offsets: offsets of the "L2" part.
        contextual_seq_len: number of contextual prefix positions passed to
            the kernel as ``n_prefix_from_right``.
        kernel: ``Kernel.TRITON`` selects the Triton kernel; any other value
            uses the PyTorch implementation.

    Returns:
        The concatenated values tensor produced by the selected kernel.
    """
    if kernel == Kernel.TRITON:
        return triton_concat_2D_jagged(
            values_left=minus_l2_x,
            values_right=l2_x,
            max_len_left=max_minus_l2_len,
            max_len_right=max_l2_len,
            offsets_left=minus_l2_offsets,
            offsets_right=l2_offsets,
            n_prefix_from_right=contextual_seq_len,
        )
    else:
        return pytorch_hstu_concat_l2_embeddings(
            contextual_seq_len=contextual_seq_len,
            max_minus_l2_len=max_minus_l2_len,
            minus_l2_x=minus_l2_x,
            minus_l2_offsets=minus_l2_offsets,
            max_l2_len=max_l2_len,
            l2_x=l2_x,
            l2_offsets=l2_offsets,
        )


def jagged_dense_bmm_broadcast_add(
    max_seq_len: int,
    seq_offsets: torch.Tensor,
    jagged: torch.Tensor,
    dense: torch.Tensor,
    bias: torch.Tensor,
    kernel: Kernel = Kernel.PYTORCH,
) -> torch.Tensor:
    """Computing out = jagged x dense + bias.

    jagged has shape (sum_B(M_i), K), dense has shape (B, K, N), and bias has
    shape (B, N), out has shape (sum_B(M_i), N).

    ``seq_offsets`` (B + 1 entries) delimits the rows of ``jagged`` that
    belong to each of the B batch entries.

    Args:
        max_seq_len: maximum sequence length M_i over the batch.
        seq_offsets: offsets of the jagged sequences, shape (B + 1,).
        jagged: jagged values, shape (sum_B(M_i), K).
        dense: per-batch dense matrices, shape (B, K, N).
        bias: per-batch bias, shape (B, N).
        kernel: ``Kernel.TRITON`` selects the Triton kernel; any other value
            uses the PyTorch implementation.

    Returns:
        Tensor of shape (sum_B(M_i), N).
    """
    # Input validation is skipped while FX symbolic tracing is active.
    if not is_fx_tracing():
        # Check ranks before unpacking/indexing the shapes below, so malformed
        # inputs fail with a clear message rather than an unpacking ValueError
        # or an IndexError, and an over-ranked bias is rejected.
        torch._assert(jagged.dim() == 2, "jagged must be 2D")
        torch._assert(dense.dim() == 3, "dense must be 3D")
        torch._assert(bias.dim() == 2, "bias must be 2D")
        _, K = jagged.shape
        B, _, N = dense.shape
        torch._assert(dense.shape[1] == K, "wrong dense shape[1]")
        torch._assert(seq_offsets.shape[0] == B + 1, "wrong seq_offsets shape[0]")
        torch._assert(bias.shape[0] == B, "wrong bias shape[0]")
        torch._assert(bias.shape[1] == N, "wrong bias shape[1]")
    if kernel == Kernel.TRITON:
        return triton_jagged_dense_bmm_broadcast_add(
            max_seq_len=max_seq_len,
            seq_offsets=seq_offsets,
            jagged=jagged,
            dense=dense,
            bias=bias,
        )
    else:
        return pytorch_jagged_dense_bmm_broadcast_add(
            max_seq_len=max_seq_len,
            seq_offsets=seq_offsets,
            jagged=jagged,
            dense=dense,
            bias=bias,
        )