in mmf/models/mmf_bert.py [0:0]
def forward(self, sample_list):
    # bert text input
    input_ids = sample_list.input_ids
    input_mask = sample_list.input_mask
    input_type_ids = sample_list.segment_ids

    input_ids = transform_to_batch_sequence(input_ids)
    input_type_ids = transform_to_batch_sequence(input_type_ids)
    input_mask = transform_to_batch_sequence(input_mask)
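    # If the dataset did not provide a mask or segment ids, default to
    # attending over every token and to a single segment.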
    if input_mask is None:
        input_mask = torch.ones_like(input_ids)
    if input_type_ids is None:
        input_type_ids = torch.zeros_like(input_ids)

    attention_mask = input_mask.unsqueeze(1).unsqueeze(2)

    # masked language modeling labels
    masked_lm_labels = getattr(sample_list, "lm_label_ids", None)
    masked_lm_labels = transform_to_batch_sequence(masked_lm_labels)

    # next-sentence prediction labels
    # is_random_next = getattr(sample_list, "is_correct", None)
    # TODO(aps): Fix later on dataset side
    is_random_next = None

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and -10000.0 for masked positions.
    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    # fp16 compatibility
    attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)
    attention_mask = (1.0 - attention_mask) * -10000.0
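    # e.g. a mask row of [1, 1, 0] becomes [0.0, 0.0, -10000.0] at this point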
    text_embedding = self.word_embedding(input_ids, input_type_ids)
    text_embedding_total = self.process_text_embedding(
        text_embedding, input_mask == 0
    )

    image_embedding_total, _ = self.process_feature_embedding(
        "image", sample_list, text_embedding_total
    )

    if self.inter_model is not None:
        image_embedding_total = self.inter_model(image_embedding_total)

    # image_embedding_total = image_embedding_total *
    #     input_mask.unsqueeze(-1).float()
    # text_embedding_total = text_embedding_total *
    #     input_mask.unsqueeze(-1).float()
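    # Optionally fuse the image and text embeddings; otherwise the
    # text-conditioned image embedding is used on its own.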
    if self.config.combine_embeddings:
        joint_embedding = self.combine_embeddings(
            ["image", "text"], [image_embedding_total, text_embedding_total]
        )
    else:
        joint_embedding = image_embedding_total

    output_dict = {}

    pooled_output = self.pooler(joint_embedding)

    if "pretraining" in self.config.training_head_type:
        prediction_scores, seq_relationship_score = self.classifier(
            joint_embedding, pooled_output
        )
        output_dict["logits"] = prediction_scores

        if masked_lm_labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(
                prediction_scores.contiguous().view(
                    -1, self.bert_config.vocab_size
                ),
                masked_lm_labels.contiguous().view(-1),
            )
            # print(seq_relationship_score.argmax(dim=1), is_random_next)
            loss_key = "{}/{}".format(
                sample_list.dataset_name, sample_list.dataset_type
            )
            output_dict["losses"] = {}
            output_dict["losses"][loss_key + "/masked_lm_loss"] = masked_lm_loss

            if is_random_next is not None:
                output_dict["seq_relationship_score"] = seq_relationship_score
                next_sentence_loss = loss_fct(
                    seq_relationship_score.contiguous().view(-1, 2),
                    is_random_next.contiguous().view(-1),
                )
                output_dict["losses"][
                    loss_key + "/next_sentence_loss"
                ] = next_sentence_loss

        return output_dict
    elif (
        "vqa" in self.config.training_head_type
        or self.config.training_head_type == "vizwiz"
    ):
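        # Pool the hidden state at index input_mask.sum(1) - 2, i.e. the
        # second-to-last unmasked position (typically the token just before
        # the final [SEP]).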
        index_to_gather = input_mask.sum(1) - 2
        pooled_output = torch.gather(
            joint_embedding,
            1,
            index_to_gather.unsqueeze(-1)
            .unsqueeze(-1)
            .expand(index_to_gather.size(0), 1, joint_embedding.size(-1)),
        )

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.contiguous().view(-1, self.answer_space_size)
        output_dict["scores"] = reshaped_logits

        return output_dict
    elif (
        self.config.training_head_type == "nlvr2"
        or self.config.training_head_type == "visual_entailment"
    ):
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        output_dict["scores"] = logits

        return output_dict

    return output_dict
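
The gather-based pooling in the VQA branch can be tried in isolation. A minimal sketch, assuming made-up shapes (batch=2, seq_len=4, hidden=8) and plain PyTorch tensors rather than an MMF sample_list:

import torch

# Hypothetical inputs: 2 sequences of length 4, hidden size 8; the second
# sequence has one padded position at the end.
joint_embedding = torch.randn(2, 4, 8)
input_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])

# Index of the second-to-last unmasked token per example: [2, 1].
index_to_gather = input_mask.sum(1) - 2

# Same gather pattern as the VQA head above: one hidden vector per example.
pooled_output = torch.gather(
    joint_embedding,
    1,
    index_to_gather.unsqueeze(-1)
    .unsqueeze(-1)
    .expand(index_to_gather.size(0), 1, joint_embedding.size(-1)),
)
print(pooled_output.shape)  # torch.Size([2, 1, 8])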