in nni/compression/pytorch/utils/mask_conflict.py [0:0]
def fix_mask(self):
    """
    Fix the mask conflict before the mask inference for the layers that
    has shape dependencies. This function should be called before the
    mask inference of the 'speedup' module. Only structured pruning masks
    are supported.

    Returns
    -------
    dict
        ``self.masks`` with the conflicting channel masks merged in place.
    """
    # Build the dependency sets along the pruned dimension:
    # conv_prune_dim == 0 -> output-channel dependencies,
    # conv_prune_dim == 1 -> input-channel dependencies.
    if self.conv_prune_dim == 0:
        channel_depen = ChannelDependency(
            self.model, self.dummy_input, self.traced, self.channel_prune_type)
    else:
        channel_depen = InputChannelDependency(
            self.model, self.dummy_input, self.traced)
    depen_sets = channel_depen.dependency_sets
    # Summing a Conv2d weight over these axes leaves only the pruned-channel axis.
    sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
    # Any existing mask tensor tells us which device new masks must live on.
    (_tmp_name, _tmp_tensor) = list(self.masks.items())[0]
    device = _tmp_tensor['weight'].device
    for dset in depen_sets:
        if len(dset) <= 1:
            # A singleton set has nothing to conflict with.
            continue
        # channel_masks is a list, each element is None or a vector, for example:
        # [[0, 1, 1, 0, 0], [0, 0, 1, 1, 0], None], None means no channel
        # is pruned.
        channel_masks = []
        fine_grained = False
        for name in dset:
            if name in self.masks:
                _, m = get_module_by_name(self.model, name)
                assert m is not None
                mask = self.masks[name]['weight']
                if type(m).__name__ == 'Conv2d':
                    channel_mask = (mask.abs().sum(sum_idx) != 0).int()
                    channel_masks.append(channel_mask)
                    # If the number of kept weights is not a whole multiple of
                    # the per-channel weight count, the mask is fine-grained.
                    if (channel_mask.sum() * (mask.numel() / mask.shape[self.conv_prune_dim])).item() != (mask > 0).sum().item():
                        fine_grained = True
                elif type(m).__name__ == 'Linear':
                    if self.conv_prune_dim == 1:
                        channel_masks.append(
                            (mask.abs().sum(0) != 0).int())
                    else:
                        channel_masks.append(
                            (mask.abs().sum(1) != 0).int())
                elif type(m).__name__ == 'BatchNorm2d':
                    # BatchNorm masks are already per-channel vectors.
                    channel_masks.append(mask.int())
                elif type(m).__name__ == 'ConvTranspose2d':
                    # convtranspose have difference memory layout, so that we need create
                    # a tmp_sum_idx for conv_transpose
                    tmp_sum_idx = (
                        0, 2, 3) if self.conv_prune_dim == 0 else (1, 2, 3)
                    channel_mask = (mask.abs().sum(tmp_sum_idx) != 0).int()
                    channel_masks.append(channel_mask)
                    if (channel_mask.sum() * (mask.numel() / mask.shape[1 - self.conv_prune_dim])).item() != (mask > 0).sum().item():
                        fine_grained = True
                else:
                    raise RuntimeError(
                        f'unsupported module type: {type(m).__name__}')
            else:
                # no mask means not pruned, equivlent to full masks
                channel_masks.append(None)
        if fine_grained:
            _logger.info("Fine-grained mask detected")
        if all(x is None for x in channel_masks):
            # Nothing in this set is pruned; no conflict to fix.
            continue
        num_channels_list = [len(x)
                             for x in channel_masks if x is not None]
        # number of channels in same set should be identical
        assert len(set(num_channels_list)) == 1
        num_channels = num_channels_list[0]
        # Treat unpruned layers as keeping every channel.
        for i, dim_mask in enumerate(channel_masks):
            if dim_mask is None:
                channel_masks[i] = torch.ones(
                    num_channels).int().to(device)
        # merge masks with 'or'
        merged_channel_mask = channel_masks[0].clone()
        for i in range(1, len(channel_masks)):
            merged_channel_mask = (
                (merged_channel_mask + channel_masks[i]) != 0).int()
        merged_index = torch.nonzero(merged_channel_mask, as_tuple=True)[0]
        # Second pass: write the merged channel mask back into every layer.
        for name in dset:
            if name not in self.masks:
                # An unpruned layer can only stay consistent if the merged
                # mask keeps all channels.
                assert all(merged_channel_mask)
                continue
            orig_mask = self.masks[name]['weight']
            _, m = get_module_by_name(self.model, name)
            new_mask = torch.zeros_like(orig_mask)
            if type(m).__name__ == 'Conv2d':
                if self.conv_prune_dim == 0:
                    new_mask[merged_index, :, :, :] = 1.
                else:
                    new_mask[:, merged_index, :, :] = 1.
            elif type(m).__name__ == 'Linear':
                if self.conv_prune_dim == 0:
                    new_mask[merged_index, :] = 1
                elif self.conv_prune_dim == 1:
                    new_mask[:, merged_index] = 1.
            elif type(m).__name__ == 'BatchNorm2d':
                new_mask = merged_channel_mask.type_as(orig_mask)
            elif type(m).__name__ == 'ConvTranspose2d':
                # ConvTranspose2d weights are laid out (in, out, kH, kW) —
                # the opposite of Conv2d — so the pruned dimension is flipped,
                # mirroring tmp_sum_idx in the collection pass above. Without
                # this branch a ConvTranspose2d accepted by the first pass
                # would hit the RuntimeError below.
                if self.conv_prune_dim == 0:
                    new_mask[:, merged_index, :, :] = 1.
                else:
                    new_mask[merged_index, :, :, :] = 1.
            else:
                raise RuntimeError(
                    f'unsupported module type: {type(m).__name__}')
            self.masks[name]['weight'] = new_mask
            if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
                if type(m).__name__ == 'Conv2d':
                    assert self.conv_prune_dim == 0
                # Bias entries follow the output channels, so only the
                # output-channel pruning case updates the bias mask.
                if self.conv_prune_dim == 0:
                    self.masks[name]['bias'] = merged_channel_mask.type_as(
                        self.masks[name]['bias'])
    return self.masks