in ax/models/torch/botorch_modular/list_surrogate.py [0:0]
def construct(self, training_data: TrainingData, **kwargs: Any) -> None:
    """Constructs the underlying BoTorch ``Model`` using the training data.

    Builds one submodel per outcome (in ``metric_names`` order), wraps them
    in a ``ModelListGP`` and stores it on ``self._model``. Also stores the
    full training data on ``self._training_data`` and the per-outcome split
    on ``self._training_data_per_outcome``.

    Args:
        training_data: ``TrainingData`` whose ``Xs``, ``Ys`` and (optional)
            ``Yvars`` lists hold the data for the submodels of
            ``ModelListGP``. The i-th entry of each list is the data for the
            i-th outcome, and the order of outcomes must match the order of
            metrics in the ``metric_names`` argument.
        **kwargs: Keyword arguments, accepts:
            - ``metric_names`` (required): Names of metrics, in the same order
              as the outcomes in training data (so if ``training_data.Xs`` is
              ``[X_A, X_B]``, the metrics are ``["A", "B"]``). These are used
              to match training data with correct submodels of
              ``ModelListGP``, so they must be unique.
            - ``fidelity_features``: Indices of columns in X that represent
              fidelity.
            - ``task_features``: Indices of columns in X that represent tasks.

    Raises:
        ValueError: If ``metric_names`` is missing or contains duplicates, if
            the number of metric names does not match the number of outcomes
            in ``training_data``, or if no model class is specified for an
            outcome.
    """
    metric_names = kwargs.get(Keys.METRIC_NAMES)
    fidelity_features = kwargs.get(Keys.FIDELITY_FEATURES, [])
    task_features = kwargs.get(Keys.TASK_FEATURES, [])
    if metric_names is None:
        raise ValueError("Metric names are required.")
    num_metrics = len(metric_names)
    # Duplicate names would collapse in the per-outcome dict below, silently
    # fitting every duplicate submodel on the last outcome's data.
    if len(set(metric_names)) != num_metrics:
        raise ValueError(f"Metric names must be unique, got {metric_names}.")
    # `TrainingData.Yvars` can be none (or empty), in which case each
    # per-outcome training data should have null `Yvar`.
    Yvars = training_data.Yvars or [None] * num_metrics
    # `zip` truncates to its shortest input, which would silently drop
    # outcomes or leave metrics without data. Fail loudly instead so the
    # resulting `ModelListGP` always has one submodel per metric.
    lengths = {
        "Xs": len(training_data.Xs),
        "Ys": len(training_data.Ys),
        "Yvars": len(Yvars),
    }
    mismatched = {key: n for key, n in lengths.items() if n != num_metrics}
    if mismatched:
        raise ValueError(
            f"Expected {num_metrics} outcomes to match `metric_names` "
            f"{metric_names}, but training data has lengths {mismatched}."
        )

    self._training_data = training_data
    self._training_data_per_outcome = {
        metric_name: TrainingData.from_block_design(X=X, Y=Y, Yvar=Yvar)
        for metric_name, X, Y, Yvar in zip(
            metric_names, training_data.Xs, training_data.Ys, Yvars
        )
    }
    submodels = []
    for m in metric_names:
        model_cls = self.botorch_submodel_class_per_outcome.get(
            m, self.botorch_submodel_class
        )
        if not model_cls:
            raise ValueError(f"No model class specified for outcome {m}.")
        # NOTE: here we do a shallow copy of `self.submodel_options`, to
        # protect from accidental modification of shared options. As it is
        # a shallow copy, it does not protect the objects in the dictionary,
        # just the dictionary itself. Per-outcome options override shared ones.
        submodel_options = {
            **self.submodel_options,
            **self.submodel_options_per_outcome.get(m, {}),
        }
        formatted_model_inputs = model_cls.construct_inputs(
            training_data=self.training_data_per_outcome[m],
            fidelity_features=fidelity_features,
            task_features=task_features,
            **submodel_options,
        )
        # pyre-ignore[45]: Py raises informative error if model is abstract.
        submodels.append(model_cls(**formatted_model_inputs))
    self._model = ModelListGP(*submodels)