optimum/neuron/models/inference/backend/modules/generation/sampling.py (220 lines of code) (raw):

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/modules/generation/sampling.py
import logging
from typing import Optional, Union

import torch
from neuronx_distributed.operators.argmax import argmax as nxd_argmax
from neuronx_distributed.operators.topk import topk as nxd_topk
from neuronx_distributed.parallel_layers import parallel_state
from neuronx_distributed.utils.utils import hardware
from neuronxcc.nki._private_kernels.cumsum import cumsum as nki_cumsum
from torch_neuronx.utils import get_platform_target
from torch_neuronx.xla_impl.ops import nki_jit, xla_hlo_call

from ...config import NxDNeuronConfig


logger = logging.getLogger("Neuron")


def mask_padded_logits(logits, rank_id, world_size, pad_size=None):
    """Mask padded vocabulary columns held by the last rank.

    A column is considered padding when it belongs to rank ``world_size - 1`` AND its
    index along the last dimension is ``>= logits.shape[-1] - pad_size``. Such columns
    are replaced by the most negative finite value of ``logits.dtype``.

    Args:
        logits: the local logits; the last dimension is the (local) vocabulary.
        rank_id: tensor holding the current rank id; it is broadcast to ``logits.shape``.
        world_size: number of ranks.
        pad_size: number of padded columns at the end of the last rank's logits.
            ``None`` or ``0`` returns ``logits`` unchanged.

    Returns:
        The logits with padded columns masked.
    """
    if pad_size is None or pad_size == 0:
        return logits
    # invalid if rank_id == world_size - 1
    last_rank_mask = torch.eq(
        torch.full(logits.shape, world_size - 1, device=logits.device, dtype=torch.int32),
        rank_id.broadcast_to(logits.shape),
    )
    # and index >= logits.shape[-1] - pad
    on_pad_mask = torch.ge(
        torch.arange(logits.shape[-1], device=logits.device, dtype=torch.int32).broadcast_to(logits.shape),
        torch.full(logits.shape, logits.shape[-1] - pad_size, device=logits.device, dtype=torch.int32),
    )
    # Product of the two boolean masks == element-wise AND.
    invalid_mask = last_rank_mask * on_pad_mask
    logits = torch.where(invalid_mask, torch.full_like(logits, torch.finfo(logits.dtype).min), logits)
    return logits


def cumsum(tensor_in, dim, on_cpu: bool = False):
    """Cumulative sum of ``tensor_in`` along ``dim``.

    On CPU this is ``torch.cumsum``. Otherwise the cumsum dimension is first moved to
    the last axis, then:
      * floating point inputs use the NKI cumsum kernel on a 2D ``(-1, cumsum_len)`` reshape,
      * other dtypes are multiplied by an upper-triangular matrix of ones.

    Args:
        tensor_in: input tensor.
        dim: dimension to accumulate along (negative values are allowed).
        on_cpu: whether to use ``torch.cumsum`` directly.

    Returns:
        A tensor with the same shape as ``tensor_in``.
    """
    if on_cpu:
        logger.debug("On CPU, using torch cumsum")
        return torch.cumsum(tensor_in, dim=dim)

    init_shape_len = len(tensor_in.shape)
    cumsum_dim = dim % init_shape_len
    last_dim = init_shape_len - 1
    is_transposed = False
    if cumsum_dim != last_dim:
        tensor_in = torch.transpose(tensor_in, cumsum_dim, last_dim)
        is_transposed = True

    init_shape = tensor_in.shape
    cumsum_len = init_shape[last_dim]

    # Prefer the NKI kernel for floating point inputs; fall back to a matmul cumsum otherwise.
    if torch.is_floating_point(tensor_in):
        logger.debug("Using NKI cumsum")
        # reshape (not view): the tensor may be non-contiguous after the transpose above.
        tensor_in = tensor_in.reshape(-1, cumsum_len)
        nki_cumsum_func = nki_jit()(nki_cumsum)
        output = torch.zeros_like(tensor_in, device=tensor_in.device, dtype=tensor_in.dtype)
        nki_cumsum_func(tensor_in, output, axis=1)
        output = output.view(init_shape)
        if is_transposed:
            output = torch.transpose(output, cumsum_dim, last_dim)
        return output

    logger.debug("Using matmul cumsum")
    # x @ triu(ones): output[..., j] = sum_{i <= j} x[..., i]
    triu = torch.triu(
        torch.ones(
            cumsum_len,
            cumsum_len,
            dtype=tensor_in.dtype,
            device=tensor_in.device,
        )
    )
    output = tensor_in @ triu
    if is_transposed:
        output = torch.transpose(output, cumsum_dim, last_dim)
    return output


