def start_training_job()

in src/hyperpod_cli/commands/job.py [0:0]


def start_training_job(recipe, override_parameters, job_name, config_file, launcher_config_path=None, launcher_config_file_name=None,
                      pull_policy=None, restart_policy=None, namespace=None,
                      service_account_name=None, priority_class_name=None, volumes=None, persistent_volume_claims=None,
                      auto_resume=None, label_selector=None, max_retry=None, deep_health_check_passed_nodes_only=None):
    """Launch a training job by invoking the SageMaker training launcher's ``main.py``.

    Two modes:

    * ``recipe is None`` -- run the launcher against a launcher config given by
      ``launcher_config_path`` (``--config-path``) and ``launcher_config_file_name``
      (``--config-name``).
    * ``recipe`` given -- run the launcher with ``recipes=<recipe>`` plus Hydra
      command-line overrides built from the remaining keyword arguments.

    When ``job_name`` is given and ``config_file`` is not, the launcher config file
    is deleted afterwards (also when the launch fails). Presumably that file was
    generated by the CLI for this invocation -- TODO confirm against the caller.

    Args:
        recipe: Recipe passed as ``recipes=<recipe>``; ``None`` selects config-file mode.
        override_parameters: JSON object string of extra ``key=value`` overrides
            (recipe mode only). JSON ``null`` is emitted as Hydra ``null``.
        job_name: Only used here for logging and the config-cleanup decision.
        config_file: Only used here for logging and the config-cleanup decision.
        launcher_config_path: Directory of the launcher config.
        launcher_config_file_name: File name of the launcher config.
        pull_policy, restart_policy, namespace, service_account_name, priority_class_name:
            Optional ``cluster.*`` overrides (recipe mode only).
        volumes: ``"volumeName:hostPath:mountPath,..."`` (recipe mode only).
        persistent_volume_claims: ``"claimName:mountPath,..."`` (recipe mode only).
        auto_resume: When truthy, adds auto-resume / max-retry annotations (recipe mode only).
        label_selector: Value for ``+cluster.label_selector``; wins over
            ``deep_health_check_passed_nodes_only``.
        max_retry: Max-retry annotation value; defaults to 1 when ``auto_resume`` is set.
        deep_health_check_passed_nodes_only: When truthy and no ``label_selector`` is
            given, uses ``DEEP_HEALTH_CHECK_PASSED_ONLY_NODE_AFFINITY_DICT`` as the selector.

    Raises:
        ValueError: If a ``volumes`` or ``persistent_volume_claims`` entry is malformed.
        SystemExit: If ``override_parameters`` is not a valid JSON object.
    """
    logger.info(f"recipe: {recipe}, override_parameters: {override_parameters}, job_name: {job_name}, config_file: {config_file}, launcher_config_path: {launcher_config_path}, launcher_config_file_name: {launcher_config_file_name}")
    env = os.environ.copy()
    # Hydra prints complete stack traces when HYDRA_FULL_ERROR=1.
    env['HYDRA_FULL_ERROR'] = '1'

    try:
        if recipe is None:
            # Config-file mode: point the launcher at an existing launcher config.
            logger.debug(f"Starting job with config {launcher_config_path}{launcher_config_file_name}")
            cmd = [
                'python3',
                f'{SAGEMAKER_TRAINING_LAUNCHER_DIR}/main.py',
                f'--config-path={launcher_config_path}',
                f'--config-name={launcher_config_file_name}',
                f'base_results_dir={os.path.abspath(os.path.join(os.getcwd(), "results"))}',
                'cluster.cluster_type=k8s',
            ]
        else:
            cmd = _build_recipe_command(
                recipe,
                override_parameters,
                pull_policy=pull_policy,
                restart_policy=restart_policy,
                namespace=namespace,
                service_account_name=service_account_name,
                priority_class_name=priority_class_name,
                volumes=volumes,
                persistent_volume_claims=persistent_volume_claims,
                auto_resume=auto_resume,
                label_selector=label_selector,
                max_retry=max_retry,
                deep_health_check_passed_nodes_only=deep_health_check_passed_nodes_only,
            )
            print(f"Final command: {' '.join(cmd)}")
        execute_command(cmd, env)
    finally:
        # Remove the launcher config even when the launch fails, so a failed run does
        # not leave it behind. Skip when a path piece is missing: os.path.join would
        # raise TypeError on None and mask the real outcome of the launch.
        if (job_name is not None and config_file is None
                and launcher_config_path is not None and launcher_config_file_name is not None):
            file_to_delete = os.path.join(launcher_config_path, launcher_config_file_name)
            if os.path.exists(file_to_delete):
                os.remove(file_to_delete)


