def rename_keys(state_dict, architecture)

in src/transformers/models/perceiver/convert_perceiver_haiku_to_pytorch.py


import numpy as np
import torch


def rename_keys(state_dict, architecture):
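    """Map the flat Haiku parameter names of a Perceiver checkpoint onto the
    HuggingFace PerceiverModel state dict layout.

    `state_dict` is a dict of numpy arrays keyed by Haiku paths; it is modified in
    place, with every value converted to a torch.Tensor. `architecture` selects
    architecture-specific renames (e.g. "multimodal_autoencoding").
    """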
    for name in list(state_dict):
        param = state_dict.pop(name)

        # PREPROCESSORS
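        # Haiku serializes parameters under path-like names ("scope/~/module/param");
        # each replace below maps such a path onto the dotted module path used by HF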
        # rename text preprocessor embeddings (for MLM model)
        name = name.replace("embed/embeddings", "input_preprocessor.embeddings.weight")
        if name.startswith("trainable_position_encoding/pos_embs"):
            name = name.replace(
                "trainable_position_encoding/pos_embs", "input_preprocessor.position_embeddings.weight"
            )

        # rename image preprocessor embeddings (for image classification model with learned position embeddings)
        name = name.replace("image_preprocessor/~/conv2_d/w", "input_preprocessor.convnet_1x1.weight")
        name = name.replace("image_preprocessor/~/conv2_d/b", "input_preprocessor.convnet_1x1.bias")
        name = name.replace(
            "image_preprocessor/~_build_network_inputs/trainable_position_encoding/pos_embs",
            "input_preprocessor.position_embeddings.position_embeddings",
        )
        name = name.replace(
            "image_preprocessor/~_build_network_inputs/position_encoding_projector/linear/w",
            "input_preprocessor.positions_projection.weight",
        )
        name = name.replace(
            "image_preprocessor/~_build_network_inputs/position_encoding_projector/linear/b",
            "input_preprocessor.positions_projection.bias",
        )

        # rename image preprocessor embeddings (for image classification model with conv processing)
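        # skip Haiku's EMA bookkeeping state ("counter"/"hidden"), which has no
        # PyTorch counterpart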
        if "counter" in name or "hidden" in name:
            continue
        name = name.replace(
            "image_preprocessor/~/conv2_d_downsample/~/conv/w", "input_preprocessor.convnet.conv.weight"
        )
        name = name.replace(
            "image_preprocessor/~/conv2_d_downsample/~/batchnorm/offset", "input_preprocessor.convnet.batchnorm.bias"
        )
        name = name.replace(
            "image_preprocessor/~/conv2_d_downsample/~/batchnorm/scale", "input_preprocessor.convnet.batchnorm.weight"
        )
        name = name.replace(
            "image_preprocessor/~/conv2_d_downsample/~/batchnorm/~/mean_ema/average",
            "input_preprocessor.convnet.batchnorm.running_mean",
        )
        name = name.replace(
            "image_preprocessor/~/conv2_d_downsample/~/batchnorm/~/var_ema/average",
            "input_preprocessor.convnet.batchnorm.running_var",
        )

        # rename image preprocessor embeddings (for optical flow model)
        name = name.replace("image_preprocessor/patches_linear/b", "input_preprocessor.conv_after_patches.bias")
        name = name.replace("image_preprocessor/patches_linear/w", "input_preprocessor.conv_after_patches.weight")

        # rename multimodal preprocessor embeddings
        name = name.replace("multimodal_preprocessor/audio_mask_token/pos_embs", "input_preprocessor.mask.audio")
        name = name.replace("multimodal_preprocessor/audio_padding/pos_embs", "input_preprocessor.padding.audio")
        name = name.replace("multimodal_preprocessor/image_mask_token/pos_embs", "input_preprocessor.mask.image")
        name = name.replace("multimodal_preprocessor/image_padding/pos_embs", "input_preprocessor.padding.image")
        name = name.replace("multimodal_preprocessor/label_mask_token/pos_embs", "input_preprocessor.mask.label")
        name = name.replace("multimodal_preprocessor/label_padding/pos_embs", "input_preprocessor.padding.label")

        # DECODERS
        # rename prefix of decoders
        # multimodal autoencoding model
        name = name.replace(
            "multimodal_decoder/~/basic_decoder/cross_attention/", "decoder.decoder.decoding_cross_attention."
        )
        name = name.replace("multimodal_decoder/~decoder_query/audio_padding/pos_embs", "decoder.padding.audio")
        name = name.replace("multimodal_decoder/~decoder_query/image_padding/pos_embs", "decoder.padding.image")
        name = name.replace("multimodal_decoder/~decoder_query/label_padding/pos_embs", "decoder.padding.label")
        name = name.replace("multimodal_decoder/~/basic_decoder/output/b", "decoder.decoder.final_layer.bias")
        name = name.replace("multimodal_decoder/~/basic_decoder/output/w", "decoder.decoder.final_layer.weight")
        if architecture == "multimodal_autoencoding":
            name = name.replace(
                "classification_decoder/~/basic_decoder/~/trainable_position_encoding/pos_embs",
                "decoder.modalities.label.decoder.output_position_encodings.position_embeddings",
            )
        # flow model
        name = name.replace(
            "flow_decoder/~/basic_decoder/cross_attention/", "decoder.decoder.decoding_cross_attention."
        )
        name = name.replace("flow_decoder/~/basic_decoder/output/w", "decoder.decoder.final_layer.weight")
        name = name.replace("flow_decoder/~/basic_decoder/output/b", "decoder.decoder.final_layer.bias")
        # image models
        name = name.replace(
            "classification_decoder/~/basic_decoder/~/trainable_position_encoding/pos_embs",
            "decoder.decoder.output_position_encodings.position_embeddings",
        )
        name = name.replace(
            "basic_decoder/~/trainable_position_encoding/pos_embs",
            "decoder.output_position_encodings.position_embeddings",
        )
        name = name.replace(
            "classification_decoder/~/basic_decoder/cross_attention/", "decoder.decoder.decoding_cross_attention."
        )
        name = name.replace("classification_decoder/~/basic_decoder/output/b", "decoder.decoder.final_layer.bias")
        name = name.replace("classification_decoder/~/basic_decoder/output/w", "decoder.decoder.final_layer.weight")
        name = name.replace("classification_decoder/~/basic_decoder/~/", "decoder.decoder.")
        name = name.replace("basic_decoder/cross_attention/", "decoder.decoding_cross_attention.")
        name = name.replace("basic_decoder/~/", "decoder.")

        # POSTPROCESSORS
        name = name.replace(
            "projection_postprocessor/linear/b", "output_postprocessor.modalities.image.classifier.bias"
        )
        name = name.replace(
            "projection_postprocessor/linear/w", "output_postprocessor.modalities.image.classifier.weight"
        )
        name = name.replace(
            "classification_postprocessor/linear/b", "output_postprocessor.modalities.label.classifier.bias"
        )
        name = name.replace(
            "classification_postprocessor/linear/w", "output_postprocessor.modalities.label.classifier.weight"
        )
        name = name.replace("audio_postprocessor/linear/b", "output_postprocessor.modalities.audio.classifier.bias")
        name = name.replace("audio_postprocessor/linear/w", "output_postprocessor.modalities.audio.classifier.weight")

        # PERCEIVER MODEL

        # rename latent embeddings
        name = name.replace("perceiver_encoder/~/trainable_position_encoding/pos_embs", "embeddings.latents")
        # rename latent embeddings (for multimodal model)
        name = name.replace("encoder/~/trainable_position_encoding/pos_embs", "embeddings.latents")

        # rename prefixes ("perceiver_encoder" scope for unimodal models, "encoder"
        # for the multimodal one); self-attention layers land in the HF
        # "encoder.self_attends" list
        for haiku_prefix in ("perceiver_encoder/~/", "encoder/~/"):
            if name.startswith(haiku_prefix):
                hf_prefix = "encoder.self_attends." if "self_attention" in name else "encoder."
                name = name.replace(haiku_prefix, hf_prefix)
        # rename layernorm parameters
        if "offset" in name:
            name = name.replace("offset", "bias")
        if "scale" in name:
            name = name.replace("scale", "weight")
        # in HuggingFace, the layernorm in between attention + MLP is just called "layernorm"
        # rename layernorm in between attention + MLP of cross-attention
        if "cross_attention" in name and "layer_norm_2" in name:
            name = name.replace("layer_norm_2", "layernorm")
        # rename layernorm in between attention + MLP of self-attention
        if "self_attention" in name and "layer_norm_1" in name:
            name = name.replace("layer_norm_1", "layernorm")

        # in HuggingFace, the layernorms for queries + keys are called "layernorm1" and "layernorm2"
        if "cross_attention" in name and "layer_norm_1" in name:
            name = name.replace("layer_norm_1", "attention.self.layernorm2")
        if "cross_attention" in name and "layer_norm" in name:
            name = name.replace("layer_norm", "attention.self.layernorm1")
        if "self_attention" in name and "layer_norm" in name:
            name = name.replace("layer_norm", "attention.self.layernorm1")

        # replace special characters with dots
        name = name.replace("-", ".")
        name = name.replace("/", ".")
        # rename keys, queries, values and output of attention layers
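        # Haiku auto-numbers the four projections of an attention block as linear,
        # linear_1, linear_2, linear_3 (query, key, value and output, in that order)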
        if ("cross_attention" in name or "self_attention" in name) and "mlp" not in name:
            if "linear.b" in name:
                name = name.replace("linear.b", "self.query.bias")
            if "linear.w" in name:
                name = name.replace("linear.w", "self.query.weight")
            if "linear_1.b" in name:
                name = name.replace("linear_1.b", "self.key.bias")
            if "linear_1.w" in name:
                name = name.replace("linear_1.w", "self.key.weight")
            if "linear_2.b" in name:
                name = name.replace("linear_2.b", "self.value.bias")
            if "linear_2.w" in name:
                name = name.replace("linear_2.w", "self.value.weight")
            if "linear_3.b" in name:
                name = name.replace("linear_3.b", "output.dense.bias")
            if "linear_3.w" in name:
                name = name.replace("linear_3.w", "output.dense.weight")
        if "self_attention_" in name:
            name = name.replace("self_attention_", "")
        if "self_attention" in name:
            name = name.replace("self_attention", "0")
        # rename dense layers of 2-layer MLP
        if "mlp" in name:
            if "linear.b" in name:
                name = name.replace("linear.b", "dense1.bias")
            if "linear.w" in name:
                name = name.replace("linear.w", "dense1.weight")
            if "linear_1.b" in name:
                name = name.replace("linear_1.b", "dense2.bias")
            if "linear_1.w" in name:
                name = name.replace("linear_1.w", "dense2.weight")

        # finally, transpose kernels and set the value: Haiku stores linear kernels as
        # (in_features, out_features) while torch.nn.Linear expects (out_features,
        # in_features); embedding tables are already in the right orientation
        if name.endswith("weight") and "embeddings" not in name:
            param = np.transpose(param)

        # Haiku keeps singleton broadcast dims on batchnorm parameters and statistics;
        # squeeze them down to the 1D shape PyTorch expects
        if "batchnorm" in name:
            param = np.squeeze(param)

        if "embedding_decoder" not in name:
            state_dict["perceiver." + name] = torch.from_numpy(param)
        else:
            state_dict[name] = torch.from_numpy(param)
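

# Minimal usage sketch (illustrative, not part of the original script): rename_keys
# mutates the flattened Haiku state dict in place. The keys below follow real Haiku
# path patterns from the replaces above, but the shapes and the "MLM" architecture
# tag are hypothetical stand-ins for an actual checkpoint.
if __name__ == "__main__":
    demo_state_dict = {
        "perceiver_encoder/~/trainable_position_encoding/pos_embs": np.zeros((256, 1280), dtype=np.float32),
        "perceiver_encoder/~/self_attention/linear/w": np.zeros((1280, 1280), dtype=np.float32),
    }
    rename_keys(demo_state_dict, architecture="MLM")
    print(sorted(demo_state_dict))
    # ['perceiver.embeddings.latents', 'perceiver.encoder.self_attends.0.self.query.weight']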