def get_slimmed_network()

in models/networks/resprune.py


import numpy as np
import torch
import torch.nn as nn

# Project-local imports; the exact paths are assumed from the repo layout.
from models import networks  # provides networks.channel_selection
from models.networks.cpreresnet import cpreresnet


def get_slimmed_network(model, model_kwargs, cfg, cfg_mask):
    assert not isinstance(model, nn.DataParallel), "Must unwrap DataParallel"

    is_cuda = next(model.parameters()).is_cuda
    print("Cfg:")
    print(cfg)

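    # Build a fresh network whose per-layer widths follow the pruned cfg.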
    newmodel = cpreresnet(cfg=cfg, **model_kwargs)

    if is_cuda:
        newmodel.cuda()

    old_modules = list(model.modules())
    new_modules = list(newmodel.modules())
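    # Bookkeeping for walking the old and new module lists in lockstep:
    # `start_mask` tracks the surviving input channels of the current layer
    # (the first conv sees all 3 RGB channels); `end_mask` tracks its
    # surviving output channels, taken from the pruning configuration.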
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[layer_id_in_cfg]
    conv_count = 0

    for layer_id in range(len(old_modules)):
        m0 = old_modules[layer_id]
        m1 = new_modules[layer_id]

        if isinstance(m0, (nn.modules.batchnorm._NormBase, nn.GroupNorm)):
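            # Indices of the output channels kept by the current mask.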
            idx1 = np.squeeze(np.argwhere(end_mask.cpu().numpy()))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))

            if isinstance(
                old_modules[layer_id + 1], networks.channel_selection
            ):
                # If the next layer is the channel selection layer, then the
                # current normalization layer won't be pruned.
                m1.weight.data = m0.weight.data.clone()
                m1.bias.data = m0.bias.data.clone()
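                # Norm layers without running stats (e.g. GroupNorm) carry
                # None buffers; propagate that as-is.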
                if m0.running_mean is None:
                    m1.running_mean = m0.running_mean
                    m1.running_var = m0.running_var
                else:
                    m1.running_mean.data = m0.running_mean.clone()
                    m1.running_var.data = m0.running_var.clone()

                # We need to set the channel selection layer.
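                # Its `indexes` buffer acts as a hard 0/1 gate: only channels
                # kept by the mask pass through the unpruned norm layer.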
                m2 = new_modules[layer_id + 1]
                m2.indexes.data.zero_()
                m2.indexes.data[idx1.tolist()] = 1.0

                layer_id_in_cfg += 1
                start_mask = end_mask.clone()
                if layer_id_in_cfg < len(cfg_mask):
                    end_mask = cfg_mask[layer_id_in_cfg]
            else:
                m1.weight.data = m0.weight.data[idx1.tolist()].clone()
                m1.bias.data = m0.bias.data[idx1.tolist()].clone()
                if m0.running_mean is None:
                    m1.running_mean = m0.running_mean
                    m1.running_var = m0.running_var
                else:
                    m1.running_mean = m0.running_mean[idx1.tolist()].clone()
                    m1.running_var = m0.running_var[idx1.tolist()].clone()
                layer_id_in_cfg += 1
                start_mask = end_mask.clone()
                if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                    end_mask = cfg_mask[layer_id_in_cfg]
        elif isinstance(m0, nn.Conv2d) and m0 is not model.fc:
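            # The very first convolution sees the raw input image, so its
            # weights are copied unchanged.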
            if conv_count == 0:
                m1.weight.data = m0.weight.data.clone()
                conv_count += 1
                continue
            if isinstance(
                old_modules[layer_id - 1],
                (
                    networks.channel_selection,
                    nn.modules.batchnorm._NormBase,
                    nn.GroupNorm,
                ),
            ):
                # This covers the convolutions in the residual block.
                # The convolutions are either after the channel selection layer
                # or after the batch normalization layer.
                conv_count += 1
                idx0 = np.squeeze(np.argwhere(start_mask.cpu().numpy()))
                idx1 = np.squeeze(np.argwhere(end_mask.cpu().numpy()))
                if idx0.size == 1:
                    idx0 = np.resize(idx0, (1,))
                if idx1.size == 1:
                    idx1 = np.resize(idx1, (1,))
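                # Slice input channels first, to match what the previous layer
                # left alive.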
                w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()

                # If the current convolution is not the last convolution in the
                # residual block, then we can change the number of output
                # channels. Currently we use `conv_count` to detect whether it
                # is such convolution.
                if conv_count % 3 != 1:
                    w1 = w1[idx1.tolist(), :, :, :].clone()
                m1.weight.data = w1.clone()
                continue

            # We need to consider the case where there are downsampling
            # convolutions. For these convolutions, we just copy the weights.
            m1.weight.data = m0.weight.data.clone()
        elif m0 is model.fc:
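            # The classifier keeps all of its outputs; only its input features
            # are sliced down to the channels that survived the last layer.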
            idx0 = np.squeeze(np.argwhere(start_mask.cpu().numpy()))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))

            m1.weight.data = m0.weight.data[:, idx0].clone()

            assert (m1.bias is None) == (m0.bias is None)
            if m1.bias is not None:
                m1.bias.data = m0.bias.data.clone()

    num_parameters = sum(param.nelement() for param in newmodel.parameters())

    return num_parameters, newmodel
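

A minimal usage sketch (hypothetical; it assumes `model` is an unwrapped,
sparsity-trained cpreresnet and that `cfg` / `cfg_mask` were produced by a
BN-scaling channel-pruning pass, as in network slimming):

    # `model_kwargs` here is illustrative; pass whatever keyword arguments
    # your cpreresnet constructor actually expects.
    num_parameters, pruned_model = get_slimmed_network(
        model,
        model_kwargs={"num_classes": 10},  # hypothetical kwargs
        cfg=cfg,                           # per-layer surviving channel counts
        cfg_mask=cfg_mask,                 # per-layer 0/1 keep masks
    )
    print(f"Slimmed network has {num_parameters} parameters.")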