in src/braket_container.py [0:0]
def get_code_setup_parameters() -> Tuple[str, str, str]:
    """
    Return the parameters needed to locate and run the customer code.

    The parameters are:
        s3_uri: the S3 location where the code is stored.
        entry_point: the entrypoint into the code.
        compression_type: the compression used to archive the code (optional).

    Each value is read from its ``AMZN_BRAKET_SCRIPT_*`` environment variable.
    To facilitate testing in local mode, any value missing from the environment
    is looked up under the same key in the JSON object stored in the ``SM_HPS``
    environment variable. The hyperparameters are only consulted when the
    environment does not already supply both s3_uri and entry_point.

    If the hyperparameters cannot be parsed as a JSON object, or if s3_uri or
    entry_point cannot be found, the failure is reported through
    ``log_failure_and_exit``.

    Returns:
        str, str, str: the code setup parameters as described above.
        compression_type is None when it is not configured anywhere.
    """
    s3_uri = os.getenv('AMZN_BRAKET_SCRIPT_S3_URI')
    entry_point = os.getenv('AMZN_BRAKET_SCRIPT_ENTRY_POINT')
    compression_type = os.getenv('AMZN_BRAKET_SCRIPT_COMPRESSION_TYPE')
    if s3_uri and entry_point:
        # Environment variables are authoritative; no need to look at hyperparameters.
        return s3_uri, entry_point, compression_type

    hyperparameters_env = os.getenv('SM_HPS')
    if hyperparameters_env:
        # Only the parse itself is guarded, so unrelated errors are not masked.
        try:
            hyperparameters = json.loads(hyperparameters_env)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            hyperparameters = None

        if isinstance(hyperparameters, dict):
            # Fill in only the values the environment did not provide.
            s3_uri = s3_uri or hyperparameters.get("AMZN_BRAKET_SCRIPT_S3_URI")
            entry_point = entry_point or hyperparameters.get("AMZN_BRAKET_SCRIPT_ENTRY_POINT")
            compression_type = compression_type or hyperparameters.get(
                "AMZN_BRAKET_SCRIPT_COMPRESSION_TYPE"
            )
        else:
            # Malformed JSON, or valid JSON that is not an object (list, string, null, ...).
            log_failure_and_exit("Hyperparameters not specified in env")

    if not s3_uri:
        log_failure_and_exit("No customer script specified")
    if not entry_point:
        log_failure_and_exit("No customer entry point specified")
    return s3_uri, entry_point, compression_type