template/v2/dirs/etc/sagemaker-ui/workflows/workflow_client.py (85 lines of code) (raw):
import argparse
from datetime import datetime, timezone
from typing import Optional
import boto3
import requests
# Base URL of the JupyterLab server; request URLs in this module are built by
# appending WORKFLOWS_API_ENDPOINT (plus a route) to it.
JUPYTERLAB_URL: str = "http://default:8888/jupyterlab/default/"
# Path, relative to JUPYTERLAB_URL, under which the workflows API routes live.
WORKFLOWS_API_ENDPOINT: str = "api/sagemaker/workflows"
# strftime pattern for outgoing timestamps: date, time with microseconds, UTC offset.
TIMESTAMP_FORMAT: str = "%Y-%m-%d %H:%M:%S.%f%z"
def _validate_response(function_name: str, response: requests.Response):
    """Pass a successful response through, or raise for any other status.

    Args:
        function_name: Label for the API call, used only in the error message.
        response: The HTTP response to check.

    Returns:
        *response* itself, when its status code is exactly 200.

    Raises:
        RuntimeError: for every status code other than 200; the message carries
            the status code and the raw response body.
    """
    status = response.status_code
    if status != 200:
        raise RuntimeError(f"{function_name} returned {status}: {str(response.content)}")
    return response
def update_local_runner_status(session: requests.Session, status: str, detailed_status: Optional[str] = None, **kwargs):
    """Report the local runner's status to the JupyterLab workflows API.

    Sends a POST to ``.../update-local-runner-status`` carrying the current UTC
    time (formatted with TIMESTAMP_FORMAT), *status* and *detailed_status*.

    Args:
        session: Session whose cookie jar already holds the ``_xsrf`` cookie;
            its value is echoed back in the ``X-Xsrftoken`` header.
        status: Status value to report.
        detailed_status: Optional free-text detail; sent as JSON null when omitted.
        **kwargs: Ignored; lets the CLI dispatcher pass every parsed argument.

    Returns:
        The successful (HTTP 200) response.

    Raises:
        KeyError: if the session has no ``_xsrf`` cookie.
        RuntimeError: if the API answers with any status other than 200.
    """
    xsrf_token = session.cookies.get_dict()["_xsrf"]
    payload = {
        "timestamp": datetime.now(timezone.utc).strftime(TIMESTAMP_FORMAT),
        "status": status,
        "detailed_status": detailed_status,
    }
    route = f"{JUPYTERLAB_URL}{WORKFLOWS_API_ENDPOINT}/update-local-runner-status"
    reply = session.post(url=route, headers={"X-Xsrftoken": xsrf_token}, json=payload)
    return _validate_response("UpdateLocalRunner", reply)
def start_local_runner(session: requests.Session, **kwargs):
    """Ask the JupyterLab workflows API to start the local runner.

    Args:
        session: Session whose cookie jar already holds the ``_xsrf`` cookie;
            its value is echoed back in the ``X-Xsrftoken`` header.
        **kwargs: Ignored; lets the CLI dispatcher pass every parsed argument.

    Returns:
        The successful (HTTP 200) response.

    Raises:
        KeyError: if the session has no ``_xsrf`` cookie.
        RuntimeError: if the API answers with any status other than 200.
    """
    cookie_values = session.cookies.get_dict()
    reply = session.post(
        url=f"{JUPYTERLAB_URL}{WORKFLOWS_API_ENDPOINT}/start-local-runner",
        headers={"X-Xsrftoken": cookie_values["_xsrf"]},
        json={},
    )
    return _validate_response("StartLocalRunner", reply)
def stop_local_runner(session: requests.Session, **kwargs):
    """Ask the JupyterLab workflows API to stop the local runner.

    Args:
        session: Session whose cookie jar already holds the ``_xsrf`` cookie;
            its value is echoed back in the ``X-Xsrftoken`` header.
        **kwargs: Ignored; lets the CLI dispatcher pass every parsed argument.

    Returns:
        The successful (HTTP 200) response.

    Raises:
        KeyError: if the session has no ``_xsrf`` cookie.
        RuntimeError: if the API answers with any status other than 200.
    """
    target = f"{JUPYTERLAB_URL}{WORKFLOWS_API_ENDPOINT}/stop-local-runner"
    csrf_headers = {"X-Xsrftoken": session.cookies.get_dict()["_xsrf"]}
    reply = session.post(url=target, headers=csrf_headers, json={})
    return _validate_response("StopLocalRunner", reply)
