in nni/compression/pytorch/speedup/compress_modules.py [0:0]
def replace_conv2d(conv, masks):
"""
    Replace the original conv with a new one according to the inferred
    masks. This function supports both fine-grained and coarse-grained
    sparsity. In the fine-grained scenario, it replaces the filters that
    happen to be totally covered by the fine-grained sparsity.
Parameters
----------
conv : torch.nn.Conv2d
The conv2d module to be replaced
    masks : tuple
        Tuple of the input masks, the output masks and the weight masks,
        for example ([input_m1, input_m2], [output_m], {'weight': weight_m})
Returns
-------
torch.nn.Conv2d
The new conv2d module
"""
in_masks, output_mask, weight_masks = masks
assert isinstance(conv, nn.Conv2d)
# the conv layer should only have one input tensor
if len(in_masks) != 1:
raise InputsNumberError()
in_mask = in_masks[0]
weight_mask = weight_masks['weight']
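    # convert_to_coarse_mask returns the indices of the pruned and of the
    # remained channels along dim 1 (the channel dimension of the activation
    # masks)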
pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
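    # weight_mask.size(1) is in_channels // groups for a grouped conv, so it is
    # multiplied by groups to recover the full input-channel count before
    # subtracting the pruned channels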
n_remained_in = weight_mask.size(1) * conv.groups - pruned_in.size(0)
n_remained_out = weight_mask.size(0) - pruned_out.size(0)
if n_remained_in != remained_in.size(0) or n_remained_out != remained_out.size(0):
raise ShapeMisMatchError()
k_size1, k_size2 = conv.kernel_size
    # Note: the group dependency of the conv layers should already be
    # resolved before reaching this point.
    # Check whether the mask tensor meets the group dependency and compute
    # the new number of groups after pruning.
# the original step size of the input channel for each group
ori_inchannel_step = int(conv.in_channels/conv.groups)
# the original step size of the output channel for each group
ori_outchannel_step = int(conv.out_channels/conv.groups)
# calculate the new_in_channel_step and new_outchannel_step first
new_inchannel_step = new_outchannel_step = None
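    # the per-group channel step of the pruned layer is taken from the first
    # group that still has remaining channels; every other remaining group must
    # match it (checked in the second loop below)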
for groupid in range(conv.groups):
in_start = groupid * ori_inchannel_step
in_end = in_start + ori_inchannel_step
out_start = groupid * ori_outchannel_step
out_end = out_start + ori_outchannel_step
current_input_index = list(
filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
current_output_index = list(
filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
        if len(current_input_index) == 0:
            # the whole group is pruned, check the next group
            continue
else:
new_inchannel_step = len(current_input_index)
new_outchannel_step = len(current_output_index)
break
    # if every group is entirely pruned, new_inchannel_step is still None and
    # the layer degenerates into an empty layer
    if not new_inchannel_step or not new_outchannel_step:
        raise EmptyLayerError()
    tmp_weight = torch.ones(
        n_remained_out, new_inchannel_step, k_size1, k_size2)
    tmp_weight = tmp_weight.to(conv.weight.device)
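    # torch.nn.Conv2d requires every group to keep the same number of channels,
    # so the remaining channels must divide evenly into the per-group steps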
if n_remained_in % new_inchannel_step != 0 or n_remained_out % new_outchannel_step != 0:
raise UnBalancedGroupError()
new_groups = 0
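    # copy the preserved weights group by group; groups that are entirely
    # pruned are skipped, so the number of groups may shrink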
for groupid in range(conv.groups):
in_start = groupid * ori_inchannel_step
in_end = in_start + ori_inchannel_step
out_start = groupid * ori_outchannel_step
out_end = out_start + ori_outchannel_step
current_input_index = list(
filter(lambda x: in_start <= x and x < in_end, remained_in.tolist()))
current_output_index = list(
filter(lambda x: out_start <= x and x < out_end, remained_out.tolist()))
# remap the global index to the group index
current_input_index = [x-in_start for x in current_input_index]
        if len(current_input_index) == 0:
            # if the whole group is pruned, then the corresponding output
            # channels should be pruned as well
            assert len(current_output_index) == 0
            continue
        # check that each remaining group keeps the same number of channels
        if len(current_input_index) != new_inchannel_step or len(current_output_index) != new_outchannel_step:
            raise UnBalancedGroupError()
# copy the weight into tmp_weight
new_out_start = new_outchannel_step * new_groups
new_out_end = new_out_start + new_outchannel_step
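        # conv.weight has shape (out_channels, in_channels // groups, k1, k2):
        # select the preserved output filters by their global index first, then
        # the preserved input channels (already remapped to group-local indices)
        # along dim 1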
tmp_weight[new_out_start:new_out_end] = torch.index_select(
conv.weight[current_output_index], 1, torch.as_tensor(current_input_index, dtype=torch.long).to(conv.weight.device))
new_groups += 1
_logger.debug("replace conv2d with in_channels: %d, out_channels: %d",
n_remained_in, n_remained_out)
    # need_bias indicates whether the new conv layer needs a bias term: if the
    # original conv does not have a bias and there is no constant that needs to
    # be folded into the bias, then need_bias is False.
    need_bias = conv.bias is not None
new_conv = torch.nn.Conv2d(in_channels=n_remained_in,
out_channels=n_remained_out,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=new_groups,
bias=need_bias,
padding_mode=conv.padding_mode)
new_conv.to(conv.weight.device)
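    # copy through .data so the in-place copy on the new parameter is not
    # tracked by autograd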
    new_conv.weight.data.copy_(tmp_weight)
# copy the bias data
if conv.bias is not None:
new_conv.bias.data.copy_(torch.index_select(
conv.bias.data, 0, remained_out))
return new_conv
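

# Illustrative usage sketch (an assumption added for documentation, not part of
# the original module): build coarse channel masks in the layout that the
# unpacking above expects, i.e. ([input_mask], output_mask, {'weight': weight_mask}),
# and let replace_conv2d shrink the layer. The demo_* names, channel choices and
# tensor shapes are made up for the example, and it assumes convert_to_coarse_mask
# treats channels whose mask is all zero as pruned.
if __name__ == '__main__':
    demo_conv = nn.Conv2d(in_channels=4, out_channels=4, kernel_size=3, bias=True)
    # keep input channels {0, 2} and output channels {1, 3}
    demo_in_mask = torch.zeros(1, 4, 8, 8)
    demo_in_mask[:, [0, 2]] = 1
    demo_out_mask = torch.zeros(1, 4, 8, 8)
    demo_out_mask[:, [1, 3]] = 1
    demo_weight_mask = torch.zeros_like(demo_conv.weight)
    demo_weight_mask[torch.tensor([1, 3]).unsqueeze(1), torch.tensor([0, 2])] = 1
    demo_new_conv = replace_conv2d(
        demo_conv, ([demo_in_mask], demo_out_mask, {'weight': demo_weight_mask}))
    assert demo_new_conv.in_channels == 2 and demo_new_conv.out_channels == 2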