def __call__()

in vision/smolvlm2/smolvlm/datasets/builder.py


    def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        ################################################################
        # PART A: Pad the text data (input_ids, attention_mask, labels)
        ################################################################
        
        attention_masks_list = []
        for ex in examples:
            # If "attention_mask" is missing, we generate it on the fly
            if "attention_mask" in ex:
                attention_masks_list.append(ex["attention_mask"])
            else:
                am = (ex["input_ids"] != self.pad_token_id).long()
                attention_masks_list.append(am)


        input_ids = self.func_pad_sequence(
            [ex["input_ids"] for ex in examples],
            batch_first=True,
            padding_value=self.pad_token_id
        )
        attention_mask = self.func_pad_sequence(
            attention_masks_list,
            batch_first=True,
            padding_value=0
        )
        labels = self.func_pad_sequence(
            [ex["labels"] for ex in examples],
            batch_first=True,
            padding_value=self.ignore_index
        )
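        # NOTE: self.func_pad_sequence is assumed to point at
        # torch.nn.utils.rnn.pad_sequence; with batch_first=True it right-pads
        # a list of 1-D tensors to the longest length in the batch, e.g.
        #   pad_sequence([tensor([1, 2, 3]), tensor([4])], batch_first=True,
        #                padding_value=0) -> tensor([[1, 2, 3], [4, 0, 0]])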

        # Optional: truncate if model_max_length is specified
        if self.model_max_length and input_ids.size(1) > self.model_max_length:
            input_ids = input_ids[:, :self.model_max_length]
            attention_mask = attention_mask[:, :self.model_max_length]
            labels = labels[:, :self.model_max_length]
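            # e.g. model_max_length = 2048 and a padded batch of length 3000
            # cuts every text tensor back to shape (batch, 2048)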

        out = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

        ################################################################
        # PART B: Handle pixel data (pixel_values) + pixel_attention_mask
        ################################################################
        # Step 1: figure out maximum frames, height, width across the batch
        pvs = [ex["pixel_values"] for ex in examples if ex.get("pixel_values") is not None]
        if pvs:  # there is at least one non-None pixel_values
            max_frames = max(pv.shape[0] for pv in pvs)
            max_h = max(pv.shape[-2] for pv in pvs)
            max_w = max(pv.shape[-1] for pv in pvs)
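            # e.g. pixel_values of shapes (8, 3, 384, 384) and (1, 3, 512, 256)
            # give max_frames = 8, max_h = 512, max_w = 384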
        else:
            max_h = max_w = self.image_size
            max_frames = 1  # TODO: verify this is a good default

        # Step 2: create padded pixel_values and pixel_attention_mask for each example
        padded_pixel_values_list = []
        padded_pixel_mask_list = []

        for ex in examples:
            pv = ex.get("pixel_values", None)
            pm = ex.get("pixel_attention_mask", None)  # shape (f, h, w) if provided

            if pv is None:
                # text-only => fill pixel data + mask with zeros
                shape_pv = (max_frames, 3, max_h, max_w)
                shape_pm = (max_frames, max_h, max_w)
                padded_pv = torch.zeros(shape_pv, dtype=torch.float32)
                padded_pm = torch.zeros(shape_pm, dtype=torch.long)
            else:
                f, c, h, w = pv.shape
                # Prepare final storage
                padded_pv = torch.zeros(
                    (max_frames, c, max_h, max_w),
                    dtype=pv.dtype,
                    device=pv.device
                )
                padded_pm = torch.zeros(
                    (max_frames, max_h, max_w),
                    dtype=torch.long,
                    device=pv.device
                )

                padded_pv[:f, :, :h, :w] = pv
                # Copy or fill the pixel attention mask
                if pm is not None:
                    padded_pm[:f, :h, :w] = pm
                else:
                    # Mark valid region as 1
                    padded_pm[:f, :h, :w] = 1

            padded_pixel_values_list.append(padded_pv)
            padded_pixel_mask_list.append(padded_pm)
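        # Every example now shares the shape (max_frames, C, max_h, max_w);
        # padded frames/pixels are zeros and their mask entries remain 0.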

        # Finally, stack along the batch dimension
        ## try not outputting pixel_values in text-only samples
        #if any("pixel_values" in ex for ex in examples):
        out["pixel_values"] = torch.stack(padded_pixel_values_list, dim=0)
        # The per-example masks were built above, so return them alongside
        out["pixel_attention_mask"] = torch.stack(padded_pixel_mask_list, dim=0)
        return out
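
Both halves of the method reduce to two standard moves, shown below in a minimal standalone sketch. The token IDs, shapes, and padding values are invented for illustration, and `pad_sequence` is assumed to be what `self.func_pad_sequence` points at:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Text fields: right-pad each 1-D tensor to the longest sequence in the batch.
    input_ids = pad_sequence(
        [torch.tensor([101, 5, 6, 102]), torch.tensor([101, 7, 102])],
        batch_first=True,
        padding_value=0,  # stand-in for self.pad_token_id
    )
    # -> tensor([[101,   5,   6, 102],
    #            [101,   7, 102,   0]])

    # Pixel data: zero-pad a (frames, C, H, W) clip to the batch maxima and
    # mark the valid region in the mask; per-example tensors become stackable.
    pv = torch.rand(3, 3, 256, 320)          # 3 frames of 256x320
    max_frames, max_h, max_w = 5, 256, 384   # maxima over the whole batch
    padded_pv = torch.zeros(max_frames, 3, max_h, max_w)
    padded_pv[:pv.shape[0], :, :pv.shape[2], :pv.shape[3]] = pv
    padded_pm = torch.zeros(max_frames, max_h, max_w, dtype=torch.long)
    padded_pm[:pv.shape[0], :pv.shape[2], :pv.shape[3]] = 1
    print(padded_pv.shape, padded_pm.shape)
    # torch.Size([5, 3, 256, 384]) torch.Size([5, 256, 384])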