in dataloading.py [0:0]
def load_datasampler(dataset, batch_size=1, shuffle=True, transform=None):
    """
    Return a data sampler that yields mini-batches of `dataset`.

    `dataset` is a dict holding tensors under the keys "features" and
    "targets"; their first dimension (number of samples) must match. Each
    call to the returned sampler creates a fresh generator that walks the
    whole dataset once. Every batch contains `batch_size` samples except
    possibly the last, which holds the remainder. If `shuffle` is `True`
    (default), a new random permutation is drawn on every call. Otherwise
    samples are visited in stored order.

    An optional callable `transform` is applied to each feature sample
    individually (a slice along dimension 0). The results are re-stacked
    along dimension 0, so `transform` must return tensors of identical
    shape. Targets are never transformed.

    Args:
        dataset: dict with "features" and "targets" tensors.
        batch_size: maximum number of samples per batch; must be >= 1.
        shuffle: whether to randomly permute the samples on each pass.
        transform: optional callable applied to each feature sample.

    Returns:
        A zero-argument generator function yielding dicts of the form
        {"features": tensor, "targets": tensor}.

    Raises:
        AssertionError: if feature and target counts differ, or if
            `transform` is given but not callable.
        ValueError: if `batch_size` is smaller than 1.
    """
    assert dataset["features"].size(0) == dataset["targets"].size(0), \
        "number of feature vectors and targets must match"
    if transform is not None:
        assert callable(transform), "transform must be callable if specified"
    # A non-positive batch size never advances the cursor, so the generator
    # would loop forever yielding empty or overlapping batches; reject it here.
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    N = dataset["features"].size(0)

    # define simple dataset sampler:
    def sampler():
        # torch.arange gives exactly N integer indices. The deprecated
        # torch.range is end-inclusive (N + 1 values) and builds floats,
        # which lose integer precision for very large N.
        perm = torch.randperm(N) if shuffle else torch.arange(N)
        for start in range(0, N, batch_size):
            # Slicing past the end clips, so the last batch may be smaller.
            index = perm[start:start + batch_size]
            # Index only dimension 0 so features of any rank (incl. 1-D) work.
            batch = dataset["features"][index]
            if transform is not None:
                # Iterating a tensor yields its slices along dimension 0.
                batch = torch.stack(
                    [transform(sample) for sample in batch], dim=0
                )
            yield {"features": batch, "targets": dataset["targets"][index]}

    # return sampler:
    return sampler