in ppuda/deepnets1m/graph.py [0:0]
def _construct_features(self):
    r"""
    Construct pytorch tensor features for nodes and edges.

    Reads ``self._nodes`` (dicts with ``'param_name'``, ``'module'`` and ``'attrs'`` keys),
    ``self.n_cells``, ``self._list_all_nodes`` and ``self._Adj``, and sets:

    - ``self.n_nodes``: number of nodes;
    - ``self.node_feat``: ``(n_nodes, 1)`` long tensor with each node's index in ``PRIMITIVES_DEEPNETS1M``;
    - ``self.node_info``: one list per cell of
      ``(node_ind, param_name_or_name, name, sz, node_ind == n_nodes - 2, node_ind == n_nodes - 1)``
      tuples, for nodes that have a module, are pooling nodes, or when ``self._list_all_nodes`` is set;
    - ``self._param_shapes``: one entry per node, the size/shape found for it or ``None``;
    - ``self._Adj``: the adjacency matrix converted to a long tensor;
    - ``self.edges``: ``(row, col, self._Adj[row, col])`` rows for every nonzero adjacency entry.

    Prints a warning when more than one ``glob_avg`` node is found.

    :return: None
    """

    def _kernel_dims(text):
        # Parse a stringified kernel size such as '(3, 3)' into [3, 3]. Surrounding brackets/spaces
        # and a trailing comma (e.g. '(3,)') are tolerated instead of failing on an empty piece.
        pieces = (piece.strip(' ()[]') for piece in text.split(','))
        return [int(piece) for piece in pieces if piece]

    self.n_nodes = len(self._nodes)
    self.node_feat = torch.zeros(self.n_nodes, 1, dtype=torch.long)
    self.node_info = [[] for _ in range(self.n_cells)]
    self._param_shapes = []
    primitives_dict = {op: i for i, op in enumerate(PRIMITIVES_DEEPNETS1M)}
    n_glob_avg = 0
    cell_ind = 0  # a node whose cell cannot be detected keeps the cell index of the previous node
    for node_ind, node in enumerate(self._nodes):
        param_name = node['param_name']
        module = node['module']
        cell_ind_ = get_cell_ind(param_name, self.n_cells)
        if cell_ind_ is not None:
            cell_ind = cell_ind_

        # Drop everything before 'stem' (or, failing that, 'pos_enc') in the name
        pos_stem = param_name.find('stem')
        pos_pos = param_name.find('pos_enc')
        if pos_stem >= 0:
            param_name = param_name[pos_stem:]
        elif pos_pos >= 0:
            param_name = param_name[pos_pos:]

        if module is not None:
            # Preprocess param_name to be consistent with the DeepNets dataset:
            # '..._ops.<i>.<int>...' becomes '..._ops.<i>.op.<int>...' (first match only)
            parts = param_name.split('.')
            for i, part in enumerate(parts):
                # Bounds check: '_ops' among the last two components used to raise IndexError
                if part != '_ops' or i + 2 >= len(parts) or parts[i + 2] == 'op':
                    continue
                try:
                    int(parts[i + 2])
                except ValueError:
                    continue
                parts.insert(i + 2, 'op')
                param_name = '.'.join(parts)
                break
            name = MODULES[type(module)](module, param_name)
        else:
            # Look up the node name with everything from 'Backward' onward removed (if present)
            ind = param_name.find('Backward')
            name = MODULES[param_name if ind == -1 else param_name[:ind]]

        n_glob_avg += int(name == 'glob_avg')
        if self.n_cells > 1 and param_name.startswith(('MaxPool', 'AvgPool')):
            # Add cell id to the names of pooling layers, so that they will be matched with proper modules in Network
            # NOTE(review): node_info below stores `name` (not param_name) for nodes without a module,
            # so this prefixed name does not appear to reach node_info -- confirm the intent.
            param_name = 'cells.{}.'.format(cell_ind) + name

        sz = None
        attrs = node['attrs']
        if isinstance(attrs, dict):
            if 'size' in attrs:
                sz = attrs['size']
            elif 'pool' in name:
                if 'kernel_size' in attrs:
                    sz = (1, 1, *_kernel_dims(attrs['kernel_size']))
                else:
                    # Pytorch 1.9+ is required to correctly extract pooling attributes, otherwise the default pooling size of 3 is used
                    sz = (1, 1, 3, 3)
        elif module is not None:
            sz = (module.weight if 'weight' in param_name else module.bias).shape
        self._param_shapes.append(sz)
        self.node_feat[node_ind] = primitives_dict[name]
        if module is not None or 'pool' in name or self._list_all_nodes:
            self.node_info[cell_ind].append(
                (node_ind,
                 param_name if module is not None else name,
                 name,
                 sz,
                 node_ind == self.n_nodes - 2,
                 node_ind == self.n_nodes - 1))

    if n_glob_avg > 1:
        print(
            '\nWARNING: n_glob_avg should be 0 or 1 in most architectures, but is %d in this architecture\n' % n_glob_avg)
    self._Adj = torch.tensor(self._Adj, dtype=torch.long)
    ind = torch.nonzero(self._Adj)  # rows, cols (assumes a 2-D adjacency matrix -- confirm)
    # One row per edge: (row, col, adjacency value)
    self.edges = torch.cat((ind, self._Adj[ind[:, 0], ind[:, 1]].view(-1, 1)), dim=1)