# experiments/net_generator.py

# Imports implied by the body below. The standard-library/numpy/h5py ones are
# unambiguous; the project-internal module paths are assumptions based on the
# surrounding repo layout and may need adjusting. merge_eval (used at the
# bottom) is assumed to be defined elsewhere in this module.
import json
import os
import subprocess
import sys
import time
from os.path import join

import h5py
import numpy as np

from ppuda.deepnets1m.genotypes import sample_genotype, to_dict
from ppuda.deepnets1m.graph import Graph
from ppuda.deepnets1m.net import Network, get_cell_ind
from ppuda.utils import capacity, set_seed

def main():

    try:
        split = sys.argv[1].lower()
        N = int(sys.argv[2])
        data_dir = sys.argv[3]
    except (IndexError, ValueError):
        print('\nExample of usage: python deepnets1m/net_generator.py train 1000000 ./data\n')
        raise
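
    # Positional arguments: the split to generate ('train', 'val', 'test', 'wide',
    # 'bnfree', 'deep', 'dense' or 'search'), the number of architectures N,
    # and the output directory for the HDF5/JSON files.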

    device = 'cpu'  # not much benefit from using CUDA at generation time

    print(split, N, data_dir, device, flush=True)

    if not os.path.exists(data_dir):
        os.mkdir(data_dir)

    set_seed(0 if split == 'val' else 1)
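    # Fixed seeds make generation reproducible: 'val' draws from its own random
    # stream (seed 0), while all other splits share seed 1.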

    min_steps = 1
    medium_steps = 2
    max_steps = 4
    min_layers = 4
    deep_layers_all = np.arange(7, 11)
    max_layers = 18
    max_params = 10 ** 7

    # 'train', 'val' and 'test' share the same network generator
    # 'wide' re-uses the 'test' split and increases the number of channels when evaluating the model
    # 'bnfree' uses the same generator except that all nets have no BN
    # 'predefined' is created on the fly in deepnets1m.loader
    if split == 'deep':
        min_layers = 10
        deep_layers_all = [18]
        max_layers = 36
        max_params = 10 ** 8
    elif split == 'dense':
        min_steps = 2
        medium_steps = 6
        max_steps = 10
        max_params = 10 ** 8
    elif split == 'search':
        # allow slightly larger networks for search, since larger networks are more likely to reach better final results
        medium_steps = 3
        max_steps = 6
        min_layers = 6
        deep_layers_all = [10]
        max_layers = 20
    else:
        assert split in ['train', 'val', 'test', 'wide', 'bnfree'], 'unsupported split: %s' % split

    try:
        gitcommit = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip()
        print('gitcommit:', gitcommit, flush=True)
    except Exception as e:
        # recording the code revision is best-effort (e.g. not a git checkout); failures are only logged
        print(e, flush=True)

    start = time.time()

    meta_data = {}
    meta_data[split] = {'nets': [], 'meta': {}}
    op_types, op_types_back, primitives, primitives_back = {}, {}, {}, {}

    h5_file = join(data_dir, 'deepnets1m_%s.hdf5' % split)
    meta_file = join(data_dir, 'deepnets1m_%s_meta.json' % split)
    for f in [h5_file, meta_file]:
        if os.path.exists(f):
            raise ValueError('File %s already exists. The script will exit now to avoid accidentally overwriting it.' % f)

    with h5py.File(h5_file, 'w') as h5_data:

        h5_data.attrs['title'] = 'DeepNets-1M'
        group = h5_data.create_group(split)
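
        # HDF5 layout: one group per split; each network idx gets two datasets,
        # '<idx>/adj' (adjacency matrix) and '<idx>/nodes' (per-node features).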

        while len(meta_data[split]['nets']) < N:

            layers = int(np.random.randint(min_layers, max_layers + 1))  # total number of cells (int() makes it JSON-serializable)
            deep_layers = np.random.choice(deep_layers_all)  # threshold above which a network is considered deep
            steps = int(np.random.randint(min_steps, (max_steps if layers <= deep_layers else medium_steps) + 1))  # number of summation nodes in a cell

            genotype = sample_genotype(steps=steps,
                                       only_pool=bool(np.random.rand() > 0.5),  # True means no trainable layers in the reduction cell
                                       drop_concat=bool(np.random.rand() > 0.5) if steps > 1 else False,  # drop some edges from the sum node to the final concat
                                       allow_none=steps > 1,  # 'none' is the zero operation that allows sparse connections
                                       allow_transformer=True)  # allow sampling multi-head self-attention (msa)
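
            # The sampled genotype follows the DARTS convention: genotype.normal and
            # genotype.reduce are lists whose entries start with an op name
            # (e.g. 'msa', 'cse', 'conv_5x5'), which the checks below rely on.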

            ks = int(np.random.choice([3, 5, 7]))  # kernel size of the first convolutional layer

            is_vit = sum([n[0] == 'msa' for n in genotype.normal + genotype.reduce]) > 0  # Visual Transformer
            is_cse = sum([n[0] == 'cse' for n in genotype.normal + genotype.reduce]) > 0  # model with CSE ops
            has_none = sum([n[0] == 'none' for n in genotype.normal + genotype.reduce]) > 0
            is_cse2 = (sum([n[0] == 'cse' for n in genotype.normal]) > 1) or (
                    sum([n[0] == 'cse' for n in genotype.reduce]) > 1)  # training GHNs on networks with multiple CSE ops often leads to NaN losses, so we avoid them
            is_conv = sum([n[0].find('conv') >= 0 for n in genotype.normal + genotype.reduce]) > 0  # at least one simple conv op
            is_conv_large = (sum([n[0] in ['conv_5x5', 'conv_7x7'] for n in genotype.normal]) > 1) or (
                    sum([n[0] in ['conv_5x5', 'conv_7x7'] for n in genotype.reduce]) > 1)  # dense convolutions are memory consuming, so we avoid having many of them

            if (is_cse and not is_conv) or is_cse2 or is_conv_large:
                continue  # skip networks that are difficult to train or too memory consuming

            if not (is_cse or is_vit or is_conv):
                # no learnable layers in this genotype, resample
                # print('no learnable layers', genotype, flush=True)
                continue

            C_mult = int(np.random.choice([1, 2]))

            # Use 1x1 convolutional layers to match the channel dimensionality at the input of each cell
            if steps > 1 or C_mult > 1:
                preproc = True
            else:
                # allow some networks without those 1x1 conv layers for diversity
                if split == 'search':
                    # not sure what the logic was here, but keep it for consistency
                    preproc = bool((not is_vit and np.random.rand() > 0.2) or (is_vit and np.random.rand() > 0.8))
                else:
                    preproc = bool(not is_vit or np.random.rand() > 0.8)

            # Use global pooling most of the time instead of a VGG-style head
            glob_avg = bool(is_vit or layers > deep_layers or np.random.rand() > 0.1)

            if split == 'bnfree':
                norm = None
            elif split == 'search':
                norm = 'bnorm'
            else:
                # allow no BN in the case of shallow networks with few ops
                norm = np.random.choice(['bnorm', None]) if layers <= (min_layers + 1) and steps <= medium_steps else 'bnorm'

            stem_type = int(np.random.choice([0, 1]))  # style of the stem: simple or ImageNet-style from DARTS

            net_args = {'stem_type': stem_type,
                        'stem_pool': bool(stem_type == 0 and np.random.rand() > 0.5),  # add an extra pooling layer in the case of a simple stem
                        'norm': norm,
                        'preproc': preproc,
                        'fc_layers': int(np.random.randint(1, 3)),  # number of fully connected layers before classification
                        'glob_avg': glob_avg,
                        'genotype': genotype,
                        'n_cells': layers,
                        'ks': ks,
                        'C_mult': C_mult,
                        'fc_dim': 256}
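
            # net_args fully specifies the sampled architecture; it is appended to the
            # JSON metadata below (with the genotype converted to a dict via to_dict).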

            skip = False
            graph = None
            num_params = {}
            for dset_name in ['cifar10', 'imagenet']:

                model = Network(C=32,  # default number of channels
                                num_classes=10,  # does not matter at this stage
                                is_imagenet_input=dset_name == 'imagenet',
                                **net_args).to(device)

                c, n = capacity(model)
                num_params[dset_name] = n
                if n > max_params:
                    print('too large architecture: %.2f M params \n' % (float(n) / 10 ** 6), flush=True)
                    skip = True
                    break

                # the computational graph is extracted only for the CIFAR-10 variant;
                # parameter counts are recorded for both input resolutions
                if dset_name == 'cifar10':
                    try:
                        graph = Graph(model, ve_cutoff=250, list_all_nodes=True)
                    except Exception as e:
                        print('\n%d: unable to construct the graph: it is likely to be disconnected' % len(meta_data[split]['nets']),
                              'has_none={}, genotype={}'.format(has_none, net_args['genotype']), flush=True)
                        print(e, '\n')
                        assert has_none  # to be disconnected, the network must contain 'none' ops
                        skip = True
                        break

            if skip:
                continue
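
            # Flatten the graph's per-cell node_info into a compact per-node encoding,
            # [primitive id, cell index, op-type id], building the primitive and
            # op-name vocabularies on the fly.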
            assert layers == len(graph.node_info), (layers, len(graph.node_info))

            cell_ind, n_nodes, nodes_array = 0, 0, []
            for j in range(layers):
                n_nodes += len(graph.node_info[j])
                for node in graph.node_info[j]:
                    param_name, name, sz = node[1:4]

                    cell_ind_ = get_cell_ind(param_name, layers)
                    if cell_ind_ is not None:
                        cell_ind = cell_ind_
                    assert cell_ind == j, (cell_ind, j, node)

                    if name == 'conv' and (len(sz) == 2 or sz[2] == sz[3] == 1):
                        name = 'conv_1x1'  # treat pointwise convolutions as a separate primitive

                    if name not in primitives:
                        ind = len(primitives)
                        primitives[name] = ind
                        primitives_back[ind] = name

                    if param_name.startswith('cells.'):
                        # remove the 'cells.<i>.' prefix so that parameter names are shared across cells
                        pos1 = param_name.find('.')
                        assert param_name[pos1 + 1:].find('.') >= 0, node
                        pos2 = pos1 + param_name[pos1 + 1:].find('.') + 2
                        param_name = param_name[pos2:]

                    if param_name not in op_types:
                        ind = len(op_types)
                        op_types[param_name] = ind
                        op_types_back[ind] = param_name

                    nodes_array.append([primitives[name], cell_ind, op_types[param_name]])

            nodes_array = np.array(nodes_array).astype(np.uint16)
            A = graph._Adj.cpu().numpy().astype(np.uint8)
            assert nodes_array.shape[0] == n_nodes == A.shape[0] == graph.n_nodes, (nodes_array.shape, n_nodes, A.shape, graph.n_nodes)

            idx = len(meta_data[split]['nets'])
            group.create_dataset(str(idx) + '/adj', data=A)
            group.create_dataset(str(idx) + '/nodes', data=nodes_array)

            net_args['num_nodes'] = int(A.shape[0])
            net_args['num_params'] = num_params
            net_args['genotype'] = to_dict(net_args['genotype'])

            meta_data[split]['nets'].append(net_args)
            meta_data[split]['meta']['primitives_ext'] = primitives_back
            meta_data[split]['meta']['unique_op_names'] = op_types_back
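
            # primitives_back/op_types_back map integer ids back to names; re-assigning
            # them every iteration keeps the final metadata vocabularies complete.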

            if (idx + 1) % 100 == 0 or idx >= N - 1:
                all_n_nodes = np.array([net['num_nodes'] for net in meta_data[split]['nets']])
                all_n_params = np.array([net['num_params']['cifar10'] for net in meta_data[split]['nets']]) / 10 ** 6
                print('N={} nets created: \t {}-{} nodes (mean\u00B1std: {:.1f}\u00B1{:.1f}) '
                      '\t {:.2f}-{:.2f} params (M) (mean\u00B1std: {:.2f}\u00B1{:.2f}) '
                      '\t {} unique primitives, {} unique param names '
                      '\t total time={:.2f} sec'.format(
                          idx + 1,
                          all_n_nodes.min(),
                          all_n_nodes.max(),
                          all_n_nodes.mean(),
                          all_n_nodes.std(),
                          all_n_params.min(),
                          all_n_params.max(),
                          all_n_params.mean(),
                          all_n_params.std(),
                          len(primitives_back),
                          len(op_types_back),
                          time.time() - start),
                      flush=True)

                with open(meta_file, 'w') as f:
                    json.dump(meta_data, f)

    print('saved to %s and %s' % (h5_file, meta_file))
    print('\ndone')

    if split == 'bnfree':
        merge_eval(data_dir)  # assumes bnfree is generated last
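

# A minimal sketch of reading one generated network back from the files produced
# above (reader code assumed, not part of the original script):
#
#   import json
#   import h5py
#
#   with h5py.File('./data/deepnets1m_train.hdf5', 'r') as f:
#       adj = f['train/0/adj'][()]      # uint8 adjacency matrix of network 0
#       nodes = f['train/0/nodes'][()]  # uint16 rows of [primitive, cell index, op type]
#
#   with open('./data/deepnets1m_train_meta.json') as f:
#       net_args = json.load(f)['train']['nets'][0]  # constructor args of network 0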