@xla_hlo_call
def rand_like(tensor):
    """Emit an HLO ``Rng`` op drawing values between 0 and 1 with the shape/dtype of ``tensor``."""
    dtype = tensor.dtype
    shape = tensor.sizes
    minimum = dtype.Constant(constant_value=0)
    maximum = dtype.Constant(constant_value=1)
    return dtype[shape].Rng(minimum, maximum, distribution=1)  # Uniform distribution


def validate_sampling_params(params: torch.Tensor, max_topk: int) -> None:
    """
    Validates sampling parameters for language models.

    Args:
        params (torch.Tensor): Tensor of shape (batch_size, 3) containing sampling parameters
                               in the order: top-k, top-p, temperature.
        max_topk (int): The maximum number of top tokens to sample from.

    Raises:
        ValueError: If ``params`` is not a 2D tensor with 3 columns, or if any of the parameters are invalid.
    """
    # Check the rank first so that a 1D tensor raises ValueError instead of IndexError.
    if params.dim() != 2 or params.shape[1] != 3:
        raise ValueError(f"Expected tensor of shape (batch_size, 3), but got {params.shape}")

    # autocast params tensor to float32
    params = params.to(torch.float32)

    # Unpack parameters
    top_k, top_p, temperature = params[:, 0], params[:, 1], params[:, 2]

    # Validate top-k value range
    valid_top_k = (top_k == -1) | ((top_k > 0) & (top_k <= max_topk))
    if not torch.all(valid_top_k):
        raise ValueError(
            f"Invalid top-k values found. top-k must be -1 or greater than 0 but less than or equal to {max_topk}. Found {top_k}."
        )

    # checks if top-k values can be represented as integers
    if not torch.equal(top_k, top_k.floor()):
        raise ValueError(
            f"Invalid top-k values found. top-k values should be able to be represented as integer values, but found decimal parts. Found {top_k=}."
        )

    # Validate top-p
    valid_top_p = (top_p > 0.0) & (top_p <= 1.0)
    if not torch.all(valid_top_p):
        raise ValueError(f"Invalid top-p values found. top-p must be in the range (0.0, 1.0]. Found {top_p=}.")

    # Validate temperature
    valid_temp = temperature > 0.0
    if not torch.all(valid_temp):
        raise ValueError(
            f"Invalid temperature values found. Temperature must be strictly greater than 0.0. Found {temperature=}."
        )


