optimum/neuron/models/inference/backend/modules/autobucketing.py
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/modules/autobucketing.py
from math import log2
from typing import List

import torch


def generate_buckets(min_length: int, max_length: int):
    if min_length == max_length:
        return [max_length]
    min_bound = int(log2(min_length))
    max_bound = round(log2(max_length))  # we use round because it creates optimal bucket spacing
    # NOTE: because range operates on [a, b), and we rounded the log2 result,
    # we won't get 2**i results close to the max_length.
    # ex. we won't see bucket spacing of [128, 256, 512, 513] or [128, 256, 510, 512]
    buckets = [2**i for i in range(min_bound, max_bound)] + [max_length]
    return buckets
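

# Illustrative only (not part of the module's API): a hedged sanity check of the
# bucket spacing above. Note how rounding log2(max_length) keeps the power-of-two
# ladder from landing right next to max_length (e.g. no [..., 512, 513]).
def _example_generate_buckets():
    assert generate_buckets(128, 128) == [128]
    assert generate_buckets(128, 1024) == [128, 256, 512, 1024]
    assert generate_buckets(128, 1000) == [128, 256, 512, 1000]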


@torch.jit.script
def generation_model_bk(
    tensors: List[torch.Tensor], buckets: torch.Tensor, padding_side: str, speculation_length: int
):
    """
    The Bucket Kernel for Token Generation Models.

    1) tensors: A list of torch tensors after running through the flattener
    2) buckets: A torch.tensor of the bucket sizes
    3) padding_side: A string specifying padding side, must be "left" or "right"
    4) speculation_length: An integer added to the last position id to account
       for speculated tokens (0 when speculation is disabled)
    """
    # assume tensors[1] is either position ids or the attention mask
    # (a sequence dim of 1 means it holds position ids)
    item = tensors[1]
    attention_mask_is_removed = item.shape[1] == 1  # indicates item is position ids
    if attention_mask_is_removed:
        position_ids = tensors[1]
        # account for the speculated tokens when computing the target position
        max_position_id = (
            position_ids[:, -1] + speculation_length
            if (position_ids[:, -1] + speculation_length).all() <= buckets[-1]
            else position_ids[:, -1]
        )
        bucket_mask = (buckets <= (max_position_id).unsqueeze(1)).to(torch.int)
        bucket_idx = torch.max(torch.argmin(bucket_mask, dim=1))
    else:
        attention_mask = tensors[1]
        position_ids = tensors[2]
        max_position_id = (
            position_ids[:, -1] + speculation_length
            if (position_ids[:, -1] + speculation_length).all() <= buckets[-1]
            else position_ids[:, -1]
        )
        bucket_mask = (buckets <= (max_position_id).unsqueeze(1)).to(torch.int)
        bucket_idx = torch.max(torch.argmin(bucket_mask, dim=1))
        bucket = buckets[bucket_idx]
        # slice the attention mask based on the selected bucket size
        if padding_side == "right":
            tensors[1] = torch.ops.aten.slice(attention_mask, dim=1, start=0, end=bucket)
        else:
            tensors[1] = torch.ops.aten.slice(attention_mask, dim=1, start=buckets[-1] - bucket, end=buckets[-1])
    return tensors, bucket_idx.to(torch.int)


def get_generation_model_bk():
    return generation_model_bk
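

# Illustrative only: a minimal sketch exercising the position-ids branch of
# generation_model_bk. The shapes are assumptions for the example; a seq dim of 1
# on tensors[1] signals that the attention mask was removed and tensors[1]
# holds position ids.
def _example_generation_model_bk():
    buckets = torch.tensor([128, 256, 512])
    input_ids = torch.zeros(2, 1, dtype=torch.long)
    position_ids = torch.tensor([[130], [40]])  # shape (batch_size, 1)
    tensors, bucket_idx = generation_model_bk([input_ids, position_ids], buckets, "right", 0)
    # position 130 spills past the 128 bucket, so bucket index 1 (size 256) wins
    assert int(bucket_idx) == 1
    assert tensors[1].shape == position_ids.shape  # nothing is sliced in this branch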


@torch.jit.script
def context_encoder_bk(tensors: List[torch.Tensor], buckets: torch.Tensor, padding_side: str, pad_token: int):
    """
    The Bucket Kernel for Context Encoding Models.

    1) tensors: A list of torch tensors after running through the flattener
    2) buckets: A torch.tensor of the bucket sizes
    3) padding_side: A string specifying padding side, must be "left" or "right"
    4) pad_token: An integer representing the pad token id. Typically this is 0.
    """
    input_ids = tensors[0]

    # -----Remarks for calculating position_idx-----
    # Counting the non-pad tokens gives the active sequence length;
    # the resulting tensor is of shape (batch_size,).
    #
    # NOTE: We derive position_ids from input_ids because
    # position_ids is eliminated from the flattener for context encoding models.
    # ----------------------------------------------
    position_idx = (input_ids != pad_token).sum(dim=1)
    position_idx = position_idx[:, None]  # shape (batch_size, 1)
    buckets = buckets[None, :]  # shape (1, n_buckets)

    # -----Remarks for choosing the bucket_idx-----
    # 1. (buckets < position_idx) produces a bucket_mask where buckets too small
    #    for the sequence are 1 and valid buckets are 0
    # 2. We convert the boolean tensor to int because argmin doesn't support
    #    boolean tensors
    # 3. We choose the minimum valid bucket per sequence, which is the first 0 value
    # 4. From the per-sequence minimum valid buckets, we choose the largest one,
    #    otherwise we'd be truncating tokens from longer sequences.
    # 5. DO NOT USE argmax since we monkeypatch it,
    #    causing issues with torch.jit.script
    # ---------------------------------------------
    bucket_mask = (buckets < position_idx).to(torch.int)  # shape (batch_size, n_buckets)
    bucket_idx = torch.max(torch.argmin(bucket_mask, dim=1))
    # select the chosen bucket after squeezing back to the original form
    bucket = buckets.squeeze(0)[bucket_idx]

    new_tensors = []
    # ---------Remarks on handling padding sides-------
    # 1. Slice from the side opposite the padding
    # 2. Identify seq_id tensors by shape and don't slice them
    # -------------------------------------------------
    if padding_side == "right":
        for tens in tensors:
            # identifies the seq_ids, which don't need to be sliced
            if len(tens.shape) == 1:
                new_tensors.append(tens)
            else:  # all other tensors are of shape (batch_size, seq_len), so we slice on seq_len
                new_tensors.append(torch.ops.aten.slice(tens, dim=1, start=0, end=bucket))
    else:
        max_idx = buckets[-1][-1]  # the largest bucket (buckets is 2-D here)
        for tens in tensors:
            # identifies the seq_ids, which don't need to be sliced
            if len(tens.shape) == 1:
                new_tensors.append(tens)
            else:
                new_tensors.append(torch.ops.aten.slice(tens, dim=1, start=max_idx - bucket, end=max_idx))
    return new_tensors, bucket_idx.to(torch.int)


def get_context_encoder_bk():
    return context_encoder_bk
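

# Illustrative only: a hedged end-to-end check of context_encoder_bk with made-up
# shapes. Three active tokens select the smallest sufficient bucket (size 4,
# index 0); 2-D tensors are sliced down to it, while the 1-D seq_ids tensor
# passes through untouched.
def _example_context_encoder_bk():
    buckets = torch.tensor([4, 8])
    input_ids = torch.tensor([[5, 6, 7, 0, 0, 0, 0, 0]])  # right-padded to 8
    attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0]])
    seq_ids = torch.tensor([0])
    new_tensors, bucket_idx = context_encoder_bk(
        [input_ids, attention_mask, seq_ids], buckets, "right", 0
    )
    assert int(bucket_idx) == 0
    assert new_tensors[0].shape[1] == 4  # sliced to the chosen bucket
    assert new_tensors[2].shape == seq_ids.shape  # seq_ids kept as-is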