def forward()

in mmf/models/lxmert.py [0:0]


    def forward(self, sample_list):
        """Run the LXMERT model on a sample list.

        Builds a per-sample image attention mask from the number of valid
        visual features, then dispatches to either the pretraining head
        (collecting its losses under ``output_dict["losses"]``, namespaced
        by dataset name/type) or the plain fine-tuning forward pass.

        Args:
            sample_list: batch container providing ``dataset_name`` and
                ``dataset_type`` plus the tensors consumed by
                ``get_image_and_text_features``.

        Returns:
            dict: model outputs; for pretraining, includes a ``"losses"``
            sub-dict keyed as ``"<dataset>/<split>/<loss_name>"``.
        """
        device = registry.get("config").training.device
        params = self.get_image_and_text_features(sample_list, device)

        if params["visual_feats"] is not None and params["image_dim"] is not None:
            device = params["visual_feats"].device
            # Feature-slot indices [0, num_feats), broadcast over the batch
            # dims; a slot is attended to only when its index is below that
            # sample's true feature count (params["image_dim"]).
            image_mask = (
                torch.arange(params["visual_feats"].size(-2))
                .expand(*params["visual_feats"].size()[:-1])
                .to(device)
            )
            if len(params["image_dim"].size()) < len(image_mask.size()):
                # Align ranks so the `<` comparison below broadcasts per sample.
                params["image_dim"] = params["image_dim"].unsqueeze(-1)
                assert len(params["image_dim"].size()) == len(image_mask.size())
            image_mask = image_mask < params["image_dim"]
            params["image_attention_mask"] = image_mask.long()
        else:
            params["image_attention_mask"] = None

        if self.config.training_head_type == "pretraining":
            output_dict = self.model(
                input_ids=params["input_ids"],
                token_type_ids=params["token_type_ids"],
                attention_mask=params["attention_mask"],
                visual_feats=params["visual_feats"],
                visual_pos=params["pos"],
                visual_attention_mask=params["image_attention_mask"],
                masked_lm_labels=params["masked_lm_labels"],
                masked_image_labels=params["masked_image_labels"],
                obj_labels=params["obj_labels"],
                matched_label=params["matched_label"],
                ans=params["ans"],
                num_features=params["max_features"],
                name=params["dataset_name"],
            )
            # Namespace every pretraining loss as "<dataset>/<split>/<loss>"
            # so MMF's loss aggregation can tell datasets/splits apart.
            loss_key = f"{sample_list.dataset_name}/{sample_list.dataset_type}"
            output_dict["losses"] = {}
            for loss_name in (
                "masked_lm_loss",
                "matched_loss",
                "visn_loss",
                "answer_loss",
            ):
                # pop() moves the loss out of the top level into "losses";
                # absent losses (task not active) are simply skipped.
                if loss_name in output_dict:
                    output_dict["losses"][f"{loss_key}/{loss_name}"] = output_dict.pop(
                        loss_name
                    )
        else:
            output_dict = self.model(
                input_ids=params["input_ids"],
                token_type_ids=params["token_type_ids"],
                attention_mask=params["attention_mask"],
                visual_feats=params["visual_feats"],
                visual_pos=params["pos"],
                visual_attention_mask=params["image_attention_mask"],
            )
        return output_dict