in datasets.py [0:0]
def iter_data_mpi(*args, n_batch, log, shuffle=False, iters=None, seed=None, split_by_rank=True):
    """Iterate over aligned tensors in minibatches, optionally sharded across MPI ranks.

    Every positional argument is indexed with the same row indices, so row ``i``
    of each tensor stays aligned. Row indices are grouped into steps of
    ``ms * n_batch`` entries (``ms`` = number of ranks, or 1 when not
    splitting). At each step this rank receives its own slice of ``n_batch``
    indices. Trailing rows that do not fill a complete step are dropped, and
    the drop is logged.

    This is a generator. Validation, seeding and logging run on the first
    ``next()`` call, not when the function is called.

    Args:
        *args: One or more tensors with equal length along axis 0. Each must
            support ``.shape[0]`` and indexing by an integer numpy array
            (e.g. numpy arrays).
        n_batch: Number of rows per minibatch per rank; must be >= 1.
        log: Callable taking one string, used for progress messages.
        shuffle: If True, visit rows in a random permutation.
        iters: Maximum number of minibatches to yield. ``None`` or ``0``
            means no limit.
        seed: If not ``None``, passed to ``np.random.seed`` before shuffling.
            This reseeds numpy's global RNG. Seed ``0`` is honored.
        split_by_rank: If True, shard each step across ranks using the
            module-level ``mpisize`` / ``mpirank``. These are presumably the
            MPI world size and this process's rank; they are defined outside
            this function. If False, behave as a single process that sees
            every row.

    Yields:
        A list with one entry per tensor in ``args``: that tensor indexed by
        this rank's ``n_batch`` row indices for the current step.

    Raises:
        ValueError: If no tensors are given, ``n_batch < 1``, or the tensors
            disagree on their length along axis 0.
    """
    if not args:
        raise ValueError('iter_data_mpi requires at least one tensor')
    if n_batch < 1:
        raise ValueError(f'n_batch must be >= 1, got {n_batch}')
    size = args[0].shape[0]
    for idx, arg in enumerate(args[1:], start=1):
        if arg.shape[0] != size:
            raise ValueError(f'mismatch in arg {idx}, shape {arg.shape[0]} vs {size}')

    # `is not None` so that seed=0 is honored; a plain truthiness test skips it.
    if seed is not None:
        np.random.seed(seed)
    idxs = np.random.permutation(size) if shuffle else np.arange(size)

    # Only consult the module-level MPI globals when actually sharding.
    if split_by_rank:
        ms, mr = mpisize, mpirank
    else:
        ms, mr = 1, 0

    # Truncate the data if it does not divide evenly into full steps of
    # ms * n_batch rows, so every rank gets the same number of minibatches.
    sequences_per_batch = ms * n_batch
    length = (idxs.size // sequences_per_batch) * sequences_per_batch
    if length != idxs.size:
        log('Truncating {}/{} sequences'.format(idxs.size - length, idxs.size))
        idxs = idxs[:length]

    # Shape (num_steps, ms, n_batch): entry [k, r] holds rank r's indices for step k.
    idxs = idxs.reshape([-1, ms, n_batch])
    log(f'Number of minibatches in this dataset: {len(idxs)}')
    for mb_idx, step in enumerate(idxs):
        indices = step[mr]
        yield [t[indices] for t in args]
        # mb_idx + 1 minibatches have been yielded so far; stop at exactly `iters`.
        if iters and mb_idx + 1 >= iters:
            break