in ax/models/torch/botorch_modular/surrogate.py [0:0]
def construct(self, training_data: TrainingData, **kwargs: Any) -> None:
    """Constructs the underlying BoTorch ``Model`` using the training data.

    Args:
        training_data: Training data for the model (for one outcome for
            the default `Surrogate`, with the exception of batched
            multi-output case, where training data is formatted with just
            one X and concatenated Ys).
        **kwargs: Optional keyword arguments, expects any of:
            - "fidelity_features": Indices of columns in X that represent
            fidelity.
            These are merged over ``self.model_options`` (call-time values
            win on key collisions) and forwarded to
            ``self.botorch_model_class.construct_inputs``.

    Raises:
        ValueError: If ``training_data`` is not a ``TrainingData`` instance.
    """
    if self._constructed_manually:
        logger.warning("Reconstructing a manually constructed `Model`.")
    if not isinstance(training_data, TrainingData):
        raise ValueError(  # pragma: no cover
            "Base `Surrogate` expects training data for single outcome."
        )
    # Call-time kwargs take precedence over options stored on the Surrogate.
    # (`**kwargs` is always a dict, so no `or {}` fallback is needed.)
    input_constructor_kwargs = {**self.model_options, **kwargs}
    self._training_data = training_data
    formatted_model_inputs = self.botorch_model_class.construct_inputs(
        training_data=self.training_data, **input_constructor_kwargs
    )
    # TODO: We currently only pass in `covar_module` and `likelihood` if they are
    # inputs to the BoTorch model. This interface will need to be expanded to a
    # ModelFactory, see D22457664, to accommodate different models in the future.
    # Both regular and keyword-only constructor parameters count as "inputs":
    # the values below are passed by name, which works for either kind, and
    # `getfullargspec(...).args` alone would miss keyword-only ones.
    model_argspec = inspect.getfullargspec(self.botorch_model_class)
    botorch_model_class_args = set(model_argspec.args) | set(
        model_argspec.kwonlyargs
    )
    if "covar_module" in botorch_model_class_args and self.covar_module_class:
        # pyre-ignore [45]
        formatted_model_inputs["covar_module"] = self.covar_module_class(
            **self.covar_module_options
        )
    if "likelihood" in botorch_model_class_args and self.likelihood_class:
        # pyre-ignore [45]
        formatted_model_inputs["likelihood"] = self.likelihood_class(
            **self.likelihood_options
        )
    # pyre-ignore [45]
    self._model = self.botorch_model_class(**formatted_model_inputs)