def get_image_and_text_features()

in mmf/models/lxmert.py [0:0]


    def get_image_and_text_features(self, sample_list, device):
        """Gather the text, image and auxiliary inputs for the LXMERT model.

        Args:
            sample_list: Batch object. Required attributes: ``input_ids``,
                ``input_mask``, ``segment_ids``, ``lm_label_ids`` and
                ``dataset_name``. Optional attributes: ``image_feature_0``,
                ``image_info_0`` (with optional ``max_features``, ``bbox`` and
                ``cls_prob``), ``image_labels``, ``targets`` / ``answers`` and
                ``is_correct``.
            device: Device that ``max_features``, the image labels,
                ``cls_prob``, the answers and ``is_correct`` are moved to.

        Returns:
            dict: Model inputs keyed by the names the model consumes. An entry
            whose source attribute is absent is ``None``.
        """

        def _as_tensor(value):
            # torch.tensor() on an existing tensor copies it and emits a
            # UserWarning, so only wrap values that are not tensors yet.
            if isinstance(value, torch.Tensor):
                return value
            return torch.tensor(value)

        # bert input
        bert_input_ids = sample_list.input_ids
        bert_input_mask = sample_list.input_mask
        bert_input_type_ids = sample_list.segment_ids
        masked_lm_labels = sample_list.lm_label_ids

        # image input
        image_info = getattr(sample_list, "image_info_0", {})
        image_dim_variable = getattr(image_info, "max_features", None)
        image_feature_variable = getattr(sample_list, "image_feature_0", None)
        if image_feature_variable is not None:
            # Box count is read from dim 1 of the feature tensor
            # (assumed (batch, num_boxes, feat_dim) — confirm with the dataset).
            num_features = image_feature_variable.shape[1]
            max_features = torch.tensor(num_features, dtype=torch.int).to(device)
        else:
            # No features: a slice end of None keeps every box in the
            # auxiliary tensors below instead of crashing on None.shape.
            num_features = None
            max_features = None
        image_location_variable = getattr(image_info, "bbox", None)
        if image_location_variable is not None:
            # Keep the first num_features boxes and their first 4 coordinates.
            image_location_variable = image_location_variable[:, :num_features, :4]

        # aux data
        image_label_variable = getattr(sample_list, "image_labels", None)
        if image_label_variable is not None:
            # NOTE(review): the `None` index plus unsqueeze(-1) adds two trailing
            # singleton dims; kept as-is since downstream losses presumably
            # expect this shape — confirm.
            image_label_variable = image_label_variable[:, :num_features, None]
            image_label_variable = image_label_variable.unsqueeze(-1).to(device)
        cls_prob = getattr(image_info, "cls_prob", None)
        if cls_prob is not None:
            # Slice before moving so only the kept boxes are transferred.
            cls_prob = _as_tensor(cls_prob)[:, :num_features, None].to(device)
        answers = getattr(sample_list, "targets", None)
        if answers is None:
            answers = getattr(sample_list, "answers", None)
        if answers is not None:
            answers = _as_tensor(answers).to(device)
        is_correct = getattr(sample_list, "is_correct", None)
        if is_correct is not None:
            is_correct = _as_tensor(is_correct).to(device)

        return {
            "input_ids": bert_input_ids,
            # segment_ids are the token type ids and input_mask is the
            # attention mask (these two were previously swapped).
            "token_type_ids": bert_input_type_ids,
            "attention_mask": bert_input_mask,
            "masked_lm_labels": masked_lm_labels,
            "visual_feats": image_feature_variable,
            "pos": image_location_variable,
            "masked_image_labels": image_label_variable,
            "obj_labels": cls_prob,
            "matched_label": is_correct,
            "ans": answers,
            "image_dim": image_dim_variable,
            "max_features": max_features,
            "dataset_name": str(sample_list.dataset_name),
        }