in crlapi/sl/architectures/firefly_vgg/sp/conv.py [0:0]
def spffn_passive_grow(self, split_idx, new_idx):
    """Rebuild ``self.module`` for a changed set of input channels.

    A fresh ``nn.Conv2d`` is created whose input channels are, in order:
    the first ``C_in - self.ein`` channels of the current layer, then one
    copy of every channel listed in ``split_idx``, then the channels listed
    in ``new_idx``.  When ``self.groups != 1`` the output channels follow
    the input channels one-to-one, so the weight rows, the bias and
    ``self.bn`` are rebuilt with that same layout and ``self.groups`` is
    updated to the new channel count.

    NOTE(review): presumably the trailing ``self.ein`` input channels are
    temporary candidate channels and ``groups != 1`` always means a
    depthwise convolution -- confirm against the rest of the class.

    Args:
        split_idx: 1-D index tensor of input channels to duplicate, or
            ``None`` when nothing is split.
        new_idx: 1-D index tensor of channels of the current weight to
            append as new channels, or ``None``.

    Returns:
        None.  ``self.module`` (and, for grouped layers, ``self.bn`` and
        ``self.groups``) are replaced in place.  The call is a no-op when
        nothing is split and ``len(new_idx) == self.ein``.
    """
    n_split = split_idx.shape[0] if split_idx is not None else 0
    n_new = new_idx.shape[0] if new_idx is not None else 0

    C_out, C_in, _, _ = self.module.weight.shape
    # Decide the layout once: ``self.groups`` is overwritten below, and a
    # grouped layer that ends up with a single channel must still take the
    # grouped code path (including the BatchNorm rebuild).
    grouped = self.groups != 1
    if grouped:
        # A grouped weight is (C_out, C_in // groups, kh, kw); the channel
        # count is taken from dim 0 instead.
        C_in = C_out
    device = self.get_device()
    c1 = C_in - self.ein  # channels kept unchanged
    if n_split == 0 and n_new == self.ein:
        return

    c2 = c1 + n_split     # end of the duplicated-channel range
    n_in = c2 + n_new     # total input channels of the rebuilt layer
    if grouped:
        self.groups = n_in
        C_out = n_in
    new_layer = nn.Conv2d(in_channels=n_in,
                          out_channels=C_out,
                          kernel_size=(self.kh, self.kw),
                          stride=(self.dh, self.dw),
                          padding=(self.ph, self.pw),
                          bias=self.has_bias,
                          groups=self.groups).to(device)

    old_w = self.module.weight.data
    new_w = new_layer.weight.data
    if grouped:
        # Channels live on dim 0; split channels are plain copies.
        new_w[:c1] = old_w[:c1]
        if n_split > 0:
            new_w[c1:c2] = old_w[split_idx]
        if n_new > 0:
            new_w[c2:] = old_w[new_idx]
    else:
        # Channels live on dim 1; a split channel and its copy each carry
        # half of the original incoming weight.
        new_w[:, :c1] = old_w[:, :c1]
        if n_split > 0:
            new_w[:, c1:c2] = old_w[:, split_idx] / 2.
            new_w[:, split_idx] /= 2.
        if n_new > 0:
            new_w[:, c2:] = old_w[:, new_idx]

    if self.has_bias:
        old_b = self.module.bias.data
        if grouped:
            # The number of output channels changed, so the bias has to be
            # remapped like the weight rows; copying it whole would leave a
            # bias whose length does not match ``out_channels``.
            new_b = new_layer.bias.data
            new_b[:c1] = old_b[:c1]
            if n_split > 0:
                new_b[c1:c2] = old_b[split_idx]
            if n_new > 0:
                new_b[c2:] = old_b[new_idx]
        else:
            new_layer.bias.data = old_b.clone()

    if grouped:
        # Output channels changed, so the BatchNorm parameters and running
        # statistics are remapped with the same layout.
        new_bn = nn.BatchNorm2d(C_out).to(device)
        for name in ("weight", "bias", "running_mean", "running_var"):
            src = getattr(self.bn, name).data
            dst = getattr(new_bn, name).data
            dst[:c1] = src[:c1]
            if n_split > 0:
                dst[c1:c2] = src[split_idx]
            if n_new > 0:
                dst[c2:] = src[new_idx]
        self.bn = new_bn

    self.module = new_layer