modules/agent-framework/airavata-jupyter-magic/airavata_jupyter_magic.py (472 lines of code) (raw):
import base64
import binascii
import json
import os
import time
from argparse import ArgumentParser
from dataclasses import dataclass
from enum import IntEnum
from pathlib import Path
from typing import NamedTuple
import jwt
import requests
from device_auth import DeviceFlowAuthenticator
from IPython.core.getipython import get_ipython
from IPython.core.interactiveshell import ExecutionInfo, ExecutionResult
from IPython.core.magic import register_cell_magic, register_line_magic
from IPython.display import HTML, Image, display
# ========================================================================
# DATA STRUCTURES
class RequestedRuntime:
cluster: str
cpus: int
memory: int
walltime: int
queue: str
group: str
class ProcessState(IntEnum):
CREATED = 0
VALIDATED = 1
STARTED = 2
PRE_PROCESSING = 3
CONFIGURING_WORKSPACE = 4
INPUT_DATA_STAGING = 5
EXECUTING = 6
MONITORING = 7
OUTPUT_DATA_STAGING = 8
POST_PROCESSING = 9
COMPLETED = 10
FAILED = 11
CANCELLING = 12
CANCELED = 13
QUEUED = 14
DEQUEUING = 15
REQUEUED = 16
RuntimeInfo = NamedTuple('RuntimeInfo', [
('agentId', str),
('experimentId', str),
('processId', str),
('cluster', str),
('queue', str),
('cpus', int),
('memory', int),
('walltime', int),
('gateway_id', str),
('group', str),
])
PENDING_STATES = [
ProcessState.CREATED,
ProcessState.VALIDATED,
ProcessState.STARTED,
ProcessState.PRE_PROCESSING,
ProcessState.CONFIGURING_WORKSPACE,
ProcessState.INPUT_DATA_STAGING,
ProcessState.EXECUTING,
ProcessState.QUEUED,
ProcessState.REQUEUED,
]
TERMINAL_STATES = [
ProcessState.DEQUEUING,
ProcessState.CANCELLING,
ProcessState.COMPLETED,
ProcessState.FAILED,
ProcessState.CANCELED,
]
@dataclass
class State:
current_runtime: str # none => local
all_runtimes: dict[str, RuntimeInfo] # user-defined runtime dict
# END OF DATA STRUCTURES
# ========================================================================
# HELPER FUNCTIONS
def get_access_token(envar_name: str = "CS_ACCESS_TOKEN", state_path: str = "/tmp/av.json") -> str | None:
"""
Get access token from environment or file
@param None:
@returns: access token if present, None otherwise
"""
token = os.getenv(envar_name)
if not token:
try:
token = json.load(Path(state_path).open("r")).get("access_token")
except (FileNotFoundError, json.JSONDecodeError):
pass
return token
def is_runtime_ready(agent_id: str) -> bool:
"""
Check if the runtime (i.e., agent job) is ready to receive requests
@param agent_id: the agent id
@returns: True if ready, False otherwise
"""
url = f"{api_base_url}/api/v1/agent/{agent_id}"
res = requests.get(url)
code = res.status_code
if code == 202:
data: dict = res.json()
return bool(data.get("agentUp", None) or False)
else:
print(f"[{code}] Runtime status check failed: {res.text}")
return False
def get_process_state(experiment_id: str, headers: dict) -> tuple[str, ProcessState]:
"""
Get process state by experiment id
@param experiment_id: the experiment id
@param headers: the headers
@returns: process id and state
"""
url = f"{api_base_url}/api/v1/exp/{experiment_id}/process"
pid, pstate = "", ProcessState.QUEUED
while not pid:
res = requests.get(url, headers=headers)
code = res.status_code
if code == 200:
data: dict = res.json()
pid = data.get("processId")
pstates = data.get("processState")
if pstates and len(pstates):
pstate = ProcessState(pstates[0].get("state"))
else:
time.sleep(5)
return pid, pstate
def generate_headers(access_token: str, gateway_id: str) -> dict:
"""
Generate headers for the request
@param access_token: the access token
@param gateway_id: the gateway id
@returns: the headers
"""
decode = jwt.decode(access_token, options={"verify_signature": False})
user_id = decode['preferred_username']
claimsMap = {
"userName": user_id,
"gatewayID": gateway_id
}
return {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + access_token,
'X-Claims': json.dumps(claimsMap)
}
def submit_agent_job(
rt_name: str,
access_token: str,
app_name: str,
cluster: str,
cpus: int,
memory: int,
walltime: int,
queue: str,
group: str,
gateway_id: str = 'default',
) -> None:
"""
Submit an agent job to the given runtime
@param rt_name: the runtime name
@param access_token: the access token
@param app_name: the application name
@param cluster: the cluster
@param cpus: the number of cpus
@param memory: the memory
@param walltime: the walltime
@param queue: the queue
@param group: the group
@param gateway_id: the gateway id
@returns: None
"""
# URL to which the POST request will be sent
url = api_base_url + '/api/v1/exp/launch'
# Data to be sent in the POST request
data = {
'experimentName': app_name,
'remoteCluster': cluster,
'cpuCount': cpus,
'nodeCount': 1,
'memory': memory,
'wallTime': walltime,
'queue': queue,
'group': group,
}
# Convert the data to JSON format
json_data = json.dumps(data)
# Send the POST request
headers = generate_headers(access_token, gateway_id)
res = requests.post(url, headers=headers, data=json_data)
code = res.status_code
# Check if the request was successful
if code == 200:
obj = res.json()
pid, pstate = get_process_state(obj['experimentId'], headers=headers)
rt = RuntimeInfo(
agentId=obj['agentId'],
experimentId=obj['experimentId'],
processId=pid,
cluster=cluster,
queue=queue,
cpus=cpus,
memory=memory,
walltime=walltime,
gateway_id=gateway_id,
group=group,
)
state.all_runtimes[rt_name] = rt
print(f'Requested runtime={rt_name}. state={pstate.value}')
else:
print(f'[{code}] Failed to request runtime={rt_name}. error={res.text}')
def wait_until_runtime_ready(rt_name: str):
"""
Block execution until the runtime is ready.
@param rt_name: the runtime name
@returns: None when ready
"""
rt = state.all_runtimes.get(rt_name, None)
if rt is None:
return print(f"Runtime {rt_name} not found.")
if rt_name == "local":
return
if not is_runtime_ready(rt.agentId):
print(f"Waiting for runtime={rt_name} to be ready...")
time.sleep(5)
while not is_runtime_ready(rt.agentId):
time.sleep(5)
else:
print(f"Runtime={rt_name} is ready!")
return True
def stop_agent_job(access_token: str, runtime_name: str, runtime: RuntimeInfo):
"""
Stop the agent job on the given runtime.
@param access_token: the access token
@param runtime_name: the runtime name
@param runtime: the runtime info
@returns: None
"""
url = api_base_url + '/api/v1/exp/terminate/' + runtime.experimentId
decode = jwt.decode(access_token, options={"verify_signature": False})
user_id = decode['preferred_username']
claimsMap = {
"userName": user_id,
"gatewayID": runtime.gateway_id
}
# Headers
headers = {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + access_token,
'X-Claims': json.dumps(claimsMap)
}
# Send the POST request
res = requests.get(url, headers=headers)
# Check if the request was successful
if res.status_code == 200:
data = res.json()
print(f"Terminated runtime={runtime_name}. state={data}")
state.all_runtimes.pop(runtime_name, None)
else:
print(
f'[{res.status_code}] Failed to terminate runtime={runtime_name}: error={res.text}')
def run_on_runtime(rt_name: str, cell: str, store_history=False, silent=False, shell_futures=True, cell_id=None):
info = ExecutionInfo(cell, store_history, silent, shell_futures, cell_id)
excResult = ExecutionResult(info)
rt = state.all_runtimes.get(rt_name, None)
if rt is None:
excResult.error_in_exec = Exception(f"Runtime {rt_name} not found.")
return excResult
url = api_base_url + '/api/v1/agent/executejupyterrequest'
data = {
"sessionId": "session1",
"keepAlive": True,
"code": cell,
"agentId": rt.agentId
}
json_data = json.dumps(data)
response = requests.post(
url, headers={'Content-Type': 'application/json'}, data=json_data)
execution_resp = response.json()
execution_id = execution_resp.get("executionId")
if not execution_id:
excResult.error_in_exec = Exception("Failed to start cell execution")
return excResult
error = execution_resp.get("error")
if error:
excResult.error_in_exec = Exception(
"Cell execution failed. Error: " + error)
return excResult
while True:
url = api_base_url + "/api/v1/agent/executejupyterresponse/" + execution_id
response = requests.get(url, headers={'Accept': 'application/json'})
json_response = response.json()
if json_response.get('available'):
break
time.sleep(1)
result_str = json_response.get('responseString')
try:
result = json.loads(result_str)
except json.JSONDecodeError as e:
excResult.error_in_exec = Exception(
f"Failed to decode response from runtime={rt_name}: {e.msg}")
return excResult
if 'outputs' in result:
for output in result['outputs']:
output_type = output.get('output_type')
if output_type == 'display_data':
data_obj = output.get('data', {})
if 'image/png' in data_obj:
image_data = data_obj['image/png']
try:
image_bytes = base64.b64decode(image_data)
display(Image(data=image_bytes, format='png'))
except binascii.Error as e:
excResult.error_in_exec = Exception(
f"Failed to decode image data: {e}")
return excResult
elif output_type == 'stream':
stream_name = output.get('name', 'stdout')
stream_text = output.get('text', '').strip()
if stream_name == 'stderr':
error_html = f"""
<div style="
color: #a71d5d;
background-color: #fdd;
border: 1px solid #a71d5d;
padding: 5px;
border-radius: 5px;
font-family: Consolas, 'Courier New', monospace;
white-space: pre-wrap;
">
{stream_text}
</div>
"""
display(HTML(error_html))
excResult.error_in_exec = Exception(stream_text)
return excResult
else:
print(stream_text)
elif output_type == 'error':
ename = output.get('ename', 'Error')
evalue = output.get('evalue', '')
traceback = output.get('traceback', [])
error_html = f"""
<div style="
color: #a71d5d;
background-color: #fdd;
border: 1px solid #a71d5d;
padding: 5px;
border-radius: 5px;
font-family: Consolas, 'Courier New', monospace;
">
<pre><strong>{ename}: {evalue}</strong>
"""
for line in traceback:
error_html += f"{line}\n"
error_html += "</pre></div>"
display(HTML(error_html))
excResult.error_in_exec = Exception(f"{ename}: {evalue}")
return excResult
elif output_type == 'execute_result':
data_obj = output.get('data', {})
if 'text/plain' in data_obj:
print(data_obj['text/plain'])
else:
if 'result' in result:
print(result['result'])
elif 'error' in result:
print(result['error']['ename'])
print(result['error']['evalue'])
print(result['error']['traceback'])
elif 'display' in result:
data_obj = result['display'].get('data', {})
if 'image/png' in data_obj:
image_data = data_obj['image/png']
try:
image_bytes = base64.b64decode(image_data)
display(Image(data=image_bytes, format='png'))
except binascii.Error as e:
excResult.error_in_exec = Exception(
f"Failed to decode image data: {e}")
return excResult
else:
# Mark as failed execution if no recognized output format is found
error_html = """
<div style="
color: #a71d5d;
background-color: #fdd;
border: 1px solid #a71d5d;
padding: 5px;
border-radius: 5px;
font-family: Consolas, 'Courier New', monospace;
">
<strong>Error:</strong> Execution failed with unrecognized output format from remote runtime.
<pre>{}</pre>
</div>
""".format(result_str)
display(HTML(error_html))
excResult.error_in_exec = Exception(
"Execution failed with unrecognized output format from remote runtime.")
return excResult
return excResult
def push_remote(local_path: str, remot_rt: str, remot_path: str) -> None:
"""
Push a local file to a remote runtime
@param local_path: the local file path
@param remot_rt: the remote runtime name
@param remot_path: the remote file path
@returns: None
"""
if not state.all_runtimes.get(remot_rt, None):
return print(MSG_NOT_INITIALIZED)
# validate paths
if not remot_path or not local_path:
return print("Please provide paths for both source and target")
# upload file
print(f"Pushing local:{local_path} to remote:{remot_path}")
url = f"{file_server_url}/upload/live/{state.all_runtimes[state.current_runtime].processId}/{remot_path}"
with open(local_path, "rb") as file:
files = {"file": file}
response = requests.post(url, files=files)
print(
f"[{response.status_code}] Uploaded local:{local_path} to remote:{remot_path}")
def pull_remote(local_path: str, remot_rt: str, remot_path: str) -> None:
"""
Pull a remote file to a local runtime
@param local_path: the local file path
@param remot_rt: the remote runtime name
@param remot_path: the remote file path
@returns: None
"""
if not state.all_runtimes.get(remot_rt, None):
return print(MSG_NOT_INITIALIZED)
# validate paths
if not remot_path or not local_path:
return print("Please provide paths for both source and target")
# download file
print(f"Pulling remote:{remot_path} to local:{local_path}")
url = f"{file_server_url}/download/live/{state.all_runtimes[state.current_runtime].processId}/{remot_path}"
response = requests.get(url)
with open(local_path, "wb") as file:
file.write(response.content)
print(
f"[{response.status_code}] Downloaded remote:{remot_path} to local:{local_path}")
# END OF HELPER FUNCTIONS
# ========================================================================
# MAGIC FUNCTIONS
@register_cell_magic
def run_on(line: str, cell: str):
"""
Run the cell on the given runtime
"""
assert ipython is not None
cell_runtime = line.strip()
orig_runtime = state.current_runtime
try:
if cell_runtime in ["local", *state.all_runtimes]:
state.current_runtime = cell_runtime
ipython.run_cell(cell, silent=True)
else:
raise Exception(f"Runtime {cell_runtime} not found.")
finally:
state.current_runtime = orig_runtime
@register_line_magic
def switch_runtime(line: str):
"""
Switch the active runtime
"""
cell_runtime = line.strip()
try:
if cell_runtime not in ["local", *state.all_runtimes]:
raise Exception(f"Runtime {cell_runtime} not found.")
except Exception as e:
raise Exception(
f"Could not switch to runtime={cell_runtime}. error={e}")
else:
state.current_runtime = cell_runtime
print(f"Switched to runtime={cell_runtime}.")
@register_line_magic
def authenticate(line: str):
"""
Authenticate to access high-performance runtimes
"""
try:
authenticator = DeviceFlowAuthenticator()
authenticator.login()
except ValueError as e:
print(f"Configuration error: {e}")
@register_line_magic
def request_runtime(line: str):
"""
Request a runtime with given capabilities
"""
access_token = get_access_token()
assert access_token is not None
[rt_name, *cmd_args] = line.strip().split()
# validate runtime name
if rt_name == "local":
return print(f"Runtime={rt_name} already exists!")
rt = state.all_runtimes.get(rt_name, None)
if rt is not None:
status = is_runtime_ready(rt.agentId)
if status:
return print(f"Runtime={rt_name} already exists!")
headers = generate_headers(access_token, rt.gateway_id)
_, pstate = get_process_state(rt.experimentId, headers)
if pstate in PENDING_STATES:
return print(f"Runtime={rt_name} is in state={pstate}. Please wait, or run '%stop_runtime {rt_name}' to stop it.")
if pstate in TERMINAL_STATES:
state.all_runtimes.pop(rt_name, None)
# parse cli args
p = ArgumentParser(
prog="request_runtime",
description="Request a runtime with given capabilities",
)
p.add_argument("--cluster", type=str, help="cluster", required=True)
p.add_argument("--cpus", type=int, help="CPU cores", required=True)
p.add_argument("--memory", type=int, help="memory (MB)", required=True)
p.add_argument("--walltime", type=int, help="time (mins)", required=True)
p.add_argument("--queue", type=str, help="resource queue", required=True)
p.add_argument("--group", type=str, help="resource group", required=True)
args = p.parse_args(cmd_args, namespace=RequestedRuntime())
submit_agent_job(
rt_name=rt_name,
access_token=access_token,
app_name='CS_Agent',
cluster=args.cluster,
cpus=args.cpus,
memory=args.memory,
walltime=args.walltime,
queue=args.queue,
group=args.group,
)
@register_line_magic
def stat_runtime(line: str):
"""
Show the status of the runtime
"""
access_token = get_access_token()
assert access_token is not None
runtime_name = line.strip()
if runtime_name in ["local", None]:
return print("Runtime=local is always available")
rt = state.all_runtimes.get(runtime_name, None)
if rt is None:
return print(f"Runtime {runtime_name} not found.")
status = is_runtime_ready(rt.agentId)
if status:
print(f"Runtime {runtime_name} is ready!")
else:
print(f"Runtime {runtime_name} is still preparing. Please wait")
@register_line_magic
def stop_runtime(runtime_name: str):
"""
Stop the runtime
"""
access_token = get_access_token()
assert access_token is not None
rt = state.all_runtimes.get(runtime_name, None)
if rt is None:
return print(f"Runtime {runtime_name} not found.")
stop_agent_job(access_token, runtime_name, rt)
@register_line_magic
def copy_data(line: str):
"""
Copy data between runtimes
"""
parts = line.strip().split()
args = {}
for part in parts:
if "=" in part:
k, v = part.split("=", 1)
args[k] = v
source = args.get("source")
target = args.get("target")
if not source or not target:
return print("Usage: %copy_data source=<runtime>:<path> target=<runtime>:<path>")
source_runtime, source_path = source.split(":")
target_runtime, target_path = target.split(":")
print(
f"Copying from {source_runtime}:{source_path} to {target_runtime}:{target_path}")
if source_runtime == "local":
push_remote(source_path, target_runtime, target_path)
elif target_runtime == "local":
pull_remote(target_path, source_runtime, source_path)
else:
print("remote-to-remote copy is not supported yet")
# END OF MAGIC FUNCTIONS
# ========================================================================
# AUTORUN
ipython = get_ipython()
if ipython is None:
raise RuntimeError("airavata_jupyter_magic requires an ipython session")
api_base_url = "https://api.gateway.cybershuttle.org"
file_server_url = "http://3.142.234.94:8050"
MSG_NOT_INITIALIZED = r"Runtime not found. Please run %request_runtime name=<name> cluster=<cluster> cpu=<cpu> memory=<memory mb> queue=<queue> walltime=<walltime minutes> group=<group> to request one."
state = State(current_runtime="local", all_runtimes={})
orig_run_cell = ipython.run_cell
def cell_has_magic(raw_cell: str) -> bool:
lines = raw_cell.strip().splitlines()
magics = (r"%switch_runtime", r"%%run_on", r"%authenticate",
r"%request_runtime", r"%stop_runtime", r"%stat_runtime", r"%copy_data")
return any(line.strip().startswith(magics) for line in lines)
def run_cell(raw_cell, store_history=False, silent=False, shell_futures=True, cell_id=None):
rt = state.current_runtime
if rt == "local" or cell_has_magic(raw_cell):
return orig_run_cell(raw_cell, store_history, silent, shell_futures, cell_id)
else:
wait_until_runtime_ready(rt)
return run_on_runtime(rt, raw_cell, store_history, silent, shell_futures, cell_id)
ipython.run_cell = run_cell
print(r"""
Loaded airavata_jupyter_magic
(current runtime = local)
%authenticate -- Authenticate to access high-performance runtimes.
%request_runtime <rt> [args] -- Request a runtime named <rt> with configuration <args>. Call multiple times to request multiple runtimes.
%stop_runtime <rt> -- Stop runtime <rt> when no longer needed.
%switch_runtime <rt> -- Switch active runtime to <rt>. All subsequent executions will use this runtime.
%%run_on <rt> -- Force a cell to always execute on <rt>, regardless of the active runtime.
%copy_data <r1:file1> <r2:file2> -- Copy <file1> in <r1> to <file2> in <r2>.
""")
# END OF AUTORUN
# ========================================================================