def modify_output()

in tinynn/graph/modifier.py [0:0]


    def modify_output(self, remove_idx):
        rnn = self.node.module

        log.debug(f'[RNN] {self.unique_name()}: remove_idx = {remove_idx}')

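        # Direction count, LSTM projection flag, and gate size determine how the
        # flat removal indices map onto the packed weight tensors below.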
        num_directions = 2 if rnn.bidirectional else 1
        has_proj = getattr(self.module(), 'proj_size', 0) > 0
        gs = rnn_gate_size(rnn)
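        # Bidirectional RNNs concatenate forward and reverse hidden states, so
        # removal indices are split into the two directions (reverse-direction
        # units sit at an offset of hidden_size) and re-based per direction.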
        if num_directions > 1:
            offset = rnn.hidden_size
            remove_idx_fwd, remove_idx_bwd = self.split_indices_with_directions(remove_idx, offset, num_directions)

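        # Gated cells stack their gates along dim 0 of weight_ih/weight_hh
        # (LSTM: 4, GRU: 3), so each removed hidden unit removes one row per gate.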
        if gs > 1:
            offset = rnn.hidden_size
            if num_directions > 1:
                remove_idx_bwd_gs = self.tile_indices_with_gate_size(remove_idx_bwd, gs, offset)
                remove_idx_fwd_gs = self.tile_indices_with_gate_size(remove_idx_fwd, gs, offset)
            else:
                remove_idx_gs = self.tile_indices_with_gate_size(remove_idx, gs, offset)

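        # With LSTM projections (proj_size > 0) the recurrent input is the
        # projected state; its removal indices come from a custom mask and are
        # split per direction just like the hidden indices.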
        remove_idx_proj = None
        if has_proj:
            remove_idx_proj = self.masker().custom_remove_idx
            if remove_idx_proj is not None:
                offset = rnn.proj_size
                remove_idx_proj_fwd, remove_idx_proj_bwd = self.split_indices_with_directions(
                    remove_idx_proj, offset, num_directions
                )

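        # Rebuild the packed weights layer by layer and direction by direction.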
        for i in range(rnn.num_layers):
            for j in range(num_directions):
                suffix = '_reverse' if j > 0 else ''
                desc = f'layer{suffix} hidden #{i}'

                weight_ih = getattr(rnn, f'weight_ih_l{i}{suffix}')
                weight_hh = getattr(rnn, f'weight_hh_l{i}{suffix}')
                weight_hr = getattr(rnn, f'weight_hr_l{i}{suffix}', None)

                bias_ih = getattr(rnn, f'bias_ih_l{i}{suffix}', None)
                bias_hh = getattr(rnn, f'bias_hh_l{i}{suffix}', None)

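                # Select the row (gate-tiled) and column removal sets that apply
                # to this direction.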
                remove_idx_r = remove_idx
                remove_idx_c = remove_idx
                remove_idx_pc = None
                if num_directions > 1:
                    if j > 0:
                        if gs > 1:
                            remove_idx_r = remove_idx_bwd_gs
                        else:
                            remove_idx_r = remove_idx_bwd
                        remove_idx_c = remove_idx_bwd
                        if has_proj:
                            remove_idx_pc = remove_idx_proj_bwd
                    else:
                        if gs > 1:
                            remove_idx_r = remove_idx_fwd_gs
                        else:
                            remove_idx_r = remove_idx_fwd
                        remove_idx_c = remove_idx_fwd
                        if has_proj:
                            remove_idx_pc = remove_idx_proj_fwd
                elif gs > 1:
                    remove_idx_r = remove_idx_gs
                    remove_idx_pc = remove_idx_proj

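                # Turn removal indices into the complementary lists of kept
                # row/column indices, sized from the recorded weight masks.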
                preserve_idx_ih_r = complementary_list(
                    list(range(self.weight_mask[f'weight_ih_l{i}{suffix}'].shape[0])), remove_idx_r
                )
                preserve_idx_hh_r = complementary_list(
                    list(range(self.weight_mask[f'weight_hh_l{i}{suffix}'].shape[0])), remove_idx_r
                )

                if weight_hr is None:
                    preserve_idx_hh_c = complementary_list(
                        list(range(self.weight_mask[f'weight_hh_l{i}{suffix}'].shape[1])), remove_idx_c
                    )
                else:
                    preserve_idx_hh_c = complementary_list(
                        list(range(self.weight_mask[f'weight_hh_l{i}{suffix}'].shape[1])), remove_idx_pc
                    )
                    preserve_idx_hr_c = complementary_list(
                        list(range(self.weight_mask[f'weight_hr_l{i}{suffix}'].shape[1])), remove_idx_c
                    )

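                # Input columns: layer 0 consumes the upstream module's output
                # and is left untouched; deeper layers consume the previous
                # layer's (possibly projected) output.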
                preserve_idx_ih_c = None
                if i != 0:
                    if weight_hr is not None:
                        preserve_idx_ih_c = complementary_list(
                            list(range(self.weight_mask[f'weight_ih_l{i}{suffix}'].shape[1])), remove_idx_proj
                        )
                    else:
                        preserve_idx_ih_c = preserve_idx_ih_r
                        if num_directions > 1 or gs > 1:
                            preserve_idx_ih_c = complementary_list(
                                list(range(self.weight_mask[f'weight_ih_l{i}{suffix}'].shape[1])), remove_idx
                            )

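                # Prune weight_ih rows (and bias_ih); for layers past the first,
                # prune its input columns as well.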
                if weight_ih.shape[0] != len(preserve_idx_ih_r):
                    if i != 0 and weight_ih.shape[1] != len(preserve_idx_ih_c):
                        desc_i = f'layer{suffix} input #{i}'
                        log.info(
                            f'[RNN] {self.unique_name()}: {desc_i} {weight_ih.shape[1]} -> {len(preserve_idx_ih_c)}'
                        )

                    log.info(f'[RNN] {self.unique_name()}: {desc} {rnn.hidden_size * gs} -> {len(preserve_idx_ih_r)}')

                    if i != 0:
                        new_w = weight_ih[preserve_idx_ih_r, :][:, preserve_idx_ih_c]
                        setattr(rnn, f'weight_ih_l{i}{suffix}', torch.nn.Parameter(new_w))
                    else:
                        setattr(rnn, f'weight_ih_l{i}{suffix}', torch.nn.Parameter(weight_ih[preserve_idx_ih_r, :]))

                    if bias_ih is not None:
                        setattr(rnn, f'bias_ih_l{i}{suffix}', torch.nn.Parameter(bias_ih[preserve_idx_ih_r]))

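                # Prune weight_hh rows/columns; with projections, weight_hh
                # columns follow proj_size and weight_hr maps hidden -> projected.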
                desc = f'layer{suffix} output #{i}'
                if weight_hh.shape[0] != len(preserve_idx_hh_r) or weight_hh.shape[1] != len(preserve_idx_hh_c):
                    log.info(f'[RNN] {self.unique_name()}: {desc} {rnn.hidden_size * gs} -> {len(preserve_idx_hh_r)}')

                    setattr(
                        rnn,
                        f'weight_hh_l{i}{suffix}',
                        torch.nn.Parameter(weight_hh[preserve_idx_hh_r, :][:, preserve_idx_hh_c]),
                    )
                    if weight_hr is not None:
                        setattr(
                            rnn,
                            f'weight_hr_l{i}{suffix}',
                            torch.nn.Parameter(weight_hr[preserve_idx_hh_c, :][:, preserve_idx_hr_c]),
                        )

                    if bias_hh is not None:
                        setattr(rnn, f'bias_hh_l{i}{suffix}', torch.nn.Parameter(bias_hh[preserve_idx_hh_r]))

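        # Record the new sizes: weight_hh columns give the (projected) output
        # width, weight_hr columns give the pruned hidden width.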
        if weight_hr is None:
            rnn.hidden_size = len(preserve_idx_hh_c)
        else:
            rnn.proj_size = len(preserve_idx_hh_c)
            rnn.hidden_size = len(preserve_idx_hr_c)
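
For reference, here is a minimal sketch of the index bookkeeping this method relies on. The helper bodies below are assumptions inferred from the call sites above, not the library's actual implementations (in the source, split_indices_with_directions and tile_indices_with_gate_size are methods on the modifier class):

    import torch.nn as nn

    def rnn_gate_size(rnn):
        # Assumed mapping: LSTM packs 4 gates (i, f, g, o) along dim 0 of its
        # weights, GRU packs 3 (r, z, n), and a vanilla RNN has a single gate.
        return {nn.LSTM: 4, nn.GRU: 3}.get(type(rnn), 1)

    def split_indices_with_directions(remove_idx, offset, num_directions):
        # Forward-direction units occupy [0, offset); reverse-direction units
        # occupy [offset, 2 * offset) and are re-based to start at zero.
        # num_directions is accepted to mirror the call site.
        fwd = [i for i in remove_idx if i < offset]
        bwd = [i - offset for i in remove_idx if i >= offset]
        return fwd, bwd

    def tile_indices_with_gate_size(remove_idx, gate_size, offset):
        # Gate g's copy of the hidden dimension starts at g * offset along
        # dim 0, so every removed unit drops one row per gate.
        return [g * offset + i for g in range(gate_size) for i in remove_idx]

    def complementary_list(full, removed):
        # Kept indices = everything not slated for removal.
        return [i for i in full if i not in removed]

For example, with a bidirectional LSTM and hidden_size=8, remove_idx=[2, 10] splits into fwd=[2] and bwd=[2], and tiling fwd with gate_size=4 and offset=8 marks rows [2, 10, 18, 26] of weight_hh_l0 for removal.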