optimum/neuron/models/inference/backend/modules/autobucketing.py (69 lines of code) (raw):

# coding=utf-8 # Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/modules/autobucketing.py from math import log2 from typing import List import torch def generate_buckets(min_length: int, max_length: int): if min_length == max_length: return [max_length] min_bound = int(log2(min_length)) max_bound = round(log2(max_length)) # we use round because it creates optimal bucket spacing # NOTE: because range operates on [a,b), and we rounded the log2 result # we won't get 2**i results close to the max_length. # ex. we won't see bucket spacing of [128,256,512,513] or [128,256,510,512] buckets = [2**i for i in range(min_bound, max_bound)] + [max_length] return buckets @torch.jit.script def generation_model_bk( tensors: List[torch.Tensor], buckets: torch.Tensor, padding_side: str, speculation_length: int ): """ The Bucket Kernel for Token Generation Models. 1) tensors: A list of torch tensors after running through the flattener 2) buckets: A torch.tensor of the bucket sizes 3) padding_side: A string specifying padding side, must be "left" or "right" """ # assume tensors[1] is either pos id or attention mask (seq dim == 1 => pos id) item = tensors[1] attention_mask_is_removed = item.shape[1] == 1 # indicates item is position Id if attention_mask_is_removed: position_ids = tensors[1] max_position_id = ( position_ids[:, -1] + speculation_length if (position_ids[:, -1] + speculation_length).all() <= buckets[-1] else position_ids[:, -1] ) bucket_mask = (buckets <= (max_position_id).unsqueeze(1)).to(torch.int) bucket_idx = torch.max(torch.argmin(bucket_mask, dim=1)) else: attention_mask = tensors[1] position_ids = tensors[2] max_position_id = ( position_ids[:, -1] + speculation_length if (position_ids[:, -1] + speculation_length).all() <= buckets[-1] else position_ids[:, -1] ) bucket_mask = (buckets <= (max_position_id).unsqueeze(1)).to(torch.int) bucket_idx = torch.max(torch.argmin(bucket_mask, dim=1)) bucket = buckets[bucket_idx] # slice the attention mask based on the selected bucket size if padding_side == "right": tensors[1] = torch.ops.aten.slice(attention_mask, dim=1, start=0, end=bucket) else: tensors[1] = torch.ops.aten.slice(attention_mask, dim=1, start=buckets[-1] - bucket, end=buckets[-1]) return tensors, bucket_idx.to(torch.int) def get_generation_model_bk(): return generation_model_bk @torch.jit.script def context_encoder_bk(tensors: List[torch.Tensor], buckets, padding_side: str, pad_token: int): """ The Bucket Kernel for Context Encoding Models. 1) tensors: A list of torch tensors after running through the flattener 2) buckets: A torch.tensor of the bucket sizes 3) padding_side: A string specifying padding side, must be "left" or "right" 4) pad_token: An integer representing the pad token id. Typically this is 0. """ input_ids = tensors[0] # -----Remarks for calculating position_idx----- # finds the number of non pad tokens and that is the active sequence_length # The resulting tensor is of shape (batch_size,) # # NOTE: We derive position_ids from input_ids because # position_ids is eliminated from the flattener for context encoding models. # ---------------------------------------------- position_idx = (input_ids != pad_token).sum(dim=1) position_idx = position_idx[:, None] # shape (batch_size, 1) buckets = buckets[None, :] # shape (1, seq_len) # -----Remarks for choosing the bucket_idx----- # 1. (buckets < position_idx) produces a bucket_mask where invalid buckets are 0 # 2. We convert the boolean tensor to int because argmin doesn't support # boolean tensors # 3. We choose the minimum valid bucket, which is the first 1 value # 4. From the minimum valid buckets, we choose the largest bucket, otherwise # we'd be truncating generated tokens from longer sequences. # 5. DO NOT USE argmax since we monkeypatch it, # causing issues with torch.jit.script # --------------------------------------------- bucket_mask = (buckets < position_idx).to(torch.int) # shape (batch_size, seq_len) bucket_idx = torch.max(torch.argmin(bucket_mask, dim=1)) # select the chosen bucket after squeezing back to original form bucket = buckets.squeeze(0)[bucket_idx] new_tensors = [] # ---------Remarks on handling padding sides------- # 1. slice from the opposite side for padding # 2. Identify seq_id tensors by shape and don't slice it # ------------------------------------------------- if padding_side == "right": for i, tens in enumerate(tensors): # identifies the seq_ids, which don't need to be sliced if len(tens.shape) == 1: new_tensors.append(tens) else: # all other tensors are of shape (batch_size,seq_len) so we slice on seq_len new_tensors.append(torch.ops.aten.slice(tens, dim=1, start=0, end=bucket)) else: max_idx = buckets[-1][-1] for i, tens in enumerate(tensors): # identifies the seq_ids, which don't need to be sliced if len(tens.shape) == 1: new_tensors.append(tens) else: new_tensors.append(torch.ops.aten.slice(tens, dim=1, start=max_idx - bucket, end=max_idx)) return new_tensors, bucket_idx.to(torch.int) def get_context_encoder_bk(): return context_encoder_bk