`tok_encode_pair` — method extracted from src/lighteval/models/abstract_model.py


    def tok_encode_pair(self, context: str, continuations: list[str], pairwise: bool = False):
        """Encodes a context with a list of continuations by taking care of the spaces in between.

        Trailing whitespace is moved from the context onto each continuation so the
        separator is tokenized as part of the continuation rather than the context.

        Args:
            context (str): The context string to be encoded.
            continuations (list[str]): List of continuation strings to be encoded.
            pairwise (bool):
                If True, encode context and continuations separately.
                If False, encode them together and then split.

        Returns:
            Tuple[list[TokenSequence], list[TokenSequence]]:
                A list of encoded contexts (one entry per continuation) and a list of
                encoded continuations, both the same length as ``continuations``.

        The advantage of pairwise is:
        1) It better aligns with how LLM predicts tokens
        2) Works in case len(tok(context,cont)) != len(tok(context)) + len(tok(continuation)).
        E.g this can happen for chinese if no space is used between context/continuation
        """
        # Migrate trailing whitespace from the context to each continuation so the
        # continuation tokens carry the separator.
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuations = [context[-n_spaces:] + cont for cont in continuations]
            context = context[:-n_spaces]

        if pairwise:
            # We don't add special tokens to the continuation as if bos is added
            # models tend to completely ignore a context
            context_enc = self.tok_encode(context, add_special_tokens=self.add_special_tokens)
            continuation_enc = [self.tok_encode(cont, add_special_tokens=False) for cont in continuations]

            # In theory the context_enc can be ended with eos token, this would again
            # cause the model to ignore the context. We thus strip the eos token from context_enc
            if len(context_enc) > 0 and context_enc[-1] == self.tokenizer.eos_token_id:
                context_enc = context_enc[:-1]

            # Same (possibly eos-stripped) context is reused for every continuation.
            context_encs = [context_enc] * len(continuation_enc)

            return context_encs, continuation_enc

        # Non-pairwise: encode context+continuation together, then split at the
        # context boundary so tokens merged at the junction stay with the continuation.
        context_enc = self.tok_encode(context)
        context_len = len(context_enc)  # loop-invariant, hoisted out of the loop
        context_encs = []
        continuations_encs = []
        for cont in continuations:
            whole_enc = self.tok_encode(context + cont)
            # If the continuation contributed no extra tokens, attribute the last
            # context token to the continuation so it is never empty.
            split_at = context_len - 1 if len(whole_enc) == context_len else context_len
            continuations_encs.append(whole_enc[split_at:])
            context_encs.append(whole_enc[:split_at])

        return context_encs, continuations_encs