def main()

in assets/training/finetune_acft_image/src/finetune/finetune.py [0:0]


def _validate_arg_types(parser, args):
    """Check that every parsed argument value is accepted by its argparse ``type`` callable.

    Values are only probed; ``args`` is not modified.

    Args:
        parser: The ``argparse.ArgumentParser`` that produced ``args``.
        args: Namespace returned by ``parser.parse_known_args()``.

    Raises:
        ACFTValidationException: if a value cannot be converted by its declared ``type``.
    """
    for action in parser._actions:
        arg_name = action.dest
        expected_type = action.type
        arg_value = getattr(args, arg_name, None)
        if not expected_type or arg_value is None:
            continue

        # For ``nargs``/``append`` actions argparse stores a list and applies ``type`` to each
        # element, so probe element-wise instead of converting the list as a whole.
        candidate_values = arg_value if isinstance(arg_value, list) else [arg_value]
        for candidate in candidate_values:
            try:
                expected_type(candidate)
            except (ValueError, TypeError) as e:
                # Not every ``type`` callable has a ``__name__`` (e.g. functools.partial).
                type_name = getattr(expected_type, "__name__", repr(expected_type))
                error_msg = f"Argument '{arg_name}' expects type {type_name}, but got value '{arg_value}'"
                raise ACFTValidationException._with_error(
                    AzureMLError.create(ACFTUserError, pii_safe_message=error_msg)
                ) from e


def _resolve_model_name_or_path(args):
    """Set ``args.model_name_or_path`` and ``args.is_continual_finetuning``.

    Precedence: ``pytorch_model_path``, then ``mlflow_model_path``. Both are joined onto
    ``args.model_path`` and mark the run as continual finetuning. Otherwise ``model_name``
    is used as-is.

    Raises:
        ACFTValidationException: if none of the three model sources is set.
    """
    mm_model_families = (MODEL_FAMILY_CLS.MMDETECTION_IMAGE, MODEL_FAMILY_CLS.MMTRACKING_VIDEO)

    # continual_finetuning is when an existing model which is already registered in the workspace
    # is used for finetuning. In that case, model_name would be None and either of the
    # pytorch_model_path or mlflow_model_path would be present.
    for model_path_arg in ("pytorch_model_path", "mlflow_model_path"):
        relative_model_path = getattr(args, model_path_arg, None)
        if relative_model_path:
            args.model_name_or_path = os.path.join(args.model_path, relative_model_path)
            args.is_continual_finetuning = True
            # For MM families the weights path is also resolved relative to args.model_path.
            if args.model_family in mm_model_families:
                args.model_weights_path_or_url = os.path.join(args.model_path, args.model_weights_path_or_url)
            return

    if getattr(args, "model_name", None):
        args.model_name_or_path = args.model_name
        args.is_continual_finetuning = False
        return

    raise ACFTValidationException._with_error(
        AzureMLError.create(ModelInputEmpty, argument_name="Model ports and model_name")
    )


def _resolve_deepspeed_config(args, use_fp16):
    """Resolve ``args.deepspeed_config`` and return the effective fp16 flag.

    If deepspeed is requested without a config, the bundled ``zero1.json`` next to this file is
    used and fp16 follows that file's ``fp16.enabled`` entry. A user-supplied config is only
    checked for valid JSON, and ``use_fp16`` is returned unchanged.

    Raises:
        ACFTValidationException: if a user-supplied deepspeed config is not valid JSON.
    """
    if not args.apply_deepspeed:
        return use_fp16

    if args.deepspeed_config is None:
        args.deepspeed_config = os.path.join(os.path.dirname(os.path.abspath(__file__)), "zero1.json")
        with open(args.deepspeed_config) as fp:
            ds_dict = json.load(fp)
        return "fp16" in ds_dict and "enabled" in ds_dict["fp16"] and ds_dict["fp16"]["enabled"]

    with open(args.deepspeed_config) as fp:
        try:
            json.load(fp)
        except json.JSONDecodeError as e:
            raise ACFTValidationException._with_error(
                AzureMLError.create(
                    ACFTUserError,
                    pii_safe_message=f"Invalid JSON in deepspeed config file: {str(e)}"
                )
            ) from e
    return use_fp16


