in tinynn/graph/modifier.py [0:0]
def register_mask(self, modifiers, importance, sparsity):
    """Zero out and register the weight/bias masks of a pruned RNN module.

    Input-side pruning (``self.dim_changes_info.pruned_idx_i``) masks columns
    of ``weight_ih_l0``. Output-side pruning
    (``self.dim_changes_info.pruned_idx_o``) masks rows/columns of every
    ``weight_ih*``, ``weight_hh*`` and ``weight_hr*`` entry of
    ``self.weight_mask`` and every non-``None`` entry of ``self.bias_mask``.
    For bidirectional modules the indices are split per direction with
    ``split_indices_with_directions``. When ``rnn_gate_size(...) > 1`` the
    row indices are expanded with ``tile_indices_with_gate_size``. Each
    touched mask is passed to ``self.masker().register_mask``.

    Args:
        modifiers: Not read by this method (kept for interface compatibility).
        importance: Mapping of per-channel importance values. Only the
            ``f'{self.unique_name()}:h'`` entry is read, and only when the
            module has a projection (``proj_size > 0``).
        sparsity: Mapping of pruning ratios. Only
            ``sparsity[self.unique_name()]`` is read, and only when the module
            has a projection.
    """
    gs = rnn_gate_size(self.module())
    num_directions = 2 if self.module().bidirectional else 1
    has_proj = hasattr(self.module(), 'proj_size') and self.module().proj_size > 0

    if self.dim_changes_info.pruned_idx_i:
        # Input channels were pruned: mask the matching columns of the
        # first-layer input-to-hidden weight.
        # NOTE(review): only 'weight_ih_l0' is masked/registered here; for
        # bidirectional modules 'weight_ih_l0_reverse' is not -- confirm the
        # masker covers it through set_in_remove_idx.
        remove_idx = self.dim_changes_info.pruned_idx_i
        self.weight_mask['weight_ih_l0'][:, remove_idx] = 0
        self.masker().set_in_remove_idx(remove_idx)

    if self.dim_changes_info.pruned_idx_o:
        if has_proj:
            # With a projection, pruned_idx_o indexes the projected outputs;
            # the hidden channels to drop are chosen separately from the
            # ':h' importance entry.
            u_name = self.unique_name()
            hu_name = f'{u_name}:h'
            remove_idx = []
            idx_num = len(importance[hu_name])
            remove_num = int(sparsity[u_name] * idx_num)
            # Use the locally computed direction count (consistent with the
            # rest of this method) instead of a separate `self.bidirectional`.
            if num_directions > 1:
                # Select the same number of channels from each half of the
                # importance values; the second half is offset by idx_num.
                idx_num //= 2
                remove_num //= 2
                remove_idx += get_smallest_k(importance[hu_name][:idx_num], remove_num)
                remove_idx += get_smallest_k(importance[hu_name][idx_num:], remove_num, offset=idx_num)
            else:
                remove_idx += get_smallest_k(importance[hu_name], remove_num)
            remove_idx_proj = self.dim_changes_info.pruned_idx_o
        else:
            remove_idx = self.dim_changes_info.pruned_idx_o
            remove_idx_proj = None

        remove_idx_fwd = None
        remove_idx_bwd = None
        remove_idx_proj_fwd = None
        remove_idx_proj_bwd = None
        if num_directions > 1:
            # Per-direction views of the hidden (and projection) indices.
            remove_idx_fwd, remove_idx_bwd = self.split_indices_with_directions(
                remove_idx, self.module().hidden_size, num_directions
            )
            if remove_idx_proj is not None:
                remove_idx_proj_fwd, remove_idx_proj_bwd = self.split_indices_with_directions(
                    remove_idx_proj, self.module().proj_size, num_directions
                )
                assert len(remove_idx_proj_fwd) == len(remove_idx_proj_bwd)

        if gs > 1:
            # Row indices of the gate-stacked tensors are expanded per gate.
            offset = self.module().hidden_size
            if num_directions > 1:
                remove_idx_bwd_gs = self.tile_indices_with_gate_size(remove_idx_bwd, gs, offset)
                remove_idx_fwd_gs = self.tile_indices_with_gate_size(remove_idx_fwd, gs, offset)
            else:
                remove_idx_gs = self.tile_indices_with_gate_size(remove_idx, gs, offset)

        for name in self.weight_mask:
            # remove_idx_r:  row indices (gate-tiled when gs > 1)
            # remove_idx_c:  hidden-size column indices for this direction
            # remove_idx_pc: projection column indices, or None without proj
            remove_idx_r = remove_idx
            remove_idx_c = remove_idx
            remove_idx_pc = None
            if num_directions > 1:
                if name.endswith('_reverse'):
                    remove_idx_r = remove_idx_bwd_gs if gs > 1 else remove_idx_bwd
                    remove_idx_c = remove_idx_bwd
                    if has_proj:
                        remove_idx_pc = remove_idx_proj_bwd
                else:
                    remove_idx_r = remove_idx_fwd_gs if gs > 1 else remove_idx_fwd
                    remove_idx_c = remove_idx_fwd
                    if has_proj:
                        remove_idx_pc = remove_idx_proj_fwd
            elif gs > 1:
                remove_idx_r = remove_idx_gs
                remove_idx_pc = remove_idx_proj

            if name.startswith('weight_ih_l0'):
                # Only rows here; 'weight_ih_l0' is registered once at the end
                # so that input-side pruning is included as well.
                self.weight_mask[name][remove_idx_r, :] = 0
            elif name.startswith('weight_ih'):
                # Deeper layers: columns use the full (all-direction) index
                # list -- presumably because they consume the previous layer's
                # combined output.
                self.weight_mask[name][remove_idx_r, :] = 0
                if remove_idx_proj is None:
                    self.weight_mask[name][:, remove_idx] = 0
                else:
                    self.weight_mask[name][:, remove_idx_proj] = 0
                self.masker().register_mask(name, self.weight_mask[name])
            elif name.startswith('weight_hh'):
                self.weight_mask[name][remove_idx_r, :] = 0
                if remove_idx_pc is None:
                    self.weight_mask[name][:, remove_idx_c] = 0
                else:
                    self.weight_mask[name][:, remove_idx_pc] = 0
                self.masker().register_mask(name, self.weight_mask[name])
            elif name.startswith('weight_hr'):
                # Projection weight: rows follow the projection indices,
                # columns follow the hidden indices.
                if remove_idx_pc is not None:
                    self.weight_mask[name][remove_idx_pc, :] = 0
                self.weight_mask[name][:, remove_idx_c] = 0
                self.masker().register_mask(name, self.weight_mask[name])

        for name in self.bias_mask:
            if self.bias_mask[name] is None:
                continue
            if num_directions > 1:
                if name.endswith('_reverse'):
                    bias_idx = remove_idx_bwd_gs if gs > 1 else remove_idx_bwd
                else:
                    bias_idx = remove_idx_fwd_gs if gs > 1 else remove_idx_fwd
            elif gs > 1:
                bias_idx = remove_idx_gs
            else:
                bias_idx = remove_idx
            self.bias_mask[name][bias_idx] = 0
            self.masker().register_mask(name, self.bias_mask[name])

        self.masker().set_ot_remove_idx(remove_idx)
        if remove_idx_proj is not None:
            self.masker().set_custom_remove_idx(remove_idx_proj)

    # Registered unconditionally so the input-side column mask from
    # pruned_idx_i is registered even when no output channels were pruned.
    self.masker().register_mask('weight_ih_l0', self.weight_mask['weight_ih_l0'])