def register_mask()

in tinynn/graph/modifier.py [0:0]


    def register_mask(self, modifiers, importance, sparsity):
        """Build the pruning masks of this RNN module and register them with its masker.

        Input-channel pruning (``dim_changes_info.pruned_idx_i``) zeroes the
        corresponding columns of the first-layer input-hidden weight(s).
        Output-channel pruning (``dim_changes_info.pruned_idx_o``) zeroes the
        affected rows/columns of every ``weight_ih*``, ``weight_hh*`` and
        ``weight_hr*`` mask and the affected entries of every non-None bias mask.
        For bidirectional modules the indices are split per direction, and when
        the cell has more than one gate (``gs > 1``) the row indices are tiled
        over the gates.

        Args:
            modifiers: Not used by this method; kept for interface compatibility.
            importance: Mapping of importance scores. Only read when the module
                has a projection (``proj_size > 0``): ``importance[f'{unique_name}:h']``
                selects the hidden channels to prune.
            sparsity: Mapping of sparsity ratios keyed by unique name. Only read
                in the projection case.
        """
        module = self.module()
        gs = rnn_gate_size(module)
        num_directions = 2 if module.bidirectional else 1
        bidirectional = num_directions > 1
        has_proj = hasattr(module, 'proj_size') and module.proj_size > 0

        # ``weight_ih_l0_reverse`` also matches ``startswith('weight_ih_l0')`` below.
        # Both first-layer input weights receive the same input columns and are
        # registered at the very end, after all row/column pruning has been applied.
        first_ih_names = ['weight_ih_l0']
        if 'weight_ih_l0_reverse' in self.weight_mask:
            first_ih_names.append('weight_ih_l0_reverse')

        if self.dim_changes_info.pruned_idx_i:
            remove_idx = self.dim_changes_info.pruned_idx_i
            for n in first_ih_names:
                self.weight_mask[n][:, remove_idx] = 0
            self.masker().set_in_remove_idx(remove_idx)

        if self.dim_changes_info.pruned_idx_o:
            if has_proj:
                # With a projection, ``pruned_idx_o`` is used as the projection index
                # list; the hidden channels to drop are picked here from the ':h'
                # importance scores (the smallest ones, per direction).
                u_name = self.unique_name()
                hu_name = f'{u_name}:h'
                scores = importance[hu_name]

                idx_num = len(scores)
                remove_num = int(sparsity[u_name] * idx_num)
                remove_idx = []
                if bidirectional:
                    idx_num //= 2
                    remove_num //= 2
                    remove_idx += get_smallest_k(scores[:idx_num], remove_num)
                    remove_idx += get_smallest_k(scores[idx_num:], remove_num, offset=idx_num)
                else:
                    remove_idx += get_smallest_k(scores, remove_num)
                remove_idx_proj = self.dim_changes_info.pruned_idx_o
            else:
                remove_idx = self.dim_changes_info.pruned_idx_o
                remove_idx_proj = None

            # Hidden / projection indices per direction; key True means a
            # '_reverse' parameter. Unidirectional modules only have the False key.
            if bidirectional:
                hidden_fwd, hidden_bwd = self.split_indices_with_directions(
                    remove_idx, module.hidden_size, num_directions
                )
                proj_fwd = proj_bwd = None
                if remove_idx_proj is not None:
                    proj_fwd, proj_bwd = self.split_indices_with_directions(
                        remove_idx_proj, module.proj_size, num_directions
                    )
                    assert len(proj_fwd) == len(proj_bwd)
                direction_idx = {False: (hidden_fwd, proj_fwd), True: (hidden_bwd, proj_bwd)}
            else:
                direction_idx = {False: (remove_idx, remove_idx_proj)}

            # Rows of weight_ih*/weight_hh* and the biases are addressed with the
            # hidden indices tiled over the ``gs`` gates (offset: hidden_size).
            idx_by_direction = {}
            for is_reverse, (hidden_idx, proj_idx) in direction_idx.items():
                if gs > 1:
                    row_idx = self.tile_indices_with_gate_size(hidden_idx, gs, module.hidden_size)
                else:
                    row_idx = hidden_idx
                idx_by_direction[is_reverse] = (row_idx, hidden_idx, proj_idx)

            for n in self.weight_mask:
                mask = self.weight_mask[n]
                row_idx, hidden_idx, proj_idx = idx_by_direction[bidirectional and n.endswith('_reverse')]

                if n.startswith('weight_ih_l0'):
                    # Registration of first-layer input weights is deferred to the end.
                    mask[row_idx, :] = 0
                elif n.startswith('weight_ih'):
                    mask[row_idx, :] = 0
                    # Columns of deeper-layer input weights use the full (unsplit)
                    # index list covering all directions.
                    if remove_idx_proj is None:
                        mask[:, remove_idx] = 0
                    else:
                        mask[:, remove_idx_proj] = 0
                    self.masker().register_mask(n, mask)
                elif n.startswith('weight_hh'):
                    mask[row_idx, :] = 0
                    if proj_idx is None:
                        mask[:, hidden_idx] = 0
                    else:
                        mask[:, proj_idx] = 0
                    self.masker().register_mask(n, mask)
                elif n.startswith('weight_hr'):
                    if proj_idx is not None:
                        mask[proj_idx, :] = 0
                    mask[:, hidden_idx] = 0
                    self.masker().register_mask(n, mask)

            for n in self.bias_mask:
                mask = self.bias_mask[n]
                if mask is None:
                    continue
                row_idx = idx_by_direction[bidirectional and n.endswith('_reverse')][0]
                mask[row_idx] = 0
                self.masker().register_mask(n, mask)
            self.masker().set_ot_remove_idx(remove_idx)

            if remove_idx_proj is not None:
                self.masker().set_custom_remove_idx(remove_idx_proj)

        for n in first_ih_names:
            self.masker().register_mask(n, self.weight_mask[n])