def _get_state_dict()

in torchaudio/pipelines/_wav2vec2/impl.py [0:0]


    def _get_state_dict(self, dl_kwargs):
        """Return the parent's state dict with selected rows of the aux head removed.

        The state dict is obtained from ``super()._get_state_dict(dl_kwargs)``.
        When ``self._remove_aux_axis`` is truthy, the ``"aux.weight"`` and
        ``"aux.bias"`` entries are rebuilt without the rows (indices along
        dimension 0) listed in ``self._remove_aux_axis``.

        Args:
            dl_kwargs: Forwarded unchanged to the parent implementation.

        Returns:
            The state dict from the parent, with the two aux entries replaced
            in place when pruning is enabled.
        """
        weights = super()._get_state_dict(dl_kwargs)
        if not self._remove_aux_axis:
            # Nothing to prune: hand back the parent's state dict untouched.
            return weights

        # Why rows are dropped from the aux (output) layer:
        #
        # For the ASR task, pretrained weights that originate from fairseq carry
        # unrelated dimensions at index 1, 2 and 3. They come from fairseq's
        # Dictionary implementation, which was designed for NLP tasks and is
        # not used during ASR training.
        # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/data/dictionary.py#L21-L37
        # https://github.com/pytorch/fairseq/blob/c5ff181125c7e6126b49a85e5ebdd5f5b6a07914/fairseq/criterions/ctc.py#L126-L129
        #
        # In addition, some pretrained weights that originate from voxpopuli
        # have an extra dimension that is almost never used and looks like a
        # mistake. The label `1` shows up in the training dataset of German
        # (1 out of 16M), English (1 / 28M), Spanish (1 / 9.4M),
        # Romanian (1 / 4.7M) and Polish (6 / 5.8M).
        for name in ("aux.weight", "aux.bias"):
            tensor = weights[name]
            kept_rows = []
            for row in range(tensor.size(0)):
                if row in self._remove_aux_axis:
                    continue
                kept_rows.append(tensor[row])
            # Re-assemble the surviving rows along a new leading dimension.
            weights[name] = torch.stack(kept_rows)
        return weights