in data/advanced_datasets.py [0:0]
def _pack_one_group(self, group_indices, batch, max_len):
ids, lbl, am, ims = [], [], [], []
for i in group_indices:
ids.extend(batch[i]["input_ids"])
lbl.extend(batch[i]["labels"])
am.extend(batch[i]["attention_mask"])
ims.extend(batch[i]["images"])
# safety: assert we never overflow
if len(ids) > max_len:
raise ValueError(f"Packed length {len(ids)} > max_len {max_len}")
return torch.stack(ids), torch.stack(lbl), torch.stack(am), ims