def _cpt_forward()

in src/peft/peft_model.py


    def _cpt_forward(self, input_ids, inputs_embeds, peft_config, task_ids, batch_size, **kwargs):
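        """
        Forward pass for CPT (Context-aware Prompt Tuning).

        Prepends the learned CPT prompt to the input embeddings and, when labels are
        provided, builds matching prefix labels plus a combined token-type mask so that
        only valid label positions contribute to the loss (all other positions are set
        to -100).
        """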
        # Extract labels from kwargs
        labels = kwargs.pop("labels")
        device = [i.device for i in [input_ids, inputs_embeds, labels] if i is not None][0]
        # Extract input_type_mask from kwargs and move it to the same device as labels
        if "input_type_mask" in kwargs.keys():
            input_type_mask = kwargs.pop("input_type_mask").to(device)
        else:
            if input_ids is None:
                N_tokens = inputs_embeds.shape[1]
            else:
                N_tokens = input_ids.shape[1]
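            # No input_type_mask was supplied, so fall back to a constant type value of 4
            # for every token in the batch.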
            input_type_mask = torch.ones((batch_size, N_tokens)).to(device) * 4

        cpt_token_ids = peft_config.cpt_token_ids
        cpt_tokens_type_mask = peft_config.cpt_tokens_type_mask

        # Generate embeddings if not provided
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        # Get prompt and concatenate with input embeddings
        prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
        prompts = prompts.to(inputs_embeds.dtype)
        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
        # If labels are provided, generate prefix labels and type mask
        cpt_labels = None
        if labels is not None:
            # Generate prefix labels and concatenate with the input labels
            prefix_labels = torch.Tensor(cpt_token_ids).long().view(1, -1)
            prefix_labels = prefix_labels.repeat(batch_size, 1).to(labels.device)
            cpt_labels = torch.cat((prefix_labels, labels), dim=1)
            # Generate prefix type mask and shift input type mask values to avoid conflicts
            prefix_type_mask = torch.Tensor(cpt_tokens_type_mask).long().view(1, -1)
            prefix_type_mask = prefix_type_mask.repeat(batch_size, 1).to(labels.device)
            adjusted_input_type_mask = input_type_mask
            adjusted_input_type_mask[adjusted_input_type_mask > 0] += prefix_type_mask.max()
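            # Note: adjusted_input_type_mask is not a copy, so the shift above also mutates input_type_mask.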
            # Concatenate prefix and shifted input type masks
            cpt_type_mask = torch.cat((prefix_type_mask, adjusted_input_type_mask), dim=1)
            # Identify valid label positions and mask invalid ones with -100
            labels_idx = (cpt_type_mask > 0) & (cpt_type_mask % 4 == 0)
            cpt_labels[~labels_idx] = -100
            # Update kwargs with the modified labels
            kwargs["labels"] = cpt_labels

        # Pass the modified inputs to the base model
        base_model_output = self.base_model(inputs_embeds=inputs_embeds, **kwargs)
        if labels is None:
            return base_model_output
        else:
            # Calculate the loss using the custom CPT loss function
            cpt_embedding = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]
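            # The PEFT type maps to the CPT tuner class, whose calculate_loss applies the CPT-specific loss.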
            base_model_output = cpt_embedding.calculate_loss(
                base_model_output, cpt_labels, cpt_type_mask, self.peft_config["default"]
            )
            return base_model_output
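
For reference, below is a standalone sketch of the label-masking arithmetic used above. The type values are made up for illustration and are not taken from a real CPT config; only the shift and the "positive multiple of 4" selection mirror the code in _cpt_forward.

    import torch

    batch_size = 2
    cpt_tokens_type_mask = [1, 2, 3, 4]          # hypothetical prefix token types
    input_type_mask = torch.tensor([[4, 4, 0],   # hypothetical input token types;
                                    [4, 4, 4]])  # 0 marks positions excluded from the loss

    prefix_type_mask = torch.tensor(cpt_tokens_type_mask).view(1, -1).repeat(batch_size, 1)
    adjusted = input_type_mask.clone()
    adjusted[adjusted > 0] += prefix_type_mask.max()  # shift input types past the prefix range
    cpt_type_mask = torch.cat((prefix_type_mask, adjusted), dim=1)

    # Dummy token ids standing in for cpt_labels; positions whose type is not a
    # positive multiple of 4 are masked out with -100, exactly as in _cpt_forward.
    cpt_labels = torch.full(cpt_type_mask.shape, 7)
    cpt_labels[~((cpt_type_mask > 0) & (cpt_type_mask % 4 == 0))] = -100

    print(cpt_type_mask)  # tensor([[1, 2, 3, 4, 8, 8, 0], [1, 2, 3, 4, 8, 8, 8]])
    print(cpt_labels)     # tensor([[-100, -100, -100, 7, 7, 7, -100], [-100, -100, -100, 7, 7, 7, 7]])

Only positions whose combined type value is a positive multiple of 4 keep their label; every other position carries the -100 sentinel, which the cross-entropy loss ignores.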