in scripts/adapet/ADAPET/src/data/Batcher.py [0:0]
def _append_values(grouped, fields):
    """Append each ``fields[key]`` to the list ``grouped[key]``, creating lists as needed."""
    for key, value in fields.items():
        grouped.setdefault(key, []).append(value)


def _tensorize_int_lists(grouped):
    """Return a copy of *grouped* with every int-valued list turned into a ``torch.tensor``.

    Only the first element of each list is inspected, so the per-key values
    are assumed to be homogeneous. Lists whose first element is not an
    ``int`` are passed through unchanged. Key order is preserved.
    """
    return {
        key: torch.tensor(values) if isinstance(values[0], int) else values
        for key, values in grouped.items()
    }


def my_collate_fn(batch):
    """Collate a batch of datapoints into per-key lists / tensors.

    Each datapoint must be a mapping with an ``"input"`` dict and an
    ``"output"`` dict. The result has the same two sections. In each section,
    every key maps to the list of that key's values across the batch, in
    batch order. A list whose first value is an ``int`` becomes a
    ``torch.tensor``.

    Notes:
        A key that appears in only some datapoints yields a shorter list.
        Values are not padded or aligned by position.
        (Presumably used as a ``DataLoader`` ``collate_fn`` — confirm with callers.)

    Args:
        batch: iterable of datapoint dicts. It is iterated exactly once.

    Returns:
        ``{"input": {...}, "output": {...}}``

    Raises:
        KeyError: if a datapoint lacks ``"input"`` or ``"output"``.
    """
    inputs = {}
    outputs = {}
    # Single pass over the batch so one-shot iterators are supported.
    for datapoint in batch:
        _append_values(inputs, datapoint["input"])
        _append_values(outputs, datapoint["output"])
    return {
        "input": _tensorize_int_lists(inputs),
        "output": _tensorize_int_lists(outputs),
    }