def update_indirect_sparsity()

in nni/compression/pytorch/speedup/infer_mask.py


    def update_indirect_sparsity(self):
        """
        This function will update the indirect sparsity. To explain what's
        indirect sparsity, for example, there is two tensors TA and TB, and
        we perform the calculation: TC = TA x TB in which TC is also a tensor.
        Once some values in TA are masked to zeros, then the corresponding
        positions in TB are also potential sparsities, because these have no
        effect of the final output(the gradient of these positions in TB equal
        to 0 all the time). This function it to fine the potential sparsity caused
        by other sparsity(we call it indirect sparsity here). Basically we can find
        these potential sparsity through gradient.
        """
        # Each node only updates its own output mask during the backward
        # pass. This is because some ops may involve broadcasting: for
        # example, OP A's output tensor may be taken by two OPs (B, C) as
        # input, so we cannot directly update the input mask at OP B or C.
        # We can only update the mask of A's output tensor once both B and
        # C have been updated (their gradients have been calculated and
        # accumulated into A's output tensor).
        # Besides, updating the mask of A's output tensor is equivalent to
        # updating the input masks of OP B and C.
        if isinstance(self.output, torch.Tensor) and self.output.grad is not None:
            # If the output has a gradient, this node has successor nodes
            # and those successors have already updated their indirect
            # sparsity, so we can mask the positions whose gradient is
            # always zero.
            gradient_sum = torch.sum(torch.abs(self.output.grad.data), dim=0)
            _grad_zero = gradient_sum == 0
            for batchid in range(self.output.size(0)):
                # set the same mask value for the whole batch
                self.output_mask[batchid][_grad_zero] = 0
        elif isinstance(self.output, tuple) or isinstance(self.output, list):
            assert isinstance(self.output_mask, (tuple, list))
            for oid, tout in enumerate(self.output):
                errmsg = 'The output only supports tensor/list of tensors'
                assert isinstance(tout, torch.Tensor), errmsg
                if tout.grad is None:
                    continue
                gradient_sum = torch.sum(torch.abs(tout.grad.data), dim=0)
                _grad_zero = gradient_sum == 0
                for batchid in range(tout.size(0)):
                    # set the same mask value for the whole batch
                    self.output_mask[oid][batchid][_grad_zero] = 0

        self.requires_grad_(True)
        # Forward inference with auto gradient enabled
        # Note: tensors that require gradients cannot be used in in-place operations
        self.random_init()
        self.apply_mask()
        # Some operators may perform in-place operations, so we need to clone the
        # inputs before passing them to self.module
        tmp_dummy_input = [x.clone() if isinstance(
            x, torch.Tensor) else x for x in self.dummy_input]
        output = self.module(*tmp_dummy_input)

        if isinstance(output, torch.Tensor) and output.grad_fn is None:
            # the output has no gradient function, so there is nothing to
            # back-propagate through
            return
        # Note: output may be a tensor or a list/tuple of tensors
        if isinstance(output, torch.Tensor):
            output.backward(self.output_mask)
        elif isinstance(output, list) or isinstance(output, tuple):
            for tid, t_out in enumerate(output):
                t_out.backward(self.output_mask[tid])

        # update the sparsity of the parameters
        for para_name in self.weights:
            grad_zero = self.weights[para_name].grad.data == 0
            self.weight_mask[para_name][grad_zero] = 0
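
For reference, here is a minimal, standalone sketch (not part of the NNI source) of the gradient trick described above. The tensor names TA, TB, TC and the element-wise product are illustrative assumptions; the idea is that positions of TB whose gradient stays zero after TA is masked are candidates for indirect sparsity.

    import torch

    TA = torch.rand(4, 8)
    TB = torch.rand(4, 8, requires_grad=True)

    # Directly mask (zero out) the first three columns of TA.
    ta_mask = torch.ones_like(TA)
    ta_mask[:, :3] = 0

    # TC = TA x TB (element-wise here); TC inherits the zeros of the masked TA.
    TC = (TA * ta_mask) * TB
    # Back-propagate an all-ones "output mask" through the calculation.
    TC.backward(torch.ones_like(TC))

    # Positions of TB whose gradient is zero for every sample in the batch can
    # never affect the output, so they are indirect sparsity candidates.
    gradient_sum = torch.sum(torch.abs(TB.grad), dim=0)
    indirect_zero = gradient_sum == 0  # True for the first three columns

The same zero-gradient test is what the method applies to the module parameters in its final loop.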