def _convert_kohya_flux_lora_to_diffusers()

in src/diffusers/loaders/lora_conversion_utils.py [0:0]


def _convert_kohya_flux_lora_to_diffusers(state_dict):
    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
        if sds_key + ".lora_down.weight" not in sds_sd:
            return
        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")

        # scale weight by alpha and dim
        rank = down_weight.shape[0]
        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()  # alpha is scalar
        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here

        # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up

    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        if sds_key + ".lora_down.weight" not in sds_sd:
            return
        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
        sd_lora_rank = down_weight.shape[0]

        # scale weight by alpha and dim
        default_alpha = torch.tensor(
            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
        )
        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
        scale = alpha / sd_lora_rank

        # calculate scale_down and scale_up
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        down_weight = down_weight * scale_down
        up_weight = up_weight * scale_up

        # calculate dims if not provided
        num_splits = len(ait_keys)
        if dims is None:
            dims = [up_weight.shape[0] // num_splits] * num_splits
        else:
            assert sum(dims) == up_weight.shape[0]

        # check upweight is sparse or not
        is_sparse = False
        if sd_lora_rank % num_splits == 0:
            ait_rank = sd_lora_rank // num_splits
            is_sparse = True
            i = 0
            for j in range(len(dims)):
                for k in range(len(dims)):
                    if j == k:
                        continue
                    is_sparse = is_sparse and torch.all(
                        up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
                    )
                i += dims[j]
            if is_sparse:
                logger.info(f"weight is sparse: {sds_key}")

        # make ai-toolkit weight
        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
        if not is_sparse:
            # down_weight is copied to each split
            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))

            # up_weight is split to each split
            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
        else:
            # down_weight is chunked to each split
            ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})  # noqa: C416

            # up_weight is sparse: only non-zero values are copied to each split
            i = 0
            for j in range(len(dims)):
                ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
                i += dims[j]

    def _convert_sd_scripts_to_ai_toolkit(sds_sd):
        ait_sd = {}
        for i in range(19):
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_attn_proj",
                f"transformer.transformer_blocks.{i}.attn.to_out.0",
            )
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_attn_qkv",
                [
                    f"transformer.transformer_blocks.{i}.attn.to_q",
                    f"transformer.transformer_blocks.{i}.attn.to_k",
                    f"transformer.transformer_blocks.{i}.attn.to_v",
                ],
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_mlp_0",
                f"transformer.transformer_blocks.{i}.ff.net.0.proj",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_mlp_2",
                f"transformer.transformer_blocks.{i}.ff.net.2",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_mod_lin",
                f"transformer.transformer_blocks.{i}.norm1.linear",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_attn_proj",
                f"transformer.transformer_blocks.{i}.attn.to_add_out",
            )
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_attn_qkv",
                [
                    f"transformer.transformer_blocks.{i}.attn.add_q_proj",
                    f"transformer.transformer_blocks.{i}.attn.add_k_proj",
                    f"transformer.transformer_blocks.{i}.attn.add_v_proj",
                ],
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_mlp_0",
                f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_mlp_2",
                f"transformer.transformer_blocks.{i}.ff_context.net.2",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_mod_lin",
                f"transformer.transformer_blocks.{i}.norm1_context.linear",
            )

        for i in range(38):
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"lora_unet_single_blocks_{i}_linear1",
                [
                    f"transformer.single_transformer_blocks.{i}.attn.to_q",
                    f"transformer.single_transformer_blocks.{i}.attn.to_k",
                    f"transformer.single_transformer_blocks.{i}.attn.to_v",
                    f"transformer.single_transformer_blocks.{i}.proj_mlp",
                ],
                dims=[3072, 3072, 3072, 12288],
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_single_blocks_{i}_linear2",
                f"transformer.single_transformer_blocks.{i}.proj_out",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_single_blocks_{i}_modulation_lin",
                f"transformer.single_transformer_blocks.{i}.norm.linear",
            )

        # TODO: alphas.
        def assign_remaining_weights(assignments, source):
            for lora_key in ["lora_A", "lora_B"]:
                orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
                for target_fmt, source_fmt, transform in assignments:
                    target_key = target_fmt.format(lora_key=lora_key)
                    source_key = source_fmt.format(orig_lora_key=orig_lora_key)
                    value = source.pop(source_key)
                    if transform:
                        value = transform(value)
                    ait_sd[target_key] = value

        if any("guidance_in" in k for k in sds_sd):
            assign_remaining_weights(
                [
                    (
                        "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
                        "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
                        None,
                    ),
                    (
                        "time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
                        "lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
                        None,
                    ),
                ],
                sds_sd,
            )

        if any("img_in" in k for k in sds_sd):
            assign_remaining_weights(
                [
                    ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
                ],
                sds_sd,
            )

        if any("txt_in" in k for k in sds_sd):
            assign_remaining_weights(
                [
                    ("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
                ],
                sds_sd,
            )

        if any("time_in" in k for k in sds_sd):
            assign_remaining_weights(
                [
                    (
                        "time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
                        "lora_unet_time_in_in_layer.{orig_lora_key}.weight",
                        None,
                    ),
                    (
                        "time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
                        "lora_unet_time_in_out_layer.{orig_lora_key}.weight",
                        None,
                    ),
                ],
                sds_sd,
            )

        if any("vector_in" in k for k in sds_sd):
            assign_remaining_weights(
                [
                    (
                        "time_text_embed.text_embedder.linear_1.{lora_key}.weight",
                        "lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
                        None,
                    ),
                    (
                        "time_text_embed.text_embedder.linear_2.{lora_key}.weight",
                        "lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
                        None,
                    ),
                ],
                sds_sd,
            )

        if any("final_layer" in k for k in sds_sd):
            # Notice the swap in processing for "final_layer".
            assign_remaining_weights(
                [
                    (
                        "norm_out.linear.{lora_key}.weight",
                        "lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
                        swap_scale_shift,
                    ),
                    ("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
                ],
                sds_sd,
            )

        remaining_keys = list(sds_sd.keys())
        te_state_dict = {}
        if remaining_keys:
            if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
                raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
            for key in remaining_keys:
                if not key.endswith("lora_down.weight"):
                    continue

                lora_name = key.split(".")[0]
                lora_name_up = f"{lora_name}.lora_up.weight"
                lora_name_alpha = f"{lora_name}.alpha"
                diffusers_name = _convert_text_encoder_lora_key(key, lora_name)

                if lora_name.startswith(("lora_te_", "lora_te1_")):
                    down_weight = sds_sd.pop(key)
                    sd_lora_rank = down_weight.shape[0]
                    te_state_dict[diffusers_name] = down_weight
                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)

                if lora_name_alpha in sds_sd:
                    alpha = sds_sd.pop(lora_name_alpha).item()
                    scale = alpha / sd_lora_rank

                    scale_down = scale
                    scale_up = 1.0
                    while scale_down * 2 < scale_up:
                        scale_down *= 2
                        scale_up /= 2

                    te_state_dict[diffusers_name] *= scale_down
                    te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up

        if len(sds_sd) > 0:
            logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")

        if te_state_dict:
            te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}

        new_state_dict = {**ait_sd, **te_state_dict}
        return new_state_dict

    def _convert_mixture_state_dict_to_diffusers(state_dict):
        new_state_dict = {}

        def _convert(original_key, diffusers_key, state_dict, new_state_dict):
            down_key = f"{original_key}.lora_down.weight"
            down_weight = state_dict.pop(down_key)
            lora_rank = down_weight.shape[0]

            up_weight_key = f"{original_key}.lora_up.weight"
            up_weight = state_dict.pop(up_weight_key)

            alpha_key = f"{original_key}.alpha"
            alpha = state_dict.pop(alpha_key)

            # scale weight by alpha and dim
            scale = alpha / lora_rank
            # calculate scale_down and scale_up
            scale_down = scale
            scale_up = 1.0
            while scale_down * 2 < scale_up:
                scale_down *= 2
                scale_up /= 2
            down_weight = down_weight * scale_down
            up_weight = up_weight * scale_up

            diffusers_down_key = f"{diffusers_key}.lora_A.weight"
            new_state_dict[diffusers_down_key] = down_weight
            new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight

        all_unique_keys = {
            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
            for k in state_dict
            if not k.startswith(("lora_unet_"))
        }
        assert all(k.startswith(("lora_transformer_", "lora_te1_")) for k in all_unique_keys), f"{all_unique_keys=}"

        has_te_keys = False
        for k in all_unique_keys:
            if k.startswith("lora_transformer_single_transformer_blocks_"):
                i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
                diffusers_key = f"single_transformer_blocks.{i}"
            elif k.startswith("lora_transformer_transformer_blocks_"):
                i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
                diffusers_key = f"transformer_blocks.{i}"
            elif k.startswith("lora_te1_"):
                has_te_keys = True
                continue
            elif k.startswith("lora_transformer_context_embedder"):
                diffusers_key = "context_embedder"
            elif k.startswith("lora_transformer_norm_out_linear"):
                diffusers_key = "norm_out.linear"
            elif k.startswith("lora_transformer_proj_out"):
                diffusers_key = "proj_out"
            elif k.startswith("lora_transformer_x_embedder"):
                diffusers_key = "x_embedder"
            elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"):
                i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1])
                diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}"
            elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"):
                i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1])
                diffusers_key = f"time_text_embed.text_embedder.linear_{i}"
            elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"):
                i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1])
                diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}"
            else:
                raise NotImplementedError(f"Handling for key ({k}) is not implemented.")

            if "attn_" in k:
                if "_to_out_0" in k:
                    diffusers_key += ".attn.to_out.0"
                elif "_to_add_out" in k:
                    diffusers_key += ".attn.to_add_out"
                elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
                    remaining = k.split("attn_")[-1]
                    diffusers_key += f".attn.{remaining}"
                elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
                    remaining = k.split("attn_")[-1]
                    diffusers_key += f".attn.{remaining}"

            _convert(k, diffusers_key, state_dict, new_state_dict)

        if has_te_keys:
            layer_pattern = re.compile(r"lora_te1_text_model_encoder_layers_(\d+)")
            attn_mapping = {
                "q_proj": ".self_attn.q_proj",
                "k_proj": ".self_attn.k_proj",
                "v_proj": ".self_attn.v_proj",
                "out_proj": ".self_attn.out_proj",
            }
            mlp_mapping = {"fc1": ".mlp.fc1", "fc2": ".mlp.fc2"}
            for k in all_unique_keys:
                if not k.startswith("lora_te1_"):
                    continue

                match = layer_pattern.search(k)
                if not match:
                    continue
                i = int(match.group(1))
                diffusers_key = f"text_model.encoder.layers.{i}"

                if "attn" in k:
                    for key_fragment, suffix in attn_mapping.items():
                        if key_fragment in k:
                            diffusers_key += suffix
                            break
                elif "mlp" in k:
                    for key_fragment, suffix in mlp_mapping.items():
                        if key_fragment in k:
                            diffusers_key += suffix
                            break

                _convert(k, diffusers_key, state_dict, new_state_dict)

        remaining_all_unet = False
        if state_dict:
            remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
        if remaining_all_unet:
            keys = list(state_dict.keys())
            for k in keys:
                state_dict.pop(k)

        if len(state_dict) > 0:
            raise ValueError(
                f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
            )

        transformer_state_dict = {
            f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.")
        }
        te_state_dict = {f"text_encoder.{k}": v for k, v in new_state_dict.items() if k.startswith("text_model.")}
        return {**transformer_state_dict, **te_state_dict}

    # This is  weird.
    # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
    # has both `peft` and non-peft state dict.
    has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
    if has_peft_state_dict:
        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
        return state_dict

    # Another weird one.
    has_mixture = any(
        k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
    )

    # ComfyUI.
    if not has_mixture:
        state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
        state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}

        has_position_embedding = any("position_embedding" in k for k in state_dict)
        if has_position_embedding:
            zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
            if zero_status_pe:
                logger.info(
                    "The `position_embedding` LoRA params are all zeros which make them ineffective. "
                    "So, we will purge them out of the current state dict to make loading possible."
                )

            else:
                logger.info(
                    "The state_dict has position_embedding LoRA params and we currently do not support them. "
                    "Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
                )
            state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}

        has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
        if has_t5xxl:
            zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
            if zero_status_t5:
                logger.info(
                    "The `t5xxl` LoRA params are all zeros which make them ineffective. "
                    "So, we will purge them out of the current state dict to make loading possible."
                )
            else:
                logger.info(
                    "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out."
                    "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
                )
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}

        has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
        if has_diffb:
            zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
            if zero_status_diff_b:
                logger.info(
                    "The `diff_b` LoRA params are all zeros which make them ineffective. "
                    "So, we will purge them out of the current state dict to make loading possible."
                )
            else:
                logger.info(
                    "`diff_b` keys found in the state dict which are currently unsupported. "
                    "So, we will filter out those keys. Open an issue if this is a problem - "
                    "https://github.com/huggingface/diffusers/issues/new."
                )
            state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}

        has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
        if has_norm_diff:
            zero_status_diff = state_dict_all_zero(state_dict, ".diff")
            if zero_status_diff:
                logger.info(
                    "The `diff` LoRA params are all zeros which make them ineffective. "
                    "So, we will purge them out of the current state dict to make loading possible."
                )
            else:
                logger.info(
                    "Normalization diff keys found in the state dict which are currently unsupported. "
                    "So, we will filter out those keys. Open an issue if this is a problem - "
                    "https://github.com/huggingface/diffusers/issues/new."
                )
            state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}

        limit_substrings = ["lora_down", "lora_up"]
        if any("alpha" in k for k in state_dict):
            limit_substrings.append("alpha")

        state_dict = {
            _custom_replace(k, limit_substrings): v
            for k, v in state_dict.items()
            if k.startswith(("lora_unet_", "lora_te_"))
        }

        if any("text_projection" in k for k in state_dict):
            logger.info(
                "`text_projection` keys found in the `state_dict` which are unexpected. "
                "So, we will filter out those keys. Open an issue if this is a problem - "
                "https://github.com/huggingface/diffusers/issues/new."
            )
            state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}

    if has_mixture:
        return _convert_mixture_state_dict_to_diffusers(state_dict)

    return _convert_sd_scripts_to_ai_toolkit(state_dict)