def __init__()

in ppuda/deepnets1m/loader.py [0:0]


    def __init__(self,
                 split='train',
                 nets_dir='./data',
                 virtual_edges=1,
                 num_ch=(32, 128),
                 fc_dim=(64, 512),
                 num_nets=None,
                 arch=None,
                 large_images=False):
        super(DeepNets1M, self).__init__()

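        # Store and validate the requested split; 'wide', 'deep', 'dense' and
        # 'bnfree' are the held-out evaluation subsets of DeepNets-1M.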
        self.split = split
        assert self.split in ['train', 'val', 'test', 'search',
                              'wide', 'deep', 'dense', 'bnfree', 'predefined'],\
            ('invalid split', self.split)
        self.is_train = self.split == 'train'

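        # virtual_edges must be >= 1; a value of 1 keeps only the real edges of the
        # computational graph, while larger values add virtual multi-hop edges.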
        self.virtual_edges = virtual_edges
        assert self.virtual_edges >= 1, virtual_edges

        if self.is_train:
            # During training we will randomly sample values from this range
            self.num_ch = torch.arange(num_ch[0], num_ch[1] + 1, 16)
            self.fc_dim = torch.arange(fc_dim[0], fc_dim[1] + 1, 64)

        self.large_images = large_images  # this affects some network parameters

        # Load one of the splits
        print('\nloading %s nets...' % self.split.upper())

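        # The 'predefined' split builds its networks in code via _get_predefined();
        # all other splits are read from the DeepNets-1M HDF5 files below.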
        if self.split == 'predefined':
            self.nets = self._get_predefined()
            n_all = len(self.nets)
            self.nodes = torch.tensor([net.n_nodes for net in self.nets])
        else:
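            # Only the HDF5 path is resolved here; the file itself is opened lazily
            # outside __init__, so h5_data starts as None. 'train' and 'search' have
            # their own files, while all evaluation splits share 'deepnets1m_eval.hdf5'.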
            self.h5_data = None
            self.h5_file = os.path.join(nets_dir, 'deepnets1m_%s.hdf5' % (split if split in ['train', 'search'] else 'eval'))

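            # Map primitive operation names to integer indices used when the
            # stored networks are converted to graphs.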
            self.primitives_dict = {op: i for i, op in enumerate(PRIMITIVES_DEEPNETS1M)}
            assert os.path.exists(self.h5_file), ('%s not found' % self.h5_file)

            # Load meta data to convert dataset files to graphs later in the _init_graph function
            to_int_dict = lambda d: {int(k): v for k, v in d.items()}
            with open(self.h5_file.replace('.hdf5', '_meta.json'), 'r') as f:
                meta = json.load(f)[split]
                n_all = len(meta['nets'])
                self.nets = meta['nets'][:n_all if num_nets is None else num_nets]
                self.primitives_ext = to_int_dict(meta['meta']['primitives_ext'])
                self.op_names_net = to_int_dict(meta['meta']['unique_op_names'])
            self.h5_idx = [arch] if arch is not None else None
            self.nodes = torch.tensor([net['num_nodes'] for net in self.nets])

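        # Optionally keep only a single architecture, selected by its index
        # within the chosen split.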
        if arch is not None:
            arch = int(arch)
            assert arch >= 0 and arch < len(self.nets), \
                'architecture with index={} is not available in the {} split with {} architectures in total'.format(arch, split, len(self.nets))
            self.nets = [self.nets[arch]]
            self.nodes = torch.tensor([self.nodes[arch]])

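        # Report how many graphs were loaded and their node-count statistics.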
        print('loaded {}/{} nets with {}-{} nodes (mean\u00B1std: {:.1f}\u00B1{:.1f})'.
              format(len(self.nets), n_all,
                     self.nodes.min().item(),
                     self.nodes.max().item(),
                     self.nodes.float().mean().item(),
                     self.nodes.float().std().item()))
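
Usage note: a minimal sketch of constructing this dataset directly, assuming the DeepNets-1M files implied by the code above ('deepnets1m_eval.hdf5' and 'deepnets1m_eval_meta.json') have already been downloaded into './data'. The snippet is illustrative and not part of loader.py.

    from ppuda.deepnets1m.loader import DeepNets1M

    # Load the validation split and keep only the architecture with index 0.
    graphs = DeepNets1M(split='val',
                        nets_dir='./data',
                        virtual_edges=1,
                        arch=0)

    # self.nodes is a tensor with one entry per loaded architecture.
    print(graphs.nodes)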