in train.py [0:0]
def batchify(fn, chunk, detailed_output=False):
    """Construct a version of ``fn`` that evaluates its input in chunks.

    ``fn`` is called as ``fn(inputs[i:i + chunk], detailed_output=...)`` on
    consecutive slices along dimension 0. The per-chunk results are joined
    back together with ``torch.cat`` along dimension 0, so no single call to
    ``fn`` sees more than ``chunk`` rows.

    Args:
        fn: Callable taking a batch plus a ``detailed_output`` keyword. With
            ``detailed_output=True`` it is expected to return
            ``(outputs, details)``, where ``details`` is a dict whose values
            can be concatenated with ``torch.cat``. Otherwise it returns a
            single concatenable value.
        chunk: Maximum number of rows passed to ``fn`` per call, or ``None``
            to disable chunking entirely.
        detailed_output: Forwarded to ``fn``. When True, the returned callable
            yields ``(outputs, details)``, with each entry of ``details``
            concatenated across chunks.

    Returns:
        A callable taking ``inputs``. ``inputs`` must support ``.shape[0]``
        and slicing along dim 0; presumably a ``torch.Tensor``.
        When ``chunk`` is None the callable is ``fn`` itself. If
        ``detailed_output`` is True, it is ``fn`` with
        ``detailed_output=True`` bound instead.

    Raises:
        ValueError: If ``chunk`` is not None and is not positive.
    """
    if chunk is None:
        if detailed_output:
            import functools

            # Returning bare ``fn`` here would silently give the caller fn's
            # default (non-detailed) output even though details were requested.
            return functools.partial(fn, detailed_output=True)
        return fn
    if chunk <= 0:
        raise ValueError(f"chunk must be a positive integer or None, got {chunk!r}")

    def _split(inputs):
        """Slice ``inputs`` into consecutive pieces of at most ``chunk`` rows."""
        pieces = [inputs[i : i + chunk] for i in range(0, inputs.shape[0], chunk)]
        # An empty batch yields no slices; evaluate fn once on the empty input
        # so the concatenation below has at least one result to work with.
        return pieces or [inputs]

    def ret(inputs):
        results = [fn(piece, detailed_output=detailed_output) for piece in _split(inputs)]
        if not detailed_output:
            return torch.cat(results, 0)
        outputs = torch.cat([out for out, _ in results], 0)
        # The first chunk's detail dict defines which keys are merged.
        details = {
            key: torch.cat([chunk_details[key] for _, chunk_details in results], 0)
            for key in results[0][1]
        }
        return outputs, details

    return ret