in mmf/models/lxmert.py [0:0]
def get_image_and_text_features(self, sample_list, device):
    """Collect the text, image and auxiliary inputs from ``sample_list``.

    Args:
        sample_list: Batch object read via attribute access. ``input_ids``,
            ``input_mask``, ``segment_ids``, ``lm_label_ids`` and
            ``dataset_name`` are read directly; ``image_info_0``,
            ``image_feature_0``, ``image_labels``, ``targets``/``answers``
            and ``is_correct`` are looked up with ``getattr`` defaults.
        device: Device that the tensors created or moved here are sent to.

    Returns:
        dict with the keys ``input_ids``, ``token_type_ids``,
        ``attention_mask``, ``masked_lm_labels``, ``visual_feats``, ``pos``,
        ``masked_image_labels``, ``obj_labels``, ``matched_label``, ``ans``,
        ``image_dim``, ``max_features`` and ``dataset_name``. Optional
        entries that are absent from ``sample_list`` are ``None``.

    Raises:
        AttributeError: If a directly-read attribute is missing, or if
            ``image_feature_0`` is absent (its ``shape`` is required).
    """
    # bert input
    bert_input_ids = sample_list.input_ids
    bert_input_mask = sample_list.input_mask
    bert_input_type_ids = sample_list.segment_ids
    masked_lm_labels = sample_list.lm_label_ids

    # image input
    image_info = getattr(sample_list, "image_info_0", {})
    image_dim_variable = getattr(image_info, "max_features", None)
    image_feature_variable = getattr(sample_list, "image_feature_0", None)
    # Keep the feature count as a plain int for slicing, so the tensor copy
    # on `device` never has to be read back with .item().
    num_features = image_feature_variable.shape[1]
    max_features = torch.tensor(num_features, dtype=torch.int).to(device)
    image_location_variable = getattr(image_info, "bbox", None)
    if image_location_variable is not None:
        # Keep only the first four values of the last axis for each of the
        # first `num_features` entries.
        image_location_variable = image_location_variable[:, :num_features, :4]

    # aux data
    image_label_variable = getattr(sample_list, "image_labels", None)
    if image_label_variable is not None:
        image_label_variable = image_label_variable[:, :num_features, None]
        image_label_variable = image_label_variable.unsqueeze(-1).to(device)

    cls_prob = getattr(image_info, "cls_prob", None)
    if cls_prob is not None:
        cls_prob = torch.tensor(cls_prob)[:, :num_features, None].to(device)

    # "targets" takes precedence; "answers" is only a fallback.
    answers = getattr(sample_list, "targets", None)
    if answers is None:
        answers = getattr(sample_list, "answers", None)
    if answers is not None:
        if not isinstance(answers, torch.Tensor):
            answers = torch.tensor(answers)
        answers = answers.to(device)

    is_correct = getattr(sample_list, "is_correct", None)
    if is_correct is not None:
        if not isinstance(is_correct, torch.Tensor):
            is_correct = torch.tensor(is_correct)
        is_correct = is_correct.to(device)

    return {
        "input_ids": bert_input_ids,
        # The segment ids are the token type ids and the input mask is the
        # attention mask; these two were previously swapped.
        # NOTE(review): checkpoints trained while the two were swapped may
        # behave differently after this fix - verify before reusing them.
        "token_type_ids": bert_input_type_ids,
        "attention_mask": bert_input_mask,
        "masked_lm_labels": masked_lm_labels,
        "visual_feats": image_feature_variable,
        "pos": image_location_variable,
        "masked_image_labels": image_label_variable,
        "obj_labels": cls_prob,
        "matched_label": is_correct,
        "ans": answers,
        "image_dim": image_dim_variable,
        "max_features": max_features,
        "dataset_name": str(sample_list.dataset_name),
    }