def _test_time_lec_update()

in train_curve.py [0:0]


def _test_time_lec_update(model, **regime_params):
    """Slim a pruned copy of ``model`` via LEC and record achieved sparsity.

    Requires that the topk values are already set on ``model``. A fresh
    pruned copy of the network is built, slimming configs are derived from
    it, and the architecture-specific pruner is asked to build the
    structurally slimmed network. If slimming fails (most commonly because
    an entire layer was pruned away), the un-slimmed fresh copy is used as
    a best-effort fallback.

    Args:
        model: wrapped network to slim (accessed via ``fresh_copy.module``
            after copying; topk values must already be set).
        **regime_params: regime options; must contain ``"model_kwargs"``.
            Mutated in place to add ``"sparsity"``.

    Returns:
        Tuple ``(slimmed_network, regime_params)``.

    Raises:
        ValueError: if ``model`` is not a supported architecture for LEC.
    """
    # We create a whole new copy of the model which is pruned.
    model_kwargs = regime_params["model_kwargs"]
    fresh_copy = utils.make_fresh_copy_of_pruned_network(model, model_kwargs)
    cfg, cfg_mask = network_utils.get_slim_configs(fresh_copy)

    builder = Builder(conv_type="StandardConv", bn_type="StandardIN")

    # Dispatch on architecture OUTSIDE the try block: previously the
    # ValueError for unsupported models was raised inside the try and
    # immediately swallowed by a bare `except:`, so callers silently got
    # the un-slimmed fallback instead of an error.
    if isinstance(model, models.cpreresnet):
        model_class = resprune
    elif isinstance(model, models.vgg.vgg):
        model_class = vggprune
    else:
        raise ValueError(
            "Model {} is not supported for LEC.".format(model)
        )

    try:
        _, slimmed_network = model_class.get_slimmed_network(
            fresh_copy.module,
            {"builder": builder, "block_builder": builder, **model_kwargs},
            cfg,
            cfg_mask,
        )
    except Exception as err:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; surface the actual error for debugging instead
        # of discarding it. Fallback remains deliberate best-effort.
        print(
            f"Something went wrong during LEC ({err!r}). Most likely, an "
            f"entire layer was deleted. Using @fresh_copy."
        )
        slimmed_network = fresh_copy

    num_parameters = sum(
        param.nelement() for param in slimmed_network.parameters()
    )

    # NOTE: DO NOT use @model here, since it has too many extra buffers in the
    # case of training a line.
    total_params = sum(param.nelement() for param in fresh_copy.parameters())
    regime_params["sparsity"] = (total_params - num_parameters) / total_params

    print(f"Got sparsity level of {regime_params['sparsity']}")

    return slimmed_network, regime_params