in nni/compression/pytorch/speedup/infer_mask.py [0:0]
def __init__(self, module, dummy_input, in_masks=None, weight_mask=None, \
output_mask=None, name=None, in_constants=None, state_dict=None, batch_dim=0):
"""
This class infers the masks of the target module automatically.
update_direct_sparsity infers the output mask from the input
masks; in contrast, update_indirect_sparsity infers the input
masks from the given output mask. Any newly found sparsity is
incrementally merged into the original in_masks and output_mask.
Parameters
----------
module: torch.nn.Module/function
The target module whose masks are inferred. It must be callable.
dummy_input: torch.Tensor/list of Tensor
The dummy_input of the target module.
in_masks: list of torch.Tensor
The input masks of the target module. If in_masks is not None, then
update_direct_sparsity and update_indirect_sparsity incrementally
update the given in_masks; otherwise, AutoMaskInference creates new
in_masks for the target module.
output_mask: torch.Tensor
The output mask of the target module. Similar to in_masks, if
output_mask is not None, then update_direct_sparsity and
update_indirect_sparsity incrementally update the given output_mask;
otherwise, AutoMaskInference creates one for the target module.
weight_mask: dict of torch.Tensor
The weight masks of the target module; each key is the name of the
corresponding mask. For example:
{'weight': torch.ones(1000, 1000), 'bias': torch.ones(1000)}
name: str
Name of the target module.
in_constants: list of torch.Tensor
The corresponding constant values of the in_masks.
state_dict: dict of torch.Tensor
The original values of the weights.
batch_dim: int
The index of the batch dimension of the input tensors.
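Examples
--------
A minimal usage sketch (the layer, shapes and variable names below
are illustrative only):

    module = torch.nn.Linear(1000, 1000)
    dummy_input = torch.rand(2, 1000)
    mask_infer = AutoMaskInference(module, dummy_input)
    # propagate sparsity from the (default all-ones) input masks
    mask_infer.update_direct_sparsity()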
"""
errmsg = '%s is not callable, please pass an nn.Module or a function' % str(
module)
assert callable(module), errmsg
self.module = module
# Initialize the dummy_input
if isinstance(dummy_input, list):
# if there are multiple input variables
self.dummy_input = dummy_input
else:
# if there is only one input variable
self.dummy_input = [dummy_input]
# Initialize the masks for input tensors
self.in_masks = in_masks if in_masks is not None else [
None] * len(self.dummy_input)
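# in_constants hold the constant values carried by the masked-out
# positions of the inputs; they default to all zeros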
self.in_constants = in_constants if in_constants is not None else [
torch.zeros_like(x) for x in self.dummy_input]
for in_id, _ in enumerate(self.in_masks):
if self.in_masks[in_id] is None and \
isinstance(self.dummy_input[in_id], torch.Tensor):
# if the input mask is None, create an all-ones mask for the corresponding input tensor
self.in_masks[in_id] = torch.ones_like(self.dummy_input[in_id])
# ones_like will put the created mask on the same device with the dummy_input
# Initialize the mask for output tensors
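# run one forward pass on the dummy input to discover the output
# structure (a tensor, or a list/tuple of tensors)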
self.output = self.module(*self.dummy_input)
if output_mask is not None:
# assume the given output mask is right
self.output_mask = output_mask
else:
if isinstance(self.output, torch.Tensor):
self.output_mask = torch.ones_like(self.output)
elif isinstance(self.output, (list, tuple)):
self.output_mask = []
for o_tensor in self.output:
if isinstance(o_tensor, torch.Tensor):
self.output_mask.append(torch.ones_like(o_tensor))
else:
# if one of the outputs is not tensor, set the corresponding
# mask to None
self.output_mask.append(None)
else:
self.output_mask = None
# Initialize the mask for the parameters
self.weights = {}
self.weight_mask = {}
if weight_mask:
self.weight_mask.update(weight_mask)
self.name = name
if isinstance(self.module, nn.Module):
# a plain function has no parameters; only an nn.Module does
# collect all the parameter tensors of the target module
for para_name, para in module.named_parameters():
self.weights[para_name] = para
if para_name not in self.weight_mask:
self.weight_mask[para_name] = torch.ones_like(para.data)
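# keep the original weight values so that the module can be restored
# to its unmasked state during mask inference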
self.state_dict = state_dict
# TODO: support batch dimensions other than 0 in the future
self.batch_dim = batch_dim