# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/modules/generation/sampling.py
import logging
from typing import Optional, Union
import torch
from neuronx_distributed.operators.argmax import argmax as nxd_argmax
from neuronx_distributed.operators.topk import topk as nxd_topk
from neuronx_distributed.parallel_layers import parallel_state
from neuronx_distributed.utils.utils import hardware
from neuronxcc.nki._private_kernels.cumsum import cumsum as nki_cumsum
from torch_neuronx.utils import get_platform_target
from torch_neuronx.xla_impl.ops import nki_jit, xla_hlo_call
from ...config import NxDNeuronConfig
logger = logging.getLogger("Neuron")
def mask_padded_logits(logits, rank_id, world_size, pad_size=None):
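    """Mask out logits that correspond to vocabulary padding.

    When the vocabulary is padded so that it divides evenly across tensor-parallel
    ranks, the last rank holds ``pad_size`` artificial positions. Their logits are
    set to the dtype minimum so the padded tokens can never be selected.
    """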
if pad_size is None or pad_size == 0:
return logits
    # logits are invalid when rank_id == world_size - 1 (the last tensor-parallel rank)
last_rank_mask = torch.eq(
torch.full(logits.shape, world_size - 1, device=logits.device, dtype=torch.int32),
rank_id.broadcast_to(logits.shape),
)
    # and the index falls in the padded range: index >= logits.shape[-1] - pad_size
on_pad_mask = torch.ge(
torch.arange(logits.shape[-1], device=logits.device, dtype=torch.int32).broadcast_to(logits.shape),
torch.full(logits.shape, logits.shape[-1] - pad_size, device=logits.device, dtype=torch.int32),
)
invalid_mask = last_rank_mask * on_pad_mask
logits = torch.where(invalid_mask, torch.full_like(logits, torch.finfo(logits.dtype).min), logits)
return logits
def cumsum(tensor_in, dim, on_cpu: bool = False):
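    """Cumulative sum of ``tensor_in`` along ``dim``.

    On CPU this defers to ``torch.cumsum``. On device, the target dimension is moved
    to the last axis, then floating-point inputs go through the NKI cumsum kernel
    while other dtypes are computed as a matmul with an upper-triangular ones matrix.
    """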
if on_cpu:
logger.debug("On CPU, using torch cumsum")
return torch.cumsum(tensor_in, dim=dim)
init_shape_len = len(tensor_in.shape)
cumsum_dim = dim % init_shape_len
last_dim = init_shape_len - 1
is_transposed = False
if cumsum_dim != last_dim:
tensor_in = torch.transpose(tensor_in, cumsum_dim, last_dim)
is_transposed = True
init_shape = tensor_in.shape
cumsum_len = init_shape[last_dim]
    # Prefer the NKI kernel for floating-point inputs; fall back to matmul cumsum otherwise
if torch.is_floating_point(tensor_in):
logger.debug("Using NKI cumsum")
tensor_in = tensor_in.view(-1, cumsum_len)
nki_cumsum_func = nki_jit()(nki_cumsum)
output = torch.zeros_like(tensor_in, device=tensor_in.device, dtype=tensor_in.dtype)
nki_cumsum_func(tensor_in, output, axis=1)
output = output.view(init_shape)
if is_transposed:
output = torch.transpose(output, cumsum_dim, last_dim)
return output
else:
logger.debug("Using matmul cumsum")
triu = torch.triu(
torch.ones(
cumsum_len,
cumsum_len,
dtype=tensor_in.dtype,
device=tensor_in.device,
)
)
output = tensor_in @ triu
if is_transposed:
output = torch.transpose(output, cumsum_dim, last_dim)
return output
@xla_hlo_call
def rand_like(tensor):
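    """Uniform random values in [0, 1) with the same shape and dtype as the input.

    Implemented as an XLA HLO ``Rng`` op (via ``xla_hlo_call``) so that random number
    generation is part of the compiled graph rather than done on the host.
    """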
dtype = tensor.dtype
shape = tensor.sizes
minimum = dtype.Constant(constant_value=0)
maximum = dtype.Constant(constant_value=1)
return dtype[shape].Rng(minimum, maximum, distribution=1) # Uniform distribution
def validate_sampling_params(params: torch.Tensor, max_topk: int) -> None:
"""
Validates sampling parameters for language models.
Args:
params (torch.Tensor): Tensor of shape (batch_size, 3) containing sampling parameters
in the order: top-k, top-p, temperature.
max_topk (int): The maximum number of top tokens to sample from.
Raises:
ValueError: If any of the parameters are invalid.
"""
if params.shape[1] != 3:
raise ValueError(f"Expected tensor of shape (batch_size, 3), but got {params.shape}")
# autocast params tensor to float32
params = params.to(torch.float32)
# Unpack parameters
top_k, top_p, temperature = params[:, 0], params[:, 1], params[:, 2]
# Validate top-k value range
valid_top_k = (top_k == -1) | ((top_k > 0) & (top_k <= max_topk))
if not torch.all(valid_top_k):
raise ValueError(
f"Invalid top-k values found. top-k must be -1 or greater than 0 but less than or equal to {max_topk}. Found {top_k}."
)
# checks if top-k values can be represented as integers
if not torch.equal(top_k, top_k.floor()):
raise ValueError(
f"Invalid top-k values found. top-k values should be able to be represented as integer values, but found decimal parts. Found {top_k=}."
)
# Validate top-p
valid_top_p = (top_p > 0.0) & (top_p <= 1.0)
if not torch.all(valid_top_p):
raise ValueError(f"Invalid top-p values found. top-p must be in the range (0.0, 1.0]. Found {top_p=}.")
# Validate temperature
valid_temp = temperature > 0.0
if not torch.all(valid_temp):
raise ValueError(
f"Invalid temperature values found. Temperature must be strictly greater than 0.0. Found {temperature=}."
)
def prepare_sampling_params(batch_size, top_k=[1], top_p=[1.0], temperature=[1.0]):
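    """Build a (batch_size, 3) tensor of per-sequence sampling parameters.

    ``top_k``, ``top_p`` and ``temperature`` may each be a scalar, a list or a tensor.
    The three must have the same length; single values are broadcast to ``batch_size``.

    Example (illustrative):
        >>> params = prepare_sampling_params(batch_size=2, top_k=[50], top_p=[0.9], temperature=[0.8])
        >>> params.shape
        torch.Size([2, 3])
    """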
top_k = prepare_tensor(top_k)
top_p = prepare_tensor(top_p)
temperature = prepare_tensor(temperature)
    assert top_k.shape[0] == top_p.shape[0] == temperature.shape[0], (
        f"sampling params shapes don't match. "
        f"Got top_k shape: {top_k.shape}, top_p shape: {top_p.shape}, temperature shape: {temperature.shape}"
    )
if top_k.shape[0] == 1:
top_k = top_k.broadcast_to(batch_size)
top_p = top_p.broadcast_to(batch_size)
temperature = temperature.broadcast_to(batch_size)
stacked = torch.stack([top_k, top_p, temperature], dim=1)
return stacked
def prepare_tensor(val: Union[torch.Tensor, list, float]):
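    """Wrap a scalar or list into a ``torch.Tensor``; tensors are returned unchanged."""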
if not torch.is_tensor(val):
if not isinstance(val, list):
val = [val]
val = torch.tensor(val)
return val
class Sampler(torch.nn.Module):
"""Add sampling code to the model graph.
The sampling method is set when compiling the model, and cannot be changed at runtime.
If the model was compiled for multinomial sampling, it is still possible
to perform greedy sampling by passing top_k=1 and top_p=1.0.
On the other hand, if the model was compiled for greedy sampling, it is not
possible to perform multinomial sampling at runtime.
For that reason, multinomial sampling is the default sampling method.
Args:
        do_sample(`Optional[bool]`): whether to use sampling or not. If False, argmax sampling is used, regardless of the sampling parameters passed at runtime.
        max_topk(`Optional[int]`): the maximum number of top tokens to sample from. It is used to optimize calculations
            by performing a single topk operation on all logits in a batch and then applying a per-sequence mask, instead of
            applying top_k to each sequence in the batch individually. Defaults to 0, which means no optimization.
on_cpu(`Optional[bool]`): whether to run on CPU or not
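    Example (illustrative, assuming an already-built ``NxDNeuronConfig`` and ``token_logits``
    of shape (batch_size, vocab_size)):
        sampler = Sampler(neuron_config, do_sample=True, on_cpu=True)
        params = prepare_sampling_params(batch_size=2, top_k=[50], top_p=[0.9], temperature=[1.0])
        tokens = sampler(token_logits, params)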
"""
def __init__(
self, neuron_config: NxDNeuronConfig, do_sample: Optional[bool] = True, on_cpu: Optional[bool] = False
):
super().__init__()
if not do_sample:
logger.warning("Greedy sampling is used. Sampling parameters will be ignored at runtime.")
self.neuron_config = neuron_config
self.do_sample = do_sample
if self.neuron_config.max_topk < 0:
logger.warning("max_topk optimization is disabled: this can lead to extremely long compilation times.")
        # A large negative value becomes ~0 after softmax; used to ignore tokens beyond the top-k/top-p range.
        self.IGNORED_LOGITS_VALUE = -3000
self.on_cpu = on_cpu
if on_cpu:
self.process_group = None
else:
self.process_group = parallel_state.get_tensor_model_parallel_group()
def _soft_max(self, logits, dim):
return torch.nn.functional.softmax(input=logits, dim=dim)
def _get_top_k_num_stages(self):
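        """Pick the number of stages for the distributed top-k based on the target hardware and tensor-parallel layout."""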
hardware_type = hardware(get_platform_target())
if (
hardware_type == hardware.TRN2
and self.neuron_config.tp_degree == self.neuron_config.world_size == 64
and self.neuron_config.logical_nc_config == 2
):
return 3
elif hardware_type == hardware.TRN1 and self.neuron_config.tp_degree == self.neuron_config.world_size == 32:
return 2
else:
return 1
def _top_k_masked(self, logits, top_k, dim, rank_id):
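        """Restrict logits to the top-k candidates of each sequence.

        When ``max_topk`` is set, a single top-k of size ``max_topk`` is computed for the
        whole batch (distributed via ``nxd_topk`` on device); otherwise the logits are fully
        sorted. Positions beyond each sequence's own ``top_k`` are then filled with
        ``IGNORED_LOGITS_VALUE``.
        """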
if self.neuron_config.max_topk > 0:
if self.on_cpu:
                sorted_logits, indices = torch.topk(input=logits, k=self.neuron_config.max_topk, dim=dim)
else:
                sorted_logits, indices = nxd_topk(
tensor=logits,
k=self.neuron_config.max_topk,
dim=dim,
gather_dim=dim,
process_group=self.process_group,
stages=self._get_top_k_num_stages(),
rank_id=rank_id,
)
else:
            sorted_logits, indices = torch.sort(input=logits, dim=dim, descending=True)
vocab_size = sorted_logits.shape[-1]
mask = torch.arange(vocab_size, device=logits.device)
mask = mask.broadcast_to(*sorted_logits.shape)
mask = torch.greater_equal(mask, top_k)
sorted_logits = sorted_logits.masked_fill_(mask, self.IGNORED_LOGITS_VALUE)
        return sorted_logits, indices
def _top_p(self, top_k_logits_values, probs_cumsum, top_p, dim):
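        """Mask logits whose cumulative probability exceeds ``top_p`` and return the renormalized cumulative probabilities."""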
top_p_mask = torch.greater(probs_cumsum, top_p)
top_k_logits_values = top_k_logits_values.masked_fill_(top_p_mask, self.IGNORED_LOGITS_VALUE)
probs_soft_max = self._soft_max(top_k_logits_values, dim) # custom call
probs_cumsum = cumsum(tensor_in=probs_soft_max, dim=dim, on_cpu=self.on_cpu)
return probs_cumsum
def _rand_selector(self, probs_cumsum, num_samples=1):
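        """Draw ``num_samples`` uniform random values per sequence, on CPU or as an HLO Rng op on device."""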
zeros = torch.zeros((probs_cumsum.shape[0], num_samples), device=probs_cumsum.device, dtype=probs_cumsum.dtype)
return torch.rand_like(zeros) if self.on_cpu else rand_like(zeros)
def _multinomial(self, probs, dim, num_samples=1):
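        """Inverse-CDF sampling: count how many cumulative probabilities fall below a uniform random draw.

        Implemented manually because ``torch.multinomial`` cannot be traced on device (see ``forward``).
        """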
probs_cumsum = cumsum(tensor_in=probs, dim=dim, on_cpu=self.on_cpu)
rand_selector = self._rand_selector(probs_cumsum, num_samples)
greater_than_rand = torch.greater(rand_selector, probs_cumsum)
counts = torch.sum(greater_than_rand, dim=dim).unsqueeze(dim)
return counts
def _argmax_sample(self, token_logits, return_values, dim):
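        """Greedy token selection: ``torch.argmax`` on CPU, distributed ``nxd_argmax`` on device (with all-ones values when ``return_values`` is set)."""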
if self.on_cpu:
return torch.argmax(token_logits, dim=dim)
else:
# distributed argmax
tokens = nxd_argmax(
tensor=token_logits,
dim=dim,
gather_dim=dim,
keepdim=False,
process_group=self.process_group,
)
values = torch.ones(tokens.shape, dtype=token_logits.dtype, device=tokens.device)
if return_values:
return tokens, values
return tokens
def _multinomial_sample(self, token_logits, sampling_params, return_values, dim, rank_id):
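        """Apply top-k, temperature and top-p filtering, then draw one token per sequence.

        ``sampling_params`` columns are (top_k, top_p, temperature). When ``return_values``
        is set, the filtered token indices and their probabilities are returned instead.
        """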
batch_size = token_logits.shape[0]
top_k = sampling_params[:, 0].reshape(batch_size, 1)
top_p = sampling_params[:, 1].reshape(batch_size, 1)
temperature = sampling_params[:, 2].reshape(batch_size, 1)
# Apply top_k first
top_k_logits_values, top_k_logits_indices = self._top_k_masked(token_logits, top_k, dim, rank_id)
# Apply temperature
top_k_logits_values = torch.divide(top_k_logits_values, temperature)
# Apply top_p
probs_soft_max = self._soft_max(top_k_logits_values, dim)
probs_cumsum = cumsum(tensor_in=probs_soft_max, dim=dim, on_cpu=self.on_cpu)
top_p = torch.max(torch.min(probs_cumsum), top_p)
top_p_mask = torch.greater(probs_cumsum, top_p).index_fill_(
dim, torch.tensor([0], device=top_p.device), False
) # need to keep at least one token
top_k_logits_values = top_k_logits_values.masked_fill_(top_p_mask, self.IGNORED_LOGITS_VALUE)
probs_soft_max = self._soft_max(top_k_logits_values, dim) # custom call
if return_values:
return top_k_logits_indices, probs_soft_max
counts = self._multinomial(probs_soft_max, dim)
return torch.gather(input=top_k_logits_indices, dim=dim, index=counts).flatten()
def forward(self, token_logits, sampling_params, return_values=False, rank_id=None):
"""
        Perform top-k, top-p, temperature and multinomial sampling.
This method is only used when compiling the model, which means that the
decision to use multinomial sampling cannot be made at runtime.
If the model was compiled for multinomial sampling, it is still possible
to perform greedy sampling by passing top_k=1 and top_p=1.0.
On the other hand, if the model was compiled for greedy sampling, it is not
possible to perform multinomial sampling at runtime.
For that reason, multinomial sampling is the default sampling method.
Inputs:
token_logits: tensor whose first dimension is Batch Size
and whose final dimension is Vocabulary Size
sampling_params: a 2D tensor of size (Batch Size, 3)
containing the following sampling params:
* top_k: value to use for top_k sampling
* top_p: value to use for top_p sampling
* temperature: value to use for temperature sampling
        Output:
            Tensor containing one sampled token id per sequence in the batch.
            Output size is (Batch Size,)
        Note: Using torch.multinomial on device causes the trace to hang.
        This is because torch.multinomial performs a number of distribution
        validation steps, which are content-dependent. Hence we implement the
        multinomial sampling here instead.
"""
dim = len(token_logits.shape) - 1 # vocab_size dimension
if self.do_sample:
return self._multinomial_sample(token_logits, sampling_params, return_values, dim, rank_id)
else:
return self._argmax_sample(token_logits, return_values, dim)