in train_indep.py [0:0]
def lec_update(model, training=False, **regime_params):
    """Run the LEC update and, outside of training, build a slimmed network.

    The sparse-module update is always delegated to
    ``sparse_module_updates``. When ``training`` is true its result is
    returned untouched. Otherwise a fresh copy of the pruned network is
    created, its smallest normalization scale parameters are zeroed using a
    single global threshold, and the copy is slimmed down with the
    architecture-specific pruning module.

    Args:
        model: The network to update. For slimming it must be (or wrap, via
            ``nn.DataParallel``) a ``networks.cpreresnet`` or
            ``networks.vgg.vgg`` instance.
        training: If true, only the sparse-module update is performed.
        **regime_params: Regime configuration forwarded to
            ``sparse_module_updates``. When ``training`` is false the
            returned params must contain ``"model_kwargs"``, ``"topk"`` and
            ``"bn_type"``.

    Returns:
        ``(model, regime_params)`` when ``training`` is true; otherwise
        ``(slimmed_network, regime_params)`` where
        ``regime_params["sparsity"]`` holds the fraction of parameters
        removed relative to the fresh copy. If slimming raises
        ``IndexError`` the un-slimmed fresh copy is returned instead.

    Raises:
        ValueError: If ``model`` is not a supported architecture.
    """
    # The original LEC paper does the update using a global threshold, so we
    # adopt that strategy here.
    model, regime_params = sparse_module_updates(
        model, training=training, **regime_params
    )
    if training:
        return model, regime_params

    # We create a pruned copy of the model.
    model_kwargs = regime_params["model_kwargs"]
    fresh_copy = utils.make_fresh_copy_of_pruned_network(model, model_kwargs)

    # The @fresh_copy needs to have its smallest normalization scale
    # parameters deleted.
    _zero_smallest_norm_scales(fresh_copy, regime_params["topk"])

    # Now that we have the sparse copy, we slim it down.
    cfg, cfg_mask = network_utils.get_slim_configs(fresh_copy)
    builder = Builder(
        conv_type="StandardConv", bn_type=regime_params["bn_type"]
    )

    if isinstance(model, nn.DataParallel):
        model = model.module
    if isinstance(model, networks.cpreresnet):
        model_class = resprune
    elif isinstance(model, networks.vgg.vgg):
        model_class = vggprune
    else:
        raise ValueError("Model {} is not supported for LEC.".format(model))

    # Unwrap the copy when it is wrapped (e.g. by nn.DataParallel); fall back
    # to the copy itself when it exposes no ``module`` attribute.
    unwrapped_copy = getattr(fresh_copy, "module", fresh_copy)

    # Only the slimming call is guarded: IndexError is the error we get if we
    # eliminate a whole layer.
    try:
        _, slimmed_network = model_class.get_slimmed_network(
            unwrapped_copy,
            {"builder": builder, "block_builder": builder, **model_kwargs},
            cfg,
            cfg_mask,
        )
    except IndexError:
        print(
            "Something went wrong during LEC - most likely, an entire "
            "layer was deleted. Using @fresh_copy."
        )
        slimmed_network = fresh_copy

    num_parameters = sum(
        param.nelement() for param in slimmed_network.parameters()
    )
    # NOTE: DO NOT use @model here, since it has too many extra buffers.
    total_params = sum(param.nelement() for param in fresh_copy.parameters())
    regime_params["sparsity"] = (
        total_params - num_parameters
    ) / total_params
    print(f"Got sparsity level of {regime_params['sparsity']}")
    return slimmed_network, regime_params


def _zero_smallest_norm_scales(network, topk):
    """Zero, in place, the smallest-magnitude norm-layer scales of a network.

    All scale (``weight``) parameters of the batch/instance/group
    normalization layers are ranked together by absolute value. Entries
    whose magnitude is not strictly greater than the global threshold have
    both their ``weight`` and matching ``bias`` entries set to zero, so
    roughly the ``topk`` fraction with the largest magnitude survives.
    """
    norm_types = (nn.modules.batchnorm._NormBase, nn.GroupNorm)
    # Norm layers built without affine parameters have ``weight is None`` in
    # PyTorch; they carry no scale to rank, so they are skipped.
    norm_layers = [
        m
        for m in network.modules()
        if isinstance(m, norm_types) and m.weight is not None
    ]

    all_weights = torch.cat(
        [m.weight.detach().abs() for m in norm_layers], dim=0
    )
    sorted_weights, _ = torch.sort(all_weights)

    # Clamp so that topk <= 0 (index == len) or topk > 1 (negative index)
    # cannot read out of range or wrap around from the end.
    num_weights = all_weights.shape[0]
    cutoff = min(max(int(num_weights * (1.0 - topk)), 0), num_weights - 1)
    threshold = sorted_weights[cutoff]

    for m in norm_layers:
        mask = m.weight.data.abs().gt(threshold).float()
        m.weight.data.mul_(mask)
        if m.bias is not None:
            m.bias.data.mul_(mask)