def _process_tokens()

in trl/trainer/bco_trainer.py [0:0]


def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **kwargs) -> dict:
    """Process tokens of a BCO specific dataset.

    At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
    completion responses is/are too long. First we truncate the prompt; if we're still too long, we truncate the
    completion.

    We also create the labels for the completion responses, which are of length equal to the sum of the length of the
    prompt and the completion response, with label_pad_token_id for the prompt tokens.

    Args:
        example: A single dataset row. Must contain `"prompt"`, `"completion"` and `"label"`. For decoder-only
            models it must also contain the pre-tokenized `"prompt_input_ids"`, `"prompt_attention_mask"`,
            `"answer_input_ids"` and `"answer_attention_mask"` lists.
        model: Optional model. Only used for encoder-decoder models, and only when it exposes
            `prepare_decoder_input_ids_from_labels`.
        **kwargs: Processing options. Always required: `prefix`, `is_encoder_decoder`, `tokenizer`,
            `max_prompt_length`. Decoder-only additionally requires `max_length`, `truncation_mode`
            (`"keep_start"` or `"keep_end"`) and `label_pad_token_id`. Encoder-decoder additionally requires
            `max_completion_length`.

    Returns:
        A dict whose keys are all prefixed with `kwargs["prefix"]`: `prompt`, `completion`, `label`,
        `prompt_input_ids`, `prompt_attention_mask`, `completion_attention_mask`, `completion_labels`, plus
        `completion_input_ids` (decoder-only) or, when available, `completion_decoder_input_ids`
        (encoder-decoder).

    Raises:
        ValueError: If, for a decoder-only model, `prompt` or `completion` is not a `str`, or if truncation is
            needed and `truncation_mode` is unknown.
    """
    prompt = example["prompt"]
    completion = example["completion"]
    prefix = kwargs["prefix"]

    batch = {
        f"{prefix}prompt": prompt,
        f"{prefix}completion": completion,
        f"{prefix}label": example["label"],
    }

    if not kwargs["is_encoder_decoder"]:
        # Check issues below for more details
        #  1. https://github.com/huggingface/trl/issues/907
        #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
        #  3. https://github.com/LianjiaTech/BELLE/issues/337

        if not isinstance(prompt, str):
            raise ValueError(f"prompt should be an str but got {type(prompt)}")

        if not isinstance(completion, str):
            raise ValueError(f"completion should be an str but got {type(completion)}")

        # prompt_* refers to just the prompt and answer_* refers to just the answer
        prompt_ids = example["prompt_input_ids"]
        prompt_mask = example["prompt_attention_mask"]
        answer_ids = example["answer_input_ids"]
        answer_mask = example["answer_attention_mask"]

        # Reserve room in max_length for a BOS/EOS token unless it is already present. An empty
        # prompt/answer counts as "not present" (previously this indexed into the empty list and crashed).
        max_length = kwargs["max_length"]
        bos_token_id = kwargs["tokenizer"].bos_token_id
        eos_token_id = kwargs["tokenizer"].eos_token_id
        if not prompt_ids or bos_token_id != prompt_ids[0]:
            max_length -= 1
        if not answer_ids or eos_token_id != answer_ids[-1]:
            max_length -= 1

        # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt
        if len(prompt_ids) + len(answer_ids) > max_length:
            max_prompt_length = kwargs["max_prompt_length"]
            truncation_mode = kwargs["truncation_mode"]
            if truncation_mode == "keep_start":
                prompt_ids = prompt_ids[:max_prompt_length]
                prompt_mask = prompt_mask[:max_prompt_length]
            elif truncation_mode == "keep_end":
                prompt_ids = prompt_ids[-max_prompt_length:]
                prompt_mask = prompt_mask[-max_prompt_length:]
            else:
                raise ValueError(f"Unknown truncation mode: {truncation_mode}")

        # if that's still too long, truncate the response. The answer budget is what is left after a
        # full-length prompt (max_prompt_length), not after the actual prompt length.
        if len(prompt_ids) + len(answer_ids) > max_length:
            answer_budget = max_length - kwargs["max_prompt_length"]
            answer_ids = answer_ids[:answer_budget]
            answer_mask = answer_mask[:answer_budget]

        # add BOS, which affects both prompt and the full completion. This is re-checked after truncation
        # because "keep_end" may have dropped a leading BOS.
        if bos_token_id is not None and (not prompt_ids or bos_token_id != prompt_ids[0]):
            prompt_ids = [bos_token_id] + prompt_ids
            prompt_mask = [1] + prompt_mask

        completion_ids = prompt_ids + answer_ids
        completion_mask = prompt_mask + answer_mask

        # add EOS, which affects only the full completion. Skipped when the tokenizer defines no EOS so
        # that `None` never ends up in the token ids.
        if eos_token_id is not None and (not answer_ids or eos_token_id != answer_ids[-1]):
            completion_ids = completion_ids + [eos_token_id]
            completion_mask = completion_mask + [1]

        # labels mirror the full completion, with every prompt position (including BOS) masked out
        num_prompt_tokens = len(prompt_ids)
        completion_labels = [kwargs["label_pad_token_id"]] * num_prompt_tokens + completion_ids[num_prompt_tokens:]

        batch[f"{prefix}prompt_input_ids"] = prompt_ids
        batch[f"{prefix}prompt_attention_mask"] = prompt_mask
        batch[f"{prefix}completion_input_ids"] = completion_ids
        batch[f"{prefix}completion_attention_mask"] = completion_mask
        batch[f"{prefix}completion_labels"] = completion_labels
    else:
        # Encoder-decoder: prompt and completion are tokenized independently, each with its own length cap.
        completion_tokens = kwargs["tokenizer"](
            completion, truncation=True, max_length=kwargs["max_completion_length"], add_special_tokens=True
        )
        prompt_tokens = kwargs["tokenizer"](
            prompt, truncation=True, max_length=kwargs["max_prompt_length"], add_special_tokens=True
        )

        batch[f"{prefix}prompt_input_ids"] = prompt_tokens["input_ids"]
        batch[f"{prefix}prompt_attention_mask"] = prompt_tokens["attention_mask"]

        batch[f"{prefix}completion_labels"] = completion_tokens["input_ids"]
        batch[f"{prefix}completion_attention_mask"] = completion_tokens["attention_mask"]
        if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
            # Read the labels back through the prefixed key; the unprefixed key only exists when the
            # prefix is empty.
            batch[f"{prefix}completion_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
                labels=torch.tensor(batch[f"{prefix}completion_labels"])
            )

    return batch