agents/registry.py (78 lines of code) (raw):

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import yaml

from agents.utils import parse_env_var_values
from mlebench.utils import get_logger

logger = get_logger(__name__)


def _require(condition: bool, message: str) -> None:
    """Raise ``AssertionError(message)`` when ``condition`` is false.

    Used instead of bare ``assert`` statements so that agent validation still
    runs when Python is started with ``-O`` (which strips ``assert``). The
    exception type is deliberately kept as ``AssertionError`` so any caller
    that already catches it keeps working.
    """
    if not condition:
        raise AssertionError(message)


@dataclass(frozen=True)
class Agent:
    """An agent definition, validated on construction.

    ``__post_init__`` checks field types and requires the ``start`` script and
    ``dockerfile`` to exist on disk.
    """

    id: str  # key of the agent's entry in its config.yaml
    name: str  # Registry.get_agent sets this to the config.yaml's parent directory name
    agents_dir: Path  # base directory that ``start`` and ``dockerfile`` are resolved against
    start: Path  # start script; must exist
    dockerfile: Path  # must exist
    kwargs: dict
    env_vars: dict
    privileged: bool = False
    kwargs_type: Optional[str] = None  # must be set whenever ``kwargs`` is non-empty

    def __post_init__(self):
        """Validate field types and the existence of the referenced files.

        Raises:
            AssertionError: if any check fails.
        """
        _require(isinstance(self.start, Path), "Agent start script must be a pathlib.Path object.")
        _require(
            isinstance(self.dockerfile, Path), "Agent dockerfile must be a pathlib.Path object."
        )
        _require(isinstance(self.kwargs, dict), "Agent kwargs must be a dictionary.")
        _require(isinstance(self.privileged, bool), "Agent privileged must be a boolean.")

        if self.kwargs_type is not None:
            _require(isinstance(self.kwargs_type, str), "Agent kwargs_type must be a string.")
        else:
            # Without a kwargs_type there is nothing to interpret kwargs with.
            _require(self.kwargs == {}, "Agent kwargs_type must be set if kwargs are provided.")

        _require(isinstance(self.env_vars, dict), "Agent env_vars must be a dictionary.")
        _require(self.start.exists(), f"start script {self.start} does not exist.")
        _require(self.dockerfile.exists(), f"dockerfile {self.dockerfile} does not exist.")

    @staticmethod
    def from_dict(data: dict) -> "Agent":
        """Build an ``Agent`` from a flat config mapping.

        ``start`` and ``dockerfile`` are interpreted relative to
        ``data["agents_dir"]``. ``kwargs``, ``kwargs_type``, ``env_vars`` and
        ``privileged`` are optional.

        Raises:
            ValueError: if any of ``agents_dir``, ``id``, ``name``, ``start`` or
                ``dockerfile`` is missing.
            AssertionError: if the resulting agent fails validation.
        """
        # Only the required-key lookups live in the try, so that a missing
        # "agents_dir" is reported the same way as any other missing key.
        try:
            agents_dir = Path(data["agents_dir"])
            agent_id = data["id"]
            name = data["name"]
            start = data["start"]
            dockerfile = data["dockerfile"]
        except KeyError as e:
            raise ValueError(f"Missing key {e} in agent config!") from e

        return Agent(
            id=agent_id,
            name=name,
            agents_dir=agents_dir,
            start=agents_dir / start,
            dockerfile=agents_dir / dockerfile,
            kwargs=data.get("kwargs", {}),
            kwargs_type=data.get("kwargs_type", None),
            env_vars=data.get("env_vars", {}),
            privileged=data.get("privileged", False),
        )


class Registry:
    """Looks up agents by id across every ``config.yaml`` under the agents directory."""

    def get_agents_dir(self) -> Path:
        """Retrieves the agents directory within the registry."""
        return Path(__file__).parent

    def get_agent(self, agent_id: str) -> Agent:
        """Fetch the agent from the registry.

        Every ``config.yaml`` below the agents directory is searched for a
        top-level key equal to ``agent_id``. Environment-variable references in
        the entry's ``kwargs`` and ``env_vars`` are resolved with
        ``parse_env_var_values``.

        Args:
            agent_id: Top-level key of the agent's entry in a config.yaml.

        Returns:
            The validated ``Agent``. Its ``name`` is the directory that
            contains the matching config.yaml.

        Raises:
            ValueError: if no config defines ``agent_id``, if the entry is not a
                mapping, or if it lacks required keys.
        """
        agents_dir = self.get_agents_dir()
        # Sorted so that, if the same id appears in several config files, the
        # one returned does not depend on filesystem enumeration order.
        for fpath in sorted(agents_dir.glob("**/config.yaml")):
            with open(fpath, "r", encoding="utf-8") as f:
                contents = yaml.safe_load(f)

            # An empty YAML file loads as None, and a non-mapping document
            # cannot hold agent entries; neither can define this agent.
            if not isinstance(contents, dict) or agent_id not in contents:
                continue

            logger.debug(f"Fetching {fpath}")
            agent_config = contents[agent_id]
            if not isinstance(agent_config, dict):
                raise ValueError(
                    f"Agent {agent_id} in {fpath} must be a mapping, "
                    f"got {type(agent_config).__name__}"
                )

            # env vars can be used both in kwargs and env_vars
            kwargs = parse_env_var_values(agent_config.get("kwargs", {}))
            env_vars = parse_env_var_values(agent_config.get("env_vars", {}))

            return Agent.from_dict(
                {
                    **agent_config,
                    "id": agent_id,
                    "name": fpath.parent.name,
                    "agents_dir": agents_dir,
                    "kwargs": kwargs,
                    "kwargs_type": agent_config.get("kwargs_type", None),
                    "env_vars": env_vars,
                    "privileged": agent_config.get("privileged", False),
                }
            )

        raise ValueError(f"Agent with id {agent_id} not found")


registry = Registry()