def forward()

in ppuda/ghn/layers.py [0:0]


    def forward(self, x, params_map, predict_class_layers=True):
        """
        Add an embedding of each node's parameter shape to the input features.

        Every row of the index table starts out as a copy of ``self.dummy_ind``;
        rows are overwritten only for the nodes listed in ``params_map`` whose
        ``'sz'`` entry is not None.

        :param x: input features; one row of shape indices is created per
            ``len(x)``. NOTE(review): presumably node features whose last
            dimension equals the concatenated embedding width - confirm with caller.
        :param params_map: mapping from a node index to a sequence whose item 0
            is a dict holding the parameter shape under ``'sz'`` and whose item 1
            is compared with ``'cls_w'`` / ``'cls_b'`` to detect classification layers.
        :param predict_class_layers: when False, the first dimension of
            classification-layer shapes is replaced by ``self.num_classes``.
        :return: ``x`` plus the concatenated channel/spatial shape embeddings.
        """
        shape_ind = self.dummy_ind.repeat(len(x), 1)

        # the warning flag is reset on every call, so at most one warning is printed per call
        self.printed_warning = False

        for node_ind in params_map:
            entry = params_map[node_ind]
            sz_org = entry[0]['sz']
            if sz_org is None:
                continue

            # bring 1D and 2D shapes to the 4D form (out, in, height, width)
            sz = sz_org
            n_dims = len(sz)
            if n_dims == 1:
                sz = (sz[0], 1, 1, 1)
            elif n_dims == 2:
                sz = (sz[0], sz[1], 1, 1)
            else:
                assert n_dims == 4, sz

            if not predict_class_layers and entry[1] in ('cls_w', 'cls_b'):
                # pretend the classification layer has the shape of the dataset the GHN was trained on
                sz = (self.num_classes, *sz[1:])

            # recognized shapes are only counted in debug mode and only until the first warning
            check_shapes = self.debug_level and not self.printed_warning
            n_recognized = 0
            for dim, size in enumerate(sz):
                is_channel = dim < 2  # dims 0-1: out/in channels, dims 2-3: kernel height/width
                lookup = self.channels_lookup if is_channel else self.spatial_lookup
                if size in lookup:
                    key = size
                else:
                    # sizes missing from the dictionary fall back to the maximum known size
                    key = self.channels[-1] if is_channel else self.spatial[-1]
                shape_ind[node_ind, dim] = lookup[key]

                if check_shapes:
                    seen = self.channels_lookup_training if is_channel else self.spatial_lookup_training
                    n_recognized += int(size in seen)

            if check_shapes and n_recognized != 4:
                # translate the chosen indices back to sizes for the message
                used = [self.channels[idx.item()] if dim < 2 else self.spatial[idx.item()]
                        for dim, idx in enumerate(shape_ind[node_ind])]
                print('WARNING: unrecognized shape %s, so the closest shape at index %s will be used instead.' % (
                    sz_org, used))
                self.printed_warning = True

        # columns 0-1 go through the channel embedding, columns 2-3 through the spatial one
        embedded = [(self.embed_channel if dim < 2 else self.embed_spatial)(shape_ind[:, dim])
                    for dim in range(4)]
        shape_embed = torch.cat(embedded, dim=1)

        return x + shape_embed