in ppuda/ghn/layers.py [0:0]
def forward(self, x, params_map, predict_class_layers=True):
    """Add shape embeddings to ``x`` and return the sum.

    For every node listed in ``params_map`` that has a parameter shape, the
    shape is padded to four entries, each entry is mapped to an embedding
    index through the channel/spatial lookup tables, and the four looked-up
    embeddings are concatenated along dim 1 and added to ``x``.

    Args:
        x: Tensor indexed by node along its first dimension (``len(x)``
            rows are allocated for the shape indices). Presumably of shape
            (num_nodes, hid) so that it matches the concatenated
            embeddings -- confirm against the caller.
        params_map: Mapping from a node index to a sequence whose first item
            is a dict with an ``'sz'`` key (the parameter shape, or ``None``)
            and whose second item is a name compared against ``'cls_w'`` /
            ``'cls_b'``.
        predict_class_layers: When false, the first shape entry of
            ``'cls_w'`` / ``'cls_b'`` nodes is replaced by
            ``self.num_classes`` before the lookup.

    Returns:
        ``x + shape_embed``.

    Raises:
        AssertionError: If a shape does not have 1, 2 or 4 entries. Raised
            explicitly (not via ``assert``) so the check is still enforced
            under ``python -O``; the exception type and payload are the same
            as the ``assert`` it replaces.

    Side effects:
        Resets ``self.printed_warning`` to ``False`` on entry and sets it to
        ``True`` once a warning has been printed during this call.
    """
    # Every row starts at self.dummy_ind; rows of nodes that are absent from
    # params_map, or whose shape is None, keep those values.
    shape_ind = self.dummy_ind.repeat(len(x), 1)

    self.printed_warning = False
    for node_ind in params_map:
        entry = params_map[node_ind]
        sz = entry[0]['sz']
        if sz is None:
            continue

        sz_org = sz  # kept unpadded for the warning message
        if len(sz) == 1:
            sz = (sz[0], 1)
        if len(sz) == 2:
            sz = (sz[0], sz[1], 1, 1)
        if len(sz) != 4:
            # An ``assert`` would be stripped under -O, after which a longer
            # shape would be silently truncated to its first four entries.
            raise AssertionError(sz)

        if not predict_class_layers and entry[1] in ('cls_w', 'cls_b'):
            # keep the classification shape as though the GHN is used on the dataset it was trained on
            sz = (self.num_classes, *sz[1:])

        # The flag cannot change while the four entries of one shape are
        # processed, so evaluate it once per node.
        warn_pending = self.debug_level and not self.printed_warning
        recognized_sz = 0
        for i, dim in enumerate(sz):
            # if not in the dictionary, then use the maximum shape
            if i < 2:  # for out/in channel dimensions
                key = dim if dim in self.channels_lookup else self.channels[-1]
                shape_ind[node_ind, i] = self.channels_lookup[key]
                if warn_pending:
                    recognized_sz += int(dim in self.channels_lookup_training)
            else:  # for kernel height/width
                key = dim if dim in self.spatial_lookup else self.spatial[-1]
                shape_ind[node_ind, i] = self.spatial_lookup[key]
                if warn_pending:
                    recognized_sz += int(dim in self.spatial_lookup_training)

        # print a warning once per call, for the first shape with an entry
        # missing from the *_lookup_training tables
        if warn_pending and recognized_sz != 4:
            used_shape = [
                self.channels[c.item()] if j < 2 else self.spatial[c.item()]
                for j, c in enumerate(shape_ind[node_ind])
            ]
            print('WARNING: unrecognized shape %s, so the closest shape at index %s will be used instead.' % (
                sz_org, used_shape))
            self.printed_warning = True

    shape_embed = torch.cat(
        (self.embed_channel(shape_ind[:, 0]),
         self.embed_channel(shape_ind[:, 1]),
         self.embed_spatial(shape_ind[:, 2]),
         self.embed_spatial(shape_ind[:, 3])), dim=1)

    return x + shape_embed