airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/plan.py (158 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
import json
import time
import os
import pydantic
from rich.progress import Progress
from .runtime import is_terminal_state
from .task import Task
import uuid
from .airavata import AiravataOperator
from .auth import context
class Plan(pydantic.BaseModel):
id: str | None = pydantic.Field(default=None)
tasks: list[Task] = []
@pydantic.field_validator("tasks", mode="before")
def default_tasks(cls, v):
if isinstance(v, list):
return [Task(**task) if isinstance(task, dict) else task for task in v]
return v
def __stage_prepare__(self) -> None:
print("Preparing to launch...")
def __stage_confirm__(self, silent: bool) -> None:
if not silent:
while True:
res = input("Ready to launch. Continue? (Y/n) ")
if res.upper() in ["N"]:
raise Exception("Launch aborted by user.")
elif res.upper() in ["Y", ""]:
break
else:
continue
def __stage_launch_task__(self) -> None:
print("Launching tasks...")
for task in self.tasks:
task.launch()
def __stage_status__(self) -> list:
statuses = []
for task in self.tasks:
statuses.append(task.status())
return statuses
def __stage_stop__(self) -> None:
print("Stopping task(s)...")
for task in self.tasks:
task.stop()
print("Task(s) stopped.")
def __stage_fetch__(self, local_dir: str) -> list[list[str]]:
print("Fetching results...")
fps = list[list[str]]()
for task in self.tasks:
fps.append(task.download_all(local_dir))
print("Results fetched.")
self.save_json(os.path.join(local_dir, "plan.json"))
return fps
def launch(self, silent: bool = True) -> None:
try:
self.__stage_prepare__()
self.__stage_confirm__(silent)
self.__stage_launch_task__()
self.save()
except Exception as e:
print(*e.args, sep="\n")
def status(self) -> None:
statuses = self.__stage_status__()
print(f"Plan {self.id} ({len(self.tasks)} tasks):")
for task, (task_id, status) in zip(self.tasks, statuses):
print(f"* {task.name}: {task_id}: {status}")
def wait_for_completion(self, check_every_n_mins: float = 0.1) -> None:
n = len(self.tasks)
try:
with Progress() as progress:
pbars = [progress.add_task(f"{task.name} ({i+1}/{n}): CHECKING", total=None) for i, task in enumerate(self.tasks)]
while True:
completed = [False] * n
statuses = self.__stage_status__()
for i, (task, (task_id, status), pbar) in enumerate(zip(self.tasks, statuses, pbars)):
completed[i] = is_terminal_state(status)
progress.update(pbar, description=f"{task.name} ({i+1}/{n}): {task_id}: {status}", completed=completed[i], refresh=True)
if all(completed):
break
sleep_time = check_every_n_mins * 60
time.sleep(sleep_time)
print("All tasks completed.")
except KeyboardInterrupt:
print("Interrupted by user.")
def download(self, local_dir: str):
assert os.path.isdir(local_dir)
self.__stage_fetch__(local_dir)
def stop(self) -> None:
self.__stage_stop__()
self.save()
def save_json(self, filename: str) -> None:
with open(filename, "w") as f:
json.dump(self.model_dump(), f, indent=2)
def save(self) -> None:
av = AiravataOperator(context.access_token)
az = av.__airavata_token__(av.access_token, av.default_gateway_id())
assert az.accessToken is not None
assert az.claimsMap is not None
headers = {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + az.accessToken,
'X-Claims': json.dumps(az.claimsMap)
}
import requests
if self.id is None:
self.id = str(uuid.uuid4())
response = requests.post("https://api.gateway.cybershuttle.org/api/v1/plan", headers=headers, json=self.model_dump())
print(f"Plan saved: {self.id}")
else:
response = requests.put(f"https://api.gateway.cybershuttle.org/api/v1/plan/{self.id}", headers=headers, json=self.model_dump())
print(f"Plan updated: {self.id}")
if response.status_code == 200:
body = response.json()
plan = json.loads(body["data"])
assert plan["id"] == self.id
else:
raise Exception(response)
def load_json(filename: str) -> Plan:
with open(filename, "r") as f:
model = json.load(f)
return Plan(**model)
def load(id: str | None) -> Plan:
assert id is not None
av = AiravataOperator(context.access_token)
az = av.__airavata_token__(av.access_token, av.default_gateway_id())
assert az.accessToken is not None
assert az.claimsMap is not None
headers = {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + az.accessToken,
'X-Claims': json.dumps(az.claimsMap)
}
import requests
response = requests.get(f"https://api.gateway.cybershuttle.org/api/v1/plan/{id}", headers=headers)
if response.status_code == 200:
body = response.json()
plan = json.loads(body["data"])
return Plan(**plan)
else:
raise Exception(response)
def query() -> list[Plan]:
av = AiravataOperator(context.access_token)
az = av.__airavata_token__(av.access_token, av.default_gateway_id())
assert az.accessToken is not None
assert az.claimsMap is not None
headers = {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + az.accessToken,
'X-Claims': json.dumps(az.claimsMap)
}
import requests
response = requests.get(f"https://api.gateway.cybershuttle.org/api/v1/plan/user", headers=headers)
if response.status_code == 200:
items: list = response.json()
plans = [json.loads(item["data"]) for item in items]
return [Plan(**plan) for plan in plans]
else:
raise Exception(response)