def prepare_sampling_params(batch_size, top_k=None, top_p=None, temperature=None):
    """Build a ``(batch_size, 3)`` sampling parameters tensor.

    Args:
        batch_size: number of sequences.
        top_k: per-sequence (or single) top-k value(s). Defaults to ``[1]``.
        top_p: per-sequence (or single) top-p value(s). Defaults to ``[1.0]``.
        temperature: per-sequence (or single) temperature value(s). Defaults to ``[1.0]``.

    Each argument may be a scalar, a list/tuple or a tensor. A single value is broadcast
    to ``batch_size``.

    Returns:
        A tensor whose columns are top_k, top_p and temperature.
    """
    # None sentinels avoid shared mutable default arguments; effective defaults are unchanged.
    # atleast_1d turns 0-d tensors into shape (1,) so that they can be broadcast below.
    top_k = torch.atleast_1d(prepare_tensor([1] if top_k is None else top_k))
    top_p = torch.atleast_1d(prepare_tensor([1.0] if top_p is None else top_p))
    temperature = torch.atleast_1d(prepare_tensor([1.0] if temperature is None else temperature))

    assert top_k.shape[0] == top_p.shape[0] == temperature.shape[0], (
        f"sampling params shapes don't match. "
        f"Got top_k shape: {top_k.shape}, top_p shape: {top_p.shape}, temperature shape: {temperature.shape}"
    )

    if top_k.shape[0] == 1:
        top_k = top_k.broadcast_to(batch_size)
        top_p = top_p.broadcast_to(batch_size)
        temperature = temperature.broadcast_to(batch_size)

    stacked = torch.stack([top_k, top_p, temperature], dim=1)
    return stacked


def prepare_tensor(val: Union[torch.Tensor, list, tuple, float]):
    """Convert ``val`` to a tensor.

    Tensors are returned as-is; lists and tuples are converted element-wise; any other
    value is wrapped in a one-element list first.
    """
    if torch.is_tensor(val):
        return val
    if not isinstance(val, (list, tuple)):
        val = [val]
    return torch.tensor(val)


