in assets/training/finetune_acft_image/src/finetune/finetune.py [0:0]
def main():
"""Driver function."""
parser = get_parser()
args, _ = parser.parse_known_args()
# Validate argument types without modifying args
for action in parser._actions:
arg_name = action.dest
expected_type = action.type
arg_value = getattr(args, arg_name, None)
if expected_type and arg_value is not None:
try:
# Attempt to cast the argument to the expected type without modifying args
expected_type(arg_value)
except (ValueError, TypeError) as e:
error_msg = f"Argument '{arg_name}' expects type {expected_type.__name__}, but got value '{arg_value}'"
raise ACFTValidationException._with_error(
AzureMLError.create(ACFTUserError, pii_safe_message=error_msg)
) from e
print(f"Deepspeed Version: {deepspeed.__version__}")
if args.task_name in [Tasks.HF_SD_TEXT_TO_IMAGE]:
parser = add_sd_args_to_parser(parser)
args, _ = parser.parse_known_args()
# Saving is needed as LoRA adapters are picked up from checkpoints.
# args.save_strategy = "no"
# args.evaluation_strategy = "no"
# args.eval_steps = None
# args.save_steps = None
# TODO: Remove the commented-out overrides above if they are not required.
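# With prior preservation (DreamBooth-style finetuning), class_data_dir is expected to hold the class images
# used for the prior-preservation loss, so create the directory if it does not exist yet.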
if args.with_prior_preservation and args.class_data_dir and not os.path.exists(args.class_data_dir):
os.makedirs(args.class_data_dir, exist_ok=True)
# The step learning rate scheduler can only come from the sweep component. The only other option available in
# sweep components is warmup_cosine, hence we raise the following exception.
if args.lr_scheduler_type == IncomingLearingScheduler.STEP:
error_string = (
f"The step scheduler is not supported by the Hugging Face and MMDetection trainers. Please choose "
f"{IncomingLearingScheduler.WARMUP_COSINE} as the learning rate scheduler."
)
raise ACFTValidationException._with_error(AzureMLError.create(ACFTUserError, pii_safe_message=error_string))
set_logging_parameters(
task_type=args.task_name,
acft_custom_dimensions={
LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
},
azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
)
# Read the preprocess component args
# Preprocess Component + Model Selector Component ---> Finetune Component
# Since all Model Selector Component args are saved via Preprocess Component, loading the Preprocess args
# suffices
model_selector_args_save_path = os.path.join(args.model_path, ModelSelectorDefaults.MODEL_SELECTOR_ARGS_SAVE_PATH)
with open(model_selector_args_save_path, "r") as rptr:
preprocess_args = json.load(rptr)
for key, value in preprocess_args.items():
if not hasattr(args, key): # update the values that don't already exist
logger.info(f"{key}, {value}")
setattr(args, key, value)
# Continual finetuning is when an existing model that is already registered in the workspace
# is used for finetuning. In that case, model_name would be None and either
# pytorch_model_path or mlflow_model_path would be present.
# For continual finetuning, we need to set model_name_or_path to the pytorch/mlflow model path.
if hasattr(args, "pytorch_model_path") and args.pytorch_model_path:
args.model_name_or_path = os.path.join(args.model_path, args.pytorch_model_path)
args.is_continual_finetuning = True
if args.model_family in (MODEL_FAMILY_CLS.MMDETECTION_IMAGE, MODEL_FAMILY_CLS.MMTRACKING_VIDEO):
args.model_weights_path_or_url = os.path.join(args.model_path, args.model_weights_path_or_url)
elif hasattr(args, "mlflow_model_path") and args.mlflow_model_path:
args.model_name_or_path = os.path.join(args.model_path, args.mlflow_model_path)
args.is_continual_finetuning = True
if args.model_family in (MODEL_FAMILY_CLS.MMDETECTION_IMAGE, MODEL_FAMILY_CLS.MMTRACKING_VIDEO):
args.model_weights_path_or_url = os.path.join(args.model_path, args.model_weights_path_or_url)
elif hasattr(args, "model_name") and args.model_name:
args.model_name_or_path = args.model_name
args.is_continual_finetuning = False
# For MMDetection/MMTracking models, model_weights_path_or_url is already a full path or URL and is used as-is.
else:
raise ACFTValidationException._with_error(
AzureMLError.create(ModelInputEmpty, argument_name="Model ports and model_name")
)
# Map the learning rate scheduler to the value expected by the Trainer class
if args.lr_scheduler_type is not None:
args.lr_scheduler_type = Mapper.LR_SCHEDULER_MAP[args.lr_scheduler_type]
# Map the optimizer to the value expected by the Trainer class
if args.optim is not None:
args.optim = Mapper.OPTIMIZER_MAP[args.optim]
# Update 'args' namespace with defaults based on task type and model selected
# Doing this before any assignment to 'args' namespace
if args.task_name in [
Tasks.MM_OBJECT_DETECTION,
Tasks.MM_INSTANCE_SEGMENTATION,
Tasks.HF_MULTI_CLASS_IMAGE_CLASSIFICATION,
Tasks.HF_MULTI_LABEL_IMAGE_CLASSIFICATION,
Tasks.MM_MULTI_OBJECT_TRACKING,
Tasks.HF_SD_TEXT_TO_IMAGE
]:
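# TrainingDefaults looks up per-task, per-model default hyperparameters for the supported vision tasks.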
training_defaults = TrainingDefaults(
task=args.task_name,
model_name_or_path=args.model_name_or_path,
)
# Update the namespace object with values from the dictionary
# Only update the values that don't already exist or are None
for key, value in training_defaults.defaults_dict.items():
if not hasattr(args, key) or getattr(args, key) is None:
setattr(args, key, value)
logger.info(f"Using learning rate scheduler - {args.lr_scheduler_type}")
logger.info(f"Using optimizer - {args.optim}")
if (
args.task_name == Tasks.HF_MULTI_LABEL_IMAGE_CLASSIFICATION
and args.label_smoothing_factor is not None
and args.label_smoothing_factor > 0.0
):
# Build the warning before overriding the value, otherwise it would always report 'from 0.0'.
msg = (
f"Label smoothing is not supported for multi-label image classification. "
f"Setting label_smoothing_factor to 0.0 from {args.label_smoothing_factor}"
)
logger.warning(msg)
args.label_smoothing_factor = 0.0
# We don't support DS & ORT training for OD and IS tasks.
if args.task_name in [Tasks.MM_OBJECT_DETECTION, Tasks.MM_INSTANCE_SEGMENTATION] and (
args.apply_deepspeed is True or args.apply_ort is True):
err_msg = (
f"apply_deepspeed or apply_ort is not yet supported for {args.task_name}. "
"Please disable DeepSpeed and ORT training."
)
raise ACFTValidationException._with_error(AzureMLError.create(ACFTUserError, pii_safe_message=err_msg))
if args.task_name in [
Tasks.MM_OBJECT_DETECTION, Tasks.MM_INSTANCE_SEGMENTATION, Tasks.MM_MULTI_OBJECT_TRACKING,
Tasks.HF_SD_TEXT_TO_IMAGE
]:
# Note: This is a temporary check to disable DeepSpeed and ORT for MM tasks until they are supported.
deepspeed_ort_arg_names = []
if args.apply_deepspeed is True:
deepspeed_ort_arg_names.append("apply_deepspeed")
if args.apply_ort is True:
deepspeed_ort_arg_names.append("apply_ort")
if deepspeed_ort_arg_names:
deepspeed_ort_arg_names = ", ".join(deepspeed_ort_arg_names)
err_msg = f"{deepspeed_ort_arg_names} not yet supported for {args.task_name}; this will be enabled in the future."
raise ACFTValidationException._with_error(
AzureMLError.create(ArgumentInvalid, argument_name=f"{deepspeed_ort_arg_names}", expected_type=err_msg)
)
if args.apply_ort is False and args.optim == IncomingOptimizerNames.ADAMW_ORT_FUSED:
error_string = (
f"ORT fused AdamW ({IncomingOptimizerNames.ADAMW_ORT_FUSED}) optimizer should only be used with ORT "
f"training."
)
raise ACFTValidationException._with_error(AzureMLError.create(ACFTUserError, pii_safe_message=error_string))
if args.apply_ort is True and args.optim != IncomingOptimizerNames.ADAMW_ORT_FUSED:
logger.warning(
f"ORT training is enabled but optimizer is not set to {IncomingOptimizerNames.ADAMW_ORT_FUSED}, "
f"setting optimizer to {IncomingOptimizerNames.ADAMW_ORT_FUSED}"
)
args.optim = IncomingOptimizerNames.ADAMW_ORT_FUSED
if args.task_name != Tasks.HF_SD_TEXT_TO_IMAGE:
# Metrics are not yet supported for the text-to-image task.
if args.task_name != Tasks.HF_MULTI_LABEL_IMAGE_CLASSIFICATION and "iou" in args.metric_for_best_model:
err_msg = (
f"{args.metric_for_best_model} metric is supported only for {Tasks.HF_MULTI_LABEL_IMAGE_CLASSIFICATION}"
)
raise ACFTValidationException._with_error(
AzureMLError.create(ArgumentInvalid, argument_name="metric_for_best_model", expected_type=err_msg)
)
# Prepare args as per the TrainingArguments class
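# fp16 mixed precision is enabled when precision is 16; this may be overridden below by the fp16 setting
# of the default deepspeed config.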
use_fp16 = bool(args.precision == 16)
# Read the default deepspeed config if apply_deepspeed is set to true without providing a config file
if args.apply_deepspeed and args.deepspeed_config is None:
args.deepspeed_config = os.path.join(os.path.dirname(os.path.abspath(__file__)), "zero1.json")
with open(args.deepspeed_config) as fp:
ds_dict = json.load(fp)
use_fp16 = "fp16" in ds_dict and "enabled" in ds_dict["fp16"] and ds_dict["fp16"]["enabled"]
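# When a deepspeed config is supplied explicitly, only validate that it is well-formed JSON here;
# its contents are handed to the trainer unchanged.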
if args.apply_deepspeed and args.deepspeed_config is not None:
with open(args.deepspeed_config) as fp:
try:
_ = json.load(fp)
except json.JSONDecodeError as e:
raise ACFTValidationException._with_error(
AzureMLError.create(
ACFTUserError,
pii_safe_message=f"Invalid JSON in deepspeed config file: {str(e)}"
)
)
args.fp16 = use_fp16
args.deepspeed = args.deepspeed_config if args.apply_deepspeed else None
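# args.fp16 and args.deepspeed correspond to the TrainingArguments fields of the same names.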
if args.metric_for_best_model in MetricConstants.METRIC_LESSER_IS_BETTER:
args.metric_greater_is_better = False
else:
args.metric_greater_is_better = True
args.load_best_model_at_end = True
args.report_to = None
args.save_safetensors = False
logger.info(f"metric_for_best_model - {args.metric_for_best_model}")
logger.info(f"metric_greater_is_better - {args.metric_greater_is_better}")
logger.info(f"save_safetensors - {args.save_safetensors}")
# setting arguments as needed for the core
args.model_selector_output = args.model_path
args.output_dir = SettingParameters.DEFAULT_OUTPUT_DIR
# Empty the output directory in master process.
# Todo: Ideally, in case of preemption, we should handle start finetuning
# from the last checkpoint. This logic will be implemented in the future.
# For now, we are deleting the output directory and starting the training
# from scratch.
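# RANK is set by the distributed launcher; only the rank-0 (master) process cleans the output directory.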
master_process = os.environ.get("RANK") == "0"
if os.path.exists(args.output_dir) and master_process:
shutil.rmtree(args.output_dir)
os.makedirs(args.output_dir, exist_ok=True)
# TODO: Overwriting the save_as_mlflow_model flag to True. Otherwise, the pipeline service will fail since it
# expects the mlflow model folder to create the model asset. This can be revisited if outputs of the component
# can be made optional.
args.save_as_mlflow_model = True
# Disable adding prefixes to logger.
args.set_log_prefix = False
logger.info(f"Using log prefix - {args.set_log_prefix}")
if args.apply_lora:
args.lora_algo = "peft"
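# lora_algo "peft" presumably selects the PEFT-based LoRA implementation downstream.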
args.label_names = ["labels"]
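# For the Hugging Face Trainer, label_names identifies the label keys in the model inputs.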
logger.info(args)
# Saving the args is done in `finetune_runner` to handle distributed training.
finetune_runner(args)