in src/nanotron/generation/sampler.py [0:0]
def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
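# sharded_logits: [batch_size, vocab_size // tp_size]. The batch is replicated across the
# group while the vocabulary is sharded, so each rank only sees a contiguous slice of the logits.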
batch_size, vocab_per_shard = sharded_logits.shape
# Find local top-k logits and their indices
local_top_k_values, local_top_k_indices = torch.topk(sharded_logits, self.k, dim=-1)
# Add offset to the indices
local_top_k_indices = local_top_k_indices + (dist.get_rank(self.pg) * vocab_per_shard)
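# After adding rank * vocab_per_shard, the indices refer to positions in the full
# vocabulary instead of positions inside this rank's shard.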
# Split the local top-k values/indices into one chunk per rank along the batch dimension
# Chunk sizes: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
min_shard_batch_size = batch_size // self.pg.size()
nb_shard_containing_extra_one = batch_size % self.pg.size()
in_split = tuple(
min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
for rank in range(self.pg.size())
)
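# e.g. with batch_size=10 and pg.size()=4: min_shard_batch_size=2, nb_shard_containing_extra_one=2,
# so in_split == (3, 3, 2, 2), which sums back to batch_size.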
# All entries of out_split must be equal so the received chunks can be concatenated along the last dimension
out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()
# Chunks for ranks >= nb_shard_containing_extra_one hold one row fewer than the first ones
local_top_k_values_in = list(torch.split(local_top_k_values, in_split, dim=0))
local_top_k_indices_in = list(torch.split(local_top_k_indices, in_split, dim=0))
# Prepare tensors for all-to-all operation
# Gather top-k logits and their indices from all shards, concatenated in tp_rank order
top_k_values_out_mat = torch.empty(
(total_out_size,) + local_top_k_values.shape[1:],
dtype=local_top_k_values.dtype,
device=local_top_k_values.device,
)
top_k_indices_out_mat = torch.empty(
(total_out_size,) + local_top_k_indices.shape[1:],
dtype=local_top_k_indices.dtype,
device=local_top_k_indices.device,
)
local_top_k_values_out = list(torch.split(top_k_values_out_mat, out_split, dim=0))
local_top_k_indices_out = list(torch.split(top_k_indices_out_mat, out_split, dim=0))
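# all_to_all semantics: rank r sends its j-th input chunk (in_split[j] rows) to rank j and
# receives in_split[r] rows from every rank. The output chunks are views into the preallocated
# *_out_mat buffers, so the collective fills them in place.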
dist.all_to_all(local_top_k_values_out, local_top_k_values_in, group=self.pg)
dist.all_to_all(local_top_k_indices_out, local_top_k_indices_in, group=self.pg)
# Concat assumes that the primary dimension is the same across all shards
sharded_local_top_k_values = torch.cat(local_top_k_values_out, dim=-1) # [sharded_batch_size, k * num_shards]
sharded_local_top_k_indices = torch.cat(
local_top_k_indices_out, dim=-1
) # [sharded_batch_size, k * num_shards]
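# Every element of the global top-k is necessarily in its own shard's local top-k, so re-running
# top-k over these k * num_shards candidates recovers the exact global top-k.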
# Select global top-k from the gathered top-k, now the top-k is across all vocab, batch_size is sharded
sharded_top_k_values, sharded_top_k_indices = torch.topk(
sharded_local_top_k_values, self.k, dim=-1
) # [sharded_batch_size, k]
# Select corresponding indices from the gathered indices
sharded_top_k_indices = sharded_local_top_k_indices.gather(
-1, sharded_top_k_indices
) # [sharded_batch_size, k]
# Apply temperature and compute softmax probabilities
probs = torch.softmax(sharded_top_k_values.to(dtype=torch.float) / self.temperature, dim=-1)
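# Scaling by a positive temperature is monotonic, so applying it after top-k selects the same
# candidate set as applying it before; the cast to float keeps the softmax numerically stable
# in low-precision runs.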
# Sample from the probabilities
sampled_indices = torch.multinomial(probs, num_samples=1) # [sharded_batch_size, 1]
# Select the corresponding token index from the global top-k indices
new_decoder_input_ids = sharded_top_k_indices.gather(-1, sampled_indices) # [sharded_batch_size, 1]
# All gather the new decoder input ids along batch dimension
gathered_new_decoder_input_ids = all_gather_batches(new_decoder_input_ids, in_split, group=self.pg)
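# all_gather_batches presumably concatenates the per-rank shards back in rank order using
# in_split, so every rank ends up with the same [batch_size, 1] tensor of sampled token ids.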
return gathered_new_decoder_input_ids