def spffn_active_grow()

in crlapi/sl/architectures/firefly_vgg/sp/conv.py [0:0]


    def spffn_active_grow(self, threshold):
        """Rebuild the conv (and batch-norm) layer with the output channels selected by ``threshold``.

        Output channels whose score ``self.w`` is ``<= threshold`` are selected.
        With ``c1 = C_out - self.eout``:

        * selected indices ``< c1`` are *split*: the channel is duplicated,
          and ``self.v[split_idx]`` is added to the original copy and
          subtracted from the duplicate (on the first ``C_in - self.ein``
          input channels only);
        * selected indices ``>= c1`` are *new*: their weights are appended
          as-is and their bias is set to zero.

        The resulting output-channel layout is
        ``[0:c1) original | [c1:c2) split copies | [c2:) new channels``.

        NOTE(review): ``self.eout`` / ``self.ein`` look like counts of extra
        candidate output/input channels, and ``self.vni`` / ``self.vno`` like
        their trained weights -- confirm against the class definition.

        Args:
            threshold: Scalar compared against ``self.w``.

        Returns:
            A tuple ``(n_added, split_idx, new_idx)`` where ``n_added`` is the
            number of split plus new channels, and the two index tensors are
            the selected indices below and at-or-above ``c1`` respectively.

        Side effects:
            Replaces ``self.module`` with a freshly built ``nn.Conv2d`` and,
            when ``self.has_bn`` is true, ``self.bn`` with a new
            ``nn.BatchNorm2d``.
        """
        idx = torch.nonzero((self.w <= threshold).float()).view(-1)

        C_out, C_in, _, _ = self.module.weight.shape
        c1 = C_out - self.eout
        c3 = C_in - self.ein

        split_idx, new_idx = idx[idx < c1], idx[idx >= c1]
        n_split = split_idx.shape[0]
        n_new = new_idx.shape[0]

        # Boundaries of the grown layer: [0:c1) kept, [c1:c2) split, [c2:) new.
        c2 = c1 + n_split
        n_total = c2 + n_new

        device = self.get_device()
        delta = self.v[split_idx, ...]

        new_layer = nn.Conv2d(in_channels=C_in,
                              out_channels=n_total,
                              kernel_size=(self.kh, self.kw),
                              stride=(self.dh, self.dw),
                              padding=(self.ph, self.pw),
                              bias=self.has_bias).to(device)

        old_W = self.module.weight.data.clone()

        # Best-effort write-back of the auxiliary weights into the working
        # copy. A failure (e.g. missing attribute or shape mismatch) is
        # deliberately ignored, but only ordinary exceptions are swallowed so
        # KeyboardInterrupt / SystemExit still propagate.
        try:
            old_W[:, c3:, :, :] = self.vni.clone()
        except Exception:
            pass
        try:
            old_W[c1:, :c3, :, :] = self.vno.clone()
        except Exception:
            pass

        new_layer.weight.data[:c1, ...] = old_W[:c1, ...]

        if n_split > 0:
            # Duplicate each split channel, then push the two copies apart
            # by +delta / -delta on the first c3 input channels.
            new_layer.weight.data[c1:c2, ...] = old_W[split_idx, ...]
            new_layer.weight.data[split_idx, :c3, ...] += delta
            new_layer.weight.data[c1:c2, :c3, ...] -= delta

        if n_new > 0:
            new_layer.weight.data[c2:, ...] = old_W[new_idx, ...]

        if self.has_bias:
            old_b = self.module.bias.data.clone()
            new_layer.bias.data[:c1, ...] = old_b[:c1, ...]
            if n_split > 0:
                # Both halves of a split share the original channel's bias.
                new_layer.bias.data[c1:c2, ...] = old_b[split_idx]
            if n_new > 0:
                new_layer.bias.data[c2:, ...] = 0.

        self.module = new_layer

        # Mirror the same channel layout in the batch-norm layer.
        if self.has_bn:
            new_bn = nn.BatchNorm2d(n_total).to(device)
            for name in ("weight", "bias", "running_mean", "running_var"):
                src = getattr(self.bn, name).data
                dst = getattr(new_bn, name).data
                dst[:c1] = src[:c1]
                if n_split > 0:
                    dst[c1:c2] = src[split_idx]
                if n_new > 0:
                    dst[c2:] = src[new_idx]
            self.bn = new_bn

        return n_split + n_new, split_idx, new_idx