def convert_model_weight()

in parlai/agents/bart/convert_fairseq_to_parlai.py


    def convert_model_weight(self, opt: Opt) -> Dict[str, Any]:
        """
        Convert state_dict between fairseq and ParlAI.

        :param opt:
            ParlAI opt

        :return state_dict:
            a state dict to load into the ParlAI model.
        """
        # deal with embeddings
        state = self.state
        agent = self.agent
        state_dict = state['model']
        return_dict = OrderedDict()
        for each_key in state_dict.keys():
            mapped_key = each_key
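            # encoder.version / decoder.version are fairseq-internal version buffers
            # with no ParlAI counterpart, so they are skipped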
            if mapped_key == 'encoder.version' or mapped_key == 'decoder.version':
                continue

            # 1. replace if embedding
            for emb in EMBEDDING_DICT_MAPPING:
                mapped_key = mapped_key.replace(emb, EMBEDDING_DICT_MAPPING[emb])

            # 2. Replace attention
            if 'encoder' in each_key and 'self_attn' in each_key:
                mapped_key = mapped_key.replace('self_attn', 'attention')
            elif 'decoder' in each_key and 'self_attn' in each_key:
                mapped_key = mapped_key.replace('self_attn', 'self_attention')
            elif 'decoder' in each_key and 'encoder_attn' in each_key:
                mapped_key = mapped_key.replace('encoder_attn', 'encoder_attention')

            # 3. Replace multihead linear layers
            #    fairseq sometimes chunks all three layers into one model weight
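            #    in_proj_weight stacks the Q, K and V projections along dim 0 as a
            #    single (3 * embed_dim, embed_dim) matrix, so it is split into three
            #    equal chunks below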
            if 'in_proj_weight' in mapped_key or 'in_proj_bias' in mapped_key:
                for weightorbias in {'weight', 'bias'}:
                    attention_project_name = 'in_proj_{}'.format(weightorbias)
                    if attention_project_name in mapped_key:
                        weight = state_dict[each_key]
                        size = int(weight.size(0) / 3)
                        weights = weight.split(size, 0)
                        # For Q, K, V in order
                        return_dict[
                            mapped_key.replace(
                                attention_project_name, 'q_lin.{}'.format(weightorbias)
                            )
                        ] = weights[0]
                        return_dict[
                            mapped_key.replace(
                                attention_project_name, 'k_lin.{}'.format(weightorbias)
                            )
                        ] = weights[1]
                        return_dict[
                            mapped_key.replace(
                                attention_project_name, 'v_lin.{}'.format(weightorbias)
                            )
                        ] = weights[2]
                continue
            elif (
                'v_proj' in mapped_key
                or 'k_proj' in mapped_key
                or 'q_proj' in mapped_key
            ):
                mapped_key = mapped_key.replace('v_proj', 'v_lin')
                mapped_key = mapped_key.replace('q_proj', 'q_lin')
                mapped_key = mapped_key.replace('k_proj', 'k_lin')

            # 4. Replace FFN layers
            for old, new in FFN_MAPPING.items():
                mapped_key = mapped_key.replace(old, new)

            # 5. Fix layer norms
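            #    ParlAI encoder layers use norm1 (self-attention) and norm2 (FFN);
            #    decoder layers use norm1 (self-attention), norm2 (encoder attention)
            #    and norm3 (FFN)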
            if 'encoder.' in mapped_key:
                mapped_key = mapped_key.replace('attention_layer_norm', 'norm1')
                mapped_key = mapped_key.replace('final_layer_norm', 'norm2')
            else:
                mapped_key = mapped_key.replace('self_attention_layer_norm', 'norm1')
                mapped_key = mapped_key.replace('encoder_attention_layer_norm', 'norm2')
                mapped_key = mapped_key.replace('final_layer_norm', 'norm3')

            for _key in ['encoder', 'decoder']:
                mapped_key = mapped_key.replace(
                    f'{_key}.layer_norm', f'{_key}.norm_embeddings'
                )
                mapped_key = mapped_key.replace(
                    f'{_key}.layernorm_embedding', f'{_key}.norm_embeddings'
                )

            weight = state_dict[each_key]
            return_dict[mapped_key] = weight

        # 6. Shuffle embedding matrix given dictionary.
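        #    fairseq's dict.txt lists tokens (GPT-2 BPE ids) in frequency order; each
        #    row is moved from its frequency-sorted index to its BPE-id index so the
        #    matrix lines up with the ParlAI dictionary (the +4 skips the four
        #    leading special tokens)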
        enc_emb_key = 'encoder.embeddings.weight'
        bart_dict = os.path.join(opt['datapath'], 'models/bart/bart.large/dict.txt')
        with PathManager.open(bart_dict) as f:
            offset_dict = {i: l.split()[0] for i, l in enumerate(f.readlines())}
        new_embs = return_dict[enc_emb_key].clone()
        for idx, new_idx in offset_dict.items():
            try:
                new_embs[int(new_idx) + 4] = return_dict[enc_emb_key][idx + 4]
            except ValueError:
                # new_idx is not an int: fairseq pads the vocab with 'madeupwordN'
                # dummy tokens, which are kept at the tail of the embedding matrix
                if 'madeupword' in new_idx:
                    pad_idx = int(new_idx.split('madeupword')[1])
                    new_embs[-(4 - pad_idx)] = return_dict['encoder.embeddings.weight'][
                        idx + 4
                    ]
        return_dict['encoder.embeddings.weight'] = new_embs

        # 7. Swap special tokens
        #    Fairseq swaps the bos and eos token order for seq2seq models.
        #
        #   ParlAI s2s models expect:
        #       Encoder: TOKENS </s>
        #       Decoder: <s> TOKENS <s>
        #   Fairseq models get:
        #       Encoder: TOKENS </s>
        #       Decoder: </s> TOKENS <s>
        #
        #   So we swap to get:
        #       Encoder: TOKENS </s>
        #       Decoder: </s> TOKENS <s>
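        #   Concretely, the swap below reassigns the bos embedding to the eos
        #   embedding (unless opt['retain_bos_emb'] is set), so ParlAI's START
        #   token reuses fairseq's </s> vector.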
        #
        size_dict = return_dict[enc_emb_key].size(0)
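        # bart.large checkpoints carry one extra embedding row (the <mask> token used
        # during denoising pretraining); drop it if the ParlAI dict does not include it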
        if size_dict == len(agent.dict) + 1 and '<mask>' not in agent.dict:
            return_dict[enc_emb_key] = return_dict[enc_emb_key][: size_dict - 1, :]
            size_dict -= 1
        specials, words = return_dict[enc_emb_key].split([4, size_dict - 4], 0)
        bos, pad, eos, unk = specials
        if not self.opt['retain_bos_emb']:
            bos = eos
        specials = torch.stack([pad, bos, eos, unk])
        fp16_pad = (8 - (len(specials) + len(words)) % 8) % 8
        fp16_pad_ez = torch.zeros(fp16_pad, specials.size(1)).type_as(specials)
        return_dict[enc_emb_key] = torch.cat(
            [
                specials,  # special tokens
                words,  # word embeddings
                fp16_pad_ez,  # fp16 requires embeddings size to be a multiple of 8
            ],
            0,
        )

        return_dict['decoder.embeddings.weight'] = return_dict[enc_emb_key]
        return_dict['embeddings.weight'] = return_dict[enc_emb_key]

        # 8. Positional Embeddings
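        #    fairseq's learned positional embeddings reserve the first two rows
        #    (positions start at padding_idx + 1 = 2), so those rows are dropped here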
        if 'encoder.position_embeddings.weight' in return_dict:
            return_dict['encoder.position_embeddings.weight'] = return_dict[
                'encoder.position_embeddings.weight'
            ][2:, :]
            return_dict['decoder.position_embeddings.weight'] = return_dict[
                'decoder.position_embeddings.weight'
            ][2:, :]
        else:
            # sinusoidal embeddings
            from fairseq.modules.sinusoidal_positional_embedding import (
                SinusoidalPositionalEmbedding,
            )

            emb = SinusoidalPositionalEmbedding.get_embedding(
                128 + 2, opt['embedding_size'], 1
            )
            del return_dict['encoder.position_embeddings._float_tensor']
            del return_dict['decoder.position_embeddings._float_tensor']

            return_dict['encoder.position_embeddings.weight'] = emb[2:]
            return_dict['decoder.position_embeddings.weight'] = emb[2:]

        return_dict['START'] = torch.LongTensor([1])  # type: ignore
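        # ParlAI's TorchGeneratorModel registers START as a buffer holding the start
        # token index; after the special-token reorder above, bos sits at index 1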
        return return_dict
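
For reference, below is a minimal, self-contained sketch of the Q/K/V split performed in step 3, using a dummy tensor instead of a real fairseq checkpoint. The key name and embedding size are illustrative only; the real function additionally applies the embedding, FFN, and layer-norm mappings shown above.

    import torch
    from collections import OrderedDict

    # Stand-in for a fairseq self-attention projection: Q, K and V are stacked
    # along dim 0, giving a (3 * embed_dim, embed_dim) weight.
    embed_dim = 8
    fairseq_key = 'encoder.layers.0.self_attn.in_proj_weight'
    state_dict = {fairseq_key: torch.randn(3 * embed_dim, embed_dim)}

    return_dict = OrderedDict()
    weight = state_dict[fairseq_key]
    q, k, v = weight.split(weight.size(0) // 3, 0)  # fairseq packs Q, K, V in order
    for lin, chunk in zip(('q_lin', 'k_lin', 'v_lin'), (q, k, v)):
        parlai_key = fairseq_key.replace('self_attn', 'attention').replace(
            'in_proj_weight', f'{lin}.weight'
        )
        return_dict[parlai_key] = chunk

    # -> encoder.layers.0.attention.q_lin.weight / k_lin.weight / v_lin.weight,
    #    each of shape (embed_dim, embed_dim)
    print({key: tuple(val.shape) for key, val in return_dict.items()})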