in mmf/models/mmf_transformer.py [0:0]
def _infer_input_ids(self, sample_list: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Collect one input tensor per configured modality.

    Walks ``self.encoders`` in order; the i-th encoder is paired with
    ``self.modality_keys[i]`` and ``self.modality_type[i]``. For each
    modality the raw tensor is fetched from ``sample_list`` through
    ``self._check_keys_for_modality`` (presumably returning the tensor of
    the first candidate key that is present -- confirm against that
    helper), optionally passed through the modality's encoder, and stored
    under the modality key.

    Args:
        sample_list: Mapping from sample-list field names to tensors.

    Returns:
        Mapping from modality key to the (possibly encoded) input tensor
        for that modality.
    """
    collected: Dict[str, Tensor] = {}
    # Index of the next slice to take from stacked text ids; it only
    # advances when a text modality actually consumed a stacked slice.
    stacked_text_slot = 0

    for position, encoder in enumerate(self.encoders.values()):
        key = self.modality_keys[position]
        kind = self.modality_type[position]
        is_text = kind == "text"

        if is_text:
            # Prefer the standard "input_ids" field and fall back to the
            # modality's own key.
            feature = self._check_keys_for_modality(sample_list, ("input_ids", key))
            # With several text modalities the ids are expected to be
            # stacked along dim 1 (e.g. B x 2 x L for two of them), in the
            # same order as the modalities appear in the config. Pick this
            # modality's slice and move on to the next one.
            if feature.dim() > 2:
                feature = feature[:, stacked_text_slot]
                stacked_text_slot += 1
        elif kind == "image":
            # "input_modal" originates from MMBT; it is accepted here so
            # that interops and datasets stay cross-compatible.
            feature = self._check_keys_for_modality(
                sample_list, (key, "image", "input_modal", "image_feature_0")
            )
        else:
            # TODO: Later deliberate if missing modalities should
            # be supported in MMFT.
            feature = self._check_keys_for_modality(sample_list, (key,))

        # A modality configured without an encoder keeps its raw tensor.
        if encoder is not None:
            feature = encoder(feature)

        # Text stays B x L here (embeddings later expand it to B x L x D).
        # Any other modality that is only B x D gets an explicit single
        # position dimension: B x 1 x D.
        if not is_text and feature.dim() == 2:
            feature = feature.unsqueeze(1)

        collected[key] = feature

    return collected