# iter_data_mpi() -- defined in datasets.py [0:0]


def iter_data_mpi(*args, n_batch, log, shuffle=False, iters=None, seed=None, split_by_rank=True):
    """Yield aligned minibatches drawn from the tensors in ``*args``.

    Every tensor must have the same size along its first axis. The row indices
    ``0..size-1`` are optionally shuffled, truncated to a multiple of
    ``ms * n_batch`` and reshaped to ``(num_minibatches, ms, n_batch)``; each
    step yields the rows of slot ``mr`` for the current minibatch. When
    ``split_by_rank`` is true, ``ms``/``mr`` come from the module-level
    ``mpisize``/``mpirank`` (defined outside this block -- presumably the MPI
    world size and this process's rank; confirm against the module).
    Otherwise ``ms = 1`` and ``mr = 0``, i.e. plain iteration.

    This is a generator function, so argument validation and logging run on
    the first ``next()`` call, not when the function is called.

    Args:
        *args: One or more objects exposing ``.shape[0]`` and supporting
            indexing with an integer NumPy array.
        n_batch: Number of rows per yielded minibatch; must be positive.
        log: Callable taking a single message string.
        shuffle: If true, visit rows in a random permutation.
        iters: Optional maximum number of minibatches to yield. ``None`` or
            ``0`` means no limit.
        seed: Optional seed passed to ``np.random.seed`` (this reseeds NumPy's
            global legacy RNG). ``0`` is a valid seed.
        split_by_rank: Whether to split each minibatch across MPI ranks.

    Yields:
        A list with one entry per input tensor, each being ``t[indices]`` for
        the same ``indices`` array.

    Raises:
        ValueError: If no tensors are given, ``n_batch`` is not positive, or
            the tensors disagree on their leading dimension.
    """
    if not args:
        raise ValueError('iter_data_mpi requires at least one tensor')
    if n_batch <= 0:
        # Guard explicitly: otherwise this surfaces as a ZeroDivisionError
        # (or a meaningless reshape) further down.
        raise ValueError(f'n_batch must be positive, got {n_batch}')

    size = args[0].shape[0]
    for idx, tensor in enumerate(args[1:], start=1):
        if tensor.shape[0] != size:
            raise ValueError(f'mismatch in arg {idx}, shape {tensor.shape[0]} vs {size}')

    # `is not None` rather than truthiness so that seed=0 is honoured.
    if seed is not None:
        np.random.seed(seed)

    if shuffle:
        idxs = np.random.permutation(np.arange(size))
    else:
        idxs = np.arange(size)

    # Only touch the MPI globals when they are actually needed.
    if split_by_rank:
        ms, mr = mpisize, mpirank
    else:
        ms, mr = 1, 0

    # Truncate the data if it does not divide evenly
    sequences_per_batch = ms * n_batch
    length = (idxs.size // sequences_per_batch) * sequences_per_batch
    if length != idxs.size:
        log('Truncating {}/{} sequences'.format(idxs.size - length, idxs.size))
    idxs = idxs[:length]
    # Reshape starting indices to K*mpi_size*n_batch
    idxs = idxs.reshape([-1, ms, n_batch])
    log(f'Number of minibatches in this dataset: {len(idxs)}')

    for mb_idx, per_rank in enumerate(idxs):
        indices = per_rank[mr]
        yield [t[indices] for t in args]
        # Stop once exactly `iters` minibatches have been produced
        # (mb_idx is 0-based, hence the +1).
        if iters and mb_idx + 1 >= iters:
            break