def spffn_passive_grow()

in crlapi/sl/architectures/firefly_vgg/sp/conv.py [0:0]


    def spffn_passive_grow(self, split_idx, new_idx):
        """Rebuild ``self.module`` with a regrown set of channels.

        The current layer's channels are laid out as ``[0, c1)``, which are
        always kept, followed by ``self.ein`` candidate channels, where
        ``c1 = C_in - self.ein``. The rebuilt layer uses the channel layout::

            [old[:c1], old[split_idx], old[new_idx]]

        * Regular convolution (``self.groups == 1``): only the input side
          changes. Each split channel's weights are halved, and the duplicate
          receives the same halved copy. The bias is copied unchanged.
        * Grouped convolution (``self.groups != 1``): the layer is treated as
          depthwise. The channel count is read from the output dimension and
          ``groups`` is reset to the new channel count. The weight, the bias
          and ``self.bn`` are all regrown with the layout above. Split weights
          are copied, not halved.

        Args:
            split_idx: index tensor (presumably 1-D) of channels to duplicate,
                or ``None`` for no splits. Indices are positions in the
                current weight tensor.
            new_idx: index tensor (presumably 1-D) of candidate channels to
                keep, or ``None`` to keep none. Indices are positions in the
                current weight tensor, not offsets into the candidate block.

        Returns:
            None. ``self.module`` is replaced in place; for grouped
            convolutions ``self.groups`` and ``self.bn`` are replaced too.
            Nothing changes when no channel is split and all ``self.ein``
            candidates are kept.
        """
        n_split = split_idx.shape[0] if split_idx is not None else 0
        n_new = new_idx.shape[0] if new_idx is not None else 0

        C_out, C_in, _, _ = self.module.weight.shape
        # Capture this once: self.groups is reassigned below, and re-reading it
        # would pick the wrong branch if the regrown layer had a single channel.
        depthwise = self.groups != 1
        if depthwise:
            # Depthwise weights are (C, 1, kh, kw): the channel count lives in
            # the output dimension.
            C_in = C_out
        device = self.get_device()
        c1 = C_in - self.ein  # channels kept unconditionally
        if n_split == 0 and n_new == self.ein:
            return

        c2 = c1 + n_split  # end of the split duplicates
        c3 = c2 + n_new  # total channel count of the rebuilt layer

        if depthwise:
            self.groups = c3
            C_out = c3
        new_layer = nn.Conv2d(in_channels=c3,
                              out_channels=C_out,
                              kernel_size=(self.kh, self.kw),
                              stride=(self.dh, self.dw),
                              padding=(self.ph, self.pw),
                              bias=self.has_bias, groups=self.groups).to(device)

        def regrow(dst, src):
            # Fill dim 0 of ``dst`` with [src[:c1], src[split_idx], src[new_idx]].
            dst[:c1] = src[:c1]
            if n_split > 0:
                dst[c1:c2] = src[split_idx]
            if n_new > 0:
                dst[c2:] = src[new_idx]

        old_w = self.module.weight.data
        new_w = new_layer.weight.data
        if depthwise:
            regrow(new_w, old_w)
            if self.has_bias:
                # Output channels were regrown too, so the bias must follow the
                # same layout (a plain copy would keep the old length).
                regrow(new_layer.bias.data, self.module.bias.data)

            # Carry over the BN configuration instead of resetting to defaults.
            old_bn = self.bn
            new_bn = nn.BatchNorm2d(C_out, eps=old_bn.eps,
                                    momentum=old_bn.momentum).to(device)
            for name in ("weight", "bias", "running_mean", "running_var"):
                regrow(getattr(new_bn, name).data, getattr(old_bn, name).data)
            self.bn = new_bn
        else:
            new_w[:, :c1] = old_w[:, :c1]
            if n_split > 0:
                # The original slot and its duplicate each get half of the
                # split channel's weights; the two halves sum to the original.
                new_w[:, c1:c2] = old_w[:, split_idx] / 2.
                new_w[:, split_idx] /= 2.
            if n_new > 0:
                new_w[:, c2:] = old_w[:, new_idx]
            if self.has_bias:
                # Output channels are unchanged, so the bias carries over as-is.
                new_layer.bias.data.copy_(self.module.bias.data)

        self.module = new_layer