def subsample()

in dataloading.py [0:0]


def subsample(data, num_samples, random=True):
    """
    Subsample every tensor in `data` down to `num_samples` rows.

    All values in `data` are sliced along dimension 0 with the same row
    indices, so entries that correspond row-by-row stay aligned. With
    `random=True` the rows are a uniformly random subset (drawn with
    `torch.randperm`, so they follow the global torch RNG state); with
    `random=False` the first `num_samples` rows are kept.

    Args:
        data: dict mapping names to tensors; must contain a "targets" entry.
            Every value must have at least `data["targets"].size(0)` rows
            along dimension 0 for the slicing below to succeed.
        num_samples: number of rows to keep; must satisfy
            0 < num_samples <= data["targets"].size(0).
        random: pick rows randomly (True) or take the leading rows (False).

    Returns:
        The same `data` dict, modified in place so each value holds only the
        selected rows.

    Raises:
        AssertionError: if `data` is not a dict, lacks "targets", or
            `num_samples` is out of range. (These checks use `assert` and are
            skipped when Python runs with -O.)
    """

    # assertions:
    assert isinstance(data, dict), "data must be a dict"
    assert "targets" in data, "data dict does not have targets field"
    # Count samples along the leading dimension, not total elements: for
    # multi-dimensional targets (e.g. several labels per sample) nelement()
    # overcounts and would let out-of-range row indices through.
    dataset_size = data["targets"].size(0)
    assert num_samples > 0, "num_samples must be positive integer value"
    assert num_samples <= dataset_size, "num_samples cannot exceed data size"

    # subsample data:
    if random:
        # One shared set of row indices keeps all entries aligned.
        indices = torch.randperm(dataset_size)[:num_samples]
        for key, value in data.items():
            # randperm returns a CPU tensor; move the indices to each value's
            # device so index_select also works for non-CPU tensors.
            data[key] = value.index_select(0, indices.to(value.device))
    else:
        for key, value in data.items():
            data[key] = value.narrow(0, 0, num_samples).contiguous()
    return data