in crlapi/sl/architectures/sp_vgg.py
def grow(self, valid_loader, **args):
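    """Return a widened deep copy of this network.

    Every SubnetConv outside the last two positions gains self.grow_n_units
    output channels; existing weights, scores, biases and BatchNorm running
    statistics are copied into the leading slice of the new tensors.
    (valid_loader is currently unused.)
    """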
    # dummy CIFAR-sized batch, forwarded through every new layer below so the
    # (disabled) sanity checks compare like-shaped activations
    x = torch.randn(64, 3, 32, 32)
    new_layers = []
    for i, layer in enumerate(self.net):
        if isinstance(layer, SubnetConv):
            # input channels: RGB for the first conv, otherwise the (possibly
            # grown) output width of the previous conv
            in_c = 3 if i == 0 else last_output_channels
            # output channels: widen by grow_n_units, except for layers in the
            # last two positions, which keep their original width
            out_c = layer.out_channels + (self.grow_n_units if i < len(self.net) - 2 else 0)
            # selection threshold: scores below this percentile are pruned, so
            # max_val is the smallest score that still gets selected
            max_val = percentile(layer.scores.abs(), layer.prune_rate)
            min_val = layer.scores.abs().min().item()
            # init the wider layer; new scores are drawn below the old
            # threshold, so the added weights start out unselected
            new_layer = SubnetConv(in_c, out_c, kernel_size=layer.kernel_size, padding=layer.padding)
            new_layer.scores.data.uniform_(min_val, max_val)
            # adjust the prune rate so the same number of weights stays
            # selected: (1 - p_new) * n_new = (1 - p_old) * n_old
            new_layer.prune_rate = 1 - (1 - layer.prune_rate) * layer.weight.numel() / new_layer.weight.numel()
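            # e.g. 1000 weights at prune_rate 0.9 keep 100 of them; grown to
            # 2000 weights, p_new = 1 - 0.1 * 1000 / 2000 = 0.95, which still
            # keeps (1 - 0.95) * 2000 = 100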
            # copy the old parameters into the leading slice; the remaining
            # slots keep their random init (weights, scores) or zero (bias)
            a, b, c, d = layer.scores.size()
            new_layer.weight.data[:a, :b, :c, :d].copy_(layer.weight.data)
            new_layer.scores.data[:a, :b, :c, :d].copy_(layer.scores.data)
            new_layer.bias.data.fill_(0)
            new_layer.bias.data[:a].copy_(layer.bias.data)
            last_output_channels = out_c
            new_layers += [new_layer]
            # sanity check (disabled): on the original channels, the grown
            # layer should reproduce the old layer's output
            # assert torch.allclose(layer(x[:, :b]), new_layer(x)[:, :a])
        elif isinstance(layer, nn.BatchNorm2d):
            # widen the BatchNorm to the new channel count; added channels
            # keep the default statistics (mean 0, var 1)
            new_bn = nn.BatchNorm2d(last_output_channels, affine=False)
            c = layer.running_mean.size(0)
            new_bn.running_mean[:c].data.copy_(layer.running_mean.data)
            new_bn.running_var[:c].data.copy_(layer.running_var.data)
            new_bn.training = layer.training
            new_layers += [new_bn]
            # sanity check (disabled): statistics on the original channels are
            # unchanged
            # assert torch.allclose(layer(x[:, :c]), new_bn(x)[:, :c], atol=1e-7)
        else:
            # parameter-free layers (e.g. ReLU, pooling) are copied as-is
            new_layers += [copy.deepcopy(layer)]
        # advance the probe batch through the layer just added
        x = new_layers[-1](x)
    net = nn.Sequential(*new_layers)
    # return a grown deep copy; the original network is left untouched
    copy_self = copy.deepcopy(self)
    copy_self.net = net
    print(net)
    return copy_self
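
# The percentile() helper used above is defined elsewhere in the repo; the
# sketch below is an assumption, written only to be consistent with its call
# sites here: q is a fraction in [0, 1] and the return value is the scalar
# below which a fraction q of the tensor's entries fall.
def percentile(t, q):
    """Return the q-th fractional percentile of t as a Python float."""
    k = 1 + round(float(q) * (t.numel() - 1))
    return t.view(-1).kthvalue(k).values.item()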