def preprocess_sample()

in mmf/models/mmf_transformer.py


    def preprocess_sample(self, sample_list: Dict[str, Any]) -> BaseTransformerInput:
        """Preprocess the sample list elements and form a BaseTransformerInput
        type object. This object standardizes how we represent multiple modalities.
        Check the definition of this dataclass in BaseTransformer.
        """

        # Input IDs: text token ids, or encoded features for image modalities
        input_ids: Dict[str, Tensor] = {}
        for idx, modality in enumerate(self.config.modalities):
            if modality.type == "text":
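                # A 3-D tensor means several text modalities were stacked along
                # dim 1; select this modality's slice by its position in config.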
                if sample_list.input_ids.dim() > 2:
                    input_ids[modality.key] = sample_list.input_ids[:, idx]
                else:
                    input_ids[modality.key] = sample_list.input_ids
            elif modality.type == "image":
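                # Prefer a raw image when the sample carries one, falling back
                # to pre-extracted features; either way the image encoder turns
                # it into a feature sequence.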
                if "image" in sample_list:
                    image_modal = sample_list.image
                else:
                    image_modal = sample_list.image_feature_0
                input_ids[modality.key] = self.image_encoder(image_modal)

        # Position IDs
        position_ids: Dict[str, Tensor] = {}
        for modality in self.config.modalities:
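            # One position id per step, 0..seq_len-1, broadcast across the
            # batch to match this modality's [batch_size, seq_len] shape.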
            position_ids[modality.key] = (
                torch.arange(
                    0,
                    input_ids[modality.key].size(1),
                    dtype=torch.long,
                    device=input_ids[modality.key].device,
                )
                .unsqueeze(0)
                .expand(input_ids[modality.key].size()[:2])
            )

        # Segment IDs
        segment_ids: Dict[str, Tensor] = {}
        for idx, modality in enumerate(self.config.modalities):
            if modality.type == "text" and hasattr(sample_list, "segment_ids"):
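                # Segment ids follow the same stacking convention as input_ids.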
                if sample_list.segment_ids.dim() > 2:
                    segment_ids[modality.key] = sample_list.segment_ids[:, idx]
                else:
                    segment_ids[modality.key] = sample_list.segment_ids
            elif hasattr(modality, "segment_id"):
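                # Non-text modalities (and text without per-sample segment ids)
                # get a constant segment id from the modality config.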
                segment_ids[modality.key] = torch.full(
                    input_ids[modality.key].size()[:2],
                    modality.segment_id,
                    dtype=torch.long,
                    device=input_ids[modality.key].device,
                )

        # Masks
        masks: Dict[str, Tensor] = {}
        for idx, modality in enumerate(self.config.modalities):
            if modality.type == "text":
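                # Text masks follow the same stacking convention as input_ids.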
                if sample_list.input_mask.dim() > 2:
                    masks[modality.key] = sample_list.input_mask[:, idx]
                else:
                    masks[modality.key] = sample_list.input_mask

            elif modality.type == "image":
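                # Use a dataset-provided image mask when available; otherwise
                # assume every encoded image position is valid.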
                if "image_mask" in sample_list:
                    masks[modality.key] = sample_list.image_mask
                else:
                    masks[modality.key] = torch.ones(
                        input_ids[modality.key].size()[:-1],
                        dtype=torch.long,
                        device=input_ids[modality.key].device,
                    )

        return BaseTransformerInput(input_ids, position_ids, segment_ids, masks)
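
A minimal usage sketch, assuming a model configured with one "text" and one
"image" modality and the usual Sample/SampleList collation (per-sample tensor
fields are stacked into a batch). The shapes, the modality keys, and the model
variable are illustrative assumptions, not values from a specific MMF config:

    import torch

    from mmf.common.sample import Sample, SampleList

    # Hypothetical per-sample tensors: 128 text tokens, one 3x224x224 image.
    sample = Sample()
    sample.input_ids = torch.zeros(128, dtype=torch.long)
    sample.input_mask = torch.ones(128, dtype=torch.long)
    sample.segment_ids = torch.zeros(128, dtype=torch.long)
    sample.image = torch.rand(3, 224, 224)

    # SampleList collates the two samples, stacking each field along a new
    # batch dimension, e.g. input_ids becomes a [2, 128] tensor.
    sample_list = SampleList([sample, sample])

    # model is an already-constructed MMFTransformer (construction omitted).
    transformer_input = model.preprocess_sample(sample_list)
    # transformer_input.input_ids["text"]  -> [2, 128] token ids
    # transformer_input.input_ids["image"] -> output of self.image_encoder
    # transformer_input.masks["image"]     -> all ones (no image_mask provided)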