in auto-installer/lambda/main.py [0:0]
def run_commands(domain_id: str, user_profile_name: str):
    """Run every command in COMMAND_SCRIPT in a Jupyter terminal of a SageMaker Studio user profile.

    Steps performed:
      1. Generate a presigned Studio URL for the user profile and log in through it, using a
         `requests.Session` so the cookies set at login persist across later requests.
      2. If the `_xsrf` cookie was not set by the login, poll the Studio `/app` endpoint every 2 seconds
         until the JupyterServer app reports "InService" or "Terminated".
      3. Create a terminal via the Jupyter `/api/terminals` REST endpoint.
      4. Connect to the terminal's websocket, wait for its "setup" message, then send each command in
         COMMAND_SCRIPT as stdin, waiting after each one until a stdout chunk matches PROMPT_REGEX.

    Args:
        domain_id: SageMaker Studio domain ID, passed as `DomainId` when generating the presigned URL.
        user_profile_name: Studio user profile name, passed as `UserProfileName`.

    Raises:
        ValueError: If polling for JupyterServer start-up ends in a status other than "InService".
    """
    log_prefix = f"[{domain_id}/{user_profile_name}]"

    logger.info(f"{log_prefix} Generating presigned URL")
    # (This will only work for IAM-authenticated Studio domains)
    presigned_resp = smclient.create_presigned_domain_url(
        DomainId=domain_id,
        UserProfileName=user_profile_name,
    )
    sagemaker_login_url = presigned_resp["AuthorizedUrl"]

    # Login URL like https://d-....studio.{AWSRegion}.sagemaker.aws/auth?token=...
    # Strip the query string and the final path segment to get the Studio base URL; the Jupyter API is
    # relative to https://d-....studio.{AWSRegion}.sagemaker.aws/jupyter/default
    studio_base_url = sagemaker_login_url.partition("?")[0].rpartition("/")[0]
    api_base_url = studio_base_url + "/jupyter/default"

    # Need to make our requests via a session so cookies/etc persist:
    reqsess = requests.Session()
    logger.info(f"{log_prefix} Logging in")
    # The response body is not needed: the request is made for the cookies it sets on the session.
    reqsess.get(sagemaker_login_url)

    # If JupyterServer app only just started up, it may not be ready yet: In which case we need to wait for
    # it to start. Here we'll use the same 2sec /app polling logic as implemented by the SMStudio front-end
    # at the time of writing:
    if "_xsrf" not in reqsess.cookies:
        logger.info(f"{log_prefix} Waiting for JupyterServer start-up...")
        app_status = "Unknown"
        while app_status not in {"InService", "Terminated"}:
            time.sleep(2)
            app_status = reqsess.get(f"{studio_base_url}/app?appType=JupyterServer&appName=default").text
            logger.debug(f"Got app_status {app_status}")

        if app_status != "InService":
            raise ValueError(f"JupyterServer app in unusable status '{app_status}'")
        logger.info(f"{log_prefix} JupyterServer app ready")
        # Hit the Jupyter base URL now that the app is up (response itself is unused; the `_xsrf`
        # cookie is read from the session below).
        reqsess.get(api_base_url)

    logger.info(f"{log_prefix} Creating terminal")
    terminal_resp = reqsess.post(
        f"{api_base_url}/api/terminals",
        # The XSRF token is required on any state-changing request types e.g. POST/DELETE/etc, but seems
        # that it can be put either in header or query string.
        # For more information on this pattern, you can see e.g:
        # https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html
        params={"_xsrf": reqsess.cookies["_xsrf"]},
    )
    terminal_name = terminal_resp.json()["name"]  # Typically e.g. '1'.

    # Actually using a terminal (or notebook kernel) is done via websocket channels
    ws_base_url = "wss://" + api_base_url.partition("://")[2] + "/terminals/websocket"
    # The websocket client does not share the requests session, so forward its cookies explicitly.
    cookie_header = "; ".join(f"{name}={value}" for name, value in reqsess.cookies.get_dict().items())
    logger.info(f"{log_prefix} Connecting to:\n{ws_base_url}/{terminal_name}")
    ws = websocket.create_connection(
        f"{ws_base_url}/{terminal_name}",
        cookie=cookie_header,
    )

    try:
        logger.info(f"{log_prefix} Waiting for setup message")
        # Messages are JSON arrays of [stream, content]. Keep reading until the "setup" one arrives.
        # (The loop condition was previously inverted - `is not None` - so this wait never ran.)
        setup = None
        while setup is None:
            res = json.loads(ws.recv())
            if res[0] == "setup":
                setup = res[1]  # Just get {} in all my tests

        # Send commands one by one, waiting for each to complete and re-show prompt:
        prompt_exp = re.compile(PROMPT_REGEX, re.MULTILINE)
        for command in COMMAND_SCRIPT:
            ws.send(json.dumps(["stdin", command + "\n"]))
            # Assuming echo is on, stdin messages will be echoed to stdout anyway so no need to log

            while True:
                res = json.loads(ws.recv())
                # res[0] is the stream so will be e.g. 'stdout', 'stderr'
                # res[1] is the content
                logger.info(f"{log_prefix} {res[0]}: {res[1]}")
                # You may want to apply some more RegExs here to log a little less verbosely, or actively
                # check for failure/success of your particular script.
                if res[0] == "stdout" and prompt_exp.search(res[1]):
                    break
                # No need to push too hard in recv()ing the tiniest possible chunks!
                time.sleep(0.1)

        logger.info(f"{log_prefix} Complete")
    finally:
        # Always release the websocket, whether the script completed or an error was raised.
        ws.close()