in mmf/models/mmf_transformer.py
def preprocess_sample(self, sample_list: Dict[str, Any]) -> BaseTransformerInput:
    """Preprocess the sample list elements and form a BaseTransformerInput
    type object. This object standardizes how we represent multiple modalities.
    Check the definition of this dataclass in BaseTransformer.
    """
    # Input IDs (or text tokens/image features)
    input_ids: Dict[str, Tensor] = {}
    for idx, modality in enumerate(self.config.modalities):
        if modality.type == "text":
            # dim > 2 means token ids from multiple text modalities are
            # stacked along dim 1; slice out this modality's tokens by idx
            if sample_list.input_ids.dim() > 2:
                input_ids[modality.key] = sample_list.input_ids[:, idx]
            else:
                input_ids[modality.key] = sample_list.input_ids
        elif modality.type == "image":
            # Prefer raw images; fall back to precomputed region features
            if "image" in sample_list:
                image_modal = sample_list.image
            else:
                image_modal = sample_list.image_feature_0
            input_ids[modality.key] = self.image_encoder(image_modal)
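    # Example (assumed shapes): with two text modalities, sample_list.input_ids
    # could be (B, 2, L), so modality idx 0 receives the (B, L) slice [:, 0];
    # an image encoder typically maps its input to a (B, regions, hidden_dim)
    # feature tensor that then plays the role of "input ids" for that stream.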
    # Position IDs
    position_ids: Dict[str, Tensor] = {}
    for modality in self.config.modalities:
        # One row [0, 1, ..., seq_len - 1], broadcast over the batch dim
        position_ids[modality.key] = (
            torch.arange(
                0,
                input_ids[modality.key].size(1),
                dtype=torch.long,
                device=input_ids[modality.key].device,
            )
            .unsqueeze(0)
            .expand(input_ids[modality.key].size()[:2])
        )
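    # Example (assumed shapes): for text ids of shape (B=2, L=128), this
    # yields a (2, 128) tensor whose every row is [0, 1, ..., 127]; expand()
    # broadcasts the single arange row without copying memory.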
    # Segment IDs
    segment_ids: Dict[str, Tensor] = {}
    for idx, modality in enumerate(self.config.modalities):
        if modality.type == "text" and hasattr(sample_list, "segment_ids"):
            # Same stacking convention as input_ids above
            if sample_list.segment_ids.dim() > 2:
                segment_ids[modality.key] = sample_list.segment_ids[:, idx]
            else:
                segment_ids[modality.key] = sample_list.segment_ids
        elif hasattr(modality, "segment_id"):
            # Constant per-modality segment id taken from the config
            segment_ids[modality.key] = torch.zeros(
                input_ids[modality.key].size()[:2],
                dtype=torch.long,
                device=input_ids[modality.key].device,
            ).fill_(modality.segment_id)
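    # Example (assumed values): an image modality configured with
    # segment_id=1 and encoded features of shape (B, 100, hidden) gets a
    # (B, 100) tensor filled with 1s.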
    # Masks
    masks: Dict[str, Tensor] = {}
    for idx, modality in enumerate(self.config.modalities):
        if modality.type == "text":
            if sample_list.input_mask.dim() > 2:
                masks[modality.key] = sample_list.input_mask[:, idx]
            else:
                masks[modality.key] = sample_list.input_mask
        elif modality.type == "image":
            if "image_mask" in sample_list:
                masks[modality.key] = sample_list.image_mask
            else:
                # No mask provided: attend to every image feature position
                masks[modality.key] = torch.ones(
                    input_ids[modality.key].size()[:-1],
                    dtype=torch.long,
                    device=input_ids[modality.key].device,
                )
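    # Example (assumed shapes): image features of shape (B, 100, hidden)
    # default to an all-ones (B, 100) mask, i.e. no region is padded out.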
    return BaseTransformerInput(input_ids, position_ids, segment_ids, masks)
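
# ---------------------------------------------------------------------------
# Minimal shape sketch (not part of the MMF source): the three tensor
# constructions used above, run on toy inputs. The batch size, sequence
# length, vocabulary size, and feature dimensions below are illustrative
# assumptions, not values taken from any MMF config.
if __name__ == "__main__":
    import torch

    text_ids = torch.randint(0, 30522, (2, 128))    # (batch, seq_len)
    image_feats = torch.randn(2, 100, 2048)         # (batch, regions, dim)

    # Position IDs: one [0 .. L-1] row broadcast over the batch
    pos = (
        torch.arange(0, text_ids.size(1), dtype=torch.long)
        .unsqueeze(0)
        .expand(text_ids.size()[:2])
    )
    assert pos.shape == (2, 128)

    # Segment IDs: a constant per-modality id, here 1 for the image stream
    seg = torch.zeros(image_feats.size()[:2], dtype=torch.long).fill_(1)
    assert seg.shape == (2, 100) and bool((seg == 1).all())

    # Default image mask: all ones, so every region is attended to
    mask = torch.ones(image_feats.size()[:-1], dtype=torch.long)
    assert mask.shape == (2, 100)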