in crlapi/sl/architectures/firefly_vgg/sp/conv.py [0:0]
def spffn_active_grow(self, threshold):
    """Grow this conv layer by splitting/activating the output neurons selected by ``threshold``.

    Output channels whose score in ``self.w`` is ``<= threshold`` are selected.

    * Selected indices below ``c1 = C_out - self.eout`` are existing neurons. Each is
      duplicated: the original row gets ``+self.v[i]`` and the copy gets
      ``-self.v[i]`` on the first ``c3 = C_in - self.ein`` input channels.
    * Selected indices ``>= c1`` are candidate rows (the last ``self.eout`` rows).
      They are appended. Unselected candidate rows are dropped.

    The conv (and the BatchNorm, when ``self.has_bn``) is rebuilt with the output
    layout ``[ original 0..c1 | split copies c1..c2 | activated new c2.. ]``.

    Args:
        threshold: selection cutoff compared against ``self.w``.

    Returns:
        ``(n_split + n_new, split_idx, new_idx)``. ``split_idx`` and ``new_idx`` are
        1-D index tensors into the old output channels.
    """
    # NOTE(review): assumes self.w is 1-D with one score per old output channel;
    # nonzero() returns (k, ndim) indices, so view(-1) is an index list only for 1-D w.
    idx = torch.nonzero(self.w <= threshold).view(-1)
    C_out, C_in = self.module.weight.shape[:2]
    c1 = C_out - self.eout  # existing output neurons; rows c1.. are candidates
    c3 = C_in - self.ein    # existing input channels; columns c3.. are candidates
    split_idx, new_idx = idx[idx < c1], idx[idx >= c1]
    n_split = split_idx.shape[0]
    n_new = new_idx.shape[0]
    c2 = c1 + n_split
    device = self.get_device()
    # NOTE(review): assumes rows of self.v align with old output channels and have
    # c3 input channels (it is added to [:, :c3] below) — confirm against caller.
    delta = self.v[split_idx, ...]

    def _grow_rows(dst, src, new_rows=None):
        # Fill dst as [kept rows :c1 | split copies c1:c2 | activated candidates c2:].
        dst[:c1] = src[:c1]
        if n_split > 0:
            dst[c1:c2] = src[split_idx]
        if n_new > 0:
            dst[c2:] = src[new_idx] if new_rows is None else new_rows

    old_W = self.module.weight.data.clone()
    # Best-effort merge of candidate weights: vni fills the candidate input columns,
    # and vno fills the candidate output rows. If an attribute is missing or None, or
    # its shape does not fit, the stored weights are kept. Only those failures are
    # caught; a bare `except` would also swallow KeyboardInterrupt/SystemExit.
    try:
        old_W[:, c3:, :, :] = self.vni.clone()
    except (AttributeError, RuntimeError, TypeError):
        pass
    try:
        old_W[c1:, :c3, :, :] = self.vno.clone()
    except (AttributeError, RuntimeError, TypeError):
        pass

    # Match the old weight dtype so e.g. a half-precision model keeps a half conv.
    new_layer = nn.Conv2d(in_channels=C_in,
                          out_channels=c2 + n_new,
                          kernel_size=(self.kh, self.kw),
                          stride=(self.dh, self.dw),
                          padding=(self.ph, self.pw),
                          bias=self.has_bias).to(device=device, dtype=old_W.dtype)
    new_W = new_layer.weight.data
    _grow_rows(new_W, old_W)
    if n_split > 0:
        # Split: the original neuron moves by +delta and its copy by -delta, on the
        # existing input channels only.
        new_W[split_idx, :c3, ...] += delta
        new_W[c1:c2, :c3, ...] -= delta
    if self.has_bias:
        # Split copies inherit their bias; activated candidates start at zero.
        _grow_rows(new_layer.bias.data, self.module.bias.data, new_rows=0.)
    self.module = new_layer

    # for batchnorm layer
    if self.has_bn:
        old_bn = self.bn
        # Keep the old BN configuration instead of silently resetting it to defaults.
        bn_kwargs = {k: getattr(old_bn, k)
                     for k in ('eps', 'momentum', 'affine', 'track_running_stats')
                     if hasattr(old_bn, k)}
        new_bn = nn.BatchNorm2d(c2 + n_new, **bn_kwargs).to(device)
        for name in ('weight', 'bias', 'running_mean', 'running_var'):
            old_t, new_t = getattr(old_bn, name, None), getattr(new_bn, name, None)
            # weight/bias are None when affine=False; running stats when not tracked.
            if old_t is not None and new_t is not None:
                _grow_rows(new_t.data, old_t.data)
        old_nbt = getattr(old_bn, 'num_batches_tracked', None)
        new_nbt = getattr(new_bn, 'num_batches_tracked', None)
        if old_nbt is not None and new_nbt is not None:
            # The running stats are carried over, so keep their batch counter too.
            new_nbt.data.copy_(old_nbt.data)
        self.bn = new_bn
    return n_split + n_new, split_idx, new_idx