class Sampler(torch.nn.Module):
    """Add sampling code to the model graph.

    The sampling method is set when compiling the model, and cannot be changed at runtime.

    If the model was compiled for multinomial sampling, it is still possible to perform greedy
    sampling by passing top_k=1 and top_p=1.0.
    On the other hand, if the model was compiled for greedy sampling, it is not possible
    to perform multinomial sampling at runtime.
    For that reason, multinomial sampling is the default sampling method.

    Args:
        neuron_config (`NxDNeuronConfig`): the neuron configuration. This class reads:
            - `max_topk`: the maximum number of top tokens to sample from. When strictly positive,
              a single topk operation is performed on all logits in a batch, then a mask is applied
              by sequence, instead of sorting the whole vocabulary. Values <= 0 disable this
              optimization.
            - `tp_degree`, `world_size`, `logical_nc_config`: used to choose the number of
              distributed topk stages.
        do_sample(`Optional[bool]`): whether to use sampling or not. If False, argmax sampling is used,
            whatever sampling parameters are passed at runtime.
        on_cpu(`Optional[bool]`): whether to run on CPU or not
    """

    def __init__(
        self, neuron_config: NxDNeuronConfig, do_sample: Optional[bool] = True, on_cpu: Optional[bool] = False
    ):
        super().__init__()
        if not do_sample:
            logger.warning("Greedy sampling is used. Sampling parameters will be ignored at runtime.")
        self.neuron_config = neuron_config
        self.do_sample = do_sample
        if self.neuron_config.max_topk < 0:
            logger.warning("max_topk optimization is disabled: this can lead to extremely long compilation times.")
        self.IGNORED_LOGITS_VALUE = -3000  # large negative values will be transformed to ~0 in softmax, this is to ignore tokens that are beyond topk range
        self.on_cpu = on_cpu
        if on_cpu:
            self.process_group = None
        else:
            self.process_group = parallel_state.get_tensor_model_parallel_group()

    def _soft_max(self, logits, dim):
        """Softmax of ``logits`` along ``dim``."""
        return torch.nn.functional.softmax(input=logits, dim=dim)

    def _get_top_k_num_stages(self):
        """Number of stages passed to the distributed topk, chosen from hardware and parallel config."""
        hardware_type = hardware(get_platform_target())
        if (
            hardware_type == hardware.TRN2
            and self.neuron_config.tp_degree == self.neuron_config.world_size == 64
            and self.neuron_config.logical_nc_config == 2
        ):
            return 3
        elif hardware_type == hardware.TRN1 and self.neuron_config.tp_degree == self.neuron_config.world_size == 32:
            return 2
        else:
            return 1

    def _top_k_masked(self, logits, top_k, dim, rank_id):
        """Return logits sorted in descending order with every entry past ``top_k`` masked.

        When ``max_topk > 0`` only the ``max_topk`` largest logits are kept (distributed topk on
        device); otherwise the whole ``dim`` is sorted.

        Returns:
            ``(sorted_logits, indices)`` where positions ``>= top_k`` of ``sorted_logits`` are
            set to ``IGNORED_LOGITS_VALUE`` and ``indices`` maps them back to vocabulary ids.
        """
        if self.neuron_config.max_topk > 0:
            if self.on_cpu:
                sorted_logits, indices = torch.topk(input=logits, k=self.neuron_config.max_topk, dim=dim)
            else:
                sorted_logits, indices = nxd_topk(
                    tensor=logits,
                    k=self.neuron_config.max_topk,
                    dim=dim,
                    gather_dim=dim,
                    process_group=self.process_group,
                    stages=self._get_top_k_num_stages(),
                    rank_id=rank_id,
                )
        else:
            sorted_logits, indices = torch.sort(input=logits, dim=dim, descending=True)

        vocab_size = sorted_logits.shape[-1]
        mask = torch.arange(vocab_size, device=logits.device)
        mask = mask.broadcast_to(*sorted_logits.shape)
        mask = torch.greater_equal(mask, top_k)
        sorted_logits = sorted_logits.masked_fill_(mask, self.IGNORED_LOGITS_VALUE)
        return sorted_logits, indices

    def _top_p(self, top_k_logits_values, probs_cumsum, top_p, dim):
        """Mask logits whose cumulative probability exceeds ``top_p`` and return the new cumsum.

        Note: ``top_k_logits_values`` is modified in place. This helper is not called by the
        other methods of this class.
        """
        top_p_mask = torch.greater(probs_cumsum, top_p)
        top_k_logits_values = top_k_logits_values.masked_fill_(top_p_mask, self.IGNORED_LOGITS_VALUE)
        probs_soft_max = self._soft_max(top_k_logits_values, dim)  # custom call
        probs_cumsum = cumsum(tensor_in=probs_soft_max, dim=dim, on_cpu=self.on_cpu)
        return probs_cumsum

    def _rand_selector(self, probs_cumsum, num_samples=1):
        """Uniform random values of shape ``(probs_cumsum.shape[0], num_samples)``."""
        zeros = torch.zeros((probs_cumsum.shape[0], num_samples), device=probs_cumsum.device, dtype=probs_cumsum.dtype)
        return torch.rand_like(zeros) if self.on_cpu else rand_like(zeros)

    def _multinomial(self, probs, dim, num_samples=1):
        """Sample one index per row of ``probs`` by inverse transform sampling.

        The sampled index is the number of cumulative probabilities strictly lower than a
        uniform random value.

        Returns:
            Indices along ``dim`` (with ``dim`` kept as a size-1 dimension).
        """
        probs_cumsum = cumsum(tensor_in=probs, dim=dim, on_cpu=self.on_cpu)
        rand_selector = self._rand_selector(probs_cumsum, num_samples)
        greater_than_rand = torch.greater(rand_selector, probs_cumsum)
        counts = torch.sum(greater_than_rand, dim=dim).unsqueeze(dim)
        # Floating point rounding can leave the last cumulative probability slightly below the
        # random value, which would count every entry and yield an index one past the end.
        # Clamp so that the subsequent gather always stays in bounds.
        counts = torch.clamp(counts, max=probs.shape[dim] - 1)
        return counts

    def _argmax_sample(self, token_logits, return_values, dim):
        """Greedy sampling: argmax of ``token_logits`` along ``dim`` (distributed on device).

        Returns:
            The token ids, or ``(token_ids, ones)`` when ``return_values`` is True, on both
            CPU and device.
        """
        if self.on_cpu:
            tokens = torch.argmax(token_logits, dim=dim)
        else:
            # distributed argmax
            tokens = nxd_argmax(
                tensor=token_logits,
                dim=dim,
                gather_dim=dim,
                keepdim=False,
                process_group=self.process_group,
            )
        if return_values:
            values = torch.ones(tokens.shape, dtype=token_logits.dtype, device=tokens.device)
            return tokens, values
        return tokens

    def _multinomial_sample(self, token_logits, sampling_params, return_values, dim, rank_id):
        """Apply top_k, temperature and top_p, then sample one token per sequence.

        Returns:
            The sampled token ids flattened to ``(batch_size,)``, or, when ``return_values`` is
            True, the candidate token ids and their probabilities after filtering.
        """
        batch_size = token_logits.shape[0]
        top_k = sampling_params[:, 0].reshape(batch_size, 1)
        top_p = sampling_params[:, 1].reshape(batch_size, 1)
        temperature = sampling_params[:, 2].reshape(batch_size, 1)

        # Apply top_k first
        top_k_logits_values, top_k_logits_indices = self._top_k_masked(token_logits, top_k, dim, rank_id)

        # Apply temperature
        top_k_logits_values = torch.divide(top_k_logits_values, temperature)

        # Apply top_p
        probs_soft_max = self._soft_max(top_k_logits_values, dim)
        probs_cumsum = cumsum(tensor_in=probs_soft_max, dim=dim, on_cpu=self.on_cpu)
        top_p = torch.max(torch.min(probs_cumsum), top_p)
        top_p_mask = torch.greater(probs_cumsum, top_p).index_fill_(
            dim, torch.tensor([0], device=top_p.device), False
        )  # need to keep at least one token
        top_k_logits_values = top_k_logits_values.masked_fill_(top_p_mask, self.IGNORED_LOGITS_VALUE)

        probs_soft_max = self._soft_max(top_k_logits_values, dim)  # custom call
        if return_values:
            return top_k_logits_indices, probs_soft_max

        counts = self._multinomial(probs_soft_max, dim)
        return torch.gather(input=top_k_logits_indices, dim=dim, index=counts).flatten()

    def forward(self, token_logits, sampling_params, return_values=False, rank_id=None):
        """
        forward to perform topk, topp, temperature and multinomial sampling.

        This method is only used when compiling the model, which means that the decision to use
        multinomial sampling cannot be made at runtime.

        If the model was compiled for multinomial sampling, it is still possible to perform greedy
        sampling by passing top_k=1 and top_p=1.0.
        On the other hand, if the model was compiled for greedy sampling, it is not possible
        to perform multinomial sampling at runtime.
        For that reason, multinomial sampling is the default sampling method.

        Inputs:
            token_logits: tensor whose first dimension is Batch Size
                and whose final dimension is Vocabulary Size
            sampling_params: a 2D tensor of size (Batch Size, 3)
                containing the following sampling params:
                    * top_k: value to use for top_k sampling
                    * top_p: value to use for top_p sampling
                    * temperature: value to use for temperature sampling
            return_values: if True, return candidate tokens and their values instead of
                sampled tokens (see Output).
            rank_id: rank id forwarded to the distributed topk.

        Output:
            Tensor containing 1 sampled token id per sequence, of shape (Batch Size,).
            When return_values is True:
                * multinomial sampling returns (candidate token ids, candidate probabilities),
                * greedy sampling returns (token ids, a tensor of ones of the same shape).

        Note: Using torch.multinomial on device causes trace to hang.
        This is because torch.multinomial performs a number of distribution
        validation steps, which is content dependent. Hence we implement multinomial
        distribution here instead.
        """
        dim = len(token_logits.shape) - 1  # vocab_size dimension
        if self.do_sample:
            return self._multinomial_sample(token_logits, sampling_params, return_values, dim, rank_id)
        else:
            return self._argmax_sample(token_logits, return_values, dim)