in data/collators.py [0:0]
def _discard_samples_that_are_too_long(self, batch, max_length):
    """Drop samples whose ``input_ids`` sequence is longer than ``max_length``.

    The four per-sample fields ``input_ids``, ``labels``, ``attention_mask``
    and ``images`` are filtered in lockstep: a sample is either kept in all
    of them or removed from all of them.

    Args:
        batch: Mapping with the keys ``"input_ids"``, ``"labels"``,
            ``"attention_mask"`` and ``"images"``, each holding one entry per
            sample. Only ``len(batch["input_ids"][i])`` decides whether a
            sample is kept. As with ``zip``, iteration stops at the shortest
            of the four fields.
        max_length: Inclusive upper bound on ``len(input_ids)`` for a sample
            to be kept.

    Returns:
        A new dict with the same four keys, each mapped to a list of the
        surviving entries in their original order. If every sample is
        discarded, the dict still has all four keys with empty lists, so
        callers can always index the result by key.
    """
    keys = ("input_ids", "labels", "attention_mask", "images")
    # Pre-seed every key so the "nothing survived" case yields the same
    # dict shape as the normal case, rather than a different return type.
    filtered = {key: [] for key in keys}
    for sample in zip(*(batch[key] for key in keys)):
        # sample[0] is this sample's input_ids entry (first element of keys).
        if len(sample[0]) <= max_length:
            for key, value in zip(keys, sample):
                filtered[key].append(value)
    return filtered