def load_datasampler()

in dataloading.py


import torch


def load_datasampler(dataset, batch_size=1, shuffle=True, transform=None):
    """
    Returns a data sampler that yields samples of the specified `dataset` with the
    given `batch_size`. An optional `transform` for samples can also be given.
    If `shuffle` is `True` (default), samples are shuffled.
    """
    assert dataset["features"].size(0) == dataset["targets"].size(0), \
        "number of feature vectors and targets must match"
    if transform is not None:
        assert callable(transform), "transform must be callable if specified"
    N = dataset["features"].size(0)

    # define simple dataset sampler:
    def sampler():
        idx = 0
        perm = torch.randperm(N) if shuffle else torch.arange(N)
        while idx < N:

            # get batch:
            start = idx
            end = min(idx + batch_size, N)
            batch = dataset["features"][perm[start:end], :]

            # apply transform:
            if transform is not None:
                transformed_batch = [
                    transform(batch[n, :]) for n in range(batch.size(0))
                ]
                batch = torch.stack(transformed_batch, dim=0)

            # return sample:
            yield {"features": batch, "targets": dataset["targets"][perm[start:end]]}
            idx += batch_size

    # return sampler:
    return sampler
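
A minimal usage sketch, assuming a toy dataset of random feature vectors and integer
targets; the `normalize` transform below is hypothetical and only illustrates the
`transform` hook:

import torch

# toy dataset: any dict with "features" and "targets" tensors of matching length
dataset = {
    "features": torch.randn(10, 4),
    "targets": torch.randint(0, 2, (10,)),
}

# hypothetical per-sample transform that rescales each feature vector to unit norm:
def normalize(sample):
    return sample / (sample.norm() + 1e-8)

# build the sampler and iterate over one pass of batches:
sampler = load_datasampler(dataset, batch_size=4, shuffle=True, transform=normalize)
for batch in sampler():
    print(batch["features"].size(), batch["targets"].size())

Note that `load_datasampler` returns the generator function itself rather than a
generator, so calling `sampler()` again starts a fresh (re-shuffled) pass over the data.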