def main()

in distilvit/quantize.py


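The excerpt refers to several names that are not defined inside main(). A sketch of the module-level imports it appears to rely on — the transformers/optimum paths are the conventional ones, while the remaining names are assumed to be defined elsewhere in distilvit/quantize.py:

import json
import os
import shutil

from transformers import AutoConfig, AutoTokenizer, HfArgumentParser
from optimum.exporters.onnx import export_models, main_export
from optimum.exporters.tasks import TasksManager

# Assumed to be defined elsewhere in this module: ConversionArguments,
# quantize, MODELS_WITHOUT_TOKENIZERS, DEFAULT_QUANTIZE_PARAMS,
# MODEL_SPECIFIC_QUANTIZE_PARAMS.
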
def main():
    parser = HfArgumentParser((ConversionArguments,))
    (conv_args,) = parser.parse_args_into_dataclasses()

    model_id = conv_args.model_id
    tokenizer_id = conv_args.tokenizer_id or model_id

    output_model_folder = os.path.join(conv_args.output_parent_dir, model_id)

    # Create output folder
    os.makedirs(output_model_folder, exist_ok=True)

    from_pretrained_kwargs = dict(
        trust_remote_code=conv_args.trust_remote_code,
    )

    # Load the model config
    config = AutoConfig.from_pretrained(model_id, **from_pretrained_kwargs)

    custom_kwargs = {}
    if conv_args.custom_onnx_configs is not None:
        if conv_args.task == "auto":
            raise Exception(
                "`--task` must be set when exporting with `--custom_onnx_configs`"
            )
        custom_onnx_configs = json.loads(conv_args.custom_onnx_configs)

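        # Each entry maps the name of an exported sub-model to a model type known to
        # optimum's TasksManager; the ONNX config registered for that type and for
        # `--task` is then instantiated against this model's config. A hypothetical
        # example: --custom_onnx_configs '{"decoder_model": "gpt2"}'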
        for key in custom_onnx_configs:
            onnx_configs = TasksManager._SUPPORTED_MODEL_TYPE[custom_onnx_configs[key]][
                "onnx"
            ]
            mapping = onnx_configs[conv_args.task]
            custom_onnx_configs[key] = mapping.func(config, **mapping.keywords)

        custom_kwargs["custom_onnx_configs"] = custom_onnx_configs

    tokenizer = None
    try:
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_id, **from_pretrained_kwargs
        )

        # To avoid inserting all chat templates into tokenizers.js, we save the chat template
        # to the tokenizer_config.json file, and load it when the tokenizer is loaded.
        if getattr(tokenizer, "chat_template", None) is None and getattr(
            tokenizer, "use_default_system_prompt", False
        ):
            # No chat template specified, and we use the default
            setattr(tokenizer, "chat_template", tokenizer.default_chat_template)

    except KeyError:
        pass  # No Tokenizer

    except Exception as e:
        if config.model_type not in MODELS_WITHOUT_TOKENIZERS:
            raise e

    core_export_kwargs = dict(
        opset=conv_args.opset,
        device=conv_args.device,
        trust_remote_code=conv_args.trust_remote_code,
        **custom_kwargs,
    )

    export_kwargs = dict(
        model_name_or_path=model_id,
        output=output_model_folder,
        task=conv_args.task,
        do_validation=not conv_args.skip_validation,
        library_name="transformers",
        **core_export_kwargs,
    )

    # Handle special cases
    if config.model_type == "marian":
        from .extra.marian import generate_tokenizer_json

        tokenizer_json = generate_tokenizer_json(model_id, tokenizer)

        with open(
            os.path.join(output_model_folder, "tokenizer.json"), "w", encoding="utf-8"
        ) as fp:
            json.dump(tokenizer_json, fp, indent=4)

    elif config.model_type == "esm":
        from .extra.esm import generate_fast_tokenizer

        fast_tokenizer = generate_fast_tokenizer(tokenizer)
        fast_tokenizer.save(os.path.join(output_model_folder, "tokenizer.json"))

    elif config.model_type == "whisper":
        if conv_args.output_attentions:
            from .extra.whisper import get_main_export_kwargs

            export_kwargs.update(
                **get_main_export_kwargs(config, "automatic-speech-recognition")
            )

    elif config.model_type in (
        "wav2vec2",
        "wav2vec2-bert",
        "hubert",
        "unispeech",
        "unispeech-sat",
    ):
        if tokenizer is not None:
            from .extra.wav2vec2 import generate_tokenizer_json

            tokenizer_json = generate_tokenizer_json(tokenizer)

            with open(
                os.path.join(output_model_folder, "tokenizer.json"),
                "w",
                encoding="utf-8",
            ) as fp:
                json.dump(tokenizer_json, fp, indent=4)

    elif config.model_type == "vits":
        if tokenizer is not None:
            from .extra.vits import generate_tokenizer_json

            tokenizer_json = generate_tokenizer_json(tokenizer)

            with open(
                os.path.join(output_model_folder, "tokenizer.json"),
                "w",
                encoding="utf-8",
            ) as fp:
                json.dump(tokenizer_json, fp, indent=4)

    elif config.model_type == "speecht5":
        # TODO allow user to specify vocoder path
        export_kwargs["model_kwargs"] = {"vocoder": "microsoft/speecht5_hifigan"}

        if tokenizer is not None:
            from .extra.speecht5 import generate_tokenizer_json

            tokenizer_json = generate_tokenizer_json(tokenizer)

            with open(
                os.path.join(output_model_folder, "tokenizer.json"),
                "w",
                encoding="utf-8",
            ) as fp:
                json.dump(tokenizer_json, fp, indent=4)

    elif config.model_type in ("owlvit", "owlv2"):
        # Override default batch size to 1, needed because non-maximum suppression is performed for exporting.
        # For more information, see https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
        export_kwargs["batch_size"] = 1

    else:
        pass  # TODO

    # Step 1. convert huggingface model to onnx
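    # Without --split_modalities, optimum's main_export handles the whole model.
    # With it, the text and vision towers (CLIP / SigLIP only) are exported as
    # separate ONNX graphs via export_models and per-modality ONNX configs.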
    if not conv_args.split_modalities:
        main_export(**export_kwargs)
    else:
        custom_export_kwargs = dict(
            output_dir=output_model_folder,
            **core_export_kwargs,
        )

        if config.model_type == "clip":
            # Handle special case for exporting text and vision models separately
            from .extra.clip import (
                CLIPTextModelWithProjectionOnnxConfig,
                CLIPVisionModelWithProjectionOnnxConfig,
            )
            from transformers.models.clip import (
                CLIPTextModelWithProjection,
                CLIPVisionModelWithProjection,
            )

            text_model = CLIPTextModelWithProjection.from_pretrained(
                model_id, **from_pretrained_kwargs
            )
            vision_model = CLIPVisionModelWithProjection.from_pretrained(
                model_id, **from_pretrained_kwargs
            )

            export_models(
                models_and_onnx_configs={
                    "text_model": (
                        text_model,
                        CLIPTextModelWithProjectionOnnxConfig(text_model.config),
                    ),
                    "vision_model": (
                        vision_model,
                        CLIPVisionModelWithProjectionOnnxConfig(vision_model.config),
                    ),
                },
                **custom_export_kwargs,
            )

        elif config.model_type == "siglip":
            # Handle special case for exporting text and vision models separately
            from .extra.siglip import (
                SiglipTextModelOnnxConfig,
                SiglipVisionModelOnnxConfig,
            )
            from transformers.models.siglip import SiglipTextModel, SiglipVisionModel

            text_model = SiglipTextModel.from_pretrained(
                model_id, **from_pretrained_kwargs
            )
            vision_model = SiglipVisionModel.from_pretrained(
                model_id, **from_pretrained_kwargs
            )

            export_models(
                models_and_onnx_configs={
                    "text_model": (
                        text_model,
                        SiglipTextModelOnnxConfig(text_model.config),
                    ),
                    "vision_model": (
                        vision_model,
                        SiglipVisionModelOnnxConfig(vision_model.config),
                    ),
                },
                **custom_export_kwargs,
            )

        # TODO: Enable once https://github.com/huggingface/optimum/pull/1552 is merged
        # elif config.model_type == 'clap':
        #     # Handle special case for exporting text and audio models separately
        #     from .extra.clap import ClapTextModelWithProjectionOnnxConfig, ClapAudioModelWithProjectionOnnxConfig
        #     from transformers.models.clap import ClapTextModelWithProjection, ClapAudioModelWithProjection

        #     text_model = ClapTextModelWithProjection.from_pretrained(model_id, **from_pretrained_kwargs)
        #     audio_model = ClapAudioModelWithProjection.from_pretrained(model_id, **from_pretrained_kwargs)

        #     export_models(
        #         models_and_onnx_configs={
        #             "text_model": (text_model, ClapTextModelWithProjectionOnnxConfig(text_model.config)),
        #             "audio_model": (audio_model, ClapAudioModelWithProjectionOnnxConfig(audio_model.config)),
        #         },
        #         **custom_export_kwargs,
        #     )

        else:
            raise Exception(
                f"Unable to export {config.model_type} model with `--split_modalities`."
            )

    # Step 2. (optional, recommended) quantize the converted model for fast inference and to reduce model size.
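    # Only freshly exported .onnx files are quantized; anything already ending in
    # _quantized.onnx (presumably left over from a previous run) is skipped.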
    if conv_args.quantize:
        # Update quantize config with model specific defaults
        quantize_config = MODEL_SPECIFIC_QUANTIZE_PARAMS.get(
            config.model_type, DEFAULT_QUANTIZE_PARAMS
        )

        # Update if user specified values
        if conv_args.per_channel is not None:
            quantize_config["per_channel"] = conv_args.per_channel

        if conv_args.reduce_range is not None:
            quantize_config["reduce_range"] = conv_args.reduce_range

        quantize(
            [
                os.path.join(output_model_folder, x)
                for x in os.listdir(output_model_folder)
                if x.endswith(".onnx") and not x.endswith("_quantized.onnx")
            ],
            **quantize_config,
        )

    # Step 3. Move .onnx files to the 'onnx' subfolder
    os.makedirs(os.path.join(output_model_folder, "onnx"), exist_ok=True)
    for file in os.listdir(output_model_folder):
        if file.endswith((".onnx", ".onnx_data")):
            shutil.move(
                os.path.join(output_model_folder, file),
                os.path.join(output_model_folder, "onnx", file),
            )

    # Step 4. Update the generation config if necessary
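    # For Whisper, alignment_heads marks the cross-attention heads used for
    # word-level timestamp alignment; .extra.whisper derives them from the config.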
    if config.model_type == "whisper":
        from transformers import GenerationConfig
        from .extra.whisper import get_alignment_heads

        generation_config = GenerationConfig.from_pretrained(
            model_id, **from_pretrained_kwargs
        )
        generation_config.alignment_heads = get_alignment_heads(config)
        generation_config.save_pretrained(output_model_folder)
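
Because the arguments come from HfArgumentParser over a ConversionArguments dataclass, the converter is driven entirely by CLI flags named after the dataclass fields used above. A minimal, hypothetical driver is sketched below; the model id, output directory, and task value are placeholders, and the exact flag spellings depend on how ConversionArguments is defined:

import sys
from distilvit.quantize import main

sys.argv = [
    "quantize.py",
    "--model_id", "org/example-image-captioning-model",  # hypothetical model id
    "--output_parent_dir", "./converted",
    "--task", "image-to-text",
    "--quantize", "True",
]
main()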