in torchnet/dataset/splitdataset.py [0:0]
def __init__(self, dataset, partitions, initial_partition=None):
    """Split ``dataset`` into named, contiguous, non-overlapping partitions.

    Args:
        dataset: The underlying dataset. ``len(dataset)`` is used when the
            partition sizes are given as fractions.
        partitions (dict): Maps a partition name to its size. Sizes are
            either non-negative integers (absolute number of samples) or,
            when they sum to at most 1, fractions of ``len(dataset)``.
        initial_partition (optional): If not ``None``, passed to
            ``self.select`` (defined elsewhere in the class) at the end of
            construction.

    Raises:
        AssertionError: If ``partitions`` is not a dict, has fewer than two
            entries, contains a negative size, has only zero sizes, or has
            non-integer sizes that sum to more than 1.
    """
    super(SplitDataset, self).__init__()
    self.dataset = dataset
    self.partitions = partitions
    # A few assertions (kept as assertions so callers still get
    # AssertionError on bad input, as before).
    assert isinstance(partitions, dict), 'partitions must be a dict'
    assert len(partitions) >= 2, \
        'SplitDataset should have at least two partitions'
    assert min(partitions.values()) >= 0, \
        'partition sizes cannot be negative'
    assert max(partitions.values()) > 0, 'all partitions cannot be empty'

    # Partitions are ordered by sorted name; partition_cum_sizes below
    # gives their running end offsets in that order.
    self.partition_names = sorted(self.partitions)
    self.partition_index = {partition: i for i, partition in
                            enumerate(self.partition_names)}
    self.partition_sizes = [self.partitions[partition] for partition in
                            self.partition_names]

    if sum(self.partition_sizes) <= 1:
        # Sizes are fractions of the dataset. Round the cumulative
        # boundaries instead of each size independently: independent
        # rounding can make the sizes sum to more than len(dataset)
        # (e.g. 0.5/0.5 of 3 samples -> 2 + 2 = 4), presumably letting the
        # later partitions reach past the end of the underlying dataset.
        num_samples = len(dataset)
        bounds = [int(round(c * num_samples))
                  for c in np.cumsum(self.partition_sizes)]
        starts = [0] + bounds[:-1]
        self.partition_sizes = [end - start
                                for start, end in zip(starts, bounds)]
    else:
        for x in self.partition_sizes:
            assert x == int(x), ('partition sizes should be integer'
                                 ' numbers, or sum up to <= 1 ')
        # Normalize integer-valued floats (e.g. 2.0) to real ints so sizes
        # and cumulative offsets are integers.
        self.partition_sizes = [int(x) for x in self.partition_sizes]
    self.partition_cum_sizes = np.cumsum(self.partition_sizes)
    if initial_partition is not None:
        self.select(initial_partition)