in nni/compression/pytorch/speedup/infer_mask.py [0:0]
def update_indirect_sparsity(self):
"""
This function will update the indirect sparsity. To explain what
indirect sparsity is, suppose there are two tensors TA and TB, and
we perform the calculation TC = TA x TB, in which TC is also a tensor.
Once some values in TA are masked to zero, the corresponding
positions in TB become potential sparsity as well, because they have
no effect on the final output (the gradient at these positions in TB
is always zero). This function finds the potential sparsity caused by
other sparsity (we call it indirect sparsity here). Basically, we can
find this potential sparsity through the gradient.
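
For example, take elementwise multiplication as a concrete case (an
illustrative sketch with made-up values, not part of the real
computation performed here):

    import torch
    TA = torch.tensor([[0., 2.], [0., 3.]], requires_grad=True)  # column 0 masked
    TB = torch.tensor([[1., 1.], [1., 1.]], requires_grad=True)
    TC = TA * TB
    TC.backward(torch.ones_like(TC))
    # TB.grad == [[0., 2.], [0., 3.]]: the positions of TB that are
    # multiplied by the masked zeros of TA receive zero gradient, so
    # they are indirect sparsity and can be masked as well.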
"""
# Each node only updates its output mask during this backward pass of
# the mask update. This is because an op's output may be broadcast to
# several consumers: for example, OP A's output tensor may be taken by
# two OPs (B and C) as input. So we cannot directly update that input
# mask at OP B or C alone; we can only update the mask of A's output
# tensor once both B and C have been updated (their gradients have
# already been calculated and accumulated into A's output tensor).
# Besides, updating the mask of A's output tensor is equivalent to
# updating the input masks of OPs B and C.
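# A minimal sketch of the fan-out case above (hypothetical ops, shown
# only to illustrate how the gradients accumulate; it is not executed):
#
#     y = opA(x)
#     y.retain_grad()                              # y feeds both B and C
#     opB(y).backward(mask_b, retain_graph=True)   # contribution from B
#     opC(y).backward(mask_c)                      # contribution from C
#     # y.grad now holds the sum of both contributions, so a position of
#     # y whose accumulated gradient is zero is unused by every consumer
#     # and can safely be masked.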
if isinstance(self.output, torch.Tensor) and self.output.grad is not None:
# if the output has a gradient, this node has successor nodes and the
# successors have already updated their indirect sparsity, so we can
# mask the positions whose gradient is always zero
gradient_sum = torch.sum(torch.abs(self.output.grad.data), dim=0)
_grad_zero = gradient_sum == 0
for batchid in range(self.output.size(0)):
# set the same mask value for the whole batch
self.output_mask[batchid][_grad_zero] = 0
elif isinstance(self.output, tuple) or isinstance(self.output, list):
assert isinstance(self.output_mask, (tuple, list))
for oid, tout in enumerate(self.output):
errmsg = 'Only a tensor or a list/tuple of tensors is supported as the output'
assert isinstance(tout, torch.Tensor), errmsg
gradient_sum = torch.sum(
torch.abs(tout.grad.data), dim=0)
_grad_zero = gradient_sum == 0
for batchid in range(tout.size(0)):
# set the same mask value for the whole batch
self.output_mask[oid][batchid][_grad_zero] = 0
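# Illustration of the reduction above (made-up values): for a gradient
# of shape [batch, channel] such as [[0., 1.], [0., 3.]], summing the
# absolute values over dim 0 gives [0., 4.]. Channel 0 is zero for every
# sample, so position 0 of the mask is cleared for every batch id,
# i.e. the resulting mask is shared across the batch dimension.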
self.requires_grad_(True)
# Forward inference with auto gradient enabled
# Note: leaf tensors that require gradient cannot be modified by in-place operations
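# Re-randomize the dummy inputs/weights and re-apply the current masks
# (based on the method names below; an assumption about their exact
# behavior), so that zero gradients found later reflect the mask
# structure rather than values that happen to be zero by coincidence.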
self.random_init()
self.apply_mask()
# Some operators may perform in-place operations, so we need to clone the
# inputs before passing them to self.module.
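# For example (a hypothetical case of what self.module may contain): a
# submodule such as nn.ReLU(inplace=True) applied directly to a leaf
# tensor that requires grad raises "RuntimeError: a leaf Variable that
# requires grad is being used in an in-place operation", while applying
# it to a clone is fine, because the clone is a non-leaf tensor produced
# by an autograd op.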
tmp_dummy_input = [x.clone() if isinstance(
x, torch.Tensor) else x for x in self.dummy_input]
output = self.module(*tmp_dummy_input)
if isinstance(output, torch.Tensor) and output.grad_fn is None:
# the output is a tensor without a gradient function, so there is
# nothing to backpropagate
return
# Note: the output may be a tensor or a list/tuple of tensors
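# Passing the output mask as the gradient seed to backward() means that
# only the unmasked output positions propagate a nonzero gradient back to
# the inputs and parameters; the masked positions contribute nothing,
# which is what exposes the indirect sparsity upstream.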
if isinstance(output, torch.Tensor):
output.backward(self.output_mask)
elif isinstance(output, list) or isinstance(output, tuple):
for tid, t_out in enumerate(output):
t_out.backward(self.output_mask[tid])
# update the sparsity of the parameters
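# A parameter position whose gradient is exactly zero under the masked
# backward pass did not influence any unmasked output, so it can be
# masked as well (the random_init() call above helps avoid gradients
# that are zero merely by numerical coincidence).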
for para_name in self.weights:
grad_zero = self.weights[para_name].grad.data == 0
self.weight_mask[para_name][grad_zero] = 0