def rename_keys()

in src/transformers/models/glpn/convert_glpn_to_pytorch.py
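
This helper remaps the parameter names of an original GLPN checkpoint onto the module hierarchy of the Hugging Face GLPN implementation. It walks every key in the state dict, applies a sequence of substring rewrites (several of which are order-sensitive, as noted in the comments below), and returns a new OrderedDict with the tensor values untouched.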


from collections import OrderedDict


def rename_keys(state_dict):
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
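        # the original keys carry a "module." prefix (typically a sign the model
        # was saved from an nn.DataParallel wrapper); map the encoder and decoder
        # onto their HF counterparts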
        if key.startswith("module.encoder"):
            key = key.replace("module.encoder", "glpn.encoder")
        if key.startswith("module.decoder"):
            key = key.replace("module.decoder", "decoder.stages")
        if "patch_embed" in key:
            # e.g. rename patch_embed1 to patch_embeddings.0
            idx = key[key.find("patch_embed") + len("patch_embed")]
            key = key.replace(f"patch_embed{idx}", f"patch_embeddings.{int(idx) - 1}")
        if "norm" in key:
            key = key.replace("norm", "layer_norm")
        if "glpn.encoder.layer_norm" in key:
            # e.g. rename layer_norm1 to layer_norm.0
            idx = key[key.find("glpn.encoder.layer_norm") + len("glpn.encoder.layer_norm")]
            key = key.replace(f"layer_norm{idx}", f"layer_norm.{int(idx) - 1}")
        if "layer_norm1" in key:
            key = key.replace("layer_norm1", "layer_norm_1")
        if "layer_norm2" in key:
            key = key.replace("layer_norm2", "layer_norm_2")
        if "block" in key:
            # e.g. rename block1 to block.0
            idx = key[key.find("block") + len("block")]
            key = key.replace(f"block{idx}", f"block.{int(idx) - 1}")
        if "attn.q" in key:
            key = key.replace("attn.q", "attention.self.query")
        if "attn.proj" in key:
            key = key.replace("attn.proj", "attention.output.dense")
        if "attn" in key:
            key = key.replace("attn", "attention.self")
        if "fc1" in key:
            key = key.replace("fc1", "dense1")
        if "fc2" in key:
            key = key.replace("fc2", "dense2")
        if "linear_pred" in key:
            key = key.replace("linear_pred", "classifier")
        if "linear_fuse" in key:
            key = key.replace("linear_fuse.conv", "linear_fuse")
            key = key.replace("linear_fuse.bn", "batch_norm")
        if "linear_c" in key:
            # e.g. rename linear_c4 to linear_c.3
            idx = key[key.find("linear_c") + len("linear_c")]
            key = key.replace(f"linear_c{idx}", f"linear_c.{int(idx) - 1}")
        if "bot_conv" in key:
            key = key.replace("bot_conv", "0.convolution")
        if "skip_conv1" in key:
            key = key.replace("skip_conv1", "1.convolution")
        if "skip_conv2" in key:
            key = key.replace("skip_conv2", "2.convolution")
        if "fusion1" in key:
            key = key.replace("fusion1", "1.fusion")
        if "fusion2" in key:
            key = key.replace("fusion2", "2.fusion")
        if "fusion3" in key:
            key = key.replace("fusion3", "3.fusion")
        if "fusion" in key and "conv" in key:
            key = key.replace("conv", "convolutional_layer")
        if key.startswith("module.last_layer_depth"):
            key = key.replace("module.last_layer_depth", "head.head")
        new_state_dict[key] = value

    return new_state_dict
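
For context, here is a minimal sketch of how the helper can be driven. The checkpoint filename is a placeholder, and the sketch assumes the .pth file holds a plain state dict at the top level; adjust the loading step if your checkpoint nests the weights differently.

import torch

# hypothetical checkpoint path for illustration
checkpoint_path = "glpn_original_checkpoint.pth"

# load the original weights onto CPU (assumes a plain state dict at top level)
state_dict = torch.load(checkpoint_path, map_location="cpu")

# remap the original key names onto the HF GLPN layout
state_dict = rename_keys(state_dict)

# spot-check a few renamed keys
for key in list(state_dict)[:5]:
    print(key)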