in src/lighteval/models/abstract_model.py [0:0]
def tok_encode_pair(self, context: str, continuations: list[str], pairwise: bool = False):
    """Encodes a context with a list of continuations by taking care of the spaces in between.

    Any trailing whitespace of ``context`` is removed from the context and
    prepended to every continuation, so the boundary whitespace is tokenized
    as part of the continuation.

    Args:
        context (str): The context string to be encoded.
        continuations (list[str]): List of continuation strings to be encoded.
        pairwise (bool):
            If True, encode context and continuations separately.
            If False, encode each ``context + continuation`` together and then
            split the result into a context part and a continuation part.

    Returns:
        tuple[list[TokenSequence], list[TokenSequence]]:
            Two lists, each with one entry per continuation: the encoded
            context for that continuation, and the encoded continuation.

    The advantage of pairwise is:
    1) It better aligns with how LLM predicts tokens
    2) Works in case len(tok(context,cont)) != len(tok(context)) + len(tok(continuation)).
       E.g this can happen for chinese if no space is used between context/continuation
    """
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        # Move the trailing whitespace from the context onto each continuation.
        trailing = context[-n_spaces:]
        continuations = [trailing + cont for cont in continuations]
        context = context[:-n_spaces]

    if pairwise:
        # We don't add special tokens to the continuation because, if a BOS is
        # added, models tend to completely ignore the context.
        context_enc = self.tok_encode(context, add_special_tokens=self.add_special_tokens)
        continuation_enc = [self.tok_encode(cont, add_special_tokens=False) for cont in continuations]
        # In theory context_enc can end with an EOS token, which would again
        # cause the model to ignore the context, so strip it.
        if len(context_enc) > 0 and context_enc[-1] == self.tokenizer.eos_token_id:
            context_enc = context_enc[:-1]
        context_encs = [context_enc] * len(continuation_enc)
        return context_encs, continuation_enc

    # Joint encoding: tokenize context + continuation together and split at the
    # length of the standalone context encoding.
    context_enc = self.tok_encode(context)
    context_len = len(context_enc)  # loop invariant
    context_encs = []
    continuations_encs = []
    for cont in continuations:
        whole_enc = self.tok_encode(context + cont)
        if len(whole_enc) > context_len:
            split = context_len
        else:
            # The joint encoding is not longer than the context alone, e.g. the
            # continuation is empty or tokens merged across the boundary. Keep
            # the last token as the continuation, so it is never empty while
            # the joint encoding has tokens.
            split = max(len(whole_enc) - 1, 0)
        continuations_encs.append(whole_enc[split:])
        context_encs.append(whole_enc[:split])
    return context_encs, continuations_encs