in nni/compression/pytorch/speedup/compress_modules.py
def replace_convtranspose2d(convtrans, masks):
"""
We need anothor replace function for
convtranspose2d, because the layout of
the weight is different from traditional
conv layers. The layout of the weight is [N_in, N_out, ksize_1, ksize_2]
Parameters
----------
convtrans : torch.nn.ConvTranspose2d
The conv2d module to be replaced
masks : Tuple of the input masks, output masks and weight masks
Tuple of the masks, for example
([input_m1, input_m2], [output_m], {'weight':weight_m})
Returns
-------
torch.nn.ConvTranspose2d
The new conv2d module
"""
in_masks, output_mask, weight_masks = masks
assert isinstance(convtrans, torch.nn.ConvTranspose2d)
if len(in_masks) != 1:
raise InputsNumberError()
in_mask = in_masks[0]
weight_mask = weight_masks['weight']
pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
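# convert_to_coarse_mask(mask, 1) is assumed here to return the pruned and the
# kept channel indices along dim 1 of the mask. For example, if in_mask keeps
# channels 0, 2 and 3 of a 5-channel input, then pruned_in == tensor([1, 4])
# and remained_in == tensor([0, 2, 3]).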
# ConvTranspose2d has the weight shape of [N_in, N_out/groups, k1, k2]
n_remained_in = weight_mask.size(0) - pruned_in.size(0)
n_remained_out = weight_mask.size(1) * convtrans.groups - pruned_out.size(0)
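# Worked example with illustrative numbers: for a weight_mask of shape
# [16, 16, k1, k2] and groups == 2 (i.e. in_channels = 16, out_channels = 32),
# pruning 4 input channels and 8 output channels gives
# n_remained_in = 16 - 4 = 12 and n_remained_out = 16 * 2 - 8 = 24.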
if n_remained_in != remained_in.size(0) or n_remained_out != remained_out.size(0):
raise ShapeMisMatchError()
k_size1, k_size2 = convtrans.kernel_size
# Note: the group dependency of the convtranspose layers should already be
# resolved before entering this function
ori_inchannel_step = int(convtrans.in_channels/convtrans.groups)
ori_outchannel_step = int(convtrans.out_channels/convtrans.groups)
new_inchannel_step = new_outchannel_step = None
for groupid in range(convtrans.groups):
in_start = groupid * ori_inchannel_step
in_end = in_start + ori_inchannel_step
out_start = groupid * ori_outchannel_step
out_end = out_start + ori_outchannel_step
current_input_index = list(
filter(lambda x: in_start <= x < in_end, remained_in.tolist()))
current_output_index = list(
filter(lambda x: out_start <= x < out_end, remained_out.tolist()))
if len(current_input_index) == 0:
# if the whole group is pruned
continue
else:
new_inchannel_step = len(current_input_index)
new_outchannel_step = len(current_output_index)
break
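# The loop above uses the first surviving group to decide how many input and
# output channels each remaining group keeps. Illustrative example (assumed
# masks, not taken from a real model): with groups == 4 and 4 input channels
# per group, if the first unpruned group keeps 2 of its input channels, then
# new_inchannel_step == 2 and every other surviving group must keep exactly 2
# as well, otherwise the layer cannot be rebuilt as a grouped ConvTranspose2d.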
# the per-group channel counts are only known if at least one group survived
if not new_inchannel_step or not new_outchannel_step:
raise EmptyLayerError()
if n_remained_in % new_inchannel_step != 0 or n_remained_out % new_outchannel_step != 0:
raise UnBalancedGroupError()
tmp_weight = torch.ones(
n_remained_in, new_outchannel_step, k_size1, k_size2)
tmp_weight = tmp_weight.to(convtrans.weight.device)
new_groups = 0
for groupid in range(convtrans.groups):
# copy the weights of this group
in_start = groupid * ori_inchannel_step
in_end = in_start + ori_inchannel_step
out_start = groupid * ori_outchannel_step
out_end = out_start + ori_outchannel_step
current_input_index = list(
filter(lambda x: in_start <= x < in_end, remained_in.tolist()))
current_output_index = list(
filter(lambda x: out_start <= x < out_end, remained_out.tolist()))
# remap the global output indices to group-local indices: in the weight of a
# ConvTranspose2d layer, dim 1 only holds the out_channels/groups output
# channels that belong to a single group
current_output_index = [x - out_start for x in current_output_index]
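# e.g. with out_channels == 8 and groups == 2, group 1 covers global output
# channels 4..7; if remained_out keeps [4, 6] in that group, the group-local
# indices used to slice dim 1 of the weight become [0, 2] (illustrative only)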
if len(current_input_index) == 0:
# if the whole group is pruned
assert len(current_output_index) == 0
continue
# check that every surviving group keeps the same number of channels
if len(current_input_index) != new_inchannel_step or len(current_output_index) != new_outchannel_step:
raise UnBalancedGroupError()
# copy the weight into tmp_weight
new_in_start = new_inchannel_step * new_groups
new_in_end = new_in_start + new_inchannel_step
out_index = torch.as_tensor(
current_output_index, dtype=torch.long).to(convtrans.weight.device)
tmp_weight[new_in_start:new_in_end] = torch.index_select(
convtrans.weight[current_input_index], 1, out_index)
new_groups += 1
_logger.debug('Replace convtranspose2d with in_channels:%d out_channels:%d',
n_remained_in, n_remained_out)
new_convtrans = torch.nn.ConvTranspose2d(in_channels=n_remained_in,
out_channels=n_remained_out,
kernel_size=convtrans.kernel_size,
stride=convtrans.stride,
padding=convtrans.padding,
output_padding=convtrans.output_padding,
dilation=convtrans.dilation,
groups=new_groups,
bias=convtrans.bias is not None,
padding_mode=convtrans.padding_mode)
new_convtrans.to(convtrans.weight.device)
# copy through .data to avoid an in-place autograd error on a leaf parameter
new_convtrans.weight.data.copy_(tmp_weight)
if convtrans.bias is not None:
if output_mask is not None:
new_convtrans.bias.data[:] = torch.index_select(
convtrans.bias.data, 0, remained_out)
else:
new_convtrans.bias.data.copy_(convtrans.bias.data)
return new_convtrans
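# A minimal usage sketch for this function, kept in comments so the module is not
# executed on import. The mask construction and the behaviour of
# convert_to_coarse_mask on 4-D masks are assumptions made for illustration:
#
#   import torch
#
#   convtrans = torch.nn.ConvTranspose2d(in_channels=4, out_channels=8, kernel_size=3)
#   in_mask = torch.ones(1, 4, 16, 16)
#   in_mask[:, 1] = 0                      # prune input channel 1
#   output_mask = torch.ones(1, 8, 18, 18)
#   output_mask[:, 5] = 0                  # prune output channel 5
#   weight_mask = {'weight': torch.ones_like(convtrans.weight)}
#
#   new_layer = replace_convtranspose2d(
#       convtrans, ([in_mask], output_mask, weight_mask))
#   # expected result: a ConvTranspose2d(3, 7, kernel_size=(3, 3)) whose weight
#   # and bias are gathered from the kept channels of the original layer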