def rename_flax_dict()

in scripts/convert_maskgit_vqgan.py [0:0]


def rename_flax_dict(params):
    """Rename the keys of ``params`` in place and return the same dict.

    Every key is first joined with "." (so keys are presumably tuples of
    strings, e.g. from a flattened Flax param tree -- confirm against the
    caller).  The dotted names are then rewritten in several passes:

    * ``encoder.ResBlock_10`` / ``ResBlock_11``  -> ``encoder.mid.0`` / ``mid.1``
    * ``encoder.ResBlock_0..9``                  -> ``encoder.down.{i}.block.{j}``
    * ``decoder.ResBlock_0`` / ``ResBlock_1``    -> ``decoder.mid.0`` / ``mid.1``
    * ``decoder.ResBlock_2..11``                 -> ``decoder.up.{4 - i}.block.{j}``
    * ``decoder.Conv_1..4``                      -> ``decoder.up.{5 - n}.upsample_conv``
    * remaining ``Conv_0`` / ``Conv_1`` / ``Conv_5`` -> ``conv_in`` / ``conv_out``
    * remaining ``GroupNorm_0``                  -> ``norm_out``
    * ``quantizer.codebook``                     -> ``quantize.embedding.embedding``

    Raises:
        KeyError: if ``decoder.Conv_{1..4}.kernel``/``.bias`` or
            ``quantizer.codebook`` is absent, or if a ResBlock index falls
            outside the ten (level, position) slots.
    """
    # Layer renames applied inside a residual block.  Only the level blocks
    # also rename Conv_2 (to nin_shortcut); the mid blocks do not.
    mid_layer_table = (
        ("Conv_0", "conv1"),
        ("Conv_1", "conv2"),
        ("GroupNorm_0", "norm1"),
        ("GroupNorm_1", "norm2"),
    )
    level_layer_table = (
        ("Conv_0", "conv1"),
        ("Conv_1", "conv2"),
        ("Conv_2", "nin_shortcut"),
        ("GroupNorm_0", "norm1"),
        ("GroupNorm_1", "norm2"),
    )
    # Conv_5 is handled before Conv_1, matching the order of the final pass.
    top_level_conv_table = (
        ("Conv_0", "conv_in"),
        ("Conv_5", "conv_out"),
        ("Conv_1", "conv_out"),
    )

    # Flat block number -> (level, position within level): two blocks per level.
    slot_of = {number: divmod(number, 2) for number in range(10)}

    def substitute(text, table):
        """Apply each (old, new) replacement of ``table`` to ``text`` in order."""
        for old, new in table:
            text = text.replace(old, new)
        return text

    def move(old_name, new_name):
        """Re-file the entry stored under ``old_name`` as ``new_name``."""
        params[new_name] = params.pop(old_name)

    def names_containing(fragment):
        """Snapshot of the current keys that contain ``fragment``."""
        return [name for name in params if fragment in name]

    def rename_mid_blocks(fragment, rules):
        """Rename the keys matched by the first applicable (marker, old, new) rule."""
        for name in names_containing(fragment):
            for marker, old, new in rules:
                if marker in name:
                    move(name, substitute(name.replace(old, new), mid_layer_table))
                    break

    def rename_level_blocks(fragment, block_path):
        """Rename every remaining ``ResBlock_<n>`` key using ``block_path(n)``."""
        for name in names_containing(fragment):
            block_name = name.split(".")[1]
            _, number = block_name.split("_")
            renamed = name.replace(block_name, block_path(int(number)))
            move(name, substitute(renamed, level_layer_table))

    def down_path(number):
        level, position = slot_of[number]
        return f"down.{level}.block.{position}"

    def up_path(number):
        # Decoder blocks 0 and 1 are the mid blocks, so levels start at 2,
        # and the level index is mirrored (4 - level).
        level, position = slot_of[number - 2]
        return f"up.{4 - level}.block.{position}"

    # Pass 1: collapse each key into a single dotted string.
    for path in list(params.keys()):
        move(path, ".".join(path))

    # Pass 2: encoder -- the last two residual blocks are the mid blocks.
    rename_mid_blocks(
        "encoder.ResBlock",
        (
            ("ResBlock_10", "ResBlock_10", "mid.0"),
            ("ResBlock_11", "ResBlock_11", "mid.1"),
        ),
    )
    rename_level_blocks("encoder.ResBlock", down_path)

    # Pass 3: decoder -- the first two residual blocks are the mid blocks.
    # The "ResBlock_1." marker keeps its trailing dot so ResBlock_10 and
    # ResBlock_11 are not mistaken for ResBlock_1.
    rename_mid_blocks(
        "decoder.ResBlock",
        (
            ("ResBlock_0", "ResBlock_0", "mid.0"),
            ("ResBlock_1.", "ResBlock_1", "mid.1"),
        ),
    )
    rename_level_blocks("decoder.ResBlock", up_path)

    # Pass 4: decoder Conv_1..Conv_4 become the upsample convs of up.4..up.1.
    for number in range(1, 5):
        for leaf in ("kernel", "bias"):
            move(
                f"decoder.Conv_{number}.{leaf}",
                f"decoder.up.{5 - number}.upsample_conv.{leaf}",
            )

    # Pass 5: whatever Conv_*/GroupNorm keys are left are the in/out layers.
    for name in list(params.keys()):
        if "Conv_" in name:
            move(name, substitute(name, top_level_conv_table))
        elif "GroupNorm" in name:
            move(name, name.replace("GroupNorm_0", "norm_out"))

    move("quantizer.codebook", "quantize.embedding.embedding")

    return params