# infer_diffusers_model_type()
#
# Defined in src/diffusers/loaders/single_file_utils.py [0:0]


def infer_diffusers_model_type(checkpoint):
    """Infer which Diffusers model family/variant a single-file checkpoint belongs to.

    The checkpoint is classified by probing for signature keys taken from the
    module-level ``CHECKPOINT_KEY_NAMES`` mapping and, where several variants share
    the same keys, by comparing tensor shapes. All checks form one ordered
    ``if``/``elif`` chain, so the first matching family wins. Reordering branches
    changes the result for checkpoints that contain keys of more than one family.

    Args:
        checkpoint: A state-dict-like mapping from parameter names to tensors. Only
            ``in`` membership tests, ``[]`` lookups and the ``.shape`` attribute of
            the values are used.

    Returns:
        str: A model-type label such as ``"v1"``, ``"xl_base"``, ``"sd3"`` or
        ``"flux-dev"``. ``"v1"`` is returned when no other family is recognised.

    Raises:
        KeyError: If a family's signature key is present but a secondary key read to
            pick the variant is missing. Examples are the ``img_in`` weight for Flux,
            ``patch_embedding`` for Wan, and the ``project_in`` keys for the DC
            autoencoder.
    """
    # Inpainting variants: dim 1 of the inpainting signature tensor is 9.
    # The v2 / SDXL probes then select which base family the inpainting model extends.
    if (
        CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
    ):
        if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
            model_type = "inpainting_v2"
        elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
            model_type = "xl_inpaint"
        else:
            model_type = "inpainting"

    elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
        model_type = "v2"

    elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint:
        model_type = "playground-v2-5"

    elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
        model_type = "xl_base"

    elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
        model_type = "xl_refiner"

    elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
        model_type = "upscale"

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]):
        if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint:
            if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint:
                model_type = "controlnet_xl_large"
            elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint:
                model_type = "controlnet_xl_mid"
            else:
                model_type = "controlnet_xl_small"
        else:
            model_type = "controlnet"

    # Stable Cascade: each stage has one signature key. The full and "lite" variants
    # are told apart by a dimension of that tensor. A key with an unexpected size falls
    # through to the later checks.
    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536
    ):
        model_type = "stable_cascade_stage_c_lite"

    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048
    ):
        model_type = "stable_cascade_stage_c"

    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576
    ):
        model_type = "stable_cascade_stage_b_lite"

    elif (
        CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640
    ):
        model_type = "stable_cascade_stage_b"

    # SD3 / SD3.5 Medium: some SD3 signature key is present with a last dim of 9216.
    elif any(key in checkpoint and checkpoint[key].shape[-1] == 9216 for key in CHECKPOINT_KEY_NAMES["sd3"]):
        pos_embed_key = next(
            (k for k in ("model.diffusion_model.pos_embed", "pos_embed") if k in checkpoint),
            None,
        )
        # Dim 1 of pos_embed is 36864 for SD3 and 147456 for SD3.5 Medium. Any other
        # size, or a missing pos_embed, falls back to the base "sd3" label. This keeps
        # model_type from being left unassigned at the return below.
        if pos_embed_key is not None and checkpoint[pos_embed_key].shape[1] == 147456:
            model_type = "sd35_medium"
        else:
            model_type = "sd3"

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
        model_type = "sd35_large"

    elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
        if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
            model_type = "animatediff_scribble"

        elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
            model_type = "animatediff_rgb"

        elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
            model_type = "animatediff_v2"

        # Guard the lookup: a missing SDXL-beta key means "not SDXL beta" and falls
        # through to the v1/v3 checks instead of raising KeyError.
        elif (
            CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"] in checkpoint
            and checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320
        ):
            model_type = "animatediff_sdxl_beta"

        elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
            model_type = "animatediff_v1"

        else:
            model_type = "animatediff_v3"

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
        # Guidance-distilled Flux checkpoints carry a guidance_in embedder; schnell does not.
        if any(
            g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
        ):
            if "model.diffusion_model.img_in.weight" in checkpoint:
                key = "model.diffusion_model.img_in.weight"
            else:
                key = "img_in.weight"

            # Fill and depth variants are recognised by dim 1 of the img_in weight.
            if checkpoint[key].shape[1] == 384:
                model_type = "flux-fill"
            elif checkpoint[key].shape[1] == 128:
                model_type = "flux-depth"
            else:
                model_type = "flux-dev"
        else:
            model_type = "flux-schnell"

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
        has_vae = "vae.encoder.conv_in.conv.bias" in checkpoint
        vae_conv_out_key = "vae.encoder.conv_out.conv.weight"
        if any(key.endswith("transformer_blocks.47.scale_shift_table") for key in checkpoint):
            model_type = "ltx-video-0.9.7"
        # has_vae only tests conv_in.bias, so also require the conv_out key that is
        # actually read. Otherwise a partial VAE would raise KeyError here.
        elif has_vae and vae_conv_out_key in checkpoint and checkpoint[vae_conv_out_key].shape[1] == 2048:
            model_type = "ltx-video-0.9.5"
        elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
            model_type = "ltx-video-0.9.1"
        else:
            model_type = "ltx-video"

    elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
        encoder_key = "encoder.project_in.conv.conv.bias"
        decoder_key = "decoder.project_in.main.conv.weight"

        if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
            model_type = "autoencoder-dc-f32c32-sana"

        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
            model_type = "autoencoder-dc-f32c32"

        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
            model_type = "autoencoder-dc-f64c128"

        else:
            model_type = "autoencoder-dc-f128c512"

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
        model_type = "mochi-1-preview"

    elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
        model_type = "hunyuan-video"

    # AuraFlow requires *all* of its signature keys, unlike most families above.
    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
        model_type = "auraflow"

    elif (
        CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
    ):
        model_type = "instruct-pix2pix"

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
        model_type = "lumina2"

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
        model_type = "sana"

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
        if "model.diffusion_model.patch_embedding.weight" in checkpoint:
            target_key = "model.diffusion_model.patch_embedding.weight"
        else:
            target_key = "patch_embedding.weight"

        # Wan sizes are distinguished by the patch-embedding weight's leading dims.
        if checkpoint[target_key].shape[0] == 1536:
            model_type = "wan-t2v-1.3B"
        elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
            model_type = "wan-t2v-14B"
        else:
            model_type = "wan-i2v-14B"
    elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
        # All Wan models use the same VAE so we can use the same default model repo to fetch the config
        model_type = "wan-t2v-14B"
    elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
        model_type = "hidream"
    else:
        # Nothing more specific matched: assume a Stable Diffusion v1 checkpoint.
        model_type = "v1"

    return model_type