def main():
    """Driver function.

    Parses and validates the CLI arguments and merges in the saved model-selector/preprocess
    args. It then applies task-specific defaults and checks, prepares the trainer-related
    settings on ``args`` and hands off to ``finetune_runner``.
    """
    parser = get_parser()
    args, _ = parser.parse_known_args()
    _validate_arg_types(parser, args)
    print(f"Deepspeed Version: {deepspeed.__version__}")

    if args.task_name in [Tasks.HF_SD_TEXT_TO_IMAGE]:
        parser = add_sd_args_to_parser(parser)
        args, _ = parser.parse_known_args()
        # Save/eval strategies are intentionally left as configured: lora adapters are
        # picked from checkpoints, so checkpoint saving is needed.

        # Todo: Remove if not required
        if args.with_prior_preservation and args.class_data_dir and not os.path.exists(args.class_data_dir):
            os.makedirs(args.class_data_dir, exist_ok=True)

    # step learning rate scheduler can only come from sweep component. The only other option that's available in
    # sweep components is warmup_cosine, hence we are raising following exception.
    if args.lr_scheduler_type == IncomingLearingScheduler.STEP:
        error_string = (
            f"Step scheduler is not supported by Huggingface and MMdetection trainer. Please choose "
            f"{IncomingLearingScheduler.WARMUP_COSINE} as the learning rate scheduler."
        )
        raise ACFTValidationException._with_error(AzureMLError.create(ACFTUserError, pii_safe_message=error_string))

    set_logging_parameters(
        task_type=args.task_name,
        acft_custom_dimensions={
            LoggingLiterals.PROJECT_NAME: PROJECT_NAME,
            LoggingLiterals.PROJECT_VERSION_NUMBER: VERSION,
        },
        azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS,
    )

    # Read the preprocess component args
    # Preprocess Component + Model Selector Component ---> Finetune Component
    # Since all Model Selector Component args are saved via Preprocess Component, loading the Preprocess args
    # suffices
    model_selector_args_save_path = os.path.join(args.model_path, ModelSelectorDefaults.MODEL_SELECTOR_ARGS_SAVE_PATH)
    with open(model_selector_args_save_path, "r") as rptr:
        preprocess_args = json.load(rptr)
    for key, value in preprocess_args.items():
        if not hasattr(args, key):  # update the values that don't already exist
            logger.info(f"{key}, {value}")
            setattr(args, key, value)

    # In case of continual finetuning, model_name_or_path points at the pytorch/mlflow model path.
    _resolve_model_name_or_path(args)

    # Map learning rate scheduler to as expected by the Trainer class
    if args.lr_scheduler_type is not None:
        args.lr_scheduler_type = Mapper.LR_SCHEDULER_MAP[args.lr_scheduler_type]

    # Map optimizer to as expected by the Trainer class
    if args.optim is not None:
        args.optim = Mapper.OPTIMIZER_MAP[args.optim]

    # Update 'args' namespace with defaults based on task type and model selected
    # Doing this before any assignment to 'args' namespace
    if args.task_name in [
        Tasks.MM_OBJECT_DETECTION,
        Tasks.MM_INSTANCE_SEGMENTATION,
        Tasks.HF_MULTI_CLASS_IMAGE_CLASSIFICATION,
        Tasks.HF_MULTI_LABEL_IMAGE_CLASSIFICATION,
        Tasks.MM_MULTI_OBJECT_TRACKING,
        Tasks.HF_SD_TEXT_TO_IMAGE
    ]:
        training_defaults = TrainingDefaults(
            task=args.task_name,
            model_name_or_path=args.model_name_or_path,
        )
        # Update the namespace object with values from the dictionary
        # Only update the values that don't already exist or are None
        for key, value in training_defaults.defaults_dict.items():
            if getattr(args, key, None) is None:
                setattr(args, key, value)

    logger.info(f"Using learning rate scheduler - {args.lr_scheduler_type}")
    logger.info(f"Using optimizer - {args.optim}")

    if (
        args.task_name == Tasks.HF_MULTI_LABEL_IMAGE_CLASSIFICATION
        and args.label_smoothing_factor is not None
        and args.label_smoothing_factor > 0.0
    ):
        # Capture the user's value before resetting so the warning reports it correctly.
        original_label_smoothing_factor = args.label_smoothing_factor
        args.label_smoothing_factor = 0.0
        msg = (
            f"Label smoothing is not supported for multi-label image classification. "
            f"Setting label_smoothing_factor to 0.0 from {original_label_smoothing_factor}"
        )
        logger.warning(msg)

    # We don't support DS & ORT training for OD and IS tasks.
    if args.task_name in [Tasks.MM_OBJECT_DETECTION, Tasks.MM_INSTANCE_SEGMENTATION] and (
        args.apply_deepspeed is True or args.apply_ort is True
    ):
        err_msg = (
            f"apply_deepspeed or apply_ort is not yet supported for {args.task_name}. "
            "Please disable ds and ort training."
        )
        raise ACFTValidationException._with_error(AzureMLError.create(ACFTUserError, pii_safe_message=err_msg))

    if args.task_name in [
        Tasks.MM_OBJECT_DETECTION, Tasks.MM_INSTANCE_SEGMENTATION, Tasks.MM_MULTI_OBJECT_TRACKING,
        Tasks.HF_SD_TEXT_TO_IMAGE
    ]:
        # Note: This is temporary check to disable deepspeed and ORT for MM tasks, until they are working.
        unsupported_arg_names = [
            arg_name for arg_name in ("apply_deepspeed", "apply_ort") if getattr(args, arg_name) is True
        ]
        if unsupported_arg_names:
            joined_arg_names = ",".join(unsupported_arg_names)
            err_msg = f"{joined_arg_names} not yet supported for {args.task_name}, will be enabled in future."
            raise ACFTValidationException._with_error(
                AzureMLError.create(ArgumentInvalid, argument_name=joined_arg_names, expected_type=err_msg)
            )

    if args.apply_ort is False and args.optim == IncomingOptimizerNames.ADAMW_ORT_FUSED:
        error_string = (
            f"ORT fused AdamW ({IncomingOptimizerNames.ADAMW_ORT_FUSED}) optimizer should only be used with ORT "
            f"training."
        )
        raise ACFTValidationException._with_error(AzureMLError.create(ACFTUserError, pii_safe_message=error_string))

    if args.apply_ort is True and args.optim != IncomingOptimizerNames.ADAMW_ORT_FUSED:
        logger.warning(
            f"ORT training is enabled but optimizer is not set to {IncomingOptimizerNames.ADAMW_ORT_FUSED}, "
            f"setting optimizer to {IncomingOptimizerNames.ADAMW_ORT_FUSED}"
        )
        args.optim = IncomingOptimizerNames.ADAMW_ORT_FUSED

    # Metrics is not supported for text-to-image task yet.
    # Guard against a missing metric so the substring check cannot raise TypeError on None.
    if (
        args.task_name != Tasks.HF_SD_TEXT_TO_IMAGE
        and args.task_name != Tasks.HF_MULTI_LABEL_IMAGE_CLASSIFICATION
        and args.metric_for_best_model
        and "iou" in args.metric_for_best_model
    ):
        err_msg = (
            f"{args.metric_for_best_model} metric supported only for {Tasks.HF_MULTI_LABEL_IMAGE_CLASSIFICATION}"
        )
        raise ACFTValidationException._with_error(
            AzureMLError.create(ArgumentInvalid, argument_name="metric_for_best_model", expected_type=err_msg)
        )

    # Prepare args as per the TrainingArguments class
    # Read the default deepspeed config if apply_deepspeed is set without providing a config file.
    use_fp16 = _resolve_deepspeed_config(args, use_fp16=bool(args.precision == 16))

    args.fp16 = use_fp16
    args.deepspeed = args.deepspeed_config if args.apply_deepspeed else None
    args.metric_greater_is_better = args.metric_for_best_model not in MetricConstants.METRIC_LESSER_IS_BETTER
    args.load_best_model_at_end = True
    args.report_to = None
    args.save_safetensors = False
    logger.info(f"metric_for_best_model - {args.metric_for_best_model}")
    logger.info(f"metric_greater_is_better - {args.metric_greater_is_better}")
    logger.info(f"save_safetensors - {args.save_safetensors}")

    # setting arguments as needed for the core
    args.model_selector_output = args.model_path
    args.output_dir = SettingParameters.DEFAULT_OUTPUT_DIR

    # Empty the output directory in master process.
    # Todo: Ideally, in case of preemption, we should handle start finetuning
    # from the last checkpoint. This logic will be implemented in the future.
    # For now, we are deleting the output directory and starting the training
    # from scratch.
    # NOTE(review): when RANK is unset (non-distributed launch) no process counts as
    # master, so a stale output dir is kept — confirm this is intended.
    master_process = os.environ.get("RANK") == "0"
    if os.path.exists(args.output_dir) and master_process:
        shutil.rmtree(args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)

    # TODO: overwriting the save_as_mlflow_model flag to True. Otherwise, it will fail the pipeline service since it
    #  expects the mlflow model folder to create model asset. It can be modified if outputs of the component can be
    #  optional.
    args.save_as_mlflow_model = True

    # Disable adding prefixes to logger.
    args.set_log_prefix = False
    logger.info(f"Using log prefix - {args.set_log_prefix}")

    if args.apply_lora:
        args.lora_algo = "peft"
        args.label_names = ["labels"]

    logger.info(args)

    # Saving the args is done in `finetune_runner` to handle the distributed training
    finetune_runner(args)