in lm_eval/models/nemo_lm.py

def generate_until(self, requests):
    if not requests:
        return []
    res = []

    def get_until(req_args):
        until = req_args.get("until", [])
        until = deepcopy(until)  # copy so we do not mutate req_args, which doubles as the cache key
        eot_token = self.tokenizer.ids_to_tokens([self.eot_token_id])[0]
        if eot_token not in until:
            until.append(eot_token)
        return until
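
    # e.g. with req_args = {"until": ["\n\n"]} and an EOS token string of
    # "<|endoftext|>", get_until returns ["\n\n", "<|endoftext|>"]; the exact
    # EOS string depends on the tokenizer in use.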

    def _collate(x):
        # Sort key: tokenized context length, so similarly sized prompts end up
        # in the same batch and padding work is minimized.
        toks = self.tok_encode(x[0])
        return len(toks), x[0]

    re_ords = Collator(
        [reg.args for reg in requests], sort_fn=_collate, group_by="gen_kwargs"
    )
    chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
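    # Batches produced here share identical gen_kwargs; re_ords.get_original()
    # restores the original request order once all chunks are processed.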
    for chunk in chunks:
        contexts, all_gen_kwargs = zip(*chunk)
        # We assume all gen kwargs in the batch are the same; this is safe
        # because the Collator groups requests by their `gen_kwargs`.
        req_args = all_gen_kwargs[0]
        # Unpack our keyword arguments.
        until = get_until(req_args)
        max_gen_toks = req_args.get("max_gen_toks", self.max_gen_toks)

        # Left-truncate each context so that the prompt plus the generation
        # budget still fits within the model's maximum sequence length.
        remaining_length = self.max_length - max_gen_toks
        contexts = []
        for context, _ in chunk:
            encoded_context = self.tok_encode(context)
            encoded_context = encoded_context[-remaining_length:]
            contexts.append(self.tok_decode(encoded_context))
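        # e.g. with max_length = 4096 and max_gen_toks = 256, only the last
        # 3840 context tokens are kept (values here are illustrative).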

        output = self.generate(
            self.model,
            inputs=contexts,
            tokens_to_generate=max_gen_toks,
            end_strings=until,
            greedy=True,
        )
        answers = output["sentences"]
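
        # Each returned sentence echoes its prompt, so strip the (truncated)
        # context to keep only the generated continuation.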
        continuations = []
        for context, answer in zip(contexts, answers):
            continuations.append(answer[len(context):])
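
        # Cut each continuation at the first occurrence of any stop string.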
        for term in until:
            continuations = [answer.split(term)[0] for answer in continuations]

        for request, answer in zip(chunk, continuations):
            self.cache_hook.add_partial("greedy_until", request, answer)
            res.append(answer)

    return re_ords.get_original(res)
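
# A minimal usage sketch (illustrative only, not part of the file above).
# `nemo_lm` stands in for an instantiated NeMo model wrapper from this module,
# and the Instance fields shown are assumptions based on the harness API:
#
#     from lm_eval.api.instance import Instance
#
#     requests = [
#         Instance(
#             request_type="generate_until",
#             doc={},
#             arguments=("Q: What is 2 + 2?\nA:", {"until": ["\n"], "max_gen_toks": 16}),
#             idx=0,
#         ),
#     ]
#     completions = nemo_lm.generate_until(requests)
#     # -> one decoded continuation string per request, in the original order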