tzrec/ops/position.py

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We use the position encoder ops from generative-recommenders as a starting point.
# https://github.com/facebookresearch/generative-recommenders
# Thanks to their public work.

from typing import Optional

import torch
from torch.fx._symbolic_trace import is_fx_tracing

from tzrec.ops import Kernel
from tzrec.ops.pytorch.pt_position import (
    pytorch_add_position_embeddings,
    pytorch_add_timestamp_positional_embeddings,
)
from tzrec.ops.triton.triton_position import (
    triton_add_position_embeddings,
    triton_add_timestamp_positional_embeddings,
)


@torch.fx.wrap
def _get_high_inds(
    high_inds: torch.Tensor,
    position_embeddings_weight: torch.Tensor,
    num_targets: Optional[torch.Tensor],
    interleave_targets: bool,
) -> torch.Tensor:
    """Compute per-sequence upper position indices: subtract the target count
    (doubled when targets are interleaved) and clamp to the size of the
    position embedding table."""
    max_pos_ind = position_embeddings_weight.size(0)
    if num_targets is not None:
        if interleave_targets:
            high_inds = high_inds - num_targets * 2
        else:
            high_inds = high_inds - num_targets
    high_inds = torch.clamp(high_inds, max=max_pos_ind - 1)
    return high_inds


def add_positional_embeddings(
    alpha: float,
    max_seq_len: int,
    position_embeddings_weight: torch.Tensor,
    seq_offsets: torch.Tensor,
    seq_lengths: torch.Tensor,
    seq_embeddings: torch.Tensor,
    num_targets: Optional[torch.Tensor],
    interleave_targets: bool,
    kernel: Kernel = Kernel.PYTORCH,
) -> torch.Tensor:
    """Add learnable position embeddings to jagged sequence embeddings,
    dispatching to the Triton or PyTorch kernel."""
    high_inds = _get_high_inds(
        seq_lengths, position_embeddings_weight, num_targets, interleave_targets
    )
    if not is_fx_tracing():
        _, D = seq_embeddings.shape
        torch._assert(
            seq_offsets.size(0) - 1 == high_inds.size(0),
            "wrong jagged_offsets shape[0]",
        )
        _, D2 = position_embeddings_weight.shape
        torch._assert(D2 == D, "wrong dense shape[1]")
    if kernel == Kernel.TRITON:
        return triton_add_position_embeddings(
            jagged=seq_embeddings,
            jagged_offsets=seq_offsets,
            high_inds=high_inds,
            max_seq_len=max_seq_len,
            dense=position_embeddings_weight,
            scale=alpha,
        )
    else:
        return pytorch_add_position_embeddings(
            jagged=seq_embeddings,
            jagged_offsets=seq_offsets,
            high_inds=high_inds,
            max_seq_len=max_seq_len,
            dense=position_embeddings_weight,
            scale=alpha,
        )


def add_timestamp_positional_embeddings(
    alpha: float,
    max_seq_len: int,
    max_contextual_seq_len: int,
    position_embeddings_weight: torch.Tensor,
    timestamp_embeddings_weight: torch.Tensor,
    seq_offsets: torch.Tensor,
    seq_lengths: torch.Tensor,
    seq_embeddings: torch.Tensor,
    timestamps: torch.Tensor,
    num_targets: Optional[torch.Tensor],
    interleave_targets: bool,
    time_bucket_fn: str = "sqrt",
    kernel: Kernel = Kernel.PYTORCH,
) -> torch.Tensor:
    """Scale jagged sequence embeddings by ``alpha`` and add position and
    timestamp embeddings (timestamps bucketed via ``time_bucket_fn``),
    dispatching to the Triton or PyTorch kernel."""
    assert time_bucket_fn in ["sqrt", "log"]
    seq_embeddings = seq_embeddings * alpha
    if kernel == Kernel.TRITON:
        return triton_add_timestamp_positional_embeddings(
            seq_embeddings=seq_embeddings,
            seq_offsets=seq_offsets,
            pos_embeddings=position_embeddings_weight,
            ts_embeddings=timestamp_embeddings_weight,
            timestamps=timestamps,
            max_seq_len=max_seq_len,
            max_contextual_seq_len=max_contextual_seq_len,
            seq_lengths=seq_lengths,
            num_targets=num_targets,
            interleave_targets=interleave_targets,
            time_bucket_fn=time_bucket_fn,
        )
    else:
        return pytorch_add_timestamp_positional_embeddings(
            seq_embeddings=seq_embeddings,
            seq_offsets=seq_offsets,
            pos_embeddings=position_embeddings_weight,
            ts_embeddings=timestamp_embeddings_weight,
            timestamps=timestamps,
            max_seq_len=max_seq_len,
            max_contextual_seq_len=max_contextual_seq_len,
            seq_lengths=seq_lengths,
            num_targets=num_targets,
            interleave_targets=interleave_targets,
            time_bucket_fn=time_bucket_fn,
        )
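

if __name__ == "__main__":
    # Minimal usage sketch for add_positional_embeddings, assuming the PyTorch
    # kernel path. Shapes follow the asserts above: seq_embeddings is the jagged
    # value tensor of shape (sum(seq_lengths), D), seq_offsets holds B + 1 prefix
    # sums, and position_embeddings_weight is (num_positions, D). The concrete
    # sizes and random values below are illustrative assumptions only.
    D, max_seq_len = 8, 16
    seq_lengths = torch.tensor([3, 5])  # B = 2 samples
    seq_offsets = torch.tensor([0, 3, 8])  # prefix sums over seq_lengths
    seq_embeddings = torch.randn(int(seq_lengths.sum().item()), D)
    position_embeddings_weight = torch.randn(max_seq_len, D)

    out = add_positional_embeddings(
        alpha=1.0,
        max_seq_len=max_seq_len,
        position_embeddings_weight=position_embeddings_weight,
        seq_offsets=seq_offsets,
        seq_lengths=seq_lengths,
        seq_embeddings=seq_embeddings,
        num_targets=None,
        interleave_targets=False,
        kernel=Kernel.PYTORCH,
    )
    # Output keeps the jagged layout of seq_embeddings; alpha is forwarded to
    # the underlying kernel as its scale argument.
    print(out.shape)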