def get_sharded_to_full_mapping()

in arctic_inference/vllm/spec_dec/vocab_parallel_embedding.py [0:0]


    def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
        """Return indices that reorder gathered logits into token-id order.

        During sampling, logits are gathered from every tensor-parallel rank.
        Each rank contributes a slice of ``num_embeddings_per_partition``
        entries laid out as
        ``[base vocab | base padding | added vocab | added padding]``.
        Concatenating those slices therefore interleaves the segments of all
        shards. The returned list is meant to reindex the gathered logits so
        that each index equals the token id it corresponds to. The list
        contains all base-vocabulary positions (in rank order), then all
        added-vocabulary positions (in rank order), then every padding
        position.

        Returns:
            ``None`` when ``tp_size < 2``, since no reordering is needed.
            Otherwise, a list of ``num_embeddings_padded`` indices into the
            gathered tensor.
        """
        if self.tp_size < 2:
            return None

        org_ids: List[int] = []
        extra_ids: List[int] = []
        pad_ids: List[int] = []
        width = self.num_embeddings_per_partition
        for rank in range(self.tp_size):
            shard = self._get_indices(self.num_embeddings_padded,
                                      self.org_vocab_size_padded,
                                      self.num_embeddings,
                                      self.org_vocab_size, rank,
                                      self.tp_size)
            # Boundaries of each segment inside this rank's slice of the
            # gathered tensor.
            org_start = width * rank
            org_stop = org_start + shard.num_org_elements
            org_pad_stop = org_start + shard.num_org_elements_padded
            extra_stop = org_pad_stop + shard.num_added_elements
            extra_pad_stop = org_pad_stop + shard.num_added_elements_padded

            org_ids += range(org_start, org_stop)
            pad_ids += range(org_stop, org_pad_stop)
            extra_ids += range(org_pad_stop, extra_stop)
            pad_ids += range(extra_stop, extra_pad_stop)
            # The padded base and added segments must exactly fill the
            # partition owned by this rank.
            assert extra_pad_stop == width * (rank + 1)

        mapping = org_ids + extra_ids + pad_ids
        assert len(mapping) == self.num_embeddings_padded
        return mapping