def __call__()

in src/nanotron/generation/sampler.py


    def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
        batch_size, vocab_per_shard = sharded_logits.shape

        # Find local top-k logits and their indices
        local_top_k_values, local_top_k_indices = torch.topk(sharded_logits, self.k, dim=-1)

        # Offset local indices into the global vocabulary: rank r owns the
        # contiguous vocab slice [r * vocab_per_shard, (r + 1) * vocab_per_shard)
        local_top_k_indices = local_top_k_indices + (dist.get_rank(self.pg) * vocab_per_shard)

        # Split local_top_k_values into a list of tensors along batch
        # We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
        min_shard_batch_size = batch_size // self.pg.size()
        nb_shard_containing_extra_one = batch_size % self.pg.size()
        in_split = tuple(
            min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
            for rank in range(self.pg.size())
        )
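        # e.g. batch_size=10 with pg.size()=4 gives min_shard_batch_size=2,
        # nb_shard_containing_extra_one=2 and in_split=(3, 3, 2, 2)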

        # out_split entries must all be equal so the chunks received from each rank
        # can be concatenated along the last dimension
        out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
        total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()

        # Ranks >= nb_shard_containing_extra_one receive one row fewer than the others
        local_top_k_values_in = list(torch.split(local_top_k_values, in_split, dim=0))
        local_top_k_indices_in = list(torch.split(local_top_k_indices, in_split, dim=0))
        # Prepare output buffers for the all-to-all that exchanges top-k logits
        # and indices between shards (received chunks are ordered by tp_rank)
        top_k_values_out_mat = torch.empty(
            (total_out_size,) + local_top_k_values.shape[1:],
            dtype=local_top_k_values.dtype,
            device=local_top_k_values.device,
        )
        top_k_indices_out_mat = torch.empty(
            (total_out_size,) + local_top_k_indices.shape[1:],
            dtype=local_top_k_indices.dtype,
            device=local_top_k_indices.device,
        )
        local_top_k_values_out = list(torch.split(top_k_values_out_mat, out_split, dim=0))
        local_top_k_indices_out = list(torch.split(top_k_indices_out_mat, out_split, dim=0))

        dist.all_to_all(local_top_k_values_out, local_top_k_values_in, group=self.pg)
        dist.all_to_all(local_top_k_indices_out, local_top_k_indices_in, group=self.pg)
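        # After the all-to-all, the sharding axis has swapped: inputs were sharded
        # over the vocab (each rank held its shard's top-k for the whole batch),
        # outputs are sharded over the batch (each rank holds every shard's top-k
        # candidates for its slice of the batch)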

        # Concat requires the batch dimension to be identical across all received chunks
        sharded_local_top_k_values = torch.cat(local_top_k_values_out, dim=-1)  # [sharded_batch_size, k * num_shards]
        sharded_local_top_k_indices = torch.cat(
            local_top_k_indices_out, dim=-1
        )  # [sharded_batch_size, k * num_shards]

        # Select the global top-k from the gathered candidates: the search now spans
        # the full vocabulary, while the batch dimension is sharded
        sharded_top_k_values, sharded_top_k_indices = torch.topk(
            sharded_local_top_k_values, self.k, dim=-1
        )  # [sharded_batch_size, k]

        # Select corresponding indices from the gathered indices
        sharded_top_k_indices = sharded_local_top_k_indices.gather(
            -1, sharded_top_k_indices
        )  # [sharded_batch_size, k]

        # Apply temperature and compute softmax probabilities
        probs = torch.softmax(sharded_top_k_values.to(dtype=torch.float) / self.temperature, dim=-1)

        # Sample one candidate per row from the probabilities
        sampled_indices = torch.multinomial(probs, num_samples=1)  # [sharded_batch_size, 1]

        # Map the sampled positions back to global token ids via the top-k indices
        new_decoder_input_ids = sharded_top_k_indices.gather(-1, sampled_indices)  # [sharded_batch_size, 1]

        # All-gather the new decoder input ids along the batch dimension
        gathered_new_decoder_input_ids = all_gather_batches(new_decoder_input_ids, in_split, group=self.pg)

        return gathered_new_decoder_input_ids
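
For intuition, the sharded method above computes, per sequence, the same result as a plain top-k sampler applied to the full unsharded logits: the union of the per-shard top-k candidates always contains the global top-k (a token in the global top-k is beaten by at most k - 1 tokens overall, hence by at most k - 1 within its own shard), so restricting the second top-k to those candidates loses nothing. The single-device sketch below is only a hypothetical reference for comparison, not part of nanotron; the function name and signature are made up, and it assumes the caller already has the full [batch_size, vocab_size] logits on one device.

    import torch

    def reference_top_k_sample(logits: torch.Tensor, k: int, temperature: float) -> torch.Tensor:
        """Hypothetical single-device equivalent of the sharded sampler above."""
        # Keep only the k largest logits per row, with their vocabulary indices
        top_k_values, top_k_indices = torch.topk(logits, k, dim=-1)  # [batch_size, k]

        # Temperature-scaled softmax restricted to the k candidates
        probs = torch.softmax(top_k_values.to(dtype=torch.float) / temperature, dim=-1)

        # Draw one candidate per row, then map it back to a vocabulary index
        sampled = torch.multinomial(probs, num_samples=1)  # [batch_size, 1]
        return top_k_indices.gather(-1, sampled)  # [batch_size, 1]

The distributed version reaches the same candidate set by taking a local top-k on each vocabulary shard, then using all_to_all to trade the vocab-sharded layout for a batch-sharded one before the second top-k and the multinomial draw.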