def random_split()

in crlapi/sl/architectures/firefly_vgg/sp/conv.py [0:0]


    def random_split(self, C_new):
        """Widen the wrapped convolution by ``C_new`` randomly split channels.

        ``C_new`` existing output channels ("parents") are sampled uniformly
        *with replacement*. For every sample a new output channel is appended
        that starts as a copy of its parent minus a small random perturbation,
        while the parent itself is shifted by an independent perturbation.
        Biases and, when ``self.has_bn`` is set, the batch-norm affine
        parameters and running statistics of each parent are copied to its
        child unchanged.

        Side effects: ``self.module`` (and ``self.bn`` when ``self.has_bn``)
        are replaced by freshly constructed, wider modules.

        Args:
            C_new: Number of output channels to add.

        Returns:
            ``(C_new, idx)`` where ``idx`` is a 1-D ``torch.long`` tensor on the
            layer's device holding the parent channel of every added channel
            (added channel ``C_out + i`` was split from ``idx[i]``).
            ``(0, None)`` when ``C_new == 0``; nothing is modified in that case.
        """
        if C_new == 0:
            return 0, None

        C_out, C_in, kh, kw = self.module.weight.shape
        # np.random.choice samples with replacement by default, so the same
        # parent channel may be picked more than once.
        parents = np.random.choice(C_out, C_new)

        device = self.get_device()
        noise_shape = (C_new, C_in, kh, kw)
        # Noise is created first and moved to the target device afterwards;
        # each kernel row (last dim) is L2-normalised, then scaled to 1e-2.
        delta1 = F.normalize(torch.randn(noise_shape).to(device), p=2, dim=-1) * 1e-2
        delta2 = F.normalize(torch.randn(noise_shape).to(device), p=2, dim=-1) * 1e-2

        idx = torch.as_tensor(parents, dtype=torch.long, device=device)

        new_layer = nn.Conv2d(in_channels=C_in,
                              out_channels=C_out + C_new,
                              kernel_size=(self.kh, self.kw),
                              stride=(self.dh, self.dw),
                              padding=(self.ph, self.pw),
                              bias=self.has_bias).to(device)

        # for current layer
        old_weight = self.module.weight.data
        new_weight = new_layer.weight.data
        new_weight[:C_out, ...] = old_weight
        new_weight[C_out:, ...] = old_weight[idx, ...]
        # `new_weight[idx] += delta1` is a non-accumulating index_put_, whose
        # result is undefined when `idx` holds duplicates. index_add_ applies
        # every perturbation, i.e. behaves like performing the splits in turn.
        new_weight.index_add_(0, idx, delta1)
        new_weight[C_out:, ...] -= delta2

        if self.has_bias:
            old_bias = self.module.bias.data
            new_layer.bias.data[:C_out] = old_bias
            new_layer.bias.data[C_out:] = old_bias[idx]

        self.module = new_layer

        # for batchnorm layer
        if self.has_bn:
            # NOTE(review): only num_features is passed, so eps / momentum /
            # num_batches_tracked of the previous batch norm are not carried
            # over - confirm this is intended.
            new_bn = nn.BatchNorm2d(C_out + C_new).to(device)
            for name in ("weight", "bias", "running_mean", "running_var"):
                source = getattr(self.bn, name).data
                target = getattr(new_bn, name).data
                target[:C_out] = source
                target[C_out:] = source[idx]
            self.bn = new_bn

        return C_new, idx