in dataloading.py [0:0]
def subsample(data, num_samples, random=True):
    """
    Subsamples the specified `data` to contain `num_samples` samples.

    Samples are laid out along dimension 0 of every tensor in `data`; the
    same sample indices are applied to every entry so that the entries stay
    aligned with each other.

    Args:
        data: dict mapping field names to tensors. Must contain a "targets"
            entry, whose size along dimension 0 defines the dataset size.
            All other entries are assumed to have the same size along
            dimension 0 (not checked here).
        num_samples: number of samples to keep; must satisfy
            `0 < num_samples <= dataset size`.
        random: if `True` (default), keep `num_samples` samples chosen via a
            random permutation; if `False`, keep the first `num_samples`
            samples.

    Returns:
        The same `data` dict object, with each value replaced by its
        subsampled tensor (the dict is modified in place).

    Raises:
        AssertionError: if `data` is not a dict, has no "targets" entry, or
            `num_samples` is outside the valid range.
    """
    # assertions:
    assert isinstance(data, dict), "data must be a dict"
    assert "targets" in data, "data dict does not have targets field"
    # Count samples along dimension 0 rather than total elements: for targets
    # with more than one dimension, `nelement()` over-counts and would produce
    # out-of-range indices for `index_select` below.
    dataset_size = data["targets"].size(0)
    assert num_samples > 0, "num_samples must be positive integer value"
    assert num_samples <= dataset_size, "num_samples cannot exceed data size"

    # subsample data:
    if random:
        # Draw the indices once so that every field keeps the same samples.
        indices = torch.randperm(dataset_size)[:num_samples]
        for key, value in data.items():
            data[key] = value.index_select(0, indices)
    else:
        for key, value in data.items():
            # `narrow` returns a view; `contiguous` gives a compact tensor.
            data[key] = value.narrow(0, 0, num_samples).contiguous()
    return data