tzrec/ops/pytorch/pt_position.py:

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We use the position encoder ops from generative-recommenders as a starting point.
# https://github.com/facebookresearch/generative-recommenders
# Thanks to their public work.

from typing import Optional

import torch

from tzrec.utils.fx_util import fx_arange, fx_unwrap_optional_tensor

torch.fx.wrap(fx_arange)
torch.fx.wrap(fx_unwrap_optional_tensor)


def pytorch_add_position_embeddings(
    jagged: torch.Tensor,
    jagged_offsets: torch.Tensor,
    high_inds: torch.Tensor,
    max_seq_len: int,
    dense: torch.Tensor,
    scale: float = 1.0,
) -> torch.Tensor:
    """Scale a jagged tensor and add dense positional embeddings.

    Position indices are clamped per row at ``high_inds``, so slots beyond
    ``high_inds`` reuse the same row of ``dense``.
    """
    jagged = jagged * scale
    B = high_inds.size(0)
    col_indices = fx_arange(max_seq_len, device=high_inds.device).expand(
        B, max_seq_len
    )
    col_indices = torch.clamp(col_indices, max=high_inds.view(-1, 1))
    dense_values = torch.index_select(dense, 0, col_indices.reshape(-1)).view(
        B, max_seq_len, -1
    )
    return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
        jagged,
        [jagged_offsets],
        dense_values,
    )[0]


@torch.fx.wrap
def _get_col_indices(
    max_seq_len: int,
    max_contextual_seq_len: int,
    max_pos_ind: int,
    seq_lengths: torch.Tensor,
    num_targets: Optional[torch.Tensor],
    interleave_targets: bool,
) -> torch.Tensor:
    """Compute per-row position indices counted backwards from the sequence end.

    Target positions (if ``num_targets`` is given) are mapped to index 0 before
    the contextual offset; the first ``max_contextual_seq_len`` slots get fixed
    contextual indices, and all indices are clamped to ``max_pos_ind - 1``.
    """
    B = seq_lengths.size(0)
    col_indices = torch.arange(max_seq_len, device=seq_lengths.device).expand(
        B, max_seq_len
    )
    if num_targets is not None:
        if interleave_targets:
            high_inds = seq_lengths - fx_unwrap_optional_tensor(num_targets) * 2
        else:
            high_inds = seq_lengths - fx_unwrap_optional_tensor(num_targets)
        col_indices = torch.clamp(col_indices, max=high_inds.view(-1, 1))
        col_indices = high_inds.view(-1, 1) - col_indices
    else:
        col_indices = seq_lengths.view(-1, 1) - col_indices
    col_indices = col_indices + max_contextual_seq_len
    col_indices = torch.clamp(col_indices, max=max_pos_ind - 1)
    if max_contextual_seq_len > 0:
        col_indices[:, :max_contextual_seq_len] = torch.arange(
            0,
            max_contextual_seq_len,
            device=col_indices.device,
            dtype=col_indices.dtype,
        ).view(1, -1)
    return col_indices


def pytorch_add_timestamp_positional_embeddings(
    seq_embeddings: torch.Tensor,
    seq_offsets: torch.Tensor,
    pos_embeddings: torch.Tensor,
    ts_embeddings: torch.Tensor,
    timestamps: torch.Tensor,
    max_seq_len: int,
    max_contextual_seq_len: int,
    seq_lengths: torch.Tensor,
    num_targets: Optional[torch.Tensor],
    interleave_targets: bool,
    time_bucket_fn: str,
) -> torch.Tensor:
    """Add position and time-bucket embeddings to jagged ``seq_embeddings``.

    Timestamps are bucketized by the delta to each row's last (query) timestamp
    using a log or sqrt bucketing function.
    """
    max_pos_ind = pos_embeddings.size(0)
    # position encoding
    pos_inds = _get_col_indices(
        max_seq_len=max_seq_len,
        max_contextual_seq_len=max_contextual_seq_len,
        max_pos_ind=max_pos_ind,
        seq_lengths=seq_lengths,
        num_targets=num_targets,
        interleave_targets=interleave_targets,
    )
    B, _ = pos_inds.shape
    # timestamp encoding
    num_time_buckets = 2048
    time_bucket_increments = 60.0
    time_bucket_divisor = 1.0
    time_delta = 0
    timestamps = torch.ops.fbgemm.jagged_to_padded_dense(
        values=timestamps.unsqueeze(-1),
        offsets=[seq_offsets],
        max_lengths=[max_seq_len],
        padding_value=0.0,
    ).squeeze(-1)
    query_time = torch.gather(
        timestamps, dim=1, index=(seq_lengths - 1).unsqueeze(1).clamp(min=0)
    )
    ts = query_time - timestamps
    ts = ts + time_delta
    ts = ts.clamp(min=1e-6) / time_bucket_increments
    if time_bucket_fn == "log":
        ts = torch.log(ts)
    else:
        ts = torch.sqrt(ts)
    ts = (ts / time_bucket_divisor).clamp(min=0).int()
    ts = torch.clamp(
        ts,
        min=0,
        max=num_time_buckets,
    )
    position_embeddings = torch.index_select(
        pos_embeddings, 0, pos_inds.reshape(-1)
    ).view(B, max_seq_len, -1)
    time_embeddings = torch.index_select(ts_embeddings, 0, ts.reshape(-1)).view(
        B, max_seq_len, -1
    )
    return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output(
        seq_embeddings,
        [seq_offsets],
        (time_embeddings + position_embeddings).to(seq_embeddings.dtype),
    )[0]
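

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# example of calling pytorch_add_timestamp_positional_embeddings on a small
# jagged batch. All shapes and hyper-parameters below (embedding dim 8, 64
# position rows, one target per sequence, minute-spaced timestamps) are
# assumptions made for this sketch, not values taken from tzrec. The
# fbgemm_gpu package must be installed so the torch.ops.fbgemm jagged kernels
# used above are registered.
if __name__ == "__main__":
    import fbgemm_gpu  # noqa: F401  # registers torch.ops.fbgemm (assumed installed)

    torch.manual_seed(0)
    B, max_seq_len, D = 2, 5, 8
    seq_lengths = torch.tensor([3, 5])
    seq_offsets = torch.zeros(B + 1, dtype=torch.int64)
    seq_offsets[1:] = torch.cumsum(seq_lengths, dim=0)
    total_len = int(seq_offsets[-1].item())

    seq_embeddings = torch.randn(total_len, D)
    # one timestamp per jagged element, spaced one minute apart
    timestamps = torch.arange(total_len, dtype=torch.float32) * 60.0
    pos_embeddings = torch.randn(64, D)  # max_pos_ind = 64 (assumed)
    ts_embeddings = torch.randn(2049, D)  # >= num_time_buckets + 1 rows (assumed)

    out = pytorch_add_timestamp_positional_embeddings(
        seq_embeddings=seq_embeddings,
        seq_offsets=seq_offsets,
        pos_embeddings=pos_embeddings,
        ts_embeddings=ts_embeddings,
        timestamps=timestamps,
        max_seq_len=max_seq_len,
        max_contextual_seq_len=0,
        seq_lengths=seq_lengths,
        num_targets=torch.tensor([1, 1]),  # last item of each sequence is a target
        interleave_targets=False,
        time_bucket_fn="sqrt",
    )
    # output keeps the jagged layout of seq_embeddings: (total_len, D)
    print(out.shape)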