def main()

in optimum/exporters/tflite/__main__.py [0:0]


def main():
    """Command-line entry point for the Optimum TensorFlow Lite exporter.

    Parses the CLI arguments registered by ``parse_args_tflite``, loads the TensorFlow model
    for the requested (or automatically inferred) task, saves the model config and any
    preprocessors next to the output, exports the model to ``<output>/model.tflite``
    (optionally quantized) and, for non-quantized exports only, validates the exported
    model's outputs against the reference model.

    Raises:
        KeyError: If ``--task auto`` is used and the task cannot be inferred from the model.
        RequestsConnectionError: If ``--task auto`` is used and the task inference request
            fails with a connection error.
        ShapeError: If output validation reports a shape mismatch.
    """
    parser = ArgumentParser("Hugging Face Optimum TensorFlow Lite exporter")
    parse_args_tflite(parser)

    # Retrieve CLI arguments; `--output` is treated as a directory and the model file goes inside it.
    args = parser.parse_args()
    args.output = args.output.joinpath("model.tflite")
    # exist_ok=True avoids the check-then-create race of testing `exists()` first.
    args.output.parent.mkdir(parents=True, exist_ok=True)

    # Infer the task when the user asked for automatic detection.
    task = args.task
    if task == "auto":
        try:
            task = TasksManager.infer_task_from_model(args.model, library_name="transformers")
        except KeyError as e:
            raise KeyError(
                "The task could not be automatically inferred. Please provide the argument --task with the task "
                f"from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
            ) from e
        except RequestsConnectionError as e:
            raise RequestsConnectionError(
                f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
            ) from e

    model = TasksManager.get_model_from_task(
        task,
        args.model,
        framework="tf",
        cache_dir=args.cache_dir,
        trust_remote_code=args.trust_remote_code,
        library_name="transformers",
    )

    tflite_config_constructor = TasksManager.get_exporter_config_constructor(
        model=model, exporter="tflite", task=task, library_name="transformers"
    )
    # TODO: find a cleaner way to do this.
    # The mandatory static axes (e.g. shape values) for the task are read from the CLI arguments
    # of the same name; `.func` reaches the config class behind the returned constructor
    # (presumably a functools.partial — confirm against TasksManager).
    shapes = {name: getattr(args, name) for name in tflite_config_constructor.func.get_mandatory_axes_for_task(task)}
    tflite_config = tflite_config_constructor(model.config, **shapes)

    # Fall back to the config's default tolerance, which may be keyed per task (ignoring "-with-past").
    if args.atol is None:
        args.atol = tflite_config.ATOL_FOR_VALIDATION
        if isinstance(args.atol, dict):
            args.atol = args.atol[task.replace("-with-past", "")]

    # Saving the model config and preprocessor as this is needed sometimes.
    model.config.save_pretrained(args.output.parent)
    maybe_save_preprocessors(args.model, args.output.parent)

    # Only the first loaded preprocessor (if any) is passed to `export`.
    preprocessors = maybe_load_preprocessors(args.output.parent)
    preprocessor = preprocessors[0] if preprocessors else None

    quantization_config = None
    if args.quantize:
        quantization_config = TFLiteQuantizationConfig(
            approach=args.quantize,
            fallback_to_float=args.fallback_to_float,
            inputs_dtype=args.inputs_type,
            outputs_dtype=args.outputs_type,
            calibration_dataset_name_or_path=args.calibration_dataset,
            calibration_dataset_config_name=args.calibration_dataset_config_name,
            num_calibration_samples=args.num_calibration_samples,
            calibration_split=args.calibration_split,
            primary_key=args.primary_key,
            secondary_key=args.secondary_key,
            question_key=args.question_key,
            context_key=args.context_key,
            image_key=args.image_key,
        )

    # The input/output names returned by `export` are not needed: validation uses the config's outputs.
    export(
        model=model,
        config=tflite_config,
        output=args.output,
        task=task,
        preprocessor=preprocessor,
        quantization_config=quantization_config,
    )

    saved_at = args.output.parent.as_posix()

    if args.quantize is not None:
        # Output validation is only performed for non-quantized exports.
        logger.info(
            "The TensorFlow Lite export succeeded (output validation is skipped for quantized models) and the "
            f"exported model was saved at: {saved_at}"
        )
        return

    try:
        validate_model_outputs(
            config=tflite_config,
            reference_model=model,
            tflite_model_path=args.output,
            tflite_named_outputs=tflite_config.outputs,
            atol=args.atol,
        )
    except ShapeError:
        # Shape mismatches are fatal: re-raise before the generic handler below can swallow them.
        raise
    except (AtolError, OutputMatchError) as e:
        # Value/name mismatches only produce a warning; the model file has already been written.
        logger.warning(
            f"The TensorFlow Lite export succeeded with the warning: {e}.\n The exported model was saved at: "
            f"{saved_at}"
        )
    except Exception as e:
        # Any other validation failure is logged rather than raised, since the export itself completed.
        logger.error(
            f"An error occured with the error message: {e}.\n The exported model was saved at: "
            f"{saved_at}"
        )
    else:
        logger.info(
            "The TensorFlow Lite export succeeded and the exported model was saved at: "
            f"{saved_at}"
        )