in crlapi/sl/architectures/firefly_vgg/sp/conv.py [0:0]
def random_split(self, C_new):
    """Widen the wrapped conv layer by ``C_new`` output channels via random splitting.

    ``C_new`` source channels are drawn uniformly *with replacement* from the
    existing ``C_out`` output channels. Each draw appends a copy of that source
    channel after the existing channels, perturbed by ``-delta2``, and adds
    ``+delta1`` to the source channel itself. Each delta is a Gaussian sample
    normalized to unit L2 norm along its last axis and scaled by ``1e-2``.
    Bias entries are copied unperturbed. If ``self.has_bn`` is set, ``self.bn``
    is replaced by a wider ``BatchNorm2d`` whose new entries copy the source
    channels' affine parameters and running statistics. The old layer's
    ``eps``, ``momentum`` and ``num_batches_tracked`` are preserved.

    Args:
        C_new: number of output channels to add; ``0`` is a no-op.

    Returns:
        ``(0, None)`` when ``C_new == 0``. Otherwise ``(C_new, idx)``, where
        ``idx`` is a ``torch.long`` tensor on the layer's device holding the
        source channel index of each appended channel.
    """
    if C_new == 0:
        return 0, None
    C_out, C_in, kh, kw = self.module.weight.shape
    # Sampled with replacement, so one channel may be split more than once.
    idx = np.random.choice(C_out, C_new)
    device = self.get_device()
    # NOTE(review): normalization is over dim=-1 (kernel width) only, not the
    # whole filter -- kept as-is; confirm this is intended.
    delta1 = F.normalize(torch.randn(C_new, C_in, kh, kw).to(device), p=2, dim=-1)
    delta2 = F.normalize(torch.randn(C_new, C_in, kh, kw).to(device), p=2, dim=-1)
    delta1 = delta1 * 1e-2
    delta2 = delta2 * 1e-2
    idx = torch.as_tensor(idx, dtype=torch.long, device=device)
    new_layer = nn.Conv2d(in_channels=C_in,
                          out_channels=C_out + C_new,
                          kernel_size=(self.kh, self.kw),
                          stride=(self.dh, self.dw),
                          padding=(self.ph, self.pw),
                          bias=self.has_bias).to(device)
    # for current layer: keep old filters, append copies of the sampled ones
    new_layer.weight.data[:C_out, ...] = self.module.weight.data
    new_layer.weight.data[C_out:, ...] = self.module.weight.data[idx, ...]
    # `w[idx] += d` is undefined when idx repeats (non-accumulating index put);
    # index_add_ adds every delta, and equals `+=` when idx has no duplicates.
    new_layer.weight.data.index_add_(0, idx, delta1)
    new_layer.weight.data[C_out:, ...] -= delta2
    if self.has_bias:
        new_layer.bias.data[:C_out] = self.module.bias.data
        new_layer.bias.data[C_out:] = self.module.bias.data[idx]
    self.module = new_layer
    # for batchnorm layer
    if self.has_bn:
        old_bn = self.bn
        # Keep the old layer's hyper-parameters instead of reverting to defaults.
        new_bn = nn.BatchNorm2d(C_out + C_new,
                                eps=old_bn.eps,
                                momentum=old_bn.momentum).to(device)
        for name in ("weight", "bias", "running_mean", "running_var"):
            old_t = getattr(old_bn, name).data
            new_t = getattr(new_bn, name).data
            new_t[:C_out] = old_t
            new_t[C_out:] = old_t[idx]
        # Preserve the step counter (used as the averaging factor when momentum is None).
        old_count = getattr(old_bn, "num_batches_tracked", None)
        if old_count is not None:
            new_bn.num_batches_tracked.copy_(old_count)
        self.bn = new_bn
    return C_new, idx