def replace_transformer_module()

in src/transformers/models/oneformer/convert_to_hf_oneformer.py


    def replace_transformer_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
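        """Remap the transformer-module (query decoder) weights from the original
        OneFormer layout (`sem_seg_head.predictor.*`) to the Hugging Face layout,
        moving the renamed tensors from `src_state_dict` into `dst_state_dict`."""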
        dst_prefix: str = "transformer_module"
        src_prefix: str = "sem_seg_head.predictor"

        def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
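            # e.g. ("a.norm", "b.norm") -> [("a.norm.weight", "b.norm.weight"), ("a.norm.bias", "b.norm.bias")]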
            return [
                (f"{src_prefix}.weight", f"{dst_prefix}.weight"),
                (f"{src_prefix}.bias", f"{dst_prefix}.bias"),
            ]

        def rename_keys_for_attn(src_prefix: str, dst_prefix: str):
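            # torch.nn.MultiheadAttention stores the fused q/k/v projection as in_proj_{weight,bias}.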
            attn_keys = [
                (f"{src_prefix}.in_proj_bias", f"{dst_prefix}.in_proj_bias"),
                (f"{src_prefix}.in_proj_weight", f"{dst_prefix}.in_proj_weight"),
            ]
            attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj"))

            return attn_keys

        def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str):
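            # in_proj_{weight,bias} are deliberately omitted here: the fused q/k/v tensors
            # are handled separately by replace_keys_qkv_transformer_decoder (see below).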
            attn_keys = []
            attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj"))

            return attn_keys

        def rename_keys_for_query_transformer_layer(src_prefix: str, dst_prefix: str):
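            # A full layer of the query transformer, with torch.nn.TransformerDecoderLayer
            # naming: linear1/linear2 (FFN), norm1-3, self_attn, and multihead_attn (cross-attention).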
            query_transformer_layer_keys = []

            query_transformer_layer_keys.extend(
                rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.linear1")
            )
            query_transformer_layer_keys.extend(
                rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.linear2")
            )
            query_transformer_layer_keys.extend(
                rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.norm1")
            )
            query_transformer_layer_keys.extend(
                rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.norm2")
            )
            query_transformer_layer_keys.extend(
                rename_keys_for_weight_bias(f"{src_prefix}.norm3", f"{dst_prefix}.norm3")
            )

            query_transformer_layer_keys.extend(
                rename_keys_for_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn")
            )

            query_transformer_layer_keys.extend(
                rename_keys_for_attn(f"{src_prefix}.multihead_attn", f"{dst_prefix}.multihead_attn")
            )

            return query_transformer_layer_keys

        def rename_keys_for_cross_attn_layer(src_prefix: str, dst_prefix: str):
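            # Cross-attention block of a decoder layer: a norm plus a full multi-head
            # attention, including its fused in_proj tensors.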
            cross_attn_layer_keys = []

            cross_attn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm"))
            cross_attn_layer_keys.extend(
                rename_keys_for_attn(f"{src_prefix}.multihead_attn", f"{dst_prefix}.multihead_attn")
            )

            return cross_attn_layer_keys

        def rename_keys_for_self_attn_layer(src_prefix: str, dst_prefix: str):
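            # Self-attention block of a decoder layer: a norm plus the reduced key set
            # from rename_keys_for_self_attn (no fused in_proj tensors).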
            self_attn_layer_keys = []

            self_attn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm"))
            self_attn_layer_keys.extend(
                rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn")
            )

            return self_attn_layer_keys

        def rename_keys_for_ffn_layer(src_prefix: str, dst_prefix: str):
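            # Feed-forward block of a decoder layer: linear1 -> linear2, plus its norm.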
            ffn_layer_keys = []

            ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.linear1"))
            ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.linear2"))
            ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm"))

            return ffn_layer_keys

        def rename_keys_for_transformer_decoder_layer(src_prefix: str, dst_prefix: str, idx: int):
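            # The source checkpoint keeps cross-attention, self-attention, and FFN blocks in
            # three parallel lists indexed by layer; HF groups them under decoder.layers.{idx}.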
            transformer_decoder_layer_keys = []

            transformer_decoder_layer_keys.extend(
                rename_keys_for_cross_attn_layer(
                    f"{src_prefix}.transformer_cross_attention_layers.{idx}", f"{dst_prefix}.{idx}.cross_attn"
                )
            )

            transformer_decoder_layer_keys.extend(
                rename_keys_for_self_attn_layer(
                    f"{src_prefix}.transformer_self_attention_layers.{idx}", f"{dst_prefix}.{idx}.self_attn"
                )
            )

            transformer_decoder_layer_keys.extend(
                rename_keys_for_ffn_layer(f"{src_prefix}.transformer_ffn_layers.{idx}", f"{dst_prefix}.{idx}.ffn")
            )

            return transformer_decoder_layer_keys

        # learned embeddings: positional embedding for the object queries and the per-scale level embedding
        renamed_keys = [
            (f"{src_prefix}.query_embed.weight", f"{dst_prefix}.queries_embedder.weight"),
            (f"{src_prefix}.level_embed.weight", f"{dst_prefix}.level_embed.weight"),
        ]

        # norm of the main transformer decoder
        renamed_keys.extend(
            rename_keys_for_weight_bias(f"{src_prefix}.decoder_norm", f"{dst_prefix}.decoder.decoder_norm")
        )

        # query input projection and the class prediction head
        renamed_keys.extend(
            rename_keys_for_weight_bias(
                f"{src_prefix}.class_input_proj", f"{dst_prefix}.decoder.query_input_projection"
            )
        )

        renamed_keys.extend(
            rename_keys_for_weight_bias(f"{src_prefix}.class_embed", f"{dst_prefix}.decoder.class_embed")
        )

        # mask_embed is a 3-layer MLP; the HF module appears to wrap each linear in a
        # block where it sits at index 0, hence the trailing ".0" in the destination key
        for i in range(3):
            renamed_keys.extend(
                rename_keys_for_weight_bias(
                    f"{src_prefix}.mask_embed.layers.{i}", f"{dst_prefix}.decoder.mask_embed.layers.{i}.0"
                )
            )

        # norm of the query (class) transformer decoder
        renamed_keys.extend(
            rename_keys_for_weight_bias(
                f"{src_prefix}.class_transformer.decoder.norm", f"{dst_prefix}.decoder.query_transformer.decoder.norm"
            )
        )

        # transformer to update queries with task tokens
        for i in range(self.config.query_dec_layers):
            renamed_keys.extend(
                rename_keys_for_query_transformer_layer(
                    f"{src_prefix}.class_transformer.decoder.layers.{i}",
                    f"{dst_prefix}.decoder.query_transformer.decoder.layers.{i}",
                )
            )

        # decoder layers: the checkpoint's parallel per-type lists each hold
        # config.decoder_layers - 1 blocks
        for i in range(self.config.decoder_layers - 1):
            renamed_keys.extend(
                rename_keys_for_transformer_decoder_layer(
                    f"{src_prefix}",
                    f"{dst_prefix}.decoder.layers",
                    i,
                )
            )

        # move every renamed tensor from src_state_dict into dst_state_dict
        self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
        # handle the fused in_proj q/k/v tensors skipped in rename_keys_for_self_attn
        self.replace_keys_qkv_transformer_decoder(dst_state_dict, src_state_dict)
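

Taken together, the nested helpers only build a flat list of (source_key, destination_key) pairs; the actual tensor moves happen in pop_all, and the fused self-attention projections are handled afterwards. Below is a minimal, self-contained sketch of both patterns, assuming pop_all simply pops each source tensor and re-stores it under the destination name, and that the q/k/v handling slices the fused projection into three equal chunks along dim 0. The helper names here are illustrative, not the converter's API.

    import torch

    def pop_all(renamed_keys, dst_state_dict, src_state_dict):
        # Move each tensor out of the source dict and store it under its new name.
        for src_key, dst_key in renamed_keys:
            dst_state_dict[dst_key] = src_state_dict.pop(src_key)

    def split_in_proj(in_proj_weight, in_proj_bias, hidden_size):
        # Hypothetical q/k/v split: the fused tensor stacks q, k, v along dim 0.
        q_w, k_w, v_w = in_proj_weight.split(hidden_size, dim=0)
        q_b, k_b, v_b = in_proj_bias.split(hidden_size, dim=0)
        return (q_w, q_b), (k_w, k_b), (v_w, v_b)

    src = {
        "sem_seg_head.predictor.decoder_norm.weight": torch.ones(256),
        "sem_seg_head.predictor.decoder_norm.bias": torch.zeros(256),
    }
    dst = {}
    pop_all(
        [
            ("sem_seg_head.predictor.decoder_norm.weight", "transformer_module.decoder.decoder_norm.weight"),
            ("sem_seg_head.predictor.decoder_norm.bias", "transformer_module.decoder.decoder_norm.bias"),
        ],
        dst,
        src,
    )
    # dst now holds the tensors under the HF names; src is empty.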