in archived/sagemaker-debugger/pytorch_iterative_model_pruning/model_resnet.py [0:0]
def prune(model, filters_list, trial, step):
# dict that has a list of filters to be pruned per layer
filters_dict = {}
for layer_name, channel, _ in filters_list:
if layer_name not in filters_dict:
filters_dict[layer_name] = []
filters_dict[layer_name].append(channel)
counter = 0
in_channels_dense = 0
exclude_filters = None
in_channels = 3
exclude = False
# iterate over layers in the ResNet model
for named_module in model.named_modules():
layer_name = named_module[0]
layer = named_module[1]
# check if current layer is a convolutional layer
if isinstance(layer, torch.nn.modules.conv.Conv2d):
# remember the output channels of non-pruned convolution (needed for pruning first fc layer)
in_channels_dense = layer.out_channels
# create key to find right weights/bias/filters for the corresponding layer
weight_name = "ResNet_" + layer_name + ".weight"
# get weight values from last available training step
weight = trial.tensor(weight_name).value(step, mode=modes.TRAIN)
# we need to adjust the number of input channels,
# if previous covolution has been pruned
# print( "current:", layer.in_channels, "previous", in_channels, layer_name, weight_name)
if "conv1" in layer_name or "conv2" in layer_name:
if layer.in_channels != in_channels:
layer.in_channels = in_channels
weight = np.delete(weight, exclude_filters, axis=1)
exclude_filters = None
# if current layer is in the list of filters to be pruned
if "conv1" in layer_name:
layer_id = layer_name.strip("conv1")
for key in filters_dict:
if len(layer_id) > 0 and layer_id in key:
print(
"Reduce output channels for conv layer",
layer_id,
"from",
layer.out_channels,
"to",
layer.out_channels - len(filters_dict[key]),
)
# set new output channels
layer.out_channels = layer.out_channels - len(filters_dict[key])
# remove corresponding filters from weights and bias
# convolution weights have dimension: filter x channel x kernel x kernel
exclude_filters = filters_dict[key]
weight = np.delete(weight, exclude_filters, axis=0)
break
# remember new size of output channels, because we need to prune subsequent convolution
in_channels = layer.out_channels
# set pruned weight and bias
layer.weight.data = torch.from_numpy(weight)
if isinstance(layer, torch.nn.modules.batchnorm.BatchNorm2d):
# get weight values from last available training step
weight_name = "ResNet_" + layer_name + ".weight"
weight = trial.tensor(weight_name).value(step, mode=modes.TRAIN)
# get bias values from last available training step
bias_name = "ResNet_" + layer_name + ".bias"
bias = trial.tensor(bias_name).value(step, mode=modes.TRAIN)
# get running_mean values from last available training step
mean_name = layer_name + ".running_mean_output_0"
mean = trial.tensor(mean_name).value(step, mode=modes.TRAIN)
# get running_var values from last available training step
var_name = layer_name + ".running_var_output_0"
var = trial.tensor(var_name).value(step, mode=modes.TRAIN)
# if current layer is in the list of filters to be pruned
if "bn1" in layer_name:
layer_id = layer_name.strip("bn1")
for key in filters_dict:
if len(layer_id) > 0 and layer_id in key:
print(
"Reduce bn layer",
layer_id,
"from",
weight.shape[0],
"to",
weight.shape[0] - len(filters_dict[key]),
)
# remove corresponding filters from weights and bias
# convolution weights have dimension: filter x channel x kernel x kernel
exclude_filters = filters_dict[key]
weight = np.delete(weight, exclude_filters, axis=0)
bias = np.delete(bias, exclude_filters, axis=0)
mean = np.delete(mean, exclude_filters, axis=0)
var = np.delete(var, exclude_filters, axis=0)
break
# set pruned weight and bias
layer.weight.data = torch.from_numpy(weight)
layer.bias.data = torch.from_numpy(bias)
layer.running_mean.data = torch.from_numpy(mean)
layer.running_var.data = torch.from_numpy(var)
layer.num_features = weight.shape[0]
in_channels = weight.shape[0]
if isinstance(layer, torch.nn.modules.linear.Linear):
# get weight values from last available training step
weight_name = "ResNet_" + layer_name + ".weight"
weight = trial.tensor(weight_name).value(step, mode=modes.TRAIN)
# get bias values from last available training step
bias_name = "ResNet_" + layer_name + ".bias"
bias = trial.tensor(bias_name).value(step, mode=modes.TRAIN)
# prune first fc layer
if exclude_filters is not None:
# in_channels_dense is the number of output channels of last non-pruned convolution layer
params = int(layer.in_features / in_channels_dense)
# prune weights of first fc layer
indexes = []
for i in exclude_filters:
indexes.extend(np.arange(i * params, (i + 1) * params))
if indexes[-1] > weight.shape[1]:
indexes.extend(np.arange(weight.shape[1] - params, weight.shape[1]))
weight = np.delete(weight, indexes, axis=1)
print(
"Reduce weights for first linear layer from",
layer.in_features,
"to",
weight.shape[1],
)
# set new in_features
layer.in_features = weight.shape[1]
exclude_filters = None
# set weights
layer.weight.data = torch.from_numpy(weight)
# set bias
layer.bias.data = torch.from_numpy(bias)
return model