in src/data.py
from collections import defaultdict

import torch


def custom_collate(batch):
    """Collate a batch of variable-length samples, zero-padding each
    sequence to the longest length in the batch."""
    batch_size = len(batch)
    max_len = int(max(sample["text_len"] for sample in batch))
    # IMPORTANT: the padding token must be 0, so start from all-zero tensors.
    # Match the samples' dtype so integer token ids are not silently cast to float.
    padded_input = torch.zeros((batch_size, max_len), dtype=batch[0]["input"].dtype)
    padded_output = torch.zeros((batch_size, max_len), dtype=batch[0]["output"].dtype)
    text_len = []
    # Metadata is assumed to be consistently present (or absent) across the batch,
    # so check the first sample once instead of clobbering md mid-loop.
    has_md = batch[0]["md"] is not None
    md = defaultdict(list) if has_md else None
    md_len = defaultdict(list) if has_md else None
    for idx, sample in enumerate(batch):
        curr_len = sample["text_len"]
        # as_tensor accepts plain ints as well as 0-dim tensors, so the stack below is safe
        text_len.append(torch.as_tensor(curr_len))
        padded_input[idx, :curr_len] = sample["input"]
        padded_output[idx, :curr_len] = sample["output"]
        if not has_md:
            continue
        # Group metadata by transform name so each group can be stacked below.
        for curr_md_transform, curr_md in sample["md"].items():
            md[curr_md_transform].append(curr_md)
        for curr_md_transform, curr_md_len in sample["md_len"].items():
            md_len[curr_md_transform].append(curr_md_len)
    text_len = torch.stack(text_len)
    if has_md:
        # Stack each transform's per-sample list into a single batched tensor.
        for curr_md_transform in md.keys():
            md[curr_md_transform] = torch.stack(md[curr_md_transform])
            md_len[curr_md_transform] = torch.stack(md_len[curr_md_transform])
    processed_batch = {"input": padded_input,
                       "output": padded_output,
                       "md": md,
                       "text_len": text_len,
                       "md_len": md_len}
    return processed_batch
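
For reference, here is a minimal usage sketch showing how custom_collate plugs into torch.utils.data.DataLoader. The _ToyTextDataset below is a hypothetical stand-in, not part of the repo; it only illustrates the sample-dict contract the collate function expects.

# Minimal usage sketch. _ToyTextDataset is a hypothetical stand-in dataset;
# any real dataset just needs to return dicts with the same keys.
import torch
from torch.utils.data import DataLoader, Dataset

class _ToyTextDataset(Dataset):
    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, i):
        n = self.lengths[i]
        seq = torch.arange(1, n + 1)  # ids start at 1 so 0 remains the pad token
        return {"input": seq, "output": seq, "text_len": n,
                "md": None, "md_len": None}

loader = DataLoader(_ToyTextDataset([3, 5, 4]), batch_size=3,
                    collate_fn=custom_collate)
batch = next(iter(loader))
assert batch["input"].shape == (3, 5)  # three samples padded to the longest (5)
assert batch["md"] is None             # no metadata in this toy batch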