# tzrec/ops/pytorch/pt_jagged_tensors.py
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# We use the jagged_tensors ops from generative-recommenders as a starting
# point: https://github.com/facebookresearch/generative-recommenders
# Thanks to their public work.
from typing import Optional, Tuple

import torch

from tzrec.utils.fx_util import fx_arange

torch.fx.wrap(fx_arange)


def _concat_2D_jagged_jagged(
values_left: torch.Tensor,
values_right: torch.Tensor,
max_len_left: int,
max_len_right: int,
offsets_left: torch.Tensor,
offsets_right: torch.Tensor,
) -> torch.Tensor:
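    """Concatenate two 2D jagged tensors along the sequence dimension.

    For each batch row, the output holds that row's left values followed by
    its right values. Both sides are padded to dense ``[B, max_len, D]``
    tensors, concatenated, and the valid positions are gathered back into a
    jagged ``[sum(lengths), D]`` layout.
    """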
max_seq_len = max_len_left + max_len_right
lengths_left = offsets_left[1:] - offsets_left[:-1]
lengths_right = offsets_right[1:] - offsets_right[:-1]
padded_left = torch.ops.fbgemm.jagged_to_padded_dense(
values=values_left,
offsets=[offsets_left],
max_lengths=[max_len_left],
padding_value=0.0,
)
padded_right = torch.ops.fbgemm.jagged_to_padded_dense(
values=values_right,
offsets=[offsets_right],
max_lengths=[max_len_right],
padding_value=0.0,
)
concatted_dense = torch.cat([padded_left, padded_right], dim=1)
mask = fx_arange(max_seq_len, device=offsets_left.device).view(1, -1)
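    # Keep positions that are valid left tokens (< lengths_left) or valid
    # right tokens (offset by max_len_left in the padded buffer).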
mask = torch.logical_or(
mask < lengths_left.view(-1, 1),
torch.logical_and(
mask >= max_len_left,
mask < max_len_left + lengths_right.view(-1, 1),
),
)
    return concatted_dense.flatten(0, 1)[mask.view(-1), :]


@torch.fx.wrap
def pytorch_concat_2D_jagged(
values_left: torch.Tensor,
values_right: torch.Tensor,
max_len_left: int,
max_len_right: int,
offsets_left: Optional[torch.Tensor],
offsets_right: Optional[torch.Tensor],
) -> torch.Tensor:
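    """Concatenate two 2D jagged (or dense) tensors row-wise.

    If ``offsets_left`` (resp. ``offsets_right``) is None, the corresponding
    values are treated as dense with a fixed row length of ``max_len_left``
    (resp. ``max_len_right``) and offsets are synthesized accordingly.

    Example (illustrative sketch; assumes fbgemm_gpu is installed)::

        left = torch.randn(3, 8)   # jagged rows of length 2 and 1
        right = torch.randn(4, 8)  # jagged rows of length 1 and 3
        out = pytorch_concat_2D_jagged(
            left,
            right,
            max_len_left=2,
            max_len_right=3,
            offsets_left=torch.tensor([0, 2, 3]),
            offsets_right=torch.tensor([0, 1, 4]),
        )  # out.shape == (7, 8)
    """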
if offsets_left is None:
B = values_left.shape[0] // max_len_left
offsets_left_non_optional = max_len_left * torch.arange(
B + 1, device=values_left.device
)
else:
offsets_left_non_optional = offsets_left
if offsets_right is None:
B = values_right.shape[0] // max_len_right
        offsets_right_non_optional = max_len_right * torch.arange(
            B + 1, device=values_right.device
        )
else:
offsets_right_non_optional = offsets_right
return _concat_2D_jagged_jagged(
values_left=values_left,
values_right=values_right,
max_len_left=max_len_left,
max_len_right=max_len_right,
offsets_left=offsets_left_non_optional,
offsets_right=offsets_right_non_optional,
    )


def _split_2D_jagged_jagged(
max_seq_len: int,
values: torch.Tensor,
offsets_left: torch.Tensor,
offsets_right: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
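    """Split a combined 2D jagged tensor into left and right jagged parts.

    Inverse of ``_concat_2D_jagged_jagged``: each jagged row of ``values``
    is assumed to hold its left values followed by its right values, with
    per-row lengths derived from the two offsets tensors.
    """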
offsets = offsets_left + offsets_right
padded_values = torch.ops.fbgemm.jagged_to_padded_dense(
values=values,
offsets=[offsets],
max_lengths=[max_seq_len],
padding_value=0.0,
).flatten(0, 1)
lengths_left = offsets_left[1:] - offsets_left[:-1]
lengths_right = offsets_right[1:] - offsets_right[:-1]
mask = fx_arange(max_seq_len, device=values.device).view(1, -1)
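    # Within each padded row, left tokens occupy [0, lengths_left) and right
    # tokens occupy [lengths_left, lengths_left + lengths_right).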
mask_left = mask < lengths_left.view(-1, 1)
mask_right = torch.logical_and(
mask >= lengths_left.view(-1, 1),
mask < (lengths_left + lengths_right).view(-1, 1),
)
    return padded_values[mask_left.view(-1), :], padded_values[mask_right.view(-1), :]


@torch.fx.wrap
def pytorch_split_2D_jagged(
max_seq_len: int,
values: torch.Tensor,
max_len_left: Optional[int],
max_len_right: Optional[int],
offsets_left: Optional[torch.Tensor],
offsets_right: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
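    """Split a 2D jagged tensor into left and right jagged tensors.

    Each jagged row of ``values`` is assumed to hold its left part followed
    by its right part. If one side's offsets are None, that side is treated
    as dense with fixed row length ``max_len_left``/``max_len_right`` and
    its offsets are synthesized from the other side's batch size.

    Example (illustrative sketch; assumes fbgemm_gpu is installed)::

        values = torch.randn(7, 8)  # jagged rows of length 3 and 4
        left, right = pytorch_split_2D_jagged(
            max_seq_len=5,
            values=values,
            max_len_left=None,
            max_len_right=None,
            offsets_left=torch.tensor([0, 2, 3]),
            offsets_right=torch.tensor([0, 1, 4]),
        )  # left.shape == (3, 8); right.shape == (4, 8)
    """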
if offsets_left is None:
assert max_len_left is not None
assert offsets_right is not None
offsets_left_non_optional = max_len_left * torch.arange(
offsets_right.shape[0], device=values.device
)
else:
offsets_left_non_optional = offsets_left
if offsets_right is None:
assert max_len_right is not None
assert offsets_left is not None
offsets_right_non_optional = max_len_right * torch.arange(
offsets_left.shape[0], device=values.device
)
else:
offsets_right_non_optional = offsets_right
return _split_2D_jagged_jagged(
max_seq_len=max_seq_len,
values=values,
offsets_left=offsets_left_non_optional,
offsets_right=offsets_right_non_optional,
    )


def pytorch_hstu_split_l2_embeddings(
max_seq_len: int,
x: torch.Tensor,
minus_l2_offsets: torch.Tensor,
l2_offsets: torch.Tensor,
contextual_seq_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
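    """Split jagged HSTU embeddings into "minus_l2" and "l2" parts.

    Each jagged row of ``x`` is assumed to be laid out as a contextual
    prefix of ``contextual_seq_len`` tokens, then the minus_l2 tokens, then
    the remaining l2 tokens, with ``x_offsets = minus_l2_offsets +
    l2_offsets``. The contextual prefix is routed to the l2 output (the
    ``l2_offsets`` lengths include it), so the returned tensors have
    ``minus_l2_offsets[-1]`` and ``l2_offsets[-1]`` rows respectively.
    """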
x_offsets = minus_l2_offsets + l2_offsets
x_lengths = x_offsets[1:] - x_offsets[:-1]
padded_x = torch.ops.fbgemm.jagged_to_padded_dense(
values=x,
offsets=[x_offsets],
max_lengths=[max_seq_len],
padding_value=0.0,
).flatten(0, 1)
minus_l2_lengths = minus_l2_offsets[1:] - minus_l2_offsets[:-1]
mask = fx_arange(max_seq_len, device=x_offsets.device).view(1, -1)
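    # minus_l2 tokens sit between the contextual prefix and the l2 tail; the
    # contextual prefix itself is routed to the l2 output.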
mask_minus_l2 = torch.logical_and(
mask >= contextual_seq_len,
mask < minus_l2_lengths.view(-1, 1) + contextual_seq_len,
)
mask_l2 = torch.logical_or(
mask < contextual_seq_len,
torch.logical_and(
mask >= minus_l2_lengths.view(-1, 1) + contextual_seq_len,
mask < x_lengths.view(-1, 1),
),
)
    return padded_x[mask_minus_l2.view(-1), :], padded_x[mask_l2.view(-1), :]


def pytorch_hstu_concat_l2_embeddings(
max_minus_l2_len: int,
minus_l2_x: torch.Tensor,
minus_l2_offsets: torch.Tensor,
max_l2_len: int,
l2_x: torch.Tensor,
l2_offsets: torch.Tensor,
contextual_seq_len: int,
) -> torch.Tensor:
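    """Recombine "minus_l2" and "l2" embeddings into one jagged tensor.

    Inverse of ``pytorch_hstu_split_l2_embeddings``: for each batch row, the
    output is the first ``contextual_seq_len`` l2 tokens, then the minus_l2
    tokens, then the remaining l2 tokens.
    """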
padded_minus_l2_x = torch.ops.fbgemm.jagged_to_padded_dense(
values=minus_l2_x,
offsets=[minus_l2_offsets],
max_lengths=[max_minus_l2_len],
padding_value=0.0,
)
padded_l2_x = torch.ops.fbgemm.jagged_to_padded_dense(
values=l2_x,
offsets=[l2_offsets],
max_lengths=[max_l2_len],
padding_value=0.0,
)
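    # Interleave per row: l2 contextual prefix, minus_l2 tokens, l2 tail.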
padded_x = torch.cat(
[
padded_l2_x[:, 0:contextual_seq_len, :],
padded_minus_l2_x,
padded_l2_x[:, contextual_seq_len:, :],
],
dim=1,
)
mask = fx_arange(max_minus_l2_len + max_l2_len, device=minus_l2_x.device).view(
1, -1
)
minus_l2_lengths = minus_l2_offsets[1:] - minus_l2_offsets[:-1]
l2_lengths = l2_offsets[1:] - l2_offsets[:-1]
mask = torch.logical_or(
mask < minus_l2_lengths.view(-1, 1) + contextual_seq_len,
torch.logical_and(
mask >= max_minus_l2_len + contextual_seq_len,
mask < max_minus_l2_len + l2_lengths.view(-1, 1),
),
)
    return padded_x.flatten(0, 1)[mask.view(-1), :]


def pytorch_jagged_dense_bmm_broadcast_add(
max_seq_len: int,
seq_offsets: torch.Tensor,
jagged: torch.Tensor,
dense: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
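    """Batched matmul of a jagged tensor with a dense tensor, plus a bias.

    For each batch ``b`` and each valid row ``i`` of that batch,
    ``out[i] = jagged[i] @ dense[b] + bias[b]``. The computation runs in
    float32 and the result is cast back to the input dtype.

    Example (illustrative sketch; assumes fbgemm_gpu is installed)::

        seq_offsets = torch.tensor([0, 2, 5])  # B = 2, total_L = 5
        jagged = torch.randn(5, 16)
        dense = torch.randn(2, 16, 32)         # [B, D, K]
        bias = torch.randn(2, 32)              # [B, K]
        out = pytorch_jagged_dense_bmm_broadcast_add(
            max_seq_len=3,
            seq_offsets=seq_offsets,
            jagged=jagged,
            dense=dense,
            bias=bias,
        )  # out.shape == (5, 32)
    """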
dtype = jagged.dtype
jagged = jagged.to(torch.float32)
dense = dense.to(torch.float32)
padded_jagged = torch.ops.fbgemm.jagged_to_padded_dense(
values=jagged,
offsets=[seq_offsets],
max_lengths=[max_seq_len],
padding_value=0.0,
)
bmm_out = torch.bmm(padded_jagged, dense)
jagged_out = torch.ops.fbgemm.dense_to_jagged(
bmm_out + bias.unsqueeze(1), [seq_offsets], total_L=jagged.shape[0]
)[0]
jagged_out = jagged_out.to(dtype)
return jagged_out