megatron_patch/lm_evaluate.py
# Copyright (c) 2023 Alibaba PAI Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
import numpy as np
from typing import List, Optional, Union
import torch
import torch.nn.functional as F
import transformers
from tqdm import tqdm
from megatron import get_args
from megatron.checkpointing import load_checkpoint
from megatron.core.enums import ModelType
from megatron.core import mpu
from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
from megatron.core.pipeline_parallel.p2p_communication import recv_forward
from megatron.core.pipeline_parallel.p2p_communication import send_forward
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.arguments import core_transformer_config_from_args
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.models.huggingface import HFLM, eval_logger
from megatron_patch.training import get_model
from megatron_patch.tokenizer import build_tokenizer, get_tokenizer


class EvalHarnessAdaptor(HFLM):
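    """Adaptor that exposes a Megatron model to the EleutherAI lm-evaluation-harness.

    It subclasses HFLM but swaps in a tensor- and pipeline-parallel Megatron
    model and a tokenizer built from the Megatron args.
    """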
def __init__(
self,
pretrained: Optional[Union[str, transformers.PreTrainedModel]] = "gpt2",
max_length: Optional[int] = None,
batch_size: Optional[Union[int, str]] = 1,
trust_remote_code: Optional[bool] = False,
**kwargs,
) -> None:
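        """Build the Megatron tokenizer and record eval settings, then let
        HFLM.__init__ drive the rest of the setup (it calls the _create_model
        override below to build the Megatron model)."""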
self.args = get_args()
build_tokenizer(self.args)
self.tokenizer = get_tokenizer()
self.is_main = torch.distributed.get_rank() == 0
self.adaptive_seq_len = self.args.adaptive_seq_len
self.model_provider = kwargs['model_provider']
super().__init__(pretrained=pretrained,
batch_size=batch_size,
trust_remote_code=trust_remote_code,
max_length=max_length,
tokenizer=self.tokenizer)

    def _create_model(
self,
pretrained: str,
**kwargs,
) -> None:
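        """Build the Megatron model from self.model_provider and load a
        checkpoint via Megatron's load_checkpoint when `pretrained` is set."""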
model_list = get_model(self.model_provider,
model_type=ModelType.encoder_or_decoder,
wrap_with_ddp=False)
if pretrained is not None:
load_checkpoint(model_list, None, None)
self._model = model_list[0]
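        # HFLM may call tie_weights() on the wrapped model; a Megatron model
        # has no such method, so attach a no-op.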
def tie_weights(self):
pass
self._model.tie_weights = types.MethodType(tie_weights, self._model)
return None

    def create_model_inputs(self, tokens):
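        """Build (tokens, position_ids, attention_mask) and (tokens, loss_mask)
        using Megatron's get_ltor_masks_and_position_ids helper."""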
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
self.eot_token_id,
self.args.reset_position_ids,
self.args.reset_attention_mask,
self.args.eod_mask_loss)
return (tokens, position_ids, attention_mask), (tokens, loss_mask)

    def _model_call(self, inps, attn_mask=None, labels=None):
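        """Run one pipeline-parallel forward pass over a padded batch.

        Returns vocabulary logits on the last pipeline stage and None on all
        other stages.
        """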
args = get_args()
        # The micro-batch shape can change from call to call, so set the correct
        # shapes here so that later pipeline stages know what sizes to expect;
        # otherwise the pipeline will deadlock.
args.micro_batch_size = len(inps)
args.seq_length = len(inps[0])
config = core_transformer_config_from_args(args)
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
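        # Receive activations from the previous pipeline stage
        # (recv_forward returns None on the first stage).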
input_tensor = recv_forward(tensor_shape, config)
# Forward pass through the model.
unwrapped_model = unwrap_model(self.model)
unwrapped_model.set_input_tensor(input_tensor)
output = self.model(*self.create_model_inputs(inps)[0])
send_forward(output, config)
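        # Only the last pipeline stage holds logits; gather the vocab-parallel
        # shards and drop the padded vocabulary entries.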
if mpu.is_pipeline_last_stage():
return gather_from_tensor_model_parallel_region(output)[..., :self.tokenizer.vocab_size]
else:
return None

    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
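        """Compute the full rolling-window log-likelihood of each request
        string (used for perplexity-style tasks)."""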
# TODO: Implement caching once we've confirmed the perplexity implementation
# TODO: automatic batch size detection for vectorization
loglikelihoods = []
with torch.no_grad():
            for (string,) in tqdm([req.args for req in requests]):
rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                    token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
)))
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
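        """Score batches of (context, continuation) pairs.

        Returns a list of (log-likelihood, is_greedy) tuples in the original
        request order; pipeline stages other than the last return placeholder
        values.
        """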
disable_tqdm = disable_tqdm if self.is_main else True
res = []
res_len = 0 # storing the result length for later
self.model.eval()
with torch.no_grad():
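            # Sort requests by descending total token length so the longest
            # request in each batch comes first (used to pick padding_length).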
def _collate(x):
toks = x[1] + x[2]
return (-len(toks), tuple(toks))
reord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
inps, contlens, inplens, padding_length = [], [], [], None
for _, context_enc, continuation_enc in chunk:
# when too long to fit in context, truncate from the left
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1):][:-1],
                        dtype=torch.long).to(self.device)
inplen, = inp.shape
cont = continuation_enc
# since in _collate we make sure length is descending, the longest is always the first one.
padding_length = padding_length if padding_length is not None else inplen
if not self.adaptive_seq_len:
padding_length = self.max_length
# pad to length
inp = torch.cat([
inp, # [seq]
torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq]
], dim=0)
inps.append(inp.unsqueeze(0))
contlens.append(cont)
inplens.append(inplen)
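                # Single pipeline-parallel forward over the padded batch; only
                # the last pipeline stage receives logits, other stages get None.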
logits = self._model_call(torch.cat(inps, dim=0))
res_len += len(chunk)
if logits is not None:
multi_logits = F.log_softmax(logits, dim=-1).cpu() # [batch, seq, vocab]
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
contlen = len(cont_toks)
logits = logits[inplen - contlen:inplen].unsqueeze(0) # [1, seq, vocab]
greedy_tokens = logits.argmax(dim=-1)
# cont_toks :: [1, seq]
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)
max_equal = (greedy_tokens == cont_toks).all()
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]
answer = (float(logits.sum()), bool(max_equal))
# partial caching
res.append(answer)
if not mpu.is_pipeline_last_stage():
            # @HACK: To keep the eval harness happy on ranks that don't have
            # access to the results, we just return randomly generated data.
res = [(np.random.rand(), np.random.rand()>0.5) for _ in requests]
return reord.get_original(res)