def _convert_non_diffusers_wan_lora_to_diffusers()

in src/diffusers/loaders/lora_conversion_utils.py [0:0]


def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
    """Rename the keys of a non-diffusers Wan LoRA state dict to the diffusers layout.

    The leading ``diffusion_model.`` prefix is stripped from every key that carries it, the
    per-block attention / FFN keys and the top-level embedder / head keys are renamed, and every
    resulting key is prefixed with ``transformer.``. ``lora_down``/``lora_up`` (or
    ``lora_A``/``lora_B``) weights become ``lora_A.weight``/``lora_B.weight``; ``diff_b`` entries are
    treated as LoRA bias and become ``lora_B.bias``
    (https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias).

    Args:
        state_dict: Mapping from original key names to parameter values. Entries whose key ends with
            ``.diff`` or ``.diff_b`` must support ``.abs()``, ``.max()``, ``.min()`` and comparison
            with ``torch.all`` (presumably ``torch.Tensor`` - confirm against callers). The mapping
            itself is not modified.

    Returns:
        A new dict containing the converted entries, each key prefixed with ``transformer.``.

    Raises:
        ValueError: If unrecognised keys remain after conversion, other than ``.diff``-style keys
            that do not reference LoRA weights (those are only logged).
    """
    prefix = "diffusion_model."
    # Only strip the prefix where it is actually present; blindly slicing would mangle other keys.
    original_state_dict = {(k[len(prefix) :] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
    converted_state_dict = {}

    # Iterating the block indices that really occur also copes with LoRAs that target no
    # transformer block at all (where `min()` / `max()` on an empty collection would raise).
    block_numbers = sorted({int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")})

    is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
    lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
    lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"

    # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
    # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
    # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
    # is okay to ignore because they do not affect the model output in a significant manner.
    threshold = 1.6e-2
    diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
    for diff_k in diff_keys:
        abs_param = original_state_dict[diff_k].abs()
        # NOTE: the criterion is the spread between the largest and the smallest absolute value,
        # not the largest absolute value itself.
        absdiff = abs_param.max() - abs_param.min()
        all_zero = torch.all(abs_param == 0).item()
        if all_zero or absdiff < threshold:
            logger.debug(
                f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
            )
            original_state_dict.pop(diff_k)

    def _move(source_key, target_key):
        # Transfer a single entry to the converted dict if (and only if) it exists.
        if source_key in original_state_dict:
            converted_state_dict[target_key] = original_state_dict.pop(source_key)

    def _move_lora(source_module, target_module, with_bias=True):
        # Transfer the down/up LoRA weights of one module and, optionally, its `diff_b` as LoRA bias.
        _move(f"{source_module}.{lora_down_key}.weight", f"{target_module}.lora_A.weight")
        _move(f"{source_module}.{lora_up_key}.weight", f"{target_module}.lora_B.weight")
        if with_bias:
            _move(f"{source_module}.diff_b", f"{target_module}.lora_B.bias")

    attention_projections = (("q", "to_q"), ("k", "to_k"), ("v", "to_v"), ("o", "to_out.0"))
    image_projections = (("k_img", "add_k_proj"), ("v_img", "add_v_proj"))
    ffn_layers = (("ffn.0", "net.0.proj"), ("ffn.2", "net.2"))

    for i in block_numbers:
        # Self-attention
        for theirs, ours in attention_projections:
            _move_lora(f"blocks.{i}.self_attn.{theirs}", f"blocks.{i}.attn1.{ours}")

        # Cross-attention
        for theirs, ours in attention_projections:
            _move_lora(f"blocks.{i}.cross_attn.{theirs}", f"blocks.{i}.attn2.{ours}")

        # Image projections are only converted when both `k_img` and `v_img` keys are present.
        if is_i2v_lora:
            for theirs, ours in image_projections:
                _move_lora(f"blocks.{i}.cross_attn.{theirs}", f"blocks.{i}.attn2.{ours}")

        # FFN
        for theirs, ours in ffn_layers:
            _move_lora(f"blocks.{i}.{theirs}", f"blocks.{i}.ffn.{ours}")

    # Remaining (non-block) modules. Every key is looked up individually so that a module that only
    # ships part of its entries (e.g. just a `diff_b`) does not trigger a `KeyError`.
    _move_lora("time_projection.1", "condition_embedder.time_proj")
    _move_lora("head.head", "proj_out")

    embedders = (
        ("text_embedding", "condition_embedder.text_embedder"),
        ("time_embedding", "condition_embedder.time_embedder"),
    )
    for theirs, ours in embedders:
        # Original layer indices 0 and 2 correspond to `linear_1` and `linear_2` respectively.
        for theirs_index, ours_index in ((0, 1), (2, 2)):
            _move_lora(f"{theirs}.{theirs_index}", f"{ours}.linear_{ours_index}")

    for ours, theirs in (("ff.net.0.proj", "img_emb.proj.1"), ("ff.net.2", "img_emb.proj.3")):
        # `diff_b` entries of the image embedder are not converted here.
        _move_lora(theirs, f"condition_embedder.image_embedder.{ours}", with_bias=False)

    if original_state_dict:
        if not all(".diff" in k for k in original_state_dict):
            raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")

        lora_diff_keys = sorted(k for k in original_state_dict if k.endswith(".diff") and "lora" in k)
        if lora_diff_keys:
            raise ValueError(f"Found unexpected `.diff` keys that reference LoRA weights: {lora_diff_keys}")

        logger.info(
            "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: "
            "https://github.com/huggingface/diffusers/issues/new"
        )

    return {f"transformer.{key}": value for key, value in converted_state_dict.items()}