def get_slimmed_network()

in models/networks/vggprune.py [0:0]


def _mask_to_indices(mask):
    """Return the positions of the non-zero entries of a (1-D) channel mask.

    Accepts a tensor on any device, a NumPy array or a sequence. The result is a
    plain ``list`` of ints, so it can index a tensor directly even when it holds
    zero or one element.
    """
    return np.flatnonzero(torch.as_tensor(mask).detach().cpu().numpy()).tolist()


def _copy_norm_layer(old, new, keep):
    """Copy channels ``keep`` of a normalization layer's affine params and running stats."""
    # Affine params are None when ``affine=False``. Running stats are None/absent
    # unless the layer tracks them (GroupNorm never does), so copy only what exists.
    for name in ("weight", "bias"):
        src, dst = getattr(old, name, None), getattr(new, name, None)
        if src is not None and dst is not None:
            dst.data = src.data[keep].clone()
    for name in ("running_mean", "running_var"):
        src = getattr(old, name, None)
        if src is not None:
            setattr(new, name, src[keep].clone())
    counter = getattr(old, "num_batches_tracked", None)
    if counter is not None:
        new.num_batches_tracked = counter.clone()


def _copy_conv_layer(old, new, keep_in, keep_out):
    """Copy the kept output/input channels of a conv weight, plus the kept bias entries."""
    new.weight.data = old.weight.data[keep_out][:, keep_in].clone()
    if old.bias is not None and new.bias is not None:
        new.bias.data = old.bias.data[keep_out].clone()


def _copy_classifier(old, new, keep_in):
    """Copy the kept input features of the final classifier; its bias is copied whole."""
    if (old.bias is None) != (new.bias is None):
        raise ValueError(
            "Classifier bias mismatch: the original and slimmed models must "
            "both have a bias or both lack one"
        )
    new.weight.data = old.weight.data[:, keep_in].clone()
    if old.bias is not None:
        new.bias.data = old.bias.data.clone()


def get_slimmed_network(model, model_kwargs, cfg, cfg_mask, *, model_builder=None):
    """Build a channel-pruned copy of ``model`` and transfer the surviving weights.

    The new network is created with ``model_builder(cfg=cfg, **model_kwargs)``
    (``vgg`` by default). The modules of the old and new networks are then walked
    in lockstep. Each normalization layer consumes the next entry of ``cfg_mask``.
    A convolution keeps the output channels selected by the mask of the norm layer
    that follows it, and the input channels selected by the previous mask.
    ``model.classifier`` keeps the input features selected by the last mask.

    Args:
        model: Unwrapped source network. It must expose a ``classifier``
            attribute, and its ``modules()`` must align one-to-one with those of
            the newly built network.
        model_kwargs: Extra keyword arguments forwarded to the builder.
        cfg: Channel configuration passed to the builder as ``cfg=``. Presumably
            it matches the number of kept channels in each mask; this is not
            checked here.
        cfg_mask: Sequence of per-norm-layer channel masks. Non-zero entries mark
            the channels to keep.
        model_builder: Optional callable used instead of ``vgg`` to build the
            slimmed network.

    Returns:
        ``(num_parameters, newmodel)``, where ``num_parameters`` is the total
        element count of ``newmodel``'s parameters.

    Raises:
        TypeError: If ``model`` is still wrapped in (Distributed)DataParallel.
        ValueError: If the two module lists do not align, or the classifier bias
            presence differs between the two models.
    """
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        raise TypeError("Must unwrap DataParallel / DistributedDataParallel (pass model.module)")

    device = next(model.parameters()).device
    print("Cfg:")
    print(cfg)

    build = vgg if model_builder is None else model_builder
    newmodel = build(cfg=cfg, **model_kwargs)
    # Put the slimmed copy on the same device as the source model.
    newmodel.to(device)

    old_modules = list(model.modules())
    new_modules = list(newmodel.modules())
    if len(old_modules) != len(new_modules):
        raise ValueError(
            "Module structure mismatch: original has {} modules, slimmed has {}".format(
                len(old_modules), len(new_modules)
            )
        )

    norm_types = (nn.modules.batchnorm._NormBase, nn.GroupNorm)
    classifier = model.classifier

    # The input channels of the first convolution are never pruned.
    first_conv = next(
        (m for m in old_modules if isinstance(m, nn.Conv2d) and m is not classifier),
        None,
    )
    keep_in = list(range(first_conv.in_channels if first_conv is not None else 3))
    layer_id_in_cfg = 0
    keep_out = _mask_to_indices(cfg_mask[layer_id_in_cfg])

    for m0, m1 in zip(old_modules, new_modules):
        if isinstance(m0, norm_types):
            _copy_norm_layer(m0, m1, keep_out)
            layer_id_in_cfg += 1
            keep_in = keep_out
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                keep_out = _mask_to_indices(cfg_mask[layer_id_in_cfg])
        elif m0 is classifier:
            _copy_classifier(m0, m1, keep_in)
        elif isinstance(m0, nn.Conv2d):
            print("In shape: {:d}, Out shape {:d}.".format(len(keep_in), len(keep_out)))
            _copy_conv_layer(m0, m1, keep_in, keep_out)

    num_parameters = sum(p.nelement() for p in newmodel.parameters())

    return num_parameters, newmodel