def check_blueprint(region: str, domain_id: str, endpoint: str, **kwargs):
    """Print whether the managed "Workflows" blueprint is enabled for *region*.

    Looks up the first managed environment blueprint named "Workflows" in the
    DataZone domain, fetches its blueprint configuration, and prints ``True`` if
    *region* appears in ``enabledRegions``, otherwise ``False``.

    This is a best-effort probe: any failure (client creation, API error, no
    matching blueprint, missing keys) prints ``False`` instead of raising.

    Args:
        region: Region to look for in the blueprint's ``enabledRegions``.
        domain_id: DataZone domain identifier.
        endpoint: Custom DataZone endpoint URL (e.g. for the gamma environment);
            an empty string (or None) means the default endpoint.
        **kwargs: Ignored; lets the CLI dispatcher pass every parsed argument.
    """
    try:
        client_options = {}
        # Only override the endpoint when one was supplied, so the client is built once.
        if endpoint:
            client_options["endpoint_url"] = endpoint
        datazone = boto3.client("datazone", **client_options)
        blueprint_id = datazone.list_environment_blueprints(
            managed=True, domainIdentifier=domain_id, name="Workflows"
        )["items"][0]["id"]
        blueprint_config = datazone.get_environment_blueprint_configuration(
            domainIdentifier=domain_id, environmentBlueprintIdentifier=blueprint_id
        )
        print(str(region in blueprint_config["enabledRegions"]))
    except Exception:
        # Deliberately best-effort: report "False" for any failure, but (unlike a
        # bare ``except:``) let SystemExit / KeyboardInterrupt propagate.
        print("False")
# Dispatch table: CLI sub-command name -> handler function. Every handler takes
# keyword arguments and absorbs extras via **kwargs.
COMMAND_REGISTRY = dict(
    [
        ("update-local-runner-status", update_local_runner_status),
        ("start-local-runner", start_local_runner),
        ("stop-local-runner", stop_local_runner),
        ("check-blueprint", check_blueprint),
    ]
)
def main():
    """Parse the command line and dispatch to the matching handler.

    Sub-commands that call the JupyterLab workflows API receive a
    ``requests.Session`` whose ``_xsrf`` cookie was populated by an initial GET
    of JUPYTERLAB_URL. ``check-blueprint`` only talks to DataZone, so it runs
    without touching JupyterLab. With no (or an unknown) sub-command, the help
    text is printed.
    """
    parser = argparse.ArgumentParser(description="Workflow local runner client")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    update_status_parser = subparsers.add_parser("update-local-runner-status", help="Update status of local runner")
    update_status_parser.add_argument("--status", type=str, required=True, help="Status to update")
    update_status_parser.add_argument("--detailed-status", type=str, required=False, help="Detailed status text")

    subparsers.add_parser("start-local-runner", help="Start local runner")
    subparsers.add_parser("stop-local-runner", help="Stop local runner")

    check_blueprint_parser = subparsers.add_parser("check-blueprint", help="Check Workflows blueprint")
    check_blueprint_parser.add_argument(
        "--domain-id", type=str, required=True, help="Datazone Domain ID for blueprint check"
    )
    check_blueprint_parser.add_argument("--region", type=str, required=True, help="Datazone Domain region")
    # check_blueprint treats "" as "use the default endpoint", so the flag is
    # optional; passing it explicitly still works as before.
    check_blueprint_parser.add_argument(
        "--endpoint", type=str, required=False, default="", help="Datazone endpoint for blueprint check"
    )

    args = parser.parse_args()

    handler = COMMAND_REGISTRY.get(args.command)
    if handler is None:
        parser.print_help()
        return

    # Commands that never call the JupyterLab API must not depend on it being reachable.
    commands_without_session = {"check-blueprint"}
    if args.command in commands_without_session:
        handler(**vars(args))
        return

    # The session is closed on exit; the initial GET populates the XSRF cookie
    # that the API handlers send back in the X-Xsrftoken header.
    with requests.Session() as session:
        session.get(JUPYTERLAB_URL)
        handler(**vars(args), session=session)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()