# airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
from .auth import context
import abc
from typing import Any
from pathlib import Path
import pydantic
# from .task import Task
Task = Any
class Runtime(abc.ABC, pydantic.BaseModel):
  """Abstract execution backend for tasks.

  A Runtime knows how to launch a task, poll its status, signal it, and
  move files to/from its working directory. Concrete implementations:
  `Mock` (local no-op testing) and `Remote` (Airavata-managed cluster).
  """
  # short identifier of the backend, e.g. "mock" or "remote"
  id: str
  # backend-specific launch parameters (cluster, queue, counts, ...).
  # default_factory avoids declaring a shared mutable default ({}).
  args: dict[str, str | int | float] = pydantic.Field(default_factory=dict)

  @abc.abstractmethod
  def execute(self, task: Task) -> None:
    """Launch `task` on this runtime, populating its refs in place."""
    ...

  @abc.abstractmethod
  def execute_py(self, libraries: list[str], code: str, task: Task) -> None:
    """Run a Python snippet (with `libraries` installed) inside the task's environment."""
    ...

  @abc.abstractmethod
  def status(self, task: Task) -> tuple[str, str]:
    """Return (job_id, state) for `task`."""
    ...

  @abc.abstractmethod
  def signal(self, signal: str, task: Task) -> None:
    """Send a control signal (e.g. stop) to `task`."""
    ...

  @abc.abstractmethod
  def ls(self, task: Task) -> list[str]:
    """List files in the task's working directory."""
    ...

  @abc.abstractmethod
  def upload(self, file: Path, task: Task) -> str:
    """Upload a local file into the task's working directory; return the remote path."""
    ...

  @abc.abstractmethod
  def download(self, file: str, local_dir: str, task: Task) -> str:
    """Download `file` from the task's working directory into `local_dir`; return the local path."""
    ...

  @abc.abstractmethod
  def cat(self, file: str, task: Task) -> bytes:
    """Return the raw contents of `file` in the task's working directory."""
    ...

  def __str__(self) -> str:
    return f"{self.__class__.__name__}(args={self.args})"

  @staticmethod
  def default() -> Runtime:
    """Default runtime for users who don't pick one explicitly."""
    return Remote.default()

  @staticmethod
  def create(id: str, args: dict[str, Any]) -> Runtime:
    """Factory: build a runtime from its string id and keyword args.

    Raises:
      ValueError: if `id` is not a known runtime id.
    """
    if id == "mock":
      return Mock(**args)
    elif id == "remote":
      return Remote(**args)
    else:
      raise ValueError(f"Unknown runtime id: {id}")

  @staticmethod
  def Remote(**kwargs):
    """Convenience constructor for the remote (Airavata) runtime."""
    return Remote(**kwargs)

  @staticmethod
  def Local(**kwargs):
    """Convenience constructor for a local (mock) runtime."""
    return Mock(**kwargs)
class Mock(Runtime):
  """No-op runtime for local testing.

  `execute` assigns random refs, `status` walks toward COMPLETED over a few
  polls, and the file operations return empty results.
  """
  # internal poll counter driving the fake state machine (pydantic private attr)
  _state: int = 0

  def __init__(self, **kwargs) -> None:
    # Accept arbitrary launch args so Runtime.create("mock", args) and
    # Runtime.Local(**kwargs) work with non-empty arguments; Mock() still works.
    super().__init__(id="mock", args=kwargs)

  def execute(self, task: Task) -> None:
    import uuid
    task.agent_ref = str(uuid.uuid4())
    task.ref = str(uuid.uuid4())

  def execute_py(self, libraries: list[str], code: str, task: Task) -> None:
    pass

  def status(self, task: Task) -> tuple[str, str]:
    # advance a random amount each poll; report COMPLETED once past threshold
    import random
    self._state += random.randint(0, 5)
    if self._state > 10:
      return "N/A", "COMPLETED"
    return "N/A", "RUNNING"

  def signal(self, signal: str, task: Task) -> None:
    pass

  def ls(self, task: Task) -> list[str]:
    return [""]

  def upload(self, file: Path, task: Task) -> str:
    return ""

  def download(self, file: str, local_dir: str, task: Task) -> str:
    return ""

  def cat(self, file: str, task: Task) -> bytes:
    return b""

  @staticmethod
  def default():
    """A mock runtime with no arguments."""
    return Mock()
