in train_curve.py [0:0]
def _test_time_lec_update(model, **regime_params):
    """Build a slimmed copy of ``model`` for LEC and record its sparsity.

    This requires that the topk values are already set on ``model``. A whole
    new pruned copy of the model is created, and the slimming routine for the
    matching architecture (``resprune`` or ``vggprune``) is run on it.

    If slimming fails for any reason -- including ``model`` being an
    unsupported architecture -- the error is reported on stdout and the
    unslimmed fresh copy is used instead, so the reported sparsity is 0.

    Args:
        model: The network to slim.
        **regime_params: Regime settings. Must contain ``"model_kwargs"``,
            which is forwarded when building the fresh copy and the slimmed
            network.

    Returns:
        A ``(slimmed_network, regime_params)`` tuple, where ``regime_params``
        has gained a ``"sparsity"`` entry: the fraction of the fresh copy's
        parameters that are absent from the slimmed network.

    Raises:
        KeyError: If ``"model_kwargs"`` is missing from ``regime_params``.
    """
    model_kwargs = regime_params["model_kwargs"]
    fresh_copy = utils.make_fresh_copy_of_pruned_network(model, model_kwargs)
    cfg, cfg_mask = network_utils.get_slim_configs(fresh_copy)
    builder = Builder(conv_type="StandardConv", bn_type="StandardIN")

    try:
        if isinstance(model, models.cpreresnet):
            model_class = resprune
        elif isinstance(model, models.vgg.vgg):
            model_class = vggprune
        else:
            # Raised inside the ``try`` on purpose: an unsupported model
            # takes the same best-effort fallback as a failed slimming.
            raise ValueError(
                "Model {} is not supported for LEC.".format(model)
            )
        _, slimmed_network = model_class.get_slimmed_network(
            fresh_copy.module,
            {"builder": builder, "block_builder": builder, **model_kwargs},
            cfg,
            cfg_mask,
        )
    except Exception as err:
        # ``Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still stop the run.
        print(
            "Something went wrong during LEC. Most likely, an entire "
            "layer was deleted. Using @fresh_copy."
        )
        # Surface the real cause; the guess above is not always right.
        print(f"LEC failure was: {err!r}")
        slimmed_network = fresh_copy

    num_parameters = sum(
        param.nelement() for param in slimmed_network.parameters()
    )
    # NOTE: DO NOT use @model here, since it has too many extra buffers in the
    # case of training a line.
    total_params = sum(param.nelement() for param in fresh_copy.parameters())
    regime_params["sparsity"] = (total_params - num_parameters) / total_params
    print(f"Got sparsity level of {regime_params['sparsity']}")
    return slimmed_network, regime_params