in src/sagemaker/workflow/function_step.py [0:0]
def _step(func):
    """Turn ``func`` into a factory for pipeline function steps.

    The returned wrapper does not run ``func``. Instead, each call records the
    invocation as a ``_FunctionStep`` and returns whatever
    ``_generate_delayed_return`` produces for that step.

    Every step setting other than ``func`` (``name``, ``role``,
    ``instance_type``, ``dependencies``, ...) is a free variable resolved from
    the enclosing scope. These are presumably the arguments of the outer
    decorator; verify against the enclosing definition.

    Args:
        func: The user function to wrap.

    Returns:
        A ``functools.wraps``-decorated wrapper accepting the same arguments
        as ``func``.

    Raises:
        ValueError: If ``dependencies`` is ``"auto_capture"``, which pipeline
            steps do not support.
    """
    if dependencies == "auto_capture":
        raise ValueError("Auto Capture of dependencies is not supported for pipeline steps.")

    # Imported here rather than at module level to avoid a circular import.
    from sagemaker.remote_function.client import RemoteExecutor

    @wraps(func)
    def wrapper(*args, **kwargs):
        """Record one call to ``func`` as a pipeline step.

        Raises:
            ValueError: If any top-level positional or keyword argument is a
                ``Join`` or ``JsonGet``. Whatever
                ``RemoteExecutor._validate_submit_args`` raises also
                propagates, and that check runs first.
        """
        # TODO: Move _validate_submit_args function out of RemoteExecutor class
        RemoteExecutor._validate_submit_args(func, *args, **kwargs)

        # Positional values first, then keyword values. Only top-level values
        # are inspected; nothing nested inside containers is examined.
        call_values = (*args, *kwargs.values())

        # Reject unsupported pipeline expressions before collecting
        # dependencies.
        for value in call_values:
            if isinstance(value, (Join, JsonGet)):
                raise ValueError(f"{type(value)} is not supported for function arguments.")

        # Steps whose outputs feed this call. Keying on id() keeps a step
        # referenced by several arguments only once. The dict preserves
        # first-seen order.
        upstream_steps = {
            id(value._step): value._step
            for value in call_values
            if isinstance(value, DelayedReturn)
        }

        # Fall back to generated defaults when the caller left these empty.
        step_name = name or unique_name_from_base_uuid4(func.__name__)
        step_display_name = display_name or f"{func.__module__}.{func.__name__}"
        step_description = description or func.__doc__ or func.__code__.co_filename

        pipeline_step = _FunctionStep(
            name=step_name,
            display_name=step_display_name,
            description=step_description,
            retry_policies=retry_policies,
            func=func,
            func_args=args,
            func_kwargs=kwargs,
            depends_on=list(upstream_steps.values()),
            dependencies=dependencies,
            pre_execution_commands=pre_execution_commands,
            pre_execution_script=pre_execution_script,
            environment_variables=environment_variables,
            image_uri=image_uri,
            instance_count=instance_count,
            instance_type=instance_type,
            job_conda_env=job_conda_env,
            job_name_prefix=job_name_prefix,
            keep_alive_period_in_seconds=keep_alive_period_in_seconds,
            max_retry_attempts=max_retry_attempts,
            max_runtime_in_seconds=max_runtime_in_seconds,
            role=role,
            security_group_ids=security_group_ids,
            subnets=subnets,
            tags=format_tags(tags),
            volume_kms_key=volume_kms_key,
            volume_size=volume_size,
            encrypt_inter_container_traffic=encrypt_inter_container_traffic,
            spark_config=spark_config,
            use_spot_instances=use_spot_instances,
            max_wait_time_in_seconds=max_wait_time_in_seconds,
        )

        # Resolved per call, after the step exists. The return annotation
        # (or None if func has none) tells _generate_delayed_return what
        # shape of result to model.
        return_hint = get_type_hints(func).get("return")
        return _generate_delayed_return(pipeline_step, type_hint=return_hint)

    return wrapper