class Remote(Runtime):
  """Runtime that launches tasks as Airavata experiments on a remote cluster.

  All operations delegate to `AiravataOperator`, authenticated with the
  current `context.access_token`.
  """

  def __init__(self, cluster: str, category: str, queue_name: str, node_count: int, cpu_count: int, walltime: int, gpu_count: int = 0, group: str = "Default") -> None:
    super().__init__(id="remote", args=dict(
      cluster=cluster,
      category=category,
      queue_name=queue_name,
      node_count=node_count,
      cpu_count=cpu_count,
      gpu_count=gpu_count,
      walltime=walltime,
      group=group,
    ))

  @staticmethod
  def _op():
    """Build an AiravataOperator bound to the current auth context.

    Imported lazily (as the original methods did) to avoid a module-level
    import cycle with .airavata.
    """
    from .airavata import AiravataOperator
    return AiravataOperator(context.access_token)

  @staticmethod
  def _require(task: Task, *attrs: str) -> None:
    """Assert that the given task attributes were populated by a prior launch.

    NOTE: assert-based like the original code — stripped under `python -O`.
    """
    for attr in attrs:
      assert getattr(task, attr) is not None, f"task.{attr} is not set"

  def execute(self, task: Task) -> None:
    """Create and launch an Airavata experiment for `task`.

    Populates task.agent_ref/pid/ref/workdir/sr_host from the launch state.
    Raises whatever `launch_experiment` raises, after logging it.
    """
    # task must not have been launched already
    assert task.ref is None
    assert task.agent_ref is None
    assert {"cluster", "group", "queue_name", "node_count", "cpu_count", "gpu_count", "walltime"}.issubset(self.args.keys())
    print(f"[Remote] Creating Experiment: name={task.name}")
    av = self._op()
    try:
      launch_state = av.launch_experiment(
        experiment_name=task.name,
        app_name=task.app_id,
        project=task.project,
        inputs=task.inputs,
        computation_resource_name=str(self.args["cluster"]),
        queue_name=str(self.args["queue_name"]),
        node_count=int(self.args["node_count"]),
        cpu_count=int(self.args["cpu_count"]),
        walltime=int(self.args["walltime"]),
        group=str(self.args["group"]),
      )
      task.agent_ref = launch_state.agent_ref
      task.pid = launch_state.process_id
      task.ref = launch_state.experiment_id
      task.workdir = launch_state.experiment_dir
      task.sr_host = launch_state.sr_host
      print(f"[Remote] Experiment Launched: id={task.ref}")
    except Exception as e:
      print(f"[Remote] Failed to launch experiment: {e}")
      # bare raise preserves the original traceback
      raise

  def execute_py(self, libraries: list[str], code: str, task: Task) -> None:
    """Run a Python snippet inside the task's remote environment and print the result."""
    self._require(task, "ref", "agent_ref", "pid")
    av = self._op()
    result = av.execute_py(task.project, libraries, code, task.agent_ref, task.pid, task.runtime.args)
    print(result)

  def status(self, task: Task) -> tuple[str, str]:
    """Return (job_id, state), preferring the job state over the experiment state."""
    self._require(task, "ref", "agent_ref")
    av = self._op()
    # prioritize job state, fallback to experiment state
    job_id, job_state = av.get_task_status(task.ref)
    if not job_state or job_state == "UN_SUBMITTED":
      return job_id, av.get_experiment_status(task.ref)
    else:
      return job_id, job_state

  def signal(self, signal: str, task: Task) -> None:
    """Stop the experiment. The `signal` argument is currently ignored."""
    self._require(task, "ref", "agent_ref")
    av = self._op()
    av.stop_experiment(task.ref)

  def ls(self, task: Task) -> list[str]:
    """List files in the task's remote working directory."""
    self._require(task, "ref", "pid", "agent_ref", "sr_host", "workdir")
    av = self._op()
    return av.list_files(task.pid, task.agent_ref, task.sr_host, task.workdir)

  def upload(self, file: Path, task: Task) -> str:
    """Upload a local file into the task's remote working directory; return the remote path."""
    self._require(task, "ref", "pid", "agent_ref", "sr_host", "workdir")
    av = self._op()
    return av.upload_files(task.pid, task.agent_ref, task.sr_host, [file], task.workdir).pop()

  def download(self, file: str, local_dir: str, task: Task) -> str:
    """Download `file` from the task's remote working directory into `local_dir`."""
    self._require(task, "ref", "pid", "agent_ref", "sr_host", "workdir")
    av = self._op()
    return av.download_file(task.pid, task.agent_ref, task.sr_host, file, task.workdir, local_dir)

  def cat(self, file: str, task: Task) -> bytes:
    """Return the raw contents of `file` from the task's remote working directory."""
    self._require(task, "ref", "pid", "agent_ref", "sr_host", "workdir")
    av = self._op()
    return av.cat_file(task.pid, task.agent_ref, task.sr_host, file, task.workdir)

  @staticmethod
  def default():
    """Default remote runtime: first GPU runtime on the Expanse cluster."""
    return list_runtimes(cluster="login.expanse.sdsc.edu", category="gpu").pop()
def list_runtimes(
    cluster: str | None = None,
    category: str | None = None,
    group: str | None = None,
    node_count: int | None = None,
    cpu_count: int | None = None,
    walltime: int | None = None,
) -> list[Runtime]:
  """Return the available runtimes, optionally filtered and overridden.

  Args:
    cluster, category, group: if given, keep only runtimes whose args match.
    node_count, cpu_count, walltime: if given, override the runtime's defaults.

  Returns:
    The matching runtimes, with any overrides applied to their args in place.
  """
  from .airavata import AiravataOperator
  av = AiravataOperator(context.access_token)
  out_runtimes: list[Runtime] = []
  for r in av.get_available_runtimes():
    # None means "no filter" for each of the three match criteria
    if cluster not in (None, r.args["cluster"]):
      continue
    if category not in (None, r.args["category"]):
      continue
    if group not in (None, r.args["group"]):
      continue
    # use `is not None` rather than truthiness so an explicit 0 is honored
    if node_count is not None:
      r.args["node_count"] = node_count
    if cpu_count is not None:
      r.args["cpu_count"] = cpu_count
    if walltime is not None:
      r.args["walltime"] = walltime
    out_runtimes.append(r)
  return out_runtimes
def is_terminal_state(x) -> bool:
  """True if `x` is a state from which a task cannot progress further."""
  return x in {"CANCELED", "COMPLETED", "FAILED"}