in src/hyperpod_cli/commands/job.py [0:0]
def start_training_job(recipe, override_parameters, job_name, config_file,
                       launcher_config_path=None, launcher_config_file_name=None,
                       pull_policy=None, restart_policy=None, namespace=None,
                       service_account_name=None, priority_class_name=None,
                       volumes=None, persistent_volume_claims=None,
                       auto_resume=None, label_selector=None, max_retry=None,
                       deep_health_check_passed_nodes_only=None):
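    """Start a training job through the SageMaker training launcher.

    When ``recipe`` is None, the launcher is run against a pre-generated config
    (``launcher_config_path``/``launcher_config_file_name``); otherwise the
    launcher command is composed from ``recipe`` plus Hydra overrides built
    from the remaining arguments.
    """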
logger.info(f"recipe: {recipe}, override_parameters: {override_parameters}, job_name: {job_name}, config_file: {config_file}, launcher_config_path: {launcher_config_path}, launcher_config_file_name: {launcher_config_file_name}")
env = os.environ.copy()
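    # HYDRA_FULL_ERROR=1 makes Hydra print full stack traces instead of its
    # abbreviated one-line error summary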
env['HYDRA_FULL_ERROR'] = '1'
if recipe is None:
logger.debug(f"Starting job with config {launcher_config_path}{launcher_config_file_name}")
cmd = [
'python3',
f'{SAGEMAKER_TRAINING_LAUNCHER_DIR}/main.py',
f'--config-path={launcher_config_path}',
f'--config-name={launcher_config_file_name}',
f'base_results_dir={os.path.abspath(os.path.join(os.getcwd(), "results"))}',
'cluster.cluster_type=k8s',
]
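        # Illustrative shape of the final command:
        #   python3 <launcher_dir>/main.py --config-path=<path> --config-name=<name> \
        #       base_results_dir=<cwd>/results cluster.cluster_type=k8s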
execute_command(cmd, env)
else:
cmd = [
'python3',
f'{SAGEMAKER_TRAINING_LAUNCHER_DIR}/main.py',
f'recipes={recipe}',
'cluster_type=k8s',
'cluster=k8s',
f'base_results_dir={os.path.abspath(os.path.join(os.getcwd(), "results"))}',
]
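        # Everything appended below is a Hydra override: plain key=value replaces
        # an existing config key, while a leading '+' adds a key that is not
        # present in the base config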
# Add pull policy if provided
if pull_policy:
cmd.append(f'cluster.pullPolicy="{pull_policy}"')
# Add restart policy if provided
if restart_policy:
cmd.append(f'cluster.restartPolicy="{restart_policy}"')
# Add namespace if provided
if namespace:
cmd.append(f'cluster.namespace="{namespace}"')
# Add service account name if provided
if service_account_name:
cmd.append(f'cluster.service_account_name="{service_account_name}"')
# Add priority class name if provided
if priority_class_name:
cmd.append(f'cluster.priority_class_name="{priority_class_name}"')
# Add volumes if provided (expecting format: "volumeName1:hostPath1:mountPath1,volumeName2:hostPath2:mountPath2")
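        # e.g. "data:/mnt/data:/data" becomes:
        #   +cluster.volumes.0.volumeName="data" +cluster.volumes.0.hostPath="/mnt/data" +cluster.volumes.0.mountPath="/data"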
if volumes:
for idx, volume in enumerate(volumes.split(',')):
vol_name, host_path, mount_path = volume.split(':')
cmd.append(f'+cluster.volumes.{idx}.volumeName="{vol_name}"')
cmd.append(f'+cluster.volumes.{idx}.hostPath="{host_path}"')
cmd.append(f'+cluster.volumes.{idx}.mountPath="{mount_path}"')
# Add persistent volume claims if provided (expecting format: "claimName1:mountPath1,claimName2:mountPath2")
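        # e.g. "fsx-claim:/fsx" becomes:
        #   +cluster.persistent_volume_claims.0.claimName="fsx-claim" +cluster.persistent_volume_claims.0.mountPath="/fsx"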
if persistent_volume_claims:
for idx, pvc in enumerate(persistent_volume_claims.split(',')):
claim_name, mount_path = pvc.split(':')
cmd.append(f'+cluster.persistent_volume_claims.{idx}.claimName="{claim_name}"')
cmd.append(f'+cluster.persistent_volume_claims.{idx}.mountPath="{mount_path}"')
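        # Add a node label selector if provided; otherwise optionally restrict the
        # job to nodes that passed the HyperPod deep health check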
if label_selector:
cmd.append(f'+cluster.label_selector={label_selector}')
elif deep_health_check_passed_nodes_only:
cmd.append(f'+cluster.label_selector={DEEP_HEALTH_CHECK_PASSED_ONLY_NODE_AFFINITY_DICT}')
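        # Enable HyperPod auto-resume by annotating the job; max_retry bounds the
        # number of automatic restarts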
if auto_resume:
            # Default max_retry to 1 when auto-resume is enabled without an explicit value
            if max_retry is None:
                max_retry = 1
annotations = {
HYPERPOD_AUTO_RESUME_ANNOTATION_KEY: auto_resume,
HYPERPOD_MAX_RETRY_ANNOTATION_KEY: max_retry,
}
cmd.append(f'+cluster.annotations="{annotations}"')
logger.info(f"override_parameters: {override_parameters}")
if override_parameters:
try:
# Parse the JSON string into a dictionary
override_dict = json.loads(override_parameters)
# Convert the dictionary into key=value pairs
for key, value in override_dict.items():
if isinstance(value, str):
# Ensure strings are properly quoted
cmd.append(f'{key}="{value}"')
else:
cmd.append(f'{key}={value}')
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON format: {e}")
sys.exit(1)
print(f"Final command: {' '.join(cmd)}")
execute_command(cmd, env)
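    # Remove the launcher config generated for job_name-based submissions; a
    # user-supplied config file is never deleted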
if job_name is not None and config_file is None:
file_to_delete = os.path.join(launcher_config_path, launcher_config_file_name)
if os.path.exists(file_to_delete):
os.remove(file_to_delete)
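
# A minimal usage sketch (illustrative values only; the recipe path, namespace,
# and volume/PVC strings are hypothetical and follow the formats documented above):
#
#   start_training_job(
#       recipe="fine-tuning/llama/hf_llama3_8b",  # hypothetical recipe path
#       override_parameters='{"recipes.run.name": "demo"}',
#       job_name=None,
#       config_file=None,
#       namespace="kubeflow",
#       volumes="data:/mnt/data:/data",
#       persistent_volume_claims="fsx-claim:/fsx",
#       auto_resume=True,
#       max_retry=2,
#   )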