def _infer_input_ids()

in mmf/models/mmf_transformer.py [0:0]


    def _infer_input_ids(self, sample_list: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Gather one input tensor per configured modality and encode it.

        Walks ``self.encoders`` in order; the encoder at position ``i`` is
        paired with ``self.modality_keys[i]`` and ``self.modality_type[i]``.
        The raw tensor for each modality is resolved from ``sample_list`` via
        ``self._check_keys_for_modality`` using a type-specific tuple of
        candidate keys, passed through the modality's encoder when one is set,
        and stored under the modality key.

        Args:
            sample_list: Mapping from field name to tensor for the current
                batch.

        Returns:
            Dict mapping each modality key to its (optionally encoded) input.
            Non-text entries that are 2-D at that point get a singleton
            dimension inserted at index 1.
        """
        features: Dict[str, Tensor] = {}
        # Which slice of stacked text ids the next text modality consumes.
        text_slot = 0

        for position, encoder in enumerate(self.encoders.values()):
            key = self.modality_keys[position]
            kind = self.modality_type[position]

            if kind == "image":
                # "input_modal" is the MMBT name for this field; accepting it
                # keeps datasets and interops cross-compatible.
                feature = self._check_keys_for_modality(
                    sample_list, (key, "image", "input_modal", "image_feature_0")
                )
            elif kind == "text":
                # Candidates: the standard "input_ids" field, then the
                # modality key itself.
                tokens = self._check_keys_for_modality(sample_list, ("input_ids", key))
                if tokens.dim() <= 2:
                    feature = tokens
                else:
                    # Several text modalities arrive stacked along dim 1
                    # (e.g. B x 2 x L for two of them). Each text modality
                    # takes the next slice, so the stacking order must match
                    # the order of the modalities in the config.
                    feature = tokens[:, text_slot]
                    text_slot += 1
            else:
                # TODO: Later deliberate if missing modalities should
                # be supported in MMFT.
                feature = self._check_keys_for_modality(sample_list, (key,))

            # A modality without an encoder keeps its raw tensor.
            if encoder is not None:
                feature = encoder(feature)

            # Text stays B x L here (per the original notes, embeddings expand
            # it to B x L x D later). Any other modality that is B x D becomes
            # B x 1 x D so it carries an explicit single-position dimension.
            if kind != "text" and feature.dim() == 2:
                feature = feature.unsqueeze(1)

            features[key] = feature

        return features