in src/sagemaker/remote_function/client.py [0:0]
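# Module-level imports are elided in this excerpt; names such as functools,
# _JobSettings, _Job, RemoteExecutor, serialization, s3_path_join, and the
# exception and folder-constant names used below are defined or imported at
# the top of client.py.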
def _remote(func):
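    # Every keyword below is a closure variable captured from the enclosing
    # remote(...) decorator factory; _JobSettings bundles them into a single
    # settings object once, outside the per-call wrapper.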
job_settings = _JobSettings(
dependencies=dependencies,
pre_execution_commands=pre_execution_commands,
pre_execution_script=pre_execution_script,
environment_variables=environment_variables,
image_uri=image_uri,
include_local_workdir=include_local_workdir,
custom_file_filter=custom_file_filter,
instance_count=instance_count,
instance_type=instance_type,
job_conda_env=job_conda_env,
job_name_prefix=job_name_prefix,
keep_alive_period_in_seconds=keep_alive_period_in_seconds,
max_retry_attempts=max_retry_attempts,
max_runtime_in_seconds=max_runtime_in_seconds,
role=role,
s3_kms_key=s3_kms_key,
s3_root_uri=s3_root_uri,
sagemaker_session=sagemaker_session,
security_group_ids=security_group_ids,
subnets=subnets,
tags=tags,
volume_kms_key=volume_kms_key,
volume_size=volume_size,
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
spark_config=spark_config,
use_spot_instances=use_spot_instances,
max_wait_time_in_seconds=max_wait_time_in_seconds,
disable_output_compression=disable_output_compression,
use_torchrun=use_torchrun,
use_mpirun=use_mpirun,
nproc_per_node=nproc_per_node,
)
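    # Each invocation of the decorated function runs as a SageMaker training
    # job; functools.wraps preserves the wrapped function's metadata.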
@functools.wraps(func)
def wrapper(*args, **kwargs):
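        # Multi-instance jobs must use exactly one distribution mechanism:
        # spark_config, use_torchrun, or use_mpirun.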
if instance_count > 1 and not (
(spark_config is not None and not use_torchrun and not use_mpirun)
or (spark_config is None and use_torchrun and not use_mpirun)
or (spark_config is None and not use_torchrun and use_mpirun)
):
            raise ValueError(
                "Remote function does not support training on multiple instances "
                + "without spark_config, use_torchrun, or use_mpirun. "
                + "Please provide instance_count = 1"
            )
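        # Validate the call against the wrapped function's signature, then
        # serialize the function and its arguments and start a training job.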
RemoteExecutor._validate_submit_args(func, *args, **kwargs)
job = _Job.start(job_settings, func, args, kwargs)
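        # Block until the job reaches a terminal status; a "Failed" status is
        # unpacked below, and any other unexpected status is surfaced as a
        # timeout.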
try:
job.wait()
except UnexpectedStatusException as usex:
if usex.actual_status == "Failed":
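                # The remote job serializes any exception raised by user code
                # to the EXCEPTION_FOLDER prefix in S3; recover it so it can
                # be re-raised locally.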
try:
exception = serialization.deserialize_exception_from_s3(
sagemaker_session=job_settings.sagemaker_session,
s3_uri=s3_path_join(
job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
),
hmac_key=job.hmac_key,
)
except ServiceError as serr:
chained_e = serr.__cause__
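                    # A 404 here means no serialized exception was uploaded,
                    # which typically indicates the job failed in its runtime
                    # environment before user code ran; fall back to the
                    # FailureReason reported by DescribeTrainingJob.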
if (
isinstance(chained_e, ClientError)
and chained_e.response["Error"]["Code"] # pylint: disable=no-member
== "404"
and chained_e.response["Error"]["Message"] # pylint: disable=no-member
== "Not Found"
):
describe_result = job.describe()
if (
"FailureReason" in describe_result
and describe_result["FailureReason"]
and "RuntimeEnvironmentError: " in describe_result["FailureReason"]
):
failure_msg = describe_result["FailureReason"].replace(
"RuntimeEnvironmentError: ", ""
)
raise RuntimeEnvironmentError(failure_msg)
raise RemoteFunctionError(
"Failed to execute remote function. "
+ "Check corresponding job for details."
)
raise serr
raise exception
raise TimeoutError(
"Job for remote function timed out before reaching a termination status."
)
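        # job.wait() returned without raising: on "Completed" fetch the
        # pickled return value from S3; a stopped job becomes an error.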
        job_status = job.describe()["TrainingJobStatus"]
        if job_status == "Completed":
            return serialization.deserialize_obj_from_s3(
                sagemaker_session=job_settings.sagemaker_session,
                s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
                hmac_key=job.hmac_key,
            )
        if job_status == "Stopped":
            raise RemoteFunctionError("Job for remote function has been aborted.")
        return None
wrapper.job_settings = job_settings
wrapper.wrapped_func = func
return wrapper
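
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of client.py): the public
# @remote decorator wraps _remote above. The instance type, dependencies
# path, and S3 root URI below are assumed placeholder values.
# ---------------------------------------------------------------------------
from sagemaker.remote_function import remote

@remote(
    instance_type="ml.m5.xlarge",            # assumed instance type
    dependencies="./requirements.txt",       # assumed local dependencies file
    s3_root_uri="s3://my-bucket/remote-fn",  # assumed S3 location
)
def divide(x, y):
    return x / y

# Runs as a SageMaker training job; on "Completed" the wrapper deserializes
# the return value from S3, and a user-code exception would be recovered from
# the EXCEPTION_FOLDER prefix and re-raised locally.
result = divide(10, 2)  # == 5.0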