def __init__()

in nni/compression/pytorch/speedup/infer_mask.py


    def __init__(self, module, dummy_input, in_masks=None, weight_mask=None, \
                output_mask=None, name=None, in_constants=None, state_dict=None, batch_dim=0):
        """
        This class infers the masks of the target module automatically.
        update_direct_sparsity infers the output mask from the input
        masks; in contrast, update_indirect_sparsity infers the input
        masks from the given output mask. Any newly found sparsity is
        incrementally merged into the original in_masks and output_mask.

        Parameters
        ----------
        module: torch.nn.Module/function
            The target module whose masks are inferred. Must be callable.
        dummy_input: torch.Tensor/list of Tensor
            The dummy_input of the target module.
        in_masks: list of torch.Tensor
            The input masks of the target module. If in_masks is not None,
            update_direct_sparsity and update_indirect_sparsity incrementally
            update the given in_masks; otherwise, AutoMaskInference creates new
            in_masks for the target module.
        output_mask: torch.Tensor
            The output mask of the target module. As with in_masks, if output_mask
            is not None, update_direct_sparsity and update_indirect_sparsity
            incrementally update the given output_mask; otherwise AutoMaskInference
            creates one for the target module.
        weight_mask: dict of the weight masks
            The weight masks of the target module, keyed by parameter name.
            For example: {'weight': torch.ones(1000, 1000), 'bias': torch.ones(1000)}
        name: str
            Name of the target module.
        in_constants: list of torch.Tensor
            The corresponding constant values of the in_masks.
        state_dict: dict of torch.Tensor
            The original values of the weights.
        batch_dim: int
            The index of the batch dimension of the input tensors.

        """
        errmsg = '%s is not callable, please pass an nn.Module or a function' % str(
            module)
        assert callable(module), errmsg
        self.module = module

        # Initialize the dummy_input
        if isinstance(dummy_input, list):
            # if there are multiple input variables
            self.dummy_input = dummy_input
        else:
            # if there is only one input variable
            self.dummy_input = [dummy_input]

        # Initialize the masks for input tensors
        self.in_masks = in_masks if in_masks is not None else [
            None] * len(self.dummy_input)
        self.in_constants = in_constants if in_constants is not None else [
            torch.zeros_like(x) for x in self.dummy_input]
        for in_id, _ in enumerate(self.in_masks):
            if self.in_masks[in_id] is None and \
                    isinstance(self.dummy_input[in_id], torch.Tensor):
                # if the input mask is None, create an all-ones mask for the corresponding input tensor
                self.in_masks[in_id] = torch.ones_like(self.dummy_input[in_id])
                # ones_like will put the created mask on the same device with the dummy_input

        # Initialize the mask for output tensors
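        # run the module once on the dummy input so the structure of the output
        # (shape, device, and whether it is a tensor or a list/tuple) is known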
        self.output = self.module(*self.dummy_input)
        # self.output.requires_grad_()
        if output_mask is not None:
            # assume the given output mask is right
            self.output_mask = output_mask
        else:
            if isinstance(self.output, torch.Tensor):
                self.output_mask = torch.ones_like(self.output)
            elif isinstance(self.output, (list, tuple)):
                self.output_mask = []
                for o_tensor in self.output:
                    if isinstance(o_tensor, torch.Tensor):
                        self.output_mask.append(torch.ones_like(o_tensor))
                    else:
                        # if one of the outputs is not a tensor, set the
                        # corresponding mask to None
                        self.output_mask.append(None)
            else:
                self.output_mask = None

        # Initialize the mask for the parameters
        self.weights = {}
        self.weight_mask = {}
        if weight_mask:
            self.weight_mask.update(weight_mask)
        self.name = name
        if isinstance(self.module, nn.Module):
            # a plain function has no parameters, so only nn.Module instances reach here;
            # get all the parameter tensors of the target module
            for param_name, para in module.named_parameters():
                self.weights[param_name] = para
                if param_name not in self.weight_mask:
                    self.weight_mask[param_name] = torch.ones_like(para.data)
        self.state_dict = state_dict
        # TODO: support batch dimensions other than 0 in the future
        self.batch_dim = batch_dim
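
A minimal usage sketch. The Linear module, the shapes, and the pruned feature
range below are hypothetical illustrations; update_direct_sparsity is the
method described in the docstring above, which propagates input sparsity
forward into output_mask:

    import torch
    import torch.nn as nn

    module = nn.Linear(1000, 1000)
    dummy_input = torch.rand(8, 1000)

    # mask out the first 100 input features (illustrative values)
    in_mask = torch.ones_like(dummy_input)
    in_mask[:, :100] = 0

    infer = AutoMaskInference(module, dummy_input, in_masks=[in_mask])
    infer.update_direct_sparsity()   # propagate the input sparsity to output_mask
    print(infer.output_mask.shape)   # torch.Size([8, 1000])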