in src/data.py
import torch
from collections import defaultdict


def custom_collate(batch):
    """Collate function that pads variable-length inputs in a batch."""
    batch_size = len(batch)
    max_len = int(max(sample["text_len"] for sample in batch))
    # IMPORTANT: enforce the padding token to be 0
    padded_input = torch.zeros((batch_size, max_len))
    padded_output = torch.zeros((batch_size, max_len))
    text_len = []
    md, md_len = defaultdict(list), defaultdict(list)
    for idx, sample in enumerate(batch):
        curr_len = sample["text_len"]
        text_len.append(curr_len)
        # Copy the unpadded sequence; positions past curr_len stay 0 (padding).
        padded_input[idx, :curr_len] = sample["input"]
        padded_output[idx, :curr_len] = sample["output"]
        sample_md = sample["md"]
        sample_md_len = sample["md_len"]
        if sample_md is None:
            # A sample without metadata disables metadata for the whole batch.
            md, md_len = None, None
            continue
        if md is None:
            # Metadata was already disabled by an earlier sample.
            continue
        for curr_md_transform, curr_md in sample_md.items():
            md[curr_md_transform].append(curr_md)
        for curr_md_transform, curr_md_len in sample_md_len.items():
            md_len[curr_md_transform].append(curr_md_len)
    # Assumes each sample's "text_len" is a 0-dim tensor.
    text_len = torch.stack(text_len)
    if md:
        # Stack per-sample metadata (and their lengths) into batch tensors.
        for curr_md_transform in md.keys():
            md[curr_md_transform] = torch.stack(md[curr_md_transform])
            md_len[curr_md_transform] = torch.stack(md_len[curr_md_transform])
    processed_batch = {"input": padded_input,
                       "output": padded_output,
                       "md": md,
                       "text_len": text_len,
                       "md_len": md_len}
    return processed_batch
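
A minimal usage sketch follows, assuming custom_collate is passed to a standard torch.utils.data.DataLoader via collate_fn. ToyDataset and its fields are hypothetical stand-ins, not part of the original file; the only contract is that __getitem__ returns a dict with the keys custom_collate consumes ("input", "output", "text_len", "md", "md_len").

from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Hypothetical dataset emitting variable-length samples (no metadata)."""

    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, i):
        n = self.lengths[i]
        return {"input": torch.arange(n, dtype=torch.float),
                "output": torch.arange(n, dtype=torch.float),
                "text_len": torch.tensor(n),
                "md": None,       # no per-sample metadata in this toy case
                "md_len": None}

loader = DataLoader(ToyDataset([3, 5, 4]), batch_size=3,
                    collate_fn=custom_collate)
batch = next(iter(loader))
print(batch["input"].shape)  # torch.Size([3, 5]); shorter rows are zero-padded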