in ppuda/deepnets1m/loader.py [0:0]
def __init__(self,
             split='train',
             nets_dir='./data',
             virtual_edges=1,
             num_ch=(32, 128),
             fc_dim=(64, 512),
             num_nets=None,
             arch=None,
             large_images=False):
    """Load one split of the DeepNets-1M architecture dataset.

    Args:
        split: one of 'train', 'val', 'test', 'search', 'wide', 'deep',
            'dense', 'bnfree' or 'predefined'. 'train' and 'search' are read
            from ``deepnets1m_<split>.hdf5``; the other non-predefined splits
            are read from ``deepnets1m_eval.hdf5``. 'predefined' uses
            ``self._get_predefined()`` instead of files.
        nets_dir: directory holding the hdf5 file and its ``*_meta.json``.
        virtual_edges: stored as ``self.virtual_edges``; must be >= 1.
        num_ch: (min, max) pair; for the 'train' split it becomes
            ``torch.arange(min, max + 1, 16)`` stored in ``self.num_ch``.
        fc_dim: (min, max) pair; for the 'train' split it becomes
            ``torch.arange(min, max + 1, 64)`` stored in ``self.fc_dim``.
        num_nets: if given, keep only the first ``num_nets`` entries of a
            file-based split (not applied to 'predefined').
        arch: optional index of a single architecture to keep. It is
            converted with ``int()`` and checked against the loaded list.
        large_images: stored as ``self.large_images``.

    Raises:
        AssertionError: invalid ``split``, ``virtual_edges < 1``, missing hdf5
            file, or ``arch`` out of range.
        ValueError, TypeError: ``arch`` cannot be converted with ``int()``.
    """
    super(DeepNets1M, self).__init__()

    self.split = split
    assert self.split in ['train', 'val', 'test', 'search',
                          'wide', 'deep', 'dense', 'bnfree', 'predefined'], \
        ('invalid split', self.split)
    self.is_train = self.split == 'train'

    self.virtual_edges = virtual_edges
    assert self.virtual_edges >= 1, virtual_edges

    # Normalize the architecture index once, so h5_idx, the range check and
    # the list indexing below all use the same int value.
    if arch is not None:
        arch = int(arch)

    if self.is_train:
        # During training we will randomly sample values from this range
        self.num_ch = torch.arange(num_ch[0], num_ch[1] + 1, 16)
        self.fc_dim = torch.arange(fc_dim[0], fc_dim[1] + 1, 64)

    self.large_images = large_images  # this affects some network parameters

    # Load one of the splits
    print('\nloading %s nets...' % self.split.upper())

    if self.split == 'predefined':
        self.nets = self._get_predefined()
        n_all = len(self.nets)
        self.nodes = torch.tensor([net.n_nodes for net in self.nets])
    else:
        # The hdf5 file is not opened here; h5_data starts out as None.
        self.h5_data = None
        h5_name = split if split in ['train', 'search'] else 'eval'
        self.h5_file = os.path.join(nets_dir, 'deepnets1m_%s.hdf5' % h5_name)

        self.primitives_dict = {op: i for i, op in enumerate(PRIMITIVES_DEEPNETS1M)}
        assert os.path.exists(self.h5_file), ('%s not found' % self.h5_file)

        def _int_keys(d):
            # JSON object keys are always strings; restore integer keys.
            return {int(k): v for k, v in d.items()}

        # Load meta data to convert dataset files to graphs later in the
        # _init_graph function. splitext only swaps the extension, unlike
        # str.replace, which would also rewrite '.hdf5' inside nets_dir.
        meta_file = os.path.splitext(self.h5_file)[0] + '_meta.json'
        with open(meta_file, 'r') as f:
            meta = json.load(f)[split]
        n_all = len(meta['nets'])
        self.nets = meta['nets'][:n_all if num_nets is None else num_nets]
        self.primitives_ext = _int_keys(meta['meta']['primitives_ext'])
        self.op_names_net = _int_keys(meta['meta']['unique_op_names'])
        self.h5_idx = [arch] if arch is not None else None
        self.nodes = torch.tensor([net['num_nodes'] for net in self.nets])

    if arch is not None:
        assert 0 <= arch < len(self.nets), \
            'architecture with index={} is not available in the {} split with {} architectures in total'.format(arch, split, len(self.nets))
        self.nets = [self.nets[arch]]
        # Slice rather than re-wrap a 0-d tensor; clone so the full
        # node-count tensor is not kept alive through a view.
        self.nodes = self.nodes[arch:arch + 1].clone()

    n_loaded = len(self.nets)
    if n_loaded == 0:
        # min()/max() raise on an empty tensor, so only report the count.
        print('loaded {}/{} nets'.format(n_loaded, n_all))
    else:
        nodes_f = self.nodes.float()
        # The unbiased std of a single value is NaN; report 0 instead.
        std = nodes_f.std().item() if n_loaded > 1 else 0.0
        print('loaded {}/{} nets with {}-{} nodes (mean\u00B1std: {:.1f}\u00B1{:.1f})'.format(
            n_loaded, n_all,
            self.nodes.min().item(),
            self.nodes.max().item(),
            nodes_f.mean().item(),
            std))