def main()

in assets/training/finetune_acft_hf_nlp/src/model_selector/model_selector.py


def main():
    """Parse args and import model."""
    # args
    parser = get_parser()
    args, _ = parser.parse_known_args()
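    # parse_known_args tolerates extra CLI arguments that the pipeline may pass
    # through; unrecognized arguments are returned separately and ignored here.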

    set_logging_parameters(
        task_type=args.task_name,
        acft_custom_dimensions={
            LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
            LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
            LoggingLiterals.COMPONENT_NAME: COMPONENT_NAME
        },
        azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
        log_level=logging.INFO,
    )

    # Validate custom model type
    if args.mlflow_model_path and \
       not Path(args.mlflow_model_path, MLFlowHFFlavourConstants.MISC_CONFIG_FILE).is_file():
        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                pii_safe_message=(
                    "MLmodel file is not found. If this is a custom model, "
                    "it needs to be connected to pytorch_model_path."
                )
            )
        )

    # Adding flavor map to args
    setattr(args, "flavor_map", FLAVOR_MAP)

    # run model selector
    model_selector_args = model_selector(args)
    model_name = model_selector_args.get("model_name", ModelSelectorConstants.MODEL_NAME_NOT_FOUND)
    logger.info(f"Model name - {model_name}")
    logger.info(f"Task name: {getattr(args, 'task_name', None)}")
    # Validate that the model is connected to the right input port
    if args.pytorch_model_path and Path(args.pytorch_model_path, MLFlowHFFlavourConstants.MISC_CONFIG_FILE).is_file():
        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                pii_safe_message=(
                    "MLflow model is connected to pytorch_model_path; "
                    "it needs to be connected to mlflow_model_path."
                )
            )
        )

    # load finetune config and update the ACFT config
    ft_config_obj = FinetuneConfig(
        task_name=args.task_name,
        model_name=model_name,
        model_type=fetch_model_type(str(Path(args.output_dir, model_name))),
        artifacts_finetune_config_path=str(
            Path(
                args.pytorch_model_path or args.mlflow_model_path or "",
                SaveFileConstants.ACFT_CONFIG_SAVE_PATH
            )
        ),
        io_finetune_config_path=args.finetune_config_path
    )
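    # artifacts_finetune_config_path points at a config saved inside the input
    # model folder, while io_finetune_config_path is the explicitly supplied
    # finetune_config input; both feed into the merged config below.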
    finetune_config = ft_config_obj.get_finetune_config()

    # read finetune config from base mlmodel file
    # Priority order: io_finetune_config > artifacts_finetune_config > base_model_finetune_config
    updated_finetune_config = deep_update(
        read_base_model_finetune_config(
            args.mlflow_model_path,
            args.task_name
        ),
        finetune_config
    )
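    # deep_update is a recursive dict merge in which the second argument wins on
    # conflicting keys, so the IO/artifacts config overrides the base-model
    # config (see the sketch after this listing).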

    # Copy the MLmodel generator config so that the fine-tuned model uses the same
    # generator config during evaluation. (Settings like `max_new_tokens` can help
    # reduce inference time.) We update generation_config.json so that no conflicts
    # exist between the model's config and its generation config. (On conflict a
    # warning is logged, and from transformers>=4.41.0 an exception is raised if the
    # `_from_model_config` key is present.)
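    # Illustrative shape (values hypothetical): the popped generator config may
    # look like {"max_new_tokens": 256}, which is merged into the base model's
    # generation_config.json below.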
    if "update_generator_config" in updated_finetune_config:
        generator_config = updated_finetune_config.pop("update_generator_config")
        base_model_generation_config_file = Path(
            args.output_dir, model_selector_args["model_name"], GENERATION_CONFIG_NAME
        )
        if base_model_generation_config_file.is_file():
            update_json_file_and_overwrite(str(base_model_generation_config_file), generator_config)
            logger.info(f"Updated {GENERATION_CONFIG_NAME} with {generator_config}")
        else:
            logger.info(f"Could not update {GENERATION_CONFIG_NAME} as not present.")
    else:
        logger.info(f"{MLFlowHFFlavourConstants.MISC_CONFIG_FILE} does not have any generation config parameters.")

    logger.info(f"Updated finetune config with base model config: {updated_finetune_config}")
    # save FT config
    with open(str(Path(args.output_dir, SaveFileConstants.ACFT_CONFIG_SAVE_PATH)), "w") as fp:
        json.dump(updated_finetune_config, fp, indent=2)
    logger.info(f"Saved {SaveFileConstants.ACFT_CONFIG_SAVE_PATH}")

    # copy the MLmodel file to the output dir. This is only applicable for MLflow models
    if args.mlflow_model_path is not None:
        mlflow_config_file = Path(args.mlflow_model_path, MLFlowHFFlavourConstants.MISC_CONFIG_FILE)
        if mlflow_config_file.is_file():
            shutil.copy(str(mlflow_config_file), args.output_dir)
            logger.info(f"Copied {MLFlowHFFlavourConstants.MISC_CONFIG_FILE} file to output dir.")

        # copy conda file
        conda_file_path = Path(args.mlflow_model_path, MLFlowHFFlavourConstants.CONDA_YAML_FILE)
        if conda_file_path.is_file():
            shutil.copy(str(conda_file_path), args.output_dir)
            logger.info(f"Copied {MLFlowHFFlavourConstants.CONDA_YAML_FILE} file to output dir.")

        # copy inference config files
        mlflow_ml_configs_dir = Path(args.mlflow_model_path, "ml_configs")
        ml_config_dir = Path(args.output_dir, "ml_configs")
        if mlflow_ml_configs_dir.is_dir():
            shutil.copytree(
                mlflow_ml_configs_dir,
                ml_config_dir
            )
            logger.info("Copied ml_configs folder to output dir.")