tzrec/ops/utils.py (41 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import List

import torch
import triton


def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` unchanged when its last stride is 1, else a contiguous copy.

    Args:
        x (Tensor): input tensor.

    Returns:
        ``x`` itself if ``x.stride(-1) == 1``, otherwise ``x.contiguous()``.
    """
    if not torch.jit.is_scripting() and torch.compiler.is_compiling():
        # Tell Dynamo this data-dependent value is in the range (0, 10**9)
        torch._check(x.size(0) > 0)
        torch._check(x.size(0) < 10**9)
    return x if x.stride(-1) == 1 else x.contiguous()


@torch.fx.wrap
def prev_power_of_2(x: int) -> int:
    """Round ``x`` down using ``triton.next_power_of_2``.

    The result of ``triton.next_power_of_2(x)`` is halved when it is strictly
    greater than ``x`` and returned as-is otherwise.

    Args:
        x (int): value to round.

    Returns:
        the rounded value as a Python int.
    """
    if not torch.compiler.is_compiling():
        # Eager path: plain integer arithmetic.
        ceil_pow2 = triton.next_power_of_2(x)
        if ceil_pow2 > x:
            return ceil_pow2 // 2
        return ceil_pow2

    # Re-write to make Dynamo happy
    value = torch.scalar_tensor(x, dtype=torch.int64)  # type: ignore[arg-type]
    # NOTE(review): the copy is taken before the call below, presumably because
    # next_power_of_2 may modify its argument in place -- confirm against triton.
    value_before = value.clone()
    ceil_pow2 = triton.next_power_of_2(value)  # type: ignore[arg-type]
    rounded = torch.where(torch.lt(value_before, ceil_pow2), ceil_pow2 // 2, ceil_pow2)
    return int(rounded.item())  # type: ignore[return-value]


# Candidate lengths consulted by autotune_max_seq_len; kept sorted ascending by
# set_static_max_seq_lens.
STATIC_MAX_SEQ_LENS: List[int] = []
# When True, autotune_max_seq_len derives its result from the runtime length
# instead of the static candidates.
USE_RUNTIME_MAX_SEQ_LEN: bool = False


def set_static_max_seq_lens(max_seq_lens: List[int]) -> None:
    """Store a sorted copy of ``max_seq_lens`` as the static candidates.

    Args:
        max_seq_lens (list): candidate max sequence lengths, in any order.
    """
    global STATIC_MAX_SEQ_LENS
    # Work on a copy so that sorting never reorders the caller's list.
    STATIC_MAX_SEQ_LENS = copy.deepcopy(max_seq_lens)
    STATIC_MAX_SEQ_LENS.sort()


def set_use_runtime_max_seq_len(use_runtime_max_seq_len: bool) -> None:
    """Select whether autotune_max_seq_len uses the runtime length.

    Args:
        use_runtime_max_seq_len (bool): new value of USE_RUNTIME_MAX_SEQ_LEN.
    """
    global USE_RUNTIME_MAX_SEQ_LEN
    USE_RUNTIME_MAX_SEQ_LEN = use_runtime_max_seq_len


def autotune_max_seq_len(runtime_max_seq_len: int) -> int:
    """Map a runtime max sequence length to the value used for autotuning.

    Args:
        runtime_max_seq_len (int): max sequence length observed at runtime.

    Returns:
        ``prev_power_of_2(runtime_max_seq_len)`` when USE_RUNTIME_MAX_SEQ_LEN
        is set; otherwise 1 if no static candidates are registered, else the
        first candidate ``>= runtime_max_seq_len``, falling back to the last
        (largest) candidate when none is large enough.
    """
    if USE_RUNTIME_MAX_SEQ_LEN:
        return prev_power_of_2(runtime_max_seq_len)

    if not STATIC_MAX_SEQ_LENS:
        return 1

    # Candidates are sorted ascending, so the first match is the smallest fit.
    for candidate in STATIC_MAX_SEQ_LENS:
        if candidate >= runtime_max_seq_len:
            return candidate
    return STATIC_MAX_SEQ_LENS[-1]