in tinynn/graph/modifier.py [0:0]
def modify_output(self, remove_idx):
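    """Prune hidden channels of an RNN/GRU/LSTM module in place.

    ``remove_idx`` lists the hidden-state channels to drop; the per-gate,
    per-direction and projection index lists used below are all derived
    from it.
    """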
    rnn = self.node.module

    log.debug(f'[RNN] {self.unique_name()}: remove_idx = {remove_idx}')

    num_directions = 2 if rnn.bidirectional else 1
    has_proj = hasattr(rnn, 'proj_size') and rnn.proj_size > 0

    gs = rnn_gate_size(rnn)
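    # A bidirectional RNN concatenates both directions along the channel
    # axis of its output, so split remove_idx into a forward and a
    # backward half before anything else.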
    if num_directions > 1:
        offset = rnn.hidden_size
        remove_idx_fwd, remove_idx_bwd = self.split_indices_with_directions(remove_idx, offset, num_directions)
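    # LSTM/GRU stack their gates along dim 0 of weight_ih/weight_hh, so a
    # hidden index has to be repeated once per gate, each copy shifted by
    # hidden_size.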
    if gs > 1:
        offset = rnn.hidden_size
        if num_directions > 1:
            remove_idx_bwd_gs = self.tile_indices_with_gate_size(remove_idx_bwd, gs, offset)
            remove_idx_fwd_gs = self.tile_indices_with_gate_size(remove_idx_fwd, gs, offset)
        else:
            remove_idx_gs = self.tile_indices_with_gate_size(remove_idx, gs, offset)
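    # With projections (LSTM, proj_size > 0) the module output has
    # proj_size channels rather than hidden_size; the projected channels
    # to drop come from the masker.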
    remove_idx_proj = None
    if has_proj:
        remove_idx_proj = self.masker().custom_remove_idx
        if remove_idx_proj is not None:
            offset = rnn.proj_size
            remove_idx_proj_fwd, remove_idx_proj_bwd = self.split_indices_with_directions(
                remove_idx_proj, offset, num_directions
            )
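    # Rewrite the weights layer by layer, once per direction.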
    for i in range(rnn.num_layers):
        for j in range(num_directions):
            suffix = '_reverse' if j > 0 else ''
            desc = f'layer{suffix} hidden #{i}'

            weight_ih = getattr(rnn, f'weight_ih_l{i}{suffix}')
            weight_hh = getattr(rnn, f'weight_hh_l{i}{suffix}')
            weight_hr = getattr(rnn, f'weight_hr_l{i}{suffix}', None)

            bias_ih = getattr(rnn, f'bias_ih_l{i}{suffix}', None)
            bias_hh = getattr(rnn, f'bias_hh_l{i}{suffix}', None)
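            # Pick the index lists for this direction: _r prunes gate rows
            # (dim 0), _c prunes hidden-state columns (dim 1) and _pc prunes
            # projected columns when proj_size > 0.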
            remove_idx_r = remove_idx
            remove_idx_c = remove_idx
            remove_idx_pc = None

            if num_directions > 1:
                if j > 0:
                    if gs > 1:
                        remove_idx_r = remove_idx_bwd_gs
                    else:
                        remove_idx_r = remove_idx_bwd
                    remove_idx_c = remove_idx_bwd
                    if has_proj:
                        remove_idx_pc = remove_idx_proj_bwd
                else:
                    if gs > 1:
                        remove_idx_r = remove_idx_fwd_gs
                    else:
                        remove_idx_r = remove_idx_fwd
                    remove_idx_c = remove_idx_fwd
                    if has_proj:
                        remove_idx_pc = remove_idx_proj_fwd
            elif gs > 1:
                remove_idx_r = remove_idx_gs
                remove_idx_pc = remove_idx_proj
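            # Turn the removal lists into keep lists over the mask shapes.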
            preserve_idx_ih_r = complementary_list(
                list(range(self.weight_mask[f'weight_ih_l{i}{suffix}'].shape[0])), remove_idx_r
            )
            preserve_idx_hh_r = complementary_list(
                list(range(self.weight_mask[f'weight_hh_l{i}{suffix}'].shape[0])), remove_idx_r
            )

            if weight_hr is None:
                preserve_idx_hh_c = complementary_list(
                    list(range(self.weight_mask[f'weight_hh_l{i}{suffix}'].shape[1])), remove_idx_c
                )
            else:
                preserve_idx_hh_c = complementary_list(
                    list(range(self.weight_mask[f'weight_hh_l{i}{suffix}'].shape[1])), remove_idx_pc
                )
                preserve_idx_hr_c = complementary_list(
                    list(range(self.weight_mask[f'weight_hr_l{i}{suffix}'].shape[1])), remove_idx_c
                )
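            # Columns of weight_ih only need pruning from layer 1 on, where
            # the input is the (possibly projected) output of the previous
            # layer.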
            preserve_idx_ih_c = None
            if i != 0:
                if weight_hr is not None:
                    preserve_idx_ih_c = complementary_list(
                        list(range(self.weight_mask[f'weight_ih_l{i}{suffix}'].shape[1])), remove_idx_proj
                    )
                else:
                    preserve_idx_ih_c = preserve_idx_ih_r
                    if num_directions > 1 or gs > 1:
                        preserve_idx_ih_c = complementary_list(
                            list(range(self.weight_mask[f'weight_ih_l{i}{suffix}'].shape[1])), remove_idx
                        )
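            # Rewrite weight_ih: rows always follow the gate pruning; columns
            # are pruned as well for layers past the first.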
            if weight_ih.shape[0] != len(preserve_idx_ih_r):
                if i != 0 and weight_ih.shape[1] != len(preserve_idx_ih_c):
                    desc_i = f'layer{suffix} input #{i}'
                    log.info(
                        f'[RNN] {self.unique_name()}: {desc_i} {weight_ih.shape[1]} -> {len(preserve_idx_ih_c)}'
                    )

                log.info(f'[RNN] {self.unique_name()}: {desc} {rnn.hidden_size * gs} -> {len(preserve_idx_ih_r)}')

                if i != 0:
                    new_w = weight_ih[preserve_idx_ih_r, :][:, preserve_idx_ih_c]
                    setattr(rnn, f'weight_ih_l{i}{suffix}', torch.nn.Parameter(new_w))
                else:
                    setattr(rnn, f'weight_ih_l{i}{suffix}', torch.nn.Parameter(weight_ih[preserve_idx_ih_r, :]))

                if bias_ih is not None:
                    setattr(rnn, f'bias_ih_l{i}{suffix}', torch.nn.Parameter(bias_ih[preserve_idx_ih_r]))
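            # Rewrite weight_hh (and weight_hr for projected LSTMs) the same
            # way.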
            desc = f'layer{suffix} output #{i}'

            if weight_hh.shape[0] != len(preserve_idx_hh_r) or weight_hh.shape[1] != len(preserve_idx_hh_c):
                log.info(f'[RNN] {self.unique_name()}: {desc} {rnn.hidden_size * gs} -> {len(preserve_idx_hh_r)}')
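                # weight_hh is rewritten identically with or without a
                # projection; only projected LSTMs also carry a weight_hr.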
                setattr(
                    rnn,
                    f'weight_hh_l{i}{suffix}',
                    torch.nn.Parameter(weight_hh[preserve_idx_hh_r, :][:, preserve_idx_hh_c]),
                )
                if weight_hr is not None:
                    setattr(
                        rnn,
                        f'weight_hr_l{i}{suffix}',
                        torch.nn.Parameter(weight_hr[preserve_idx_hh_c, :][:, preserve_idx_hr_c]),
                    )

                if bias_hh is not None:
                    setattr(rnn, f'bias_hh_l{i}{suffix}', torch.nn.Parameter(bias_hh[preserve_idx_hh_r]))
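    # After all layers are rewritten, record the new channel counts on the
    # module; with a projection the output width is proj_size, not
    # hidden_size.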
    if weight_hr is None:
        rnn.hidden_size = len(preserve_idx_hh_c)
    else:
        rnn.proj_size = len(preserve_idx_hh_c)
        rnn.hidden_size = len(preserve_idx_hr_c)
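
# A minimal illustration of the index math above, assuming the helper
# semantics implied by the calls (hypothetical values, not part of this
# module): for an LSTM, rnn_gate_size(rnn) == 4 and the four gates are
# stacked along dim 0 of weight_ih/weight_hh, so with hidden_size == 4,
# dropping hidden channel 1 removes rows 1, 5, 9 and 13, i.e.
# tile_indices_with_gate_size([1], 4, 4) would return [1, 5, 9, 13]; and
# for a bidirectional layer, split_indices_with_directions([1, 6], 4, 2)
# would yield ([1], [2]) -- forward channel 1 and backward channel 6 - 4.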