tzrec/ops/utils.py (41 lines of code) (raw):
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import List
import torch
import triton
def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor:
if not torch.jit.is_scripting() and torch.compiler.is_compiling():
# Tell Dynamo this data-dependent value is in the range (0, 10**9)
torch._check(x.size(0) > 0)
torch._check(x.size(0) < 10**9)
if x.stride(-1) == 1:
return x
return x.contiguous()
@torch.fx.wrap
def prev_power_of_2(x: int) -> int:
if torch.compiler.is_compiling():
# Re-write to make Dynamo happy
x_tensor = torch.scalar_tensor(x, dtype=torch.int64) # type: ignore[arg-type]
x_tensor_orig = x_tensor.clone()
out = triton.next_power_of_2(x_tensor) # type: ignore[arg-type]
return int(torch.where(torch.lt(x_tensor_orig, out), out // 2, out).item()) # type: ignore[return-value]
else:
out = triton.next_power_of_2(x)
return out // 2 if out > x else out
STATIC_MAX_SEQ_LENS: List[int] = []
USE_RUNTIME_MAX_SEQ_LEN: bool = False
def set_static_max_seq_lens(max_seq_lens: List[int]) -> None:
global STATIC_MAX_SEQ_LENS
STATIC_MAX_SEQ_LENS = copy.deepcopy(max_seq_lens)
STATIC_MAX_SEQ_LENS.sort()
def set_use_runtime_max_seq_len(use_runtime_max_seq_len: bool) -> None:
global USE_RUNTIME_MAX_SEQ_LEN
USE_RUNTIME_MAX_SEQ_LEN = use_runtime_max_seq_len
def autotune_max_seq_len(runtime_max_seq_len: int) -> int:
global USE_RUNTIME_MAX_SEQ_LEN
if USE_RUNTIME_MAX_SEQ_LEN:
return prev_power_of_2(runtime_max_seq_len)
else:
if STATIC_MAX_SEQ_LENS == []:
return 1
for max_len in STATIC_MAX_SEQ_LENS:
if max_len >= runtime_max_seq_len:
return max_len
return STATIC_MAX_SEQ_LENS[-1]