def _build_recipe_command(recipe, override_parameters, pull_policy=None, restart_policy=None, namespace=None,
                          service_account_name=None, priority_class_name=None, volumes=None,
                          persistent_volume_claims=None, auto_resume=None, label_selector=None,
                          max_retry=None, deep_health_check_passed_nodes_only=None):
    """Return the launcher argv for recipe mode: base arguments followed by Hydra overrides.

    Raises:
        ValueError: If a volume or persistent-volume-claim entry is malformed.
        SystemExit: If ``override_parameters`` is not a valid JSON object.
    """
    cmd = [
        'python3',
        f'{SAGEMAKER_TRAINING_LAUNCHER_DIR}/main.py',
        f'recipes={recipe}',
        'cluster_type=k8s',
        'cluster=k8s',
        f'base_results_dir={os.path.abspath(os.path.join(os.getcwd(), "results"))}',
    ]

    # Optional scalar cluster settings, appended in this fixed order as quoted strings.
    for option, value in (
        ('pullPolicy', pull_policy),
        ('restartPolicy', restart_policy),
        ('namespace', namespace),
        ('service_account_name', service_account_name),
        ('priority_class_name', priority_class_name),
    ):
        if value:
            cmd.append(f'cluster.{option}="{value}"')

    # Host-path volumes: "volumeName1:hostPath1:mountPath1,volumeName2:hostPath2:mountPath2".
    if volumes:
        volume_specs = [spec.strip() for spec in volumes.split(',') if spec.strip()]
        for idx, volume in enumerate(volume_specs):
            parts = volume.split(':')
            if len(parts) != 3:
                raise ValueError(
                    f"Invalid volume '{volume}': expected format 'volumeName:hostPath:mountPath'"
                )
            vol_name, host_path, mount_path = parts
            cmd.append(f'+cluster.volumes.{idx}.volumeName="{vol_name}"')
            cmd.append(f'+cluster.volumes.{idx}.hostPath="{host_path}"')
            cmd.append(f'+cluster.volumes.{idx}.mountPath="{mount_path}"')

    # Persistent volume claims: "claimName1:mountPath1,claimName2:mountPath2".
    if persistent_volume_claims:
        pvc_specs = [spec.strip() for spec in persistent_volume_claims.split(',') if spec.strip()]
        for idx, pvc in enumerate(pvc_specs):
            parts = pvc.split(':')
            if len(parts) != 2:
                raise ValueError(
                    f"Invalid persistent volume claim '{pvc}': expected format 'claimName:mountPath'"
                )
            claim_name, mount_path = parts
            cmd.append(f'+cluster.persistent_volume_claims.{idx}.claimName="{claim_name}"')
            cmd.append(f'+cluster.persistent_volume_claims.{idx}.mountPath="{mount_path}"')

    # An explicit label selector takes precedence over the deep-health-check preset.
    if label_selector:
        cmd.append(f'+cluster.label_selector={label_selector}')
    elif deep_health_check_passed_nodes_only:
        cmd.append(f'+cluster.label_selector={DEEP_HEALTH_CHECK_PASSED_ONLY_NODE_AFFINITY_DICT}')

    if auto_resume:
        # max_retry defaults to 1 when auto-resume is requested without an explicit value.
        if max_retry is None:
            max_retry = 1
        annotations = {
            HYPERPOD_AUTO_RESUME_ANNOTATION_KEY: auto_resume,
            HYPERPOD_MAX_RETRY_ANNOTATION_KEY: max_retry,
        }
        cmd.append(f'+cluster.annotations="{annotations}"')

    logger.info(f"override_parameters: {override_parameters}")
    if override_parameters:
        try:
            override_dict = json.loads(override_parameters)
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON format: {e}")
            sys.exit(1)
        # Valid JSON that is not an object (e.g. a list) has no key=value meaning.
        if not isinstance(override_dict, dict):
            logger.error(
                f"Invalid JSON format: expected a JSON object, got {type(override_dict).__name__}"
            )
            sys.exit(1)

        for key, value in override_dict.items():
            if isinstance(value, str):
                # Quote strings so Hydra does not reinterpret them (e.g. as numbers).
                cmd.append(f'{key}="{value}"')
            elif value is None:
                # Hydra spells null as `null`; str(None) would become the string "None".
                cmd.append(f'{key}=null')
            else:
                cmd.append(f'{key}={value}')

    return cmd