in vision/smolvlm2/smolvlm/datasets/builder.py [0:0]
def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    ################################################################
    # PART A: Pad the text data (input_ids, attention_mask, labels)
    ################################################################
    attention_masks_list = []
    for ex in examples:
        # If "attention_mask" is missing, generate one on the fly
        if "attention_mask" in ex:
            attention_masks_list.append(ex["attention_mask"])
        else:
            am = (ex["input_ids"] != self.pad_token_id).long()
            attention_masks_list.append(am)
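    # Caveat: deriving the mask as (input_ids != pad_token_id) assumes the
    # pad id never appears as a genuine token inside a sequence; if it can,
    # the dataset should supply an explicit attention_mask instead.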
    input_ids = self.func_pad_sequence(
        [ex["input_ids"] for ex in examples],
        batch_first=True,
        padding_value=self.pad_token_id,
    )
    attention_mask = self.func_pad_sequence(
        attention_masks_list,
        batch_first=True,
        padding_value=0,
    )
    labels = self.func_pad_sequence(
        [ex["labels"] for ex in examples],
        batch_first=True,
        padding_value=self.ignore_index,
    )
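    # self.func_pad_sequence is presumably torch.nn.utils.rnn.pad_sequence
    # (an assumption -- it is bound elsewhere in the class): it right-pads
    # each 1-D tensor to the longest sequence in the batch and returns a
    # (batch_size, max_seq_len) tensor.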
    # Optional: truncate if model_max_length is specified
    if self.model_max_length and input_ids.size(1) > self.model_max_length:
        input_ids = input_ids[:, :self.model_max_length]
        attention_mask = attention_mask[:, :self.model_max_length]
        labels = labels[:, :self.model_max_length]
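    # Note that truncation runs after padding and clips from the right, so
    # any supervision (labels) in the truncated tail is discarded.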
    out = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
    ################################################################
    # PART B: Handle pixel data (pixel_values) + pixel_attention_mask
    ################################################################
    # Step 1: figure out the maximum frames, height, and width across the batch
    pvs = [ex["pixel_values"] for ex in examples if "pixel_values" in ex]
    if pvs:  # there is at least one non-None pixel_values
        max_frames = max(pv.shape[0] for pv in pvs)
        max_h = max(pv.shape[-2] for pv in pvs)
        max_w = max(pv.shape[-1] for pv in pvs)
    else:
        max_h = max_w = self.image_size
        max_frames = 1  # TODO: verify this is a good default
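    # For an all-text batch, zero-filled placeholders of shape
    # (1, 3, image_size, image_size) are emitted below with an all-zero
    # mask, so downstream code can treat every batch uniformly.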
    # Step 2: create padded pixel_values and pixel_attention_mask for each example
    padded_pixel_values_list = []
    padded_pixel_mask_list = []
    for ex in examples:
        pv = ex.get("pixel_values", None)           # shape (f, c, h, w) if provided
        pm = ex.get("pixel_attention_mask", None)   # shape (f, h, w) if provided
        if pv is None:
            # text-only => fill pixel data + mask with zeros
            shape_pv = (max_frames, 3, max_h, max_w)
            shape_pm = (max_frames, max_h, max_w)
            padded_pv = torch.zeros(shape_pv, dtype=torch.float32)
            padded_pm = torch.zeros(shape_pm, dtype=torch.long)
        else:
            f, c, h, w = pv.shape
            # Prepare final storage, zero-padded up to the batch maxima
            padded_pv = torch.zeros(
                (max_frames, c, max_h, max_w),
                dtype=pv.dtype,
                device=pv.device,
            )
            padded_pm = torch.zeros(
                (max_frames, max_h, max_w),
                dtype=torch.long,
                device=pv.device,
            )
            padded_pv[:f, :, :h, :w] = pv
            # Copy the provided pixel attention mask, or mark the valid region as 1
            if pm is not None:
                padded_pm[:f, :h, :w] = pm
            else:
                padded_pm[:f, :h, :w] = 1
        padded_pixel_values_list.append(padded_pv)
        padded_pixel_mask_list.append(padded_pm)
    # Finally, stack along the batch dimension
    # (an earlier experiment considered not emitting pixel_values for
    #  text-only batches: if any("pixel_values" in ex for ex in examples): ...)
    out["pixel_values"] = torch.stack(padded_pixel_values_list, dim=0)
    out["pixel_attention_mask"] = torch.stack(padded_pixel_mask_list, dim=0)
    return out
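
# ----------------------------------------------------------------------
# Usage sketch (not from the source file): a minimal demonstration of the
# text-padding semantics above, assuming func_pad_sequence is
# torch.nn.utils.rnn.pad_sequence and pad_token_id = 0. All tensor values
# are made up for illustration.
# ----------------------------------------------------------------------
import torch
from torch.nn.utils.rnn import pad_sequence

ex_img = {"input_ids": torch.tensor([101, 7, 8, 9]),
          "pixel_values": torch.rand(2, 3, 4, 6)}    # 2 frames, 4x6 pixels
ex_txt = {"input_ids": torch.tensor([101, 5])}       # text-only example

batch_ids = pad_sequence([ex_img["input_ids"], ex_txt["input_ids"]],
                         batch_first=True, padding_value=0)
print(batch_ids)
# tensor([[101,   7,   8,   9],
#         [101,   5,   0,   0]])

# Pixel side: ex_txt would receive zeros of shape (2, 3, 4, 6) plus an
# all-zero pixel_attention_mask, while ex_img keeps a mask of 1s over its
# valid 4x6 region.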