def custom_collate()

in src/data.py [0:0]


from collections import defaultdict

import torch


def custom_collate(batch):
    """Collate variable-length samples into a single zero-padded batch."""
    batch_size = len(batch)
    max_len = int(max(sample["text_len"] for sample in batch))

    # IMPORTANT: enforce the padding token to be 0 (torch.zeros); the default
    # dtype is float32, so integer token ids are cast on assignment below.
    padded_input = torch.zeros((batch_size, max_len))
    padded_output = torch.zeros((batch_size, max_len))
    text_len = []

    # Per-transform metadata ("md") and its lengths, accumulated across samples
    md, md_len = defaultdict(list), defaultdict(list)

    for idx, sample in enumerate(batch):
        curr_len = sample["text_len"]
        text_len.append(curr_len)
        padded_input[idx, :curr_len] = sample["input"]
        padded_output[idx, :curr_len] = sample["output"]

        sample_md = sample["md"]
        sample_md_len = sample["md_len"]
        if sample_md is None:
            md = None
            md_len = None
            continue

        for curr_md_transform, curr_md in sample_md.items():
            md[curr_md_transform].append(curr_md)

        for curr_md_transform, curr_md_len in sample_md_len.items():
            md_len[curr_md_transform].append(curr_md_len)

    # Lengths may be Python ints or scalar tensors; as_tensor handles both.
    text_len = torch.as_tensor(text_len)

    if md:
        # Stack the accumulated per-transform metadata into batch tensors.
        for curr_md_transform in md:
            md[curr_md_transform] = torch.stack(md[curr_md_transform])
            md_len[curr_md_transform] = torch.stack(md_len[curr_md_transform])

    processed_batch = {"input": padded_input,
                       "output": padded_output,
                       "md": md,
                       "text_len": text_len,
                       "md_len": md_len}

    return processed_batch
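
A minimal usage sketch, assuming a map-style dataset whose __getitem__ returns a dict with the keys consumed above ("input", "output", "text_len", "md", "md_len"). The ToyTextDataset class, the copy-task targets, and the sample values are hypothetical and only illustrate how custom_collate plugs into a DataLoader.


import torch
from torch.utils.data import DataLoader, Dataset


class ToyTextDataset(Dataset):
    """Hypothetical dataset emitting the dict layout expected by custom_collate."""

    def __init__(self, sequences):
        self.sequences = sequences  # list of 1-D token-id tensors

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        return {
            "input": seq,
            "output": seq,         # copy task, for illustration only
            "text_len": len(seq),
            "md": None,            # no per-sample metadata in this sketch
            "md_len": None,
        }


loader = DataLoader(ToyTextDataset([torch.arange(1, 6), torch.arange(1, 3)]),
                    batch_size=2,
                    collate_fn=custom_collate)
batch = next(iter(loader))
print(batch["input"].shape)   # torch.Size([2, 5]) -- padded to the longest sample
print(batch["text_len"])      # tensor([5, 2])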