def _get_indices()

in arctic_inference/vllm/spec_dec/vocab_parallel_embedding.py [0:0]


    def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
                     vocab_size: int, org_vocab_size: int, tp_rank: int,
                     tp_size: int) -> VocabParallelEmbeddingShardIndices:
        """Compute the start/end indices owned by ``tp_rank`` of ``tp_size``.

        The layout follows the one outlined in the class docstring. Two ranges
        are derived with ``vocab_range_from_global_vocab_size``:

        * the original-vocab range, from splitting ``org_vocab_size_padded``
          across the ranks;
        * the added-vocab range, from splitting the padded number of added
          embeddings (``vocab_size_padded - org_vocab_size_padded``) across the
          ranks, shifted by ``offset=org_vocab_size``.

        Each range is reported twice: as computed from the padded sizes, and
        with padding removed by clamping original-vocab bounds to
        ``org_vocab_size`` and added-vocab bounds to ``vocab_size``.

        Returns:
            ``VocabParallelEmbeddingShardIndices`` built positionally from the
            padded org start/end, padded added start/end, unpadded org
            start/end and unpadded added start/end, in that order.
        """
        # Bounds as laid out in the padded tensor.
        org_lo_padded, org_hi_padded = vocab_range_from_global_vocab_size(
            org_vocab_size_padded, tp_rank, tp_size)
        added_lo_padded, added_hi_padded = vocab_range_from_global_vocab_size(
            vocab_size_padded - org_vocab_size_padded,
            tp_rank,
            tp_size,
            offset=org_vocab_size,
        )

        # Strip padding: no original-vocab bound may exceed org_vocab_size and
        # no added-vocab bound may exceed vocab_size.
        org_lo, org_hi = (min(bound, org_vocab_size)
                          for bound in (org_lo_padded, org_hi_padded))
        added_lo, added_hi = (min(bound, vocab_size)
                              for bound in (added_lo_padded, added_hi_padded))

        return VocabParallelEmbeddingShardIndices(
            org_lo_padded, org_hi_padded,
            added_lo_padded, added_hi_padded,
            org_lo, org_hi,
            added_lo, added_hi,
        )