airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/task.py (78 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from typing import Any import pydantic from .runtime import Runtime from rich.progress import Progress class Task(pydantic.BaseModel): name: str app_id: str project: str = pydantic.Field(default="Default Project") inputs: dict[str, Any] runtime: Runtime ref: str | None = pydantic.Field(default=None) pid: str | None = pydantic.Field(default=None) agent_ref: str | None = pydantic.Field(default=None) workdir: str | None = pydantic.Field(default=None) sr_host: str | None = pydantic.Field(default=None) @pydantic.field_validator("runtime", mode="before") def set_runtime(cls, v): if isinstance(v, dict) and "id" in v: id = v.pop("id") args = v.pop("args", {}) return Runtime.create(id=id, args=args) return v def __str__(self) -> str: return f"Task(\nname={self.name}\napp_id={self.app_id}\ninputs={self.inputs}\nruntime={self.runtime}\nref={self.ref}\nagent_ref={self.agent_ref}\nfile_path={self.sr_host}:{self.workdir}\n)" def launch(self, force=True) -> None: if not force and self.ref is not None: print(f"[Task] Task {self.name} has already launched: ref={self.ref}") return if self.ref is not None: input("[NOTE] Past runs will be overwritten! Hit Enter to continue...") self.ref = None self.agent_ref = None print(f"[Task] Executing {self.name} on {self.runtime}") self.runtime.execute(self) def status(self) -> tuple[str, str]: assert self.ref is not None return self.runtime.status(self) def ls(self) -> list[str]: assert self.ref is not None return self.runtime.ls(self) def upload(self, file: str) -> str: assert self.ref is not None from pathlib import Path return self.runtime.upload(Path(file), self) def download(self, file: str, local_dir: str) -> str: assert self.ref is not None from pathlib import Path Path(local_dir).mkdir(parents=True, exist_ok=True) return self.runtime.download(file, local_dir, self) def download_all(self, local_dir: str) -> list[str]: assert self.ref is not None import os os.makedirs(local_dir, exist_ok=True) fps_task = list[str]() files = self.ls() with Progress() as progress: pbar = progress.add_task(f"Downloading: ...", total=len(files)) for remote_fp in self.ls(): fp = self.runtime.download(remote_fp, local_dir, self) progress.update(pbar, description=f"Downloading: {remote_fp}", advance=1) fps_task.append(fp) progress.update(pbar, description=f"Downloading: DONE", refresh=True) return fps_task def cat(self, file: str) -> bytes: assert self.ref is not None return self.runtime.cat(file, self) def stop(self) -> None: assert self.ref is not None return self.runtime.signal("SIGTERM", self) def context(self, packages: list[str]) -> Any: def decorator(func): def wrapper(*args, **kwargs): from .scripter import scriptize make_script = scriptize(func) return self.runtime.execute_py(packages, make_script(*args, **kwargs), self) return wrapper return decorator