chatlearn/models/vllm/hooks/vllm_0_3_0/sampler.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks of vllm-0.3.0 sampler to allgather logits of all ranks."""
import inspect
# pylint: disable=unused-import,wildcard-import
from vllm.model_executor.layers import sampler
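# Patch only builds whose Sampler._get_logits does not already all-gather the logits:
# inspect the installed implementation's source and look for the collective call
# before overriding it.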
source = inspect.getsource(sampler.Sampler._get_logits)
if 'tensor_model_parallel_all_gather' not in source:
    import torch
    from typing import Dict, List, Optional, Tuple

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        # Each tensor-parallel rank computes logits only for its own vocab shard;
        # gather them so every rank sees the full vocabulary before sampling.
        from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_all_gather  # pylint: disable=import-outside-toplevel
        logits = tensor_model_parallel_all_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits

    # Replace the installed method so sampling sees the gathered logits.
    sampler.Sampler._get_logits = _get_logits
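
# A minimal self-check, added for illustration (it assumes a local vLLM 0.3.0
# install and is not part of the original hook). Importing this module is what
# applies the patch, so running the file directly only confirms that the active
# Sampler._get_logits routes through tensor_model_parallel_all_gather, whether
# natively or via the override above.
if __name__ == "__main__":
    active_source = inspect.getsource(sampler.Sampler._get_logits)
    assert 'tensor_model_parallel_all_gather' in active_source
    print("Sampler._get_logits all-gathers logits across tensor-parallel ranks.")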