src/nanotron/generation/sampler.py [38:67]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
        batch_size, vocab_per_shard = sharded_logits.shape

        # Split max_values/max_indices into a list of tensors along batch
        # We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
        min_shard_batch_size = batch_size // self.pg.size()
        nb_shard_containing_extra_one = batch_size % self.pg.size()
        in_split = tuple(
            min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
            for rank in range(self.pg.size())
        )

        # out_split should be all equal to be able to concat at last dimension
        out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
        total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()

        # Prepare tensors for all-to-all operation
        # Gather logits from all vocab shards but shard on batch, tp_rank first
        sharded_logits_out = torch.empty(
            (total_out_size, vocab_per_shard),
            dtype=sharded_logits.dtype,
            device=sharded_logits.device,
        )  # [pg_size * sharded_batch_size, vocab_per_shard]

        local_sharded_logits_in = list(torch.split(sharded_logits, in_split, dim=0))
        local_sharded_logits_out = list(torch.split(sharded_logits_out, out_split, dim=0))

        dist.all_to_all(local_sharded_logits_out, local_sharded_logits_in, group=self.pg)

        logits = torch.cat(local_sharded_logits_out, dim=-1)  # [sharded_batch_size, vocab_size]
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



src/nanotron/generation/sampler.py [245:276]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def __call__(self, sharded_logits: torch.Tensor) -> torch.Tensor:
        # We will cross batch and vocab shards to sample from the full vocab and a part of the batch
        # (right now logits are sharded on vocab and batch, so we need to do all-to-all)
        batch_size, vocab_per_shard = sharded_logits.shape

        # Split max_values/max_indices into a list of tensors along batch
        # We have: [min_shard_batch_size + 1] * nb_shard_containing_extra_one + [min_shard_batch_size] * (self.pg.size() - nb_shard_containing_extra_one)
        min_shard_batch_size = batch_size // self.pg.size()
        nb_shard_containing_extra_one = batch_size % self.pg.size()
        in_split = tuple(
            min_shard_batch_size + 1 if rank < nb_shard_containing_extra_one else min_shard_batch_size
            for rank in range(self.pg.size())
        )

        # out_split should be all equal to be able to concat at last dimension
        out_split = (in_split[dist.get_rank(self.pg)],) * self.pg.size()
        total_out_size = in_split[dist.get_rank(self.pg)] * self.pg.size()

        # Prepare tensors for all-to-all operation
        # Gather logits from all vocab shards but shard on batch, tp_rank first
        sharded_logits_out = torch.empty(
            (total_out_size, vocab_per_shard),
            dtype=sharded_logits.dtype,
            device=sharded_logits.device,
        )  # [pg_size * sharded_batch_size, vocab_per_shard]

        local_sharded_logits_in = list(torch.split(sharded_logits, in_split, dim=0))
        local_sharded_logits_out = list(torch.split(sharded_logits_out, out_split, dim=0))

        dist.all_to_all(local_sharded_logits_out, local_sharded_logits_in, group=self.pg)

        logits = torch.cat(local_sharded_logits_out, dim=-1)  # [sharded_batch_size, vocab_size]
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



