# superres_convert_ldm_unet_checkpoint()
#
# Extracted from scripts/convert_if.py


def superres_convert_ldm_unet_checkpoint(unet_state_dict: dict, config: dict, path=None, extract_ema=False) -> dict:
    """
    Convert a super-resolution LDM UNet state dict into the diffusers UNet key layout.

    Walks the LDM naming scheme (``time_embed``, ``input_blocks``, ``middle_block``,
    ``output_blocks``, ``out``) and emits a new dict keyed with diffusers names
    (``time_embedding``, ``down_blocks``, ``mid_block``, ``up_blocks``, ``conv_out``).

    Args:
        unet_state_dict: Source LDM UNet weights. NOTE: partially mutated in place —
            downsampler ``op`` weights are ``pop``-ed below, and the ``assign_*``
            helpers may also consume entries (their definitions are not visible here).
        config: diffusers UNet config dict. Keys read here: ``"class_embed_type"``,
            ``"layers_per_block"`` (int or per-block sequence), ``"down_block_types"``.
        path: Unused; kept for signature parity with sibling converters — TODO confirm.
        extract_ema: Unused in this function; EMA extraction presumably happens
            upstream — TODO confirm against caller.

    Returns:
        A new dict mapping diffusers parameter names to the original tensors.

    Raises:
        NotImplementedError: If ``config["class_embed_type"]`` is not one of
            ``None``, ``"timestep"`` or ``"projection"``.
    """
    new_checkpoint = {}

    # Time embedding MLP: LDM time_embed.{0,2} -> diffusers time_embedding.linear_{1,2}.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    # Optional class/augmentation embedding (LDM "aug_proj" MLP).
    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["aug_proj.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["aug_proj.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["aug_proj.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["aug_proj.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    # Optional text-encoder hidden-state projection.
    if "encoder_proj.weight" in unet_state_dict:
        new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict["encoder_proj.weight"]
        new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict["encoder_proj.bias"]

    # Optional pooled-text embedding stack: remap the four encoder_pooling
    # sub-modules to the diffusers add_embedding names.
    if "encoder_pooling.0.weight" in unet_state_dict:
        mapping = {
            "encoder_pooling.0": "add_embedding.norm1",
            "encoder_pooling.1": "add_embedding.pool",
            "encoder_pooling.2": "add_embedding.proj",
            "encoder_pooling.3": "add_embedding.norm2",
        }
        for key in unet_state_dict.keys():
            if key.startswith("encoder_pooling"):
                # "encoder_pooling.N" is a fixed-width prefix (N is a single digit here).
                prefix = key[: len("encoder_pooling.0")]
                new_key = key.replace(prefix, mapping[prefix])
                new_checkpoint[new_key] = unet_state_dict[key]

    # Stem conv: first layer of input_blocks.0.
    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    # Output head: norm (out.0) + conv (out.2).
    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    # Block count = number of distinct "input_blocks.N" prefixes.
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        # Trailing "." prevents "input_blocks.1." from matching "input_blocks.11.*".
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        # No trailing dot here; safe only because middle blocks number < 10 — presumably always 3.
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }
    # With a per-block layer list, the downsampler sits at the cumulative end of
    # each block (each block contributes layers_per_block resnets + 1 downsampler).
    if not isinstance(config["layers_per_block"], int):
        layers_per_block_list = [e + 1 for e in config["layers_per_block"]]
        layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
        downsampler_ids = layers_per_block_cumsum
    else:
        # TODO need better check than i in [4, 8, 12, 16]
        downsampler_ids = [4, 8, 12, 16]

    # Down path. input_blocks.0 is the stem conv (handled above), so start at 1.
    for i in range(1, num_input_blocks):
        if isinstance(config["layers_per_block"], int):
            layers_per_block = config["layers_per_block"]
            block_id = (i - 1) // (layers_per_block + 1)
            layer_in_block_id = (i - 1) % (layers_per_block + 1)
        else:
            # Variable layers per block: locate which block the cumulative index lands in.
            block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if (i - 1) < n)
            passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
            layer_in_block_id = (i - 1) - passed_blocks

        # Sub-module 0 is the resnet (or a conv downsampler, excluded via ".op"),
        # sub-module 1 (if present) is the attention.
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        # Conv-style downsampler ("op"): pop so these keys are not reprocessed later.
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)

        # Resnet-style downsampler blocks place the downsampler as a resnet at the
        # end of the block, so redirect those indices to "downsamplers.0".
        block_type = config["down_block_types"][block_id]
        if (
            block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D"
        ) and i in downsampler_ids:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
        else:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}

        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            old_path = f"input_blocks.{i}.1"
            new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"

            # NOTE(review): assign_attention_to_checkpoint presumably handles the
            # qkv-projection splitting before the generic path renaming below — confirm.
            assign_attention_to_checkpoint(
                new_checkpoint=new_checkpoint,
                unet_state_dict=unet_state_dict,
                old_path=old_path,
                new_path=new_path,
                config=config,
            )

            paths = renew_attention_paths(attentions)
            meta_path = {"old": old_path, "new": new_path}
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                unet_state_dict,
                additional_replacements=[meta_path],
                config=config,
            )

    # Middle block: fixed layout resnet / attention / resnet.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    old_path = "middle_block.1"
    new_path = "mid_block.attentions.0"

    assign_attention_to_checkpoint(
        new_checkpoint=new_checkpoint,
        unet_state_dict=unet_state_dict,
        old_path=old_path,
        new_path=new_path,
        config=config,
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )
    # Up path indexing mirrors the down path but with the per-block list reversed.
    if not isinstance(config["layers_per_block"], int):
        layers_per_block_list = list(reversed([e + 1 for e in config["layers_per_block"]]))
        layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))

    for i in range(num_output_blocks):
        if isinstance(config["layers_per_block"], int):
            layers_per_block = config["layers_per_block"]
            block_id = i // (layers_per_block + 1)
            layer_in_block_id = i % (layers_per_block + 1)
        else:
            block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if i < n)
            passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
            layer_in_block_id = i - passed_blocks

        # Strip the "output_blocks.N." prefix, then bucket remaining names by
        # their sub-module index within this output block.
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        # len(output_block_list) == 1 -> resnet
        # len(output_block_list) == 2 -> resnet, attention or resnet, upscale resnet
        # len(output_block_list) == 3 -> resnet, attention, upscale resnet

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]

            # Disambiguate sub-module 1: a resnet has "in_layers" params, an attention doesn't.
            has_attention = True
            if len(output_block_list) == 2 and any("in_layers" in k for k in output_block_list["1"]):
                has_attention = False

            maybe_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}

            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # A sub-module whose (sorted) params are exactly conv.bias/conv.weight
            # is a conv upsampler, not an attention.
            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # this layer was no attention
                has_attention = False
                maybe_attentions = []

            if has_attention:
                old_path = f"output_blocks.{i}.1"
                new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"

                assign_attention_to_checkpoint(
                    new_checkpoint=new_checkpoint,
                    unet_state_dict=unet_state_dict,
                    old_path=old_path,
                    new_path=new_path,
                    config=config,
                )

                paths = renew_attention_paths(maybe_attentions)
                meta_path = {
                    "old": old_path,
                    "new": new_path,
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            # Resnet-style upsampler: the last sub-module is an upscaling resnet
            # mapped onto "upsamplers.0".
            if len(output_block_list) == 3 or (not has_attention and len(maybe_attentions) > 0):
                layer_id = len(output_block_list) - 1
                resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.{layer_id}" in key]
                paths = renew_resnet_paths(resnets)
                meta_path = {"old": f"output_blocks.{i}.{layer_id}", "new": f"up_blocks.{block_id}.upsamplers.0"}
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Single sub-module: plain resnet; rename each param path directly.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint