in src/hyperpod_nemo_adapter/collections/data/hf_data_module.py [0:0]
def mm_collate_fn(examples):
    """Collate a list of per-sample feature mappings into one batch dict.

    The set and order of keys are taken from the first example. Values for
    ``"pixel_values"`` are joined with ``torch.concat`` along dim 0, so
    samples may differ in their leading dimension. Every other key is
    combined with ``torch.stack``, which adds a new leading batch dimension
    and therefore requires identical per-sample shapes.

    Args:
        examples: Sequence of mappings. Every value must be convertible
            with ``torch.as_tensor``. An empty sequence raises
            ``IndexError``, because the keys are read from ``examples[0]``.

    Returns:
        dict mapping each key of the first example to its batched tensor.
    """

    def _merge(key):
        # Convert every sample's value for this key, then pick the joiner:
        # concat keeps per-sample rows flat; stack adds a batch axis.
        tensors = [torch.as_tensor(example[key]) for example in examples]
        combine = torch.concat if key == "pixel_values" else torch.stack
        return combine(tensors, dim=0)

    return {key: _merge(key) for key in list(examples[0].keys())}