# assets/training/finetune_acft_hf_nlp/src/model_selector/model_selector.py
def _validate_mlflow_model_port(args):
    """Fail fast if a custom (non-MLflow) model is wired to ``mlflow_model_path``.

    A genuine MLflow model folder always contains the MLmodel metadata file;
    its absence means the user connected a custom/pytorch model to the wrong
    input port.

    :param args: parsed component arguments (reads ``mlflow_model_path``).
    :raises ACFTValidationException: when the MLmodel file is missing.
    """
    if args.mlflow_model_path and \
            not Path(args.mlflow_model_path, MLFlowHFFlavourConstants.MISC_CONFIG_FILE).is_file():
        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                pii_safe_message=(
                    "MLmodel file is not found, If this is a custom model "
                    "it needs to be connected to pytorch_model_path"
                )
            )
        )


def _validate_pytorch_model_port(args):
    """Fail fast if an MLflow model is wired to ``pytorch_model_path``.

    The presence of an MLmodel file marks the folder as an MLflow model, which
    must be connected to ``mlflow_model_path`` instead.

    :param args: parsed component arguments (reads ``pytorch_model_path``).
    :raises ACFTValidationException: when an MLmodel file is found on the pytorch port.
    """
    if args.pytorch_model_path and \
            Path(args.pytorch_model_path, MLFlowHFFlavourConstants.MISC_CONFIG_FILE).is_file():
        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                pii_safe_message=(
                    "MLFLOW model is connected to pytorch_model_path, "
                    "it needs to be connected to mlflow_model_path"
                )
            )
        )


def _build_finetune_config(args, model_name):
    """Assemble the merged finetune config for the selected model.

    Priority order (highest wins):
    io_finetune_config > artifacts_finetune_config > base_model_finetune_config.

    :param args: parsed component arguments.
    :param model_name: name of the selected model (used to locate its folder
        under ``args.output_dir``).
    :return: merged finetune config dict.
    """
    ft_config_obj = FinetuneConfig(
        task_name=args.task_name,
        model_name=model_name,
        model_type=fetch_model_type(str(Path(args.output_dir, model_name))),
        artifacts_finetune_config_path=str(
            Path(
                # Exactly one of the two model ports is expected to be set;
                # fall back to "" so Path construction never fails.
                args.pytorch_model_path or args.mlflow_model_path or "",
                SaveFileConstants.ACFT_CONFIG_SAVE_PATH
            )
        ),
        io_finetune_config_path=args.finetune_config_path
    )
    finetune_config = ft_config_obj.get_finetune_config()
    # Base-model config (read from the MLmodel file) has the lowest priority.
    return deep_update(
        read_base_model_finetune_config(
            args.mlflow_model_path,
            args.task_name
        ),
        finetune_config
    )


def _apply_generator_config(args, model_name, finetune_config):
    """Merge MLmodel generator settings into the base model's generation_config.json.

    Copies the MLmodel generator config so that the finetuned model uses the
    same generator config during evaluation (settings like ``max_new_tokens``
    can reduce inference time). We patch generation_config.json directly so no
    conflict arises between the model's config and its generator config — with
    transformers>=4.41.0 such conflicts raise when ``_from_model_config`` is
    present.

    NOTE: mutates *finetune_config* — pops the ``update_generator_config`` key
    so it is not persisted into the saved ACFT config.

    :param args: parsed component arguments (reads ``output_dir``).
    :param model_name: name of the selected model folder under ``output_dir``.
    :param finetune_config: merged finetune config dict (mutated in place).
    """
    if "update_generator_config" not in finetune_config:
        logger.info(f"{MLFlowHFFlavourConstants.MISC_CONFIG_FILE} does not have any generation config parameters.")
        return
    generator_config = finetune_config.pop("update_generator_config")
    base_model_generation_config_file = Path(args.output_dir, model_name, GENERATION_CONFIG_NAME)
    if base_model_generation_config_file.is_file():
        update_json_file_and_overwrite(str(base_model_generation_config_file), generator_config)
        logger.info(f"Updated {GENERATION_CONFIG_NAME} with {generator_config}")
    else:
        logger.info(f"Could not update {GENERATION_CONFIG_NAME} as not present.")


def _copy_mlflow_artifacts(args):
    """Copy MLflow-specific artifacts (MLmodel, conda yaml, ml_configs) to the output dir.

    Only applicable when the input model is an MLflow model; a no-op for
    pytorch/custom models.

    :param args: parsed component arguments (reads ``mlflow_model_path`` and
        ``output_dir``).
    """
    if args.mlflow_model_path is None:
        return
    mlflow_config_file = Path(args.mlflow_model_path, MLFlowHFFlavourConstants.MISC_CONFIG_FILE)
    if mlflow_config_file.is_file():
        shutil.copy(str(mlflow_config_file), args.output_dir)
        logger.info(f"Copied {MLFlowHFFlavourConstants.MISC_CONFIG_FILE} file to output dir.")
    # copy conda file
    conda_file_path = Path(args.mlflow_model_path, MLFlowHFFlavourConstants.CONDA_YAML_FILE)
    if conda_file_path.is_file():
        shutil.copy(str(conda_file_path), args.output_dir)
        logger.info(f"Copied {MLFlowHFFlavourConstants.CONDA_YAML_FILE} file to output dir.")
    # copy inference config files
    mlflow_ml_configs_dir = Path(args.mlflow_model_path, "ml_configs")
    ml_config_dir = Path(args.output_dir, "ml_configs")
    if mlflow_ml_configs_dir.is_dir():
        shutil.copytree(mlflow_ml_configs_dir, ml_config_dir)
        logger.info("Copied ml_configs folder to output dir.")


def main():
    """Parse args and import model."""
    parser = get_parser()
    args, _ = parser.parse_known_args()
    set_logging_parameters(
        task_type=args.task_name,
        acft_custom_dimensions={
            LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
            LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
            LoggingLiterals.COMPONENT_NAME: COMPONENT_NAME
        },
        azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
        log_level=logging.INFO,
    )
    # Validate custom model type before running the selector.
    _validate_mlflow_model_port(args)
    # Adding flavor map to args (direct assignment over setattr with a literal name).
    args.flavor_map = FLAVOR_MAP
    # run model selector
    model_selector_args = model_selector(args)
    model_name = model_selector_args.get("model_name", ModelSelectorConstants.MODEL_NAME_NOT_FOUND)
    logger.info(f"Model name - {model_name}")
    logger.info(f"Task name: {getattr(args, 'task_name', None)}")
    # Validate port for right model type.
    _validate_pytorch_model_port(args)
    # Merge finetune configs (io > artifacts > base-model MLmodel).
    updated_finetune_config = _build_finetune_config(args, model_name)
    # Propagate MLmodel generator config (pops the key from the merged config).
    _apply_generator_config(args, model_name, updated_finetune_config)
    logger.info(f"Updated finetune config with base model config: {updated_finetune_config}")
    # save FT config
    with open(Path(args.output_dir, SaveFileConstants.ACFT_CONFIG_SAVE_PATH), "w", encoding="utf-8") as wptr:
        json.dump(updated_finetune_config, wptr, indent=2)
    logger.info(f"Saved {SaveFileConstants.ACFT_CONFIG_SAVE_PATH}")
    # Copy MLflow artifacts (MLmodel / conda / ml_configs) — mlflow models only.
    _copy_mlflow_artifacts(args)