# lec_update() — from train_indep.py

def lec_update(model, training=False, **regime_params):
    """Run one LEC update step; in eval mode, also prune and slim the model.

    During training this just applies the sparse-module updates and returns.
    In evaluation mode it additionally builds a pruned copy of the network,
    zeroes the globally-smallest norm-layer scale parameters, slims the
    network down to the surviving channels, and records the achieved
    sparsity in ``regime_params["sparsity"]``.

    Args:
        model: The network being trained (possibly wrapped in
            ``nn.DataParallel``).
        training: If True, only perform the sparse-module update.
        **regime_params: Regime state dict. Reads ``model_kwargs``,
            ``topk`` (fraction of norm weights to keep) and ``bn_type``;
            writes ``sparsity`` in eval mode.

    Returns:
        ``(model, regime_params)`` in training mode, or
        ``(slimmed_network, regime_params)`` in eval mode.

    Raises:
        ValueError: If the (unwrapped) model is not a supported
            architecture for LEC (cpreresnet or vgg).
    """
    # The original LEC paper does the update using a global threshold, so we
    # adopt that strategy here.
    model, regime_params = sparse_module_updates(
        model, training=training, **regime_params
    )

    if training:
        return model, regime_params

    # Eval path: create a pruned copy of the model.
    model_kwargs = regime_params["model_kwargs"]
    fresh_copy = utils.make_fresh_copy_of_pruned_network(model, model_kwargs)

    # The @fresh_copy needs to have its smallest norm-layer parameters
    # deleted: gather every norm layer's |weight| into one flat tensor and
    # find the global magnitude threshold that keeps the top-k fraction.
    topk = regime_params["topk"]
    norm_types = (nn.modules.batchnorm._NormBase, nn.GroupNorm)
    all_weights = torch.cat(
        [
            m.weight.abs()
            for m in fresh_copy.modules()
            if isinstance(m, norm_types)
        ],
        dim=0,
    )
    sorted_weights, _ = torch.sort(all_weights)
    # Clamp the index so a tiny @topk (index == numel) cannot raise
    # IndexError here — this line sits outside the try/except below.
    num_weights = all_weights.shape[0]
    threshold_idx = min(int(num_weights * (1.0 - topk)), num_weights - 1)
    threshold = sorted_weights[threshold_idx]

    # Zero out every norm weight/bias whose magnitude is at or below the
    # threshold (strict greater-than keeps ties out).
    for m in fresh_copy.modules():
        if isinstance(m, norm_types):
            mask = m.weight.data.clone().abs().gt(threshold).float()
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)

    # Now that we have the sparse copy, we slim it down.
    cfg, cfg_mask = network_utils.get_slim_configs(fresh_copy)

    builder = Builder(
        conv_type="StandardConv", bn_type=regime_params["bn_type"]
    )

    try:
        if isinstance(model, nn.DataParallel):
            model = model.module

        if isinstance(model, networks.cpreresnet):
            model_class = resprune
        elif isinstance(model, networks.vgg.vgg):
            model_class = vggprune
        else:
            raise ValueError(
                "Model {} is not supported for LEC.".format(model)
            )

        # NOTE(review): assumes @fresh_copy is DataParallel-wrapped (hence
        # .module) — confirm against make_fresh_copy_of_pruned_network.
        _, slimmed_network = model_class.get_slimmed_network(
            fresh_copy.module,
            {"builder": builder, "block_builder": builder, **model_kwargs},
            cfg,
            cfg_mask,
        )
    except IndexError:
        # This is the error if we eliminate a whole layer.
        print(
            "Something went wrong during LEC - most likely, an entire "
            "layer was deleted. Using @fresh_copy."
        )
        slimmed_network = fresh_copy

    num_parameters = sum(
        param.nelement() for param in slimmed_network.parameters()
    )

    # NOTE: DO NOT use @model here, since it has too many extra buffers.
    total_params = sum(
        param.nelement() for param in fresh_copy.parameters()
    )
    regime_params["sparsity"] = (
        total_params - num_parameters
    ) / total_params

    print(f"Got sparsity level of {regime_params['sparsity']}")

    return slimmed_network, regime_params