def grow()

in crlapi/sl/architectures/sp_vgg.py [0:0]


    def grow(self, valid_loader, **args):
        """Return a deep copy of this model whose ``net`` has been widened.

        Each ``SubnetConv`` in ``self.net`` is replaced by a new ``SubnetConv``
        with ``self.grow_n_units`` extra output channels (except for layers in
        the last two positions of ``self.net``, which keep their width) and an
        input width matching the previous grown conv. The old weights, scores
        and biases are copied into the leading slice of the new tensors.
        Each ``nn.BatchNorm2d`` is rebuilt at the new width with the old
        running statistics copied into its leading channels. Any other layer
        is deep-copied unchanged.

        The model itself is not modified.

        Args:
            valid_loader: Unused by this implementation; kept for interface
                compatibility.
            **args: Ignored; kept for interface compatibility.

        Returns:
            A deep copy of ``self`` whose ``net`` attribute is the new, wider
            ``nn.Sequential``.
        """
        # NOTE: no dummy forward pass is run through the new layers here.
        # Doing so with BatchNorm layers in training mode would overwrite the
        # running statistics that were just copied with statistics of noise.
        new_layers = []
        n_layers = len(self.net)

        for i, layer in enumerate(self.net):
            if isinstance(layer, SubnetConv):
                # input size: RGB for the very first layer, otherwise the
                # (already grown) width of the previous conv
                in_c = 3 if i == 0 else last_output_channels

                # output size: layers in the last two positions are not grown
                extra_units = self.grow_n_units if i < n_layers - 2 else 0
                out_c = layer.out_channels + extra_units

                # Range of |score| in which the new scores are initialised:
                # from the smallest existing score up to the percentile at
                # the layer's prune rate.
                abs_scores = layer.scores.abs()
                max_val = percentile(abs_scores, layer.prune_rate)
                min_val = abs_scores.min().item()

                # init new layer
                new_layer = SubnetConv(
                    in_c,
                    out_c,
                    kernel_size=layer.kernel_size,
                    padding=layer.padding,
                )
                new_layer.scores.data.uniform_(min_val, max_val)

                # adjust the prune rate so that the same amount of points
                # get selected
                kept_fraction = 1 - layer.prune_rate
                size_ratio = layer.weight.numel() / new_layer.weight.numel()
                new_layer.prune_rate = 1 - kept_fraction * size_ratio

                # copy the old params into the leading corner of the new ones
                out_old, in_old, k_h, k_w = layer.scores.size()
                new_layer.weight.data[:out_old, :in_old, :k_h, :k_w].copy_(
                    layer.weight.data
                )
                new_layer.scores.data[:out_old, :in_old, :k_h, :k_w].copy_(
                    layer.scores.data
                )
                new_layer.bias.data.fill_(0)
                new_layer.bias.data[:out_old].copy_(layer.bias.data)

                last_output_channels = out_c
                new_layers.append(new_layer)

            elif isinstance(layer, nn.BatchNorm2d):
                # Keep the numerical configuration of the original layer so
                # the copied channels normalise exactly as before.
                new_bn = nn.BatchNorm2d(
                    last_output_channels,
                    eps=layer.eps,
                    momentum=layer.momentum,
                    affine=False,
                )
                n_old = layer.running_mean.size(0)
                new_bn.running_mean.data[:n_old].copy_(layer.running_mean.data)
                new_bn.running_var.data[:n_old].copy_(layer.running_var.data)
                if layer.num_batches_tracked is not None:
                    new_bn.num_batches_tracked.data.copy_(
                        layer.num_batches_tracked.data
                    )

                # preserve train / eval mode of the original layer
                new_bn.train(layer.training)
                new_layers.append(new_bn)

            else:
                new_layers.append(copy.deepcopy(layer))

        net = nn.Sequential(*new_layers)

        copy_self = copy.deepcopy(self)
        copy_self.net = net
        print(net)

        return copy_self