def batchify()

in train.py [0:0]


def batchify(fn, chunk, detailed_output=False):
    """Construct a version of ``fn`` that applies to smaller batches.

    The returned callable slices ``inputs`` along dimension 0 into pieces of
    at most ``chunk`` rows, calls ``fn`` on each piece, and concatenates the
    per-piece results along dimension 0.

    Args:
        fn: Callable invoked as ``fn(piece, detailed_output=detailed_output)``.
            When ``detailed_output`` is true it must return an
            ``(outputs, details)`` pair where ``details`` is a dict of
            tensors; otherwise it must return a single tensor.
        chunk: Maximum number of rows per call to ``fn``, or ``None`` to
            disable chunking.
        detailed_output: Whether to request and merge the per-piece
            ``details`` dicts.

    Returns:
        ``fn`` itself when ``chunk`` is ``None``. Otherwise a function
        ``ret(inputs)`` returning the concatenated outputs, or an
        ``(outputs, details)`` pair when ``detailed_output`` is true. In that
        pair, each ``details`` value is the dim-0 concatenation of the
        per-piece values for that key (keys taken from the first piece).
    """
    if chunk is None:
        # NOTE(review): ``fn`` is returned unwrapped, so ``detailed_output`` is
        # NOT forwarded here; callers get whatever ``fn`` returns by default.
        # Confirm that callers passing detailed_output=True never use chunk=None.
        return fn

    def ret(inputs):
        num_rows = inputs.shape[0]
        # An empty input would produce no pieces, and zip()/torch.cat() on an
        # empty list raise. Run fn once on the (empty) input instead so the
        # result keeps fn's output shape and dtype.
        starts = range(0, num_rows, chunk) if num_rows > 0 else (0,)
        results = [
            fn(inputs[start : start + chunk], detailed_output=detailed_output)
            for start in starts
        ]

        if not detailed_output:
            return torch.cat(results, 0)

        # Each result is an (outputs, details) pair; transpose into two tuples.
        chunk_outputs, chunk_details = zip(*results)
        outputs = torch.cat(chunk_outputs, 0)
        details = {
            key: torch.cat([piece_details[key] for piece_details in chunk_details], 0)
            for key in chunk_details[0]
        }
        return outputs, details

    return ret