in ppuda/deepnets1m/loader.py [0:0]
def _init_graph(self, A, nodes, layers):
    """Build a ``Graph`` from a stored adjacency matrix and per-node records.

    Args:
        A: square adjacency matrix as a NumPy array (it is passed to
            ``torch.from_numpy``). ``A.shape[0]`` must equal ``len(nodes)``.
            The caller's array is not modified.
        nodes: sequence of per-node records. ``node[0]`` indexes
            ``self.primitives_ext`` (primitive/op name), ``node[1]`` is the cell
            index, and ``node[2]`` indexes ``self.op_names_net`` (module/parameter
            name in the network).
        layers: number of cells. ``node_info`` gets one list per cell, so every
            ``node[1]`` is expected to be in ``range(layers)``.

    Returns:
        A ``Graph`` built with:
          * ``node_feat``: ``(N, 1)`` long tensor of ids from ``self.primitives_dict``;
          * ``node_info``: for each cell, tuples ``(node_ind, name_op_net, name, sz,
            is_second_to_last_node, is_last_node)`` for conv/pool/bias/bn/ln/pos_enc
            nodes;
          * ``A``: a long-tensor copy of ``A`` whose entries greater than
            ``self.virtual_edges`` are set to 0.
        The per-node shapes are attached as ``graph._param_shapes``. They are
        placeholder shapes for pooling and 1x1 conv nodes and ``None`` otherwise.

    Raises:
        AssertionError: if ``A.shape[0] != len(nodes)`` or the diagonal of ``A``
            (self-loops) is non-zero.
    """
    N = A.shape[0]
    assert N == len(nodes), (N, len(nodes))

    node_feat = torch.zeros(N, 1, dtype=torch.long)
    node_info = [[] for _ in range(layers)]
    param_shapes = []
    last_ind = len(nodes) - 1
    # Besides conv/pool nodes, these primitive names are also recorded in node_info.
    info_names = ('bias', 'bn', 'ln', 'pos_enc')

    for node_ind, node in enumerate(nodes):
        name = self.primitives_ext[node[0]]
        cell_ind = node[1]
        name_op_net = self.op_names_net[node[2]]
        sz = None

        if not name_op_net.startswith('classifier'):
            # fix some inconsistency between names in different versions of our code
            if len(name_op_net) == 0:
                name_op_net = 'input'
            elif name_op_net.endswith('to_out.0.'):
                name_op_net += 'weight'
            else:
                # Insert an 'op' component two places after the first '_ops' whose
                # slot there holds an integer, e.g. 'x._ops.1.0.w' -> 'x._ops.1.op.0.w'.
                parts = name_op_net.split('.')
                for i, part in enumerate(parts):
                    # Bounds check: '_ops' may be one of the last two components.
                    if part != '_ops' or i + 2 >= len(parts) or parts[i + 2] == 'op':
                        continue
                    try:
                        int(parts[i + 2])  # only patch when the slot is an index
                    except ValueError:
                        continue
                    parts.insert(i + 2, 'op')
                    name_op_net = '.'.join(parts)
                    break  # only the first match is patched

            name_op_net = 'cells.%d.%s' % (cell_ind, name_op_net)

            # Names containing 'stem' / 'pos_enc' are cut to start at that token,
            # which drops the 'cells.%d.' prefix added above.
            stem_p = name_op_net.find('stem')
            pos_enc_p = name_op_net.find('pos_enc')
            if stem_p >= 0:
                name_op_net = name_op_net[stem_p:]
            elif pos_enc_p >= 0:
                name_op_net = name_op_net[pos_enc_p:]
            elif 'pool' in name:
                sz = (1, 1, 3, 3)  # assume all pooling layers are 3x3 in our DeepNets-1M

        if name.startswith('conv_'):
            if name == 'conv_1x1':
                sz = (3, 16, 1, 1)  # just some random shape for visualization purposes
            name = 'conv'  # remove kernel size info from the name
        elif name.find('conv_') > 0 or name.find('pool_') > 0:
            # Remove kernel size info from the name.
            # Assumes a 4-char suffix such as '_3x3'.
            name = name[:-4]
        elif name == 'fc-b':
            name = 'bias'

        param_shapes.append(sz)
        node_feat[node_ind] = self.primitives_dict[name]
        if 'conv' in name or 'pool' in name or name in info_names:
            node_info[cell_ind].append(
                (node_ind, name_op_net, name, sz,
                 node_ind == last_ind - 1, node_ind == last_ind))

    # Copy explicitly. torch.from_numpy shares memory with A, and .long() is a
    # no-op for int64 input, so without a copy the in-place masking below would
    # overwrite the caller's array.
    A = torch.from_numpy(A).to(torch.long, copy=True)
    A[A > self.virtual_edges] = 0
    n_loops = A[np.diag_indices_from(A)].sum()
    assert n_loops == 0, ('no loops should be in the graph', n_loops)

    graph = Graph(node_feat=node_feat, node_info=node_info, A=A)
    graph._param_shapes = param_shapes
    return graph