def forward()

in mmf/models/mmf_bert.py


    def forward(self, sample_list):
        # bert text input
        input_ids = sample_list.input_ids
        input_mask = sample_list.input_mask
        input_type_ids = sample_list.segment_ids
        input_ids = transform_to_batch_sequence(input_ids)
        input_type_ids = transform_to_batch_sequence(input_type_ids)
        input_mask = transform_to_batch_sequence(input_mask)

        if input_mask is None:
            input_mask = torch.ones_like(input_ids)
        if input_type_ids is None:
            input_type_ids = torch.zeros_like(input_ids)
        attention_mask = input_mask.unsqueeze(1).unsqueeze(2)
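        # attention_mask now has shape [batch_size, 1, 1, seq_length], so it
        # broadcasts over every attention head and every query position.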

        # masked language modeling labels
        masked_lm_labels = getattr(sample_list, "lm_label_ids", None)
        masked_lm_labels = transform_to_batch_sequence(masked_lm_labels)
        # next sentence prediction labels (disabled; see TODO below)
        # is_random_next = getattr(sample_list, "is_correct", None)
        # TODO(aps): Fix later on dataset side
        is_random_next = None

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        # fp16 compatibility
        attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)
        attention_mask = (1.0 - attention_mask) * -10000.0
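        # For example (a minimal sketch): input_mask = [1, 1, 0] becomes
        # attention_mask = [0.0, 0.0, -10000.0], pushing the padded position's
        # post-softmax attention weight to ~0.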

        text_embedding = self.word_embedding(input_ids, input_type_ids)
        text_embedding_total = self.process_text_embedding(
            text_embedding, input_mask == 0
        )

        image_embedding_total, _ = self.process_feature_embedding(
            "image", sample_list, text_embedding_total
        )

        if self.inter_model is not None:
            image_embedding_total = self.inter_model(image_embedding_total)

        # image_embedding_total = image_embedding_total * input_mask.unsqueeze(-1).float()
        # text_embedding_total = text_embedding_total * input_mask.unsqueeze(-1).float()

        if self.config.combine_embeddings:
            joint_embedding = self.combine_embeddings(
                ["image", "text"], [image_embedding_total, text_embedding_total]
            )
        else:
            joint_embedding = image_embedding_total

        output_dict = {}

        pooled_output = self.pooler(joint_embedding)
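        # self.pooler is assumed to follow the usual BERT convention of
        # projecting the hidden state of the first ([CLS]) token.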

        if "pretraining" in self.config.training_head_type:
            prediction_scores, seq_relationship_score = self.classifier(
                joint_embedding, pooled_output
            )
            output_dict["logits"] = prediction_scores

            if masked_lm_labels is not None:
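                # Positions labeled -1 (the convention for tokens that were
                # not masked) are excluded from the loss via ignore_index.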
                loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
                masked_lm_loss = loss_fct(
                    prediction_scores.contiguous().view(
                        -1, self.bert_config.vocab_size
                    ),
                    masked_lm_labels.contiguous().view(-1),
                )
                loss_key = f"{sample_list.dataset_name}/{sample_list.dataset_type}"

                output_dict["losses"] = {}
                output_dict["losses"][loss_key + "/masked_lm_loss"] = masked_lm_loss

                if is_random_next is not None:
                    output_dict["seq_relationship_score"] = seq_relationship_score

                    next_sentence_loss = loss_fct(
                        seq_relationship_score.contiguous().view(-1, 2),
                        is_random_next.contiguous().view(-1),
                    )
                    output_dict["losses"][
                        loss_key + "/next_sentence_loss"
                    ] = next_sentence_loss
            return output_dict
        elif (
            "vqa" in self.config.training_head_type
            or self.config.training_head_type == "vizwiz"
        ):
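            # input_mask.sum(1) is the unpadded sequence length, so this
            # indexes the second-to-last real token (presumably the token
            # just before the final [SEP]).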
            index_to_gather = input_mask.sum(1) - 2

            pooled_output = torch.gather(
                joint_embedding,
                1,
                index_to_gather.unsqueeze(-1)
                .unsqueeze(-1)
                .expand(index_to_gather.size(0), 1, joint_embedding.size(-1)),
            )

            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)
            reshaped_logits = logits.contiguous().view(-1, self.answer_space_size)

            output_dict["scores"] = reshaped_logits
            return output_dict
        elif self.config.training_head_type in ("nlvr2", "visual_entailment"):
            pooled_output = self.dropout(pooled_output)
            logits = self.classifier(pooled_output)
            output_dict["scores"] = logits
            return output_dict

        return output_dict
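
For orientation, here is a minimal, hypothetical sketch of driving this forward pass by hand. It assumes MMF's Sample/SampleList containers and an already-constructed MMFBert model; the tensor contents, the image_feature_0 field, and the model variable are illustrative stand-ins, not values taken from this file.

    import torch
    from mmf.common.sample import Sample, SampleList

    # Hypothetical inputs; in practice an MMF dataset/processor builds these.
    sample = Sample()
    sample.input_ids = torch.tensor([101, 2054, 2003, 2023, 102])  # [CLS] ... [SEP]
    sample.input_mask = torch.ones(5, dtype=torch.long)            # no padding
    sample.segment_ids = torch.zeros(5, dtype=torch.long)          # single segment
    # Visual features consumed by process_feature_embedding; the field name
    # and shape here are assumptions, not taken from this file.
    sample.image_feature_0 = torch.rand(36, 2048)

    sample_list = SampleList([sample])
    sample_list.dataset_name = "vqa2"   # used to build the loss keys above
    sample_list.dataset_type = "train"

    # `model` stands for a constructed MMFBert instance (not shown here).
    output = model(sample_list)
    print(output["scores"].shape)       # e.g. [1, answer_space_size] for a VQA head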