def make_fresh_copy_of_pruned_network()

in utils.py [0:0]


def make_fresh_copy_of_pruned_network(
    model: nn.Module, model_kwargs: Dict
) -> nn.Module:
    """Rebuild ``model`` with standard layers and copy its effective weights.

    A new network of the same class as ``model.module`` is constructed with a
    ``Builder`` configured for ``"StandardConv"`` convolutions and the norm
    type reported by ``get_norm_type_string(model)``. The new network is
    wrapped in ``nn.DataParallel``, loaded with ``model``'s state dict (minus
    every entry whose key ends in ``"1"``), and finally every module of
    ``model`` that exposes ``get_weight()`` has the result of that call
    written into the same-named module of the new network.

    Args:
        model: Source network. It must expose the wrapped network as
            ``model.module`` (e.g. an ``nn.DataParallel`` wrapper) and have
            at least one parameter.
        model_kwargs: Extra keyword arguments forwarded to the network
            constructor alongside ``builder`` and ``block_builder``.

    Returns:
        The new network, wrapped in ``nn.DataParallel``.

    Raises:
        AssertionError: If a module's ``get_weight()`` returns a tuple/list
            whose length is neither 1 nor 2.
    """
    norm_type_string = get_norm_type_string(model)
    builder = Builder(conv_type="StandardConv", bn_type=norm_type_string)

    fresh: nn.Module = type(model.module)(
        builder=builder, block_builder=builder, **model_kwargs
    )
    # The network has to be on the GPU before it is wrapped in DataParallel.
    if next(model.parameters()).is_cuda:
        fresh = fresh.cuda()
    fresh = nn.DataParallel(fresh)

    # Drop every entry whose key ends in "1" before loading.
    # NOTE(review): presumably these are the extra parameters that only the
    # non-standard layers own (e.g. a second weight tensor) - confirm.
    # Entries are deleted in place so the mapping object returned by
    # state_dict() (including the version metadata PyTorch attaches to it)
    # is handed to load_state_dict() otherwise unchanged.
    state_dict = model.state_dict()
    for key in [k for k in state_dict if k.endswith("1")]:
        del state_dict[key]

    fresh.load_state_dict(state_dict)

    # The only part we should need to fix are modules with a get_weight()
    # function: their effective weights are whatever get_weight() returns,
    # not the raw tensors that were just loaded from the state dict.
    name_to_fresh = dict(fresh.named_modules())

    for name, module in model.named_modules():
        if not hasattr(module, "get_weight"):
            continue

        print(f"Adjusting weight at module {name}")

        pieces = module.get_weight()
        # A bare (non tuple/list) return value is treated as the weight
        # alone, so it is normalised to a 1-element tuple.
        if not isinstance(pieces, (tuple, list)):
            pieces = (pieces,)

        # Explicit raise instead of `assert` so the check survives `-O`.
        if len(pieces) not in (1, 2):
            raise AssertionError(f"Invalid len(pieces)={len(pieces)}")

        target = name_to_fresh[name]
        # Assign the tensor itself, never the enclosing container.
        target.weight.data = pieces[0]
        if len(pieces) == 2:
            target.bias.data = pieces[1]

    return fresh