in mmf/models/lxmert.py [0:0]
def forward(self, sample_list):
    """Run LXMERT on a batch and return its outputs.

    Builds an attention mask over the image regions (position ``i`` of an
    image is attended iff ``i < image_dim``, i.e. it is a real, non-padded
    region), then dispatches either to the pretraining head — whose losses
    are re-keyed under ``output_dict["losses"]`` — or to the task head.

    Args:
        sample_list: MMF sample list carrying the batch plus
            ``dataset_name`` and ``dataset_type`` attributes.

    Returns:
        dict: model outputs. For pretraining, each loss is moved into
        ``output_dict["losses"]`` under the key
        ``"<dataset_name>/<dataset_type>/<loss_name>"``.
    """
    device = registry.get("config").training.device
    params = self.get_image_and_text_features(sample_list, device)

    # Build the image attention mask only when visual inputs are present.
    if params["visual_feats"] is not None and params["image_dim"] is not None:
        # Prefer the device the features actually live on.
        device = params["visual_feats"].device
        # Region indices 0..num_regions-1, broadcast across the batch dims.
        image_mask = (
            torch.arange(params["visual_feats"].size(-2))
            .expand(*params["visual_feats"].size()[:-1])
            .to(device)
        )
        # Align image_dim's rank with the mask so the comparison broadcasts.
        if params["image_dim"].dim() < image_mask.dim():
            params["image_dim"] = params["image_dim"].unsqueeze(-1)
        assert params["image_dim"].dim() == image_mask.dim()
        # Positions beyond the per-image region count are masked out.
        params["image_attention_mask"] = (
            image_mask < params["image_dim"]
        ).long()
    else:
        params["image_attention_mask"] = None

    if self.config.training_head_type == "pretraining":
        output_dict = self.model(
            input_ids=params["input_ids"],
            token_type_ids=params["token_type_ids"],
            attention_mask=params["attention_mask"],
            visual_feats=params["visual_feats"],
            visual_pos=params["pos"],
            visual_attention_mask=params["image_attention_mask"],
            masked_lm_labels=params["masked_lm_labels"],
            masked_image_labels=params["masked_image_labels"],
            obj_labels=params["obj_labels"],
            matched_label=params["matched_label"],
            ans=params["ans"],
            num_features=params["max_features"],
            name=params["dataset_name"],
        )
        # Namespace each reported loss as "<dataset>/<split>/<loss_name>".
        loss_key = "{}/{}".format(
            sample_list.dataset_name, sample_list.dataset_type
        )
        output_dict["losses"] = {}
        for loss_name in (
            "masked_lm_loss",
            "matched_loss",
            "visn_loss",
            "answer_loss",
        ):
            if loss_name in output_dict:
                output_dict["losses"][
                    loss_key + "/" + loss_name
                ] = output_dict.pop(loss_name)
    else:
        output_dict = self.model(
            input_ids=params["input_ids"],
            token_type_ids=params["token_type_ids"],
            attention_mask=params["attention_mask"],
            visual_feats=params["visual_feats"],
            visual_pos=params["pos"],
            visual_attention_mask=params["image_attention_mask"],
        )
    return output_dict