def fix_mask()

in nni/compression/pytorch/utils/mask_conflict.py [0:0]


    def fix_mask(self):
        """
        Fix the mask conflict before the mask inference for the layers that
        have shape dependencies. This function should be called before the
        mask inference of the 'speedup' module. Only structured pruning masks
        are supported.

        Returns
        -------
        dict
            ``self.masks`` after conflicts are resolved; each entry maps a
            layer name to a dict with a 'weight' mask tensor and optionally
            a 'bias' mask tensor.
        """
        # Pick the dependency analysis matching the pruned conv dimension:
        # dim 0 (output channels) vs dim 1 (input channels).
        if self.conv_prune_dim == 0:
            channel_depen = ChannelDependency(
                self.model, self.dummy_input, self.traced, self.channel_prune_type)
        else:
            channel_depen = InputChannelDependency(
                self.model, self.dummy_input, self.traced)
        depen_sets = channel_depen.dependency_sets
        # Dims reduced away so that only the pruned channel dim remains.
        sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)

        # Infer the target device from an arbitrary existing mask tensor.
        (_tmp_name, _tmp_tensor) = list(self.masks.items())[0]
        device = _tmp_tensor['weight'].device

        for dset in depen_sets:
            if len(dset) <= 1:
                # A singleton dependency set has no conflict to resolve.
                continue
            # channel_masks is a list, each element is None or a vector, for example:
            # [[0, 1, 1, 0, 0], [0, 0, 1, 1, 0], None], None means no channel
            # is pruned.
            channel_masks = []
            fine_grained = False
            for name in dset:
                if name in self.masks:
                    _, m = get_module_by_name(self.model, name)
                    assert m is not None
                    mask = self.masks[name]['weight']
                    if type(m).__name__ == 'Conv2d':
                        channel_mask = (mask.abs().sum(sum_idx) != 0).int()
                        channel_masks.append(channel_mask)
                        # If zeroing whole channels cannot account for every
                        # zero in the mask, the mask is fine-grained.
                        if (channel_mask.sum() * (mask.numel() / mask.shape[self.conv_prune_dim])).item() != (mask > 0).sum().item():
                            fine_grained = True
                    elif type(m).__name__ == 'Linear':
                        if self.conv_prune_dim == 1:
                            channel_masks.append(
                                (mask.abs().sum(0) != 0).int())
                        else:
                            channel_masks.append(
                                (mask.abs().sum(1) != 0).int())
                    elif type(m).__name__ == 'BatchNorm2d':
                        channel_masks.append(mask.int())
                    elif type(m).__name__ == 'ConvTranspose2d':
                        # ConvTranspose2d stores its weight with the channel
                        # dims swapped relative to Conv2d (in, out, kh, kw),
                        # so we need a flipped tmp_sum_idx for conv_transpose.
                        tmp_sum_idx = (
                            0, 2, 3) if self.conv_prune_dim == 0 else (1, 2, 3)
                        channel_mask = (mask.abs().sum(tmp_sum_idx) != 0).int()
                        channel_masks.append(channel_mask)
                        if (channel_mask.sum() * (mask.numel() / mask.shape[1 - self.conv_prune_dim])).item() != (mask > 0).sum().item():
                            fine_grained = True
                    else:
                        raise RuntimeError(
                            f'unsupported module type: {type(m).__name__}')
                else:
                    # no mask means not pruned, equivalent to full masks
                    channel_masks.append(None)
            if fine_grained:
                _logger.info("Fine-grained mask detected")
            if all(x is None for x in channel_masks):
                # No layer in this set is pruned; nothing to fix.
                continue
            num_channels_list = [len(x)
                                 for x in channel_masks if x is not None]
            # number of channels in same set should be identical
            assert len(set(num_channels_list)) == 1
            num_channels = num_channels_list[0]

            # Treat unpruned layers as keeping every channel.
            for i, dim_mask in enumerate(channel_masks):
                if dim_mask is None:
                    channel_masks[i] = torch.ones(
                        num_channels).int().to(device)

            # merge masks with 'or': a channel survives if ANY layer keeps it
            merged_channel_mask = channel_masks[0].clone()
            for i in range(1, len(channel_masks)):
                merged_channel_mask = (
                    (merged_channel_mask + channel_masks[i]) != 0).int()

            merged_index = torch.nonzero(merged_channel_mask, as_tuple=True)[0]

            for name in dset:
                if name not in self.masks:
                    # An unpruned layer shares this set, so the merged mask
                    # must keep all channels or shapes would disagree.
                    assert all(merged_channel_mask)
                    continue
                orig_mask = self.masks[name]['weight']
                _, m = get_module_by_name(self.model, name)
                new_mask = torch.zeros_like(orig_mask)
                if type(m).__name__ == 'Conv2d':
                    if self.conv_prune_dim == 0:
                        new_mask[merged_index, :, :, :] = 1.
                    else:
                        new_mask[:, merged_index, :, :] = 1.
                elif type(m).__name__ == 'ConvTranspose2d':
                    # Bug fix: ConvTranspose2d masks were collected in the
                    # loop above but previously raised RuntimeError here.
                    # Its weight layout is (in, out, kh, kw), so the indexed
                    # dim is flipped relative to Conv2d.
                    if self.conv_prune_dim == 0:
                        new_mask[:, merged_index, :, :] = 1.
                    else:
                        new_mask[merged_index, :, :, :] = 1.
                elif type(m).__name__ == 'Linear':
                    if self.conv_prune_dim == 0:
                        new_mask[merged_index, :] = 1.
                    elif self.conv_prune_dim == 1:
                        new_mask[:, merged_index] = 1.
                elif type(m).__name__ == 'BatchNorm2d':
                    new_mask = merged_channel_mask.type_as(orig_mask)
                else:
                    raise RuntimeError(
                        f'unsupported module type: {type(m).__name__}')
                self.masks[name]['weight'] = new_mask
                if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
                    if type(m).__name__ == 'Conv2d':
                        assert self.conv_prune_dim == 0
                    if self.conv_prune_dim == 0:
                        # Bias has one entry per output channel, matching the
                        # merged channel mask when pruning dim 0.
                        self.masks[name]['bias'] = merged_channel_mask.type_as(
                            self.masks[name]['bias'])

        return self.masks