ax/metrics/torchx.py (43 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, cast

import pandas as pd
from ax.core import Trial
from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.core.metric import Metric
from ax.runners.torchx import TORCHX_TRACKER_BASE
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none

logger = get_logger(__name__)


# torchx is an optional dependency: only the import itself is guarded, so an
# unrelated ImportError can never be mistaken for "torchx is not installed".
try:
    from torchx.runtime.tracking import FsspecResultTracker
except ImportError:
    logger.warning(
        "torchx package not found. If you would like to use TorchXMetric, please "
        "install torchx."
    )
else:

    class TorchXMetric(Metric):
        """
        Fetches AppMetric (the observation returned by the trial job/app)
        via the ``torchx.tracking`` module. Assumes that the app used
        the tracker in the following manner:

        .. code-block:: python

            tracker = torchx.runtime.tracking.FsspecResultTracker(tracker_base)
            tracker[str(trial_index)] = {metric_name: value}

            # -- or --
            tracker[str(trial_index)] = {"metric_name/mean": mean_value,
                                        "metric_name/sem": sem_value}

        """

        def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> Data:
            """Read this metric's observation for ``trial`` from the tracker.

            The tracker base is taken from
            ``trial.run_metadata[TORCHX_TRACKER_BASE]`` and the result entry
            is looked up by ``trial.index``. A value stored directly under
            the metric name is used as the mean (with no SEM); otherwise the
            ``"<name>/mean"`` and ``"<name>/sem"`` keys are consulted.

            Args:
                trial: The trial whose result should be fetched.
                **kwargs: Accepted for interface compatibility; unused here.

            Returns:
                A ``Data`` object wrapping a single-row dataframe with the
                columns ``arm_name``, ``trial_index``, ``metric_name``,
                ``mean`` and ``sem``.

            Raises:
                KeyError: If no mean value for this metric is present in the
                    tracker entry (including the case where only a SEM was
                    written).
            """
            tracker_base = trial.run_metadata[TORCHX_TRACKER_BASE]
            tracker = FsspecResultTracker(tracker_base)
            res = tracker[trial.index]

            if self.name in res:
                mean = res[self.name]
                sem = None
            else:
                mean = res.get(f"{self.name}/mean")
                sem = res.get(f"{self.name}/sem")

            # A SEM without a mean is not a usable observation; fail loudly
            # rather than emitting a row whose mean is None.
            if mean is None:
                raise KeyError(
                    f"Observation for `{self.name}` not found in tracker at base "
                    f"`{tracker_base}`. Ensure that the trial job is writing the "
                    "results at the same tracker base."
                )

            df_dict = {
                # The cast exposes ``.arm`` on the BaseTrial-typed argument;
                # presumably ``not_none`` rejects a trial without an arm.
                "arm_name": not_none(cast(Trial, trial).arm).name,
                "trial_index": trial.index,
                "metric_name": self.name,
                "mean": mean,
                "sem": sem,
            }
            return Data(df=pd.DataFrame.from_records([df_dict]))