in assets/training/finetune_acft_hf_nlp/src/finetune/finetune.py
def finetune(args: Namespace):
    """Finetune."""
    logger.info(f"full_determinism is set to {args.enable_full_determinism}")
    if args.enable_full_determinism:
        enable_full_determinism(args.seed)
    else:
        set_seed(args.seed)
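    # NOTE: HF transformers' `enable_full_determinism` seeds all RNGs and additionally
    # forces deterministic CUDA/cuDNN kernels (slower, but bit-for-bit reproducible),
    # whereas `set_seed` only seeds the RNGs.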
    # Update the model name or path
    model_name_or_path = Path(args.model_selector_output, args.model_name)
    if model_name_or_path.is_dir():
        args.model_name_or_path = model_name_or_path
    else:
        args.model_name_or_path = args.model_name
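    # If the model selector component downloaded the model, use the local folder;
    # otherwise fall back to the bare model name (e.g. a Hugging Face hub id).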
    # fetch model asset id
    model_asset_id = getattr(args, "model_asset_id", None) or ""
    # additional logging
    logger.info(f"Model name: {getattr(args, 'model_name', None)}")
    logger.info(f"Task name: {getattr(args, 'task_name', None)}")
    logger.info(f"Model asset id: {model_asset_id}")
    logger.info(f"enable LoRA: {getattr(args, 'apply_lora', None)}")
    logger.info(f"enable DeepSpeed: {getattr(args, 'apply_deepspeed', None)}")
    logger.info(f"enable ORT: {getattr(args, 'apply_ort', None)}")
    logger.info(f"Precision: {getattr(args, 'precision', None)}")
    # force `ignore_mismatched_sizes` to False for the model types listed in
    # IGNORE_MISMATCHED_SIZES_FALSE_MODELS
    if (
        hasattr(args, "model_type")
        and args.model_type in IGNORE_MISMATCHED_SIZES_FALSE_MODELS
    ):
        logger.info(
            f"Identified model type: {args.model_type}. Forcing `ignore_mismatched_sizes` to False."
        )
        setattr(args, "ignore_mismatched_sizes", False)
    # set eval_accumulation_steps to None if a non-positive value was passed
    eval_accumulation_steps = getattr(args, "eval_accumulation_steps", -1)
    if eval_accumulation_steps is not None and eval_accumulation_steps <= 0:
        setattr(args, "eval_accumulation_steps", None)
    logger.info(f"eval_accumulation_steps: {getattr(args, 'eval_accumulation_steps', None)}")
    # read FT config
    ft_config_path = Path(args.model_selector_output, SaveFileConstants.ACFT_CONFIG_SAVE_PATH)
    if ft_config_path.is_file():
        with open(ft_config_path, "r") as rptr:
            ft_config = json.load(rptr)
        setattr(args, "finetune_config", ft_config)
        logger.info("Added finetune config to `component_args`")
        # Read the lora parameters from finetune config
        if "lora_algo" in ft_config:
            logger.info(f'Setting lora_algo to: {ft_config.get("lora_algo")}')
            setattr(args, "lora_algo", ft_config.get("lora_algo"))
        if "lora_target_modules" in ft_config:
            logger.info(f'Setting lora_target_modules to: {ft_config.get("lora_target_modules")}')
            setattr(args, "lora_target_modules", ft_config.get("lora_target_modules"))
        # Read leaf modules for MoE models from finetune config
        if "leaf_modules_of_moe_models" in ft_config:
            logger.info(f'Setting leaf_modules_of_moe_models to: {ft_config.get("leaf_modules_of_moe_models")}')
            setattr(args, "leaf_modules_of_moe_models", ft_config.get("leaf_modules_of_moe_models"))
        # Read hf trainer args from finetune config
        _set_hf_trainer_args_from_finetune_config(args, ft_config)
    else:
        logger.info(f"{SaveFileConstants.ACFT_CONFIG_SAVE_PATH} does not exist")
        setattr(args, "finetune_config", {})
    # `mlflow_ft_conf` - contains all mlflow related properties
    mlflow_ft_conf = {
        "mlflow_model_signature": {},
        "mlflow_hftransformers_misc_conf": {},
        "mlflow_flavor": None,
    }
    mlmodel_data = _load_mlflow_model(args.model_selector_output)
    mlflow_flavor = None
    if mlmodel_data is not None:
        mlflow_flavors = [
            MLFLOW_FLAVORS.TRANSFORMERS,
            MLFLOW_FLAVORS.HFTRANSFORMERS,
            MLFLOW_FLAVORS.HFTRANSFORMERSV2,
        ]
        mlflow_flavor = _get_model_flavor(mlflow_flavors, mlmodel_data)
        mlflow_ft_conf["mlflow_flavor"] = mlflow_flavor
    # set task based mlflow_model_signature
    if getattr(args, "task_name", None) is not None:
        if mlflow_flavor is not None and mlflow_flavor in MLFLOW_MODEL_SIGNATURES_FOR_FLAVOR:
            if args.task_name in MLFLOW_MODEL_SIGNATURES_FOR_FLAVOR[mlflow_flavor]:
                mlflow_ft_conf["mlflow_model_signature"] = deep_update(
                    mlflow_ft_conf["mlflow_model_signature"],
                    MLFLOW_MODEL_SIGNATURES_FOR_FLAVOR[mlflow_flavor][args.task_name],
                )
                logger.info(
                    f"Adding mlflow model signature for task {args.task_name} - "
                    f"{MLFLOW_MODEL_SIGNATURES_FOR_FLAVOR[mlflow_flavor][args.task_name]}"
                )
    # set `mlflow_flavor` in finetune args
    setattr(args, "mlflow_flavor", mlflow_flavor)
    # remove mlflow_model_signature if empty
    if (
        "mlflow_model_signature" in mlflow_ft_conf
        and len(mlflow_ft_conf["mlflow_model_signature"]) == 0
    ):
        del mlflow_ft_conf["mlflow_model_signature"]
    model_name_or_type = None
    # pass `mlflow_hftransformers_misc_conf` to be set in mlflow model;
    # a model_type match takes precedence over a model_name match
    if hasattr(args, "model_name") and args.model_name in MLFLOW_HFTRANSFORMERS_MISC_CONF:
        model_name_or_type = args.model_name
    if hasattr(args, "model_type") and args.model_type in MLFLOW_HFTRANSFORMERS_MISC_CONF:
        model_name_or_type = args.model_type
    if model_name_or_type is not None:
        mlflow_hftransformers_misc_conf = MLFLOW_HFTRANSFORMERS_MISC_CONF[model_name_or_type]
        logger.info(
            f"Forcing `mlflow_hftransformers_misc_conf` to {mlflow_hftransformers_misc_conf} "
            f"for {model_name_or_type}"
        )
        mlflow_ft_conf["mlflow_hftransformers_misc_conf"] = deep_update(
            mlflow_ft_conf["mlflow_hftransformers_misc_conf"],
            mlflow_hftransformers_misc_conf,
        )
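    # `deep_update` is assumed to merge nested dicts recursively, e.g. (illustrative):
    # deep_update({"a": {"x": 1}}, {"a": {"y": 2}}) == {"a": {"x": 1, "y": 2}}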
    metadata = {}
    # if MLmodel file exists, pass its contents to the finetuned model as `base_model_mlmodel`
    mlflow_config_file = Path(args.model_selector_output, MLFlowHFFlavourConstants.MISC_CONFIG_FILE)
    if mlflow_config_file.is_file():
        import yaml
        mlflow_data = None
        try:
            with open(mlflow_config_file, "r") as rptr:
                mlflow_data = yaml.safe_load(rptr)
                metadata = mlflow_data.get("metadata", {})
        except Exception as e:
            logger.info(f"Unable to load MLmodel file - {e}")
        if mlflow_data is not None:
            # pass base model MLmodel file data if available
            mlflow_hftransformers_misc_conf = mlflow_ft_conf.get("mlflow_hftransformers_misc_conf", {})
            mlflow_hftransformers_misc_conf.update({"base_model_mlmodel": mlflow_data})
            mlflow_ft_conf["mlflow_hftransformers_misc_conf"] = deep_update(
                mlflow_ft_conf["mlflow_hftransformers_misc_conf"],
                mlflow_hftransformers_misc_conf,
            )
            logger.info(f"Setting `base_model_mlmodel` in finetuned mlflow model - {mlflow_hftransformers_misc_conf}")
        else:
            logger.info("MLmodel file is empty")
    else:
        logger.info("MLmodel file does not exist")
    if mlmodel_data is not None:
        # pass base model MLmodel file data if available
        mlflow_hftransformers_misc_conf = mlflow_ft_conf.get("mlflow_hftransformers_misc_conf", {})
        mlflow_hftransformers_misc_conf.update({"base_model_mlmodel": mlmodel_data})
        mlflow_ft_conf["mlflow_hftransformers_misc_conf"] = deep_update(
            mlflow_ft_conf["mlflow_hftransformers_misc_conf"],
            mlflow_hftransformers_misc_conf,
        )
        logger.info(f"Setting `base_model_mlmodel` in finetuned mlflow model - {mlflow_hftransformers_misc_conf}")
    else:
        logger.info("MLmodel data is empty")
    # if the input is a pytorch model, read metadata from metadata.json if it exists
    if not metadata:
        metadatapath = os.path.join(model_name_or_path, ModelSelectorDefaults.MODEL_DEFAULTS_PATH)
        if os.path.isfile(metadatapath):
            with open(metadatapath, "r") as rptr:
                metadata = json.load(rptr)
logger.info(f"FT MLFlow config - {mlflow_ft_conf}")
mlflow_ft_conf = deep_update(mlflow_ft_conf, args.finetune_config.get("mlflow_ft_conf", {}))
args.finetune_config["mlflow_ft_conf"] = deepcopy(mlflow_ft_conf)
logger.info(f"Updated FT MLFlow config - {args.finetune_config['mlflow_ft_conf']}")
    # Below arguments are needed for HF training args
    args.output_dir = args.pytorch_model_folder
    Path(args.output_dir).mkdir(exist_ok=True, parents=True)
    if args.precision == 16:
        set_16bit_precision(args)
    args.finetune_in_8bit = args.precision == 8  # 8 bit finetune
    args.finetune_in_4bit = args.precision == 4  # 4 bit finetune
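    # Assumption: `precision` is one of 32/16/8/4; 8 and 4 select quantized (QLoRA-style)
    # finetuning, which is validated against QLORA_SUPPORTED_MODEL_TYPES below.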
    # set flash-attention
    set_flash_attention(args)
    # set gradient-checkpointing
    set_gradient_checkpointing(args)
    validate_learning_rate(args)
    if args.finetune_in_8bit or args.finetune_in_4bit:
        if hasattr(args, "model_type") and args.model_type not in QLORA_SUPPORTED_MODEL_TYPES:
            raise ACFTValidationException._with_error(
                AzureMLError.create(
                    ACFTUserError,
                    pii_safe_message=(
                        f"Quantized finetune is not supported for model family: {args.model_type}."
                    )
                )
            )
        logger.info("Enabling QLoRA finetuning")
        if not args.apply_lora:
            logger.info("LoRA is not enabled. Setting it to true.")
            setattr(args, "apply_lora", True)
        if args.apply_deepspeed:
            logger.info(
                "DeepSpeed is enabled, which is not compatible with QLoRA. "
                "Resetting DeepSpeed to false."
            )
            setattr(args, "apply_deepspeed", False)
        if args.gradient_checkpointing:
            logger.info(
                "Gradient checkpointing is enabled, which is not compatible with QLoRA. "
                "Resetting gradient checkpointing to false."
            )
            setattr(args, "gradient_checkpointing", False)
    setattr(args, "apply_ort", can_apply_ort(args, logger))
    # DeepSpeed enabled
    if args.apply_deepspeed:
        setup_and_validate_deepspeed(args)
    else:
        # ignore any provided deepspeed config when apply_deepspeed is false
        args.deepspeed = None
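    # `args.deepspeed` holds the DeepSpeed config consumed downstream by the HF trainer;
    # clearing it keeps a stray config file from silently re-enabling DeepSpeed.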
    if (
        not isinstance(args.evaluation_steps_interval, float)
        or args.evaluation_steps_interval < 0.0
        or args.evaluation_steps_interval > 1.0
    ):
        args.evaluation_steps_interval = 0.0
    else:
        logger.info(f"evaluation_steps_interval: {args.evaluation_steps_interval}")
    if args.save_strategy == SaveStrategy.EVALUATION_STRATEGY:
        logger.info(f"Setting save strategy to evaluation strategy: {args.evaluation_strategy}, {args.eval_steps}")
        args.save_strategy = args.evaluation_strategy
        args.save_steps = args.eval_steps
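    # Matching save_strategy/save_steps to the evaluation settings keeps checkpoints
    # aligned with eval runs (the HF Trainer requires this for e.g. load_best_model_at_end).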
    # set up vllm for finetuned model inference
    metadata = setup_vllm(args.task_name, args.finetune_config, metadata)
    args.model_metadata = update_acft_metadata(
        metadata=metadata,
        finetuning_task=args.task_name,
        base_model_asset_id=model_asset_id,
    )
    setup_automl_nlp(args)
    # Saving the args is done in `run_finetune` to handle distributed training
    hf_task_runner = get_task_runner(task_name=args.task_name)()
    hf_task_runner.run_finetune(args)
    # post-training: run on the main process only to avoid race conditions
    if is_main_process():
        # copy conda file
        conda_file_path = Path(args.model_selector_output, MLFlowHFFlavourConstants.CONDA_YAML_FILE)
        if conda_file_path.is_file():
            shutil.copy(str(conda_file_path), args.output_dir)
            logger.info(f"Copied {MLFlowHFFlavourConstants.CONDA_YAML_FILE} file to output dir.")
        # copy inference config files
        mlflow_ml_configs_dir = Path(args.model_selector_output, "ml_configs")
        ml_config_dir = Path(args.output_dir, "ml_configs")
        if mlflow_ml_configs_dir.is_dir():
            shutil.copytree(
                mlflow_ml_configs_dir,
                ml_config_dir,
            )
            logger.info("Copied ml_configs folder to output dir.")