in src/peft/peft_model.py
def _cpt_forward(self, input_ids, inputs_embeds, peft_config, task_ids, batch_size, **kwargs):
    # Extract labels from kwargs
    labels = kwargs.pop("labels")
    device = [i.device for i in [input_ids, inputs_embeds, labels] if i is not None][0]
    # Extract input_type_mask from kwargs and move it to the inputs' device
    if "input_type_mask" in kwargs:
        input_type_mask = kwargs.pop("input_type_mask").to(device)
    else:
        if input_ids is None:
            N_tokens = inputs_embeds.shape[1]
        else:
            N_tokens = input_ids.shape[1]
        # No mask supplied: default every token to type 4
        input_type_mask = torch.ones((batch_size, N_tokens), dtype=torch.long, device=device) * 4
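    # A reading of the type-code scheme (inferred from this function, not an
    # official spec): 0 marks padding and positive multiples of 4 mark output
    # tokens, the only positions whose labels survive the filter further down;
    # the all-4 default above therefore trains on every token of the sample.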
    cpt_token_ids = peft_config.cpt_token_ids
    cpt_tokens_type_mask = peft_config.cpt_tokens_type_mask

    # Generate embeddings if not provided
    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)
    # Get the learned prompt and prepend it to the input embeddings
    prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
    prompts = prompts.to(inputs_embeds.dtype)
    inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
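    # inputs_embeds is now (batch_size, n_prefix_tokens + N_tokens, hidden_dim),
    # so the labels and the type mask are extended below by the same prefix length.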
    # If labels are provided, generate prefix labels and type mask
    cpt_labels = None
    if labels is not None:
        # Generate prefix labels and concatenate with the input labels
        prefix_labels = torch.tensor(cpt_token_ids, dtype=torch.long).view(1, -1)
        prefix_labels = prefix_labels.repeat(batch_size, 1).to(labels.device)
        cpt_labels = torch.cat((prefix_labels, labels), dim=1)
        # Generate prefix type mask and shift input type mask values to avoid conflicts
        prefix_type_mask = torch.tensor(cpt_tokens_type_mask, dtype=torch.long).view(1, -1)
        prefix_type_mask = prefix_type_mask.repeat(batch_size, 1).to(labels.device)
        # Shift the sample's type ids above the prefix's largest id so the two
        # ranges cannot collide; clone so the caller's tensor is not mutated
        adjusted_input_type_mask = input_type_mask.clone()
        adjusted_input_type_mask[adjusted_input_type_mask > 0] += prefix_type_mask.max()
        # Concatenate prefix and shifted input type masks
        cpt_type_mask = torch.cat((prefix_type_mask, adjusted_input_type_mask), dim=1)
        # Keep labels only at output positions (type id a positive multiple of 4);
        # every other position is set to -100 so the loss ignores it
        labels_idx = (cpt_type_mask > 0) & (cpt_type_mask % 4 == 0)
        cpt_labels[~labels_idx] = -100
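        # Worked example with hypothetical values: a one-example prefix typed
        # [1, 2, 3, 4] has max 4, so an input typed [1, 2, 3, 4] shifts to
        # [5, 6, 7, 8] and cpt_type_mask = [1, 2, 3, 4, 5, 6, 7, 8]; only the
        # positions typed 4 and 8 (the output token of each part) keep labels.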
        # Update kwargs with the modified labels
        kwargs["labels"] = cpt_labels
    # Pass the modified inputs to the base model
    base_model_output = self.base_model(inputs_embeds=inputs_embeds, **kwargs)
    if labels is None:
        return base_model_output
    else:
        # Calculate the loss using the custom CPT loss function
        cpt_embedding = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]
        base_model_output = cpt_embedding.calculate_loss(
            base_model_output, cpt_labels, cpt_type_mask, self.peft_config["default"]
        )
        return base_model_output
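
For context, a minimal sketch of how this code path is reached from user code. It is a hedged illustration, not part of this file: the model name ("gpt2"), the prefix text, and the mask values are assumptions, with CPTConfig's cpt_token_ids / cpt_tokens_type_mask fields taken from the function above and cpt_mask assumed alongside them.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import CPTConfig, get_peft_model

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
model = AutoModelForCausalLM.from_pretrained("gpt2")

# One in-context demonstration serves as the CPT prefix; its final token ("4")
# is the answer, so it alone gets type 4 and keeps its label in _cpt_forward.
prefix_ids = tokenizer("2 + 2 = 4", add_special_tokens=False)["input_ids"]
config = CPTConfig(
    cpt_token_ids=prefix_ids,
    cpt_mask=[1] * len(prefix_ids),
    cpt_tokens_type_mask=[1] * (len(prefix_ids) - 1) + [4],
)
peft_model = get_peft_model(model, config)

batch = tokenizer("3 + 3 = 6", return_tensors="pt")
labels = batch["input_ids"].clone()
# All type 4: train on every token of the new sample, matching the default
# branch of _cpt_forward when input_type_mask is omitted.
input_type_mask = torch.full_like(batch["input_ids"], 4)
output = peft_model(input_ids=batch["input_ids"], labels=labels, input_type_mask=input_type_mask)
print(output.loss)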