in ax/modelbridge/generation_strategy.py [0:0]
def _fit_current_model(self, data: Data) -> None:
    """Fit the current generation step's model on all available data.

    If the most recent generator run was produced by the current step, has
    model state recorded after generation, and ``self.model`` is truthy, that
    state is extracted and forwarded to the step's ``fit`` call as keyword
    arguments. The step's fitted model is then stored on this strategy as
    ``self._model``.

    Args:
        data: Data to fit the current step's model with. When its dataframe
            is non-empty, the sorted unique trial indices it covers are
            logged at debug level.
    """
    # Extra keyword arguments for `fit`; stays empty unless reusable model
    # state is found on the last generator run below.
    model_state_kwargs = {}
    last_run = self.last_generator_run
    # NOTE: Looking only at the single last generator run will not be easily
    # compatible with `GenerationNode`; that will likely require finding the
    # last generator run per model. Not a problem for now, since GS only
    # allows `GenerationStep`-s. Potential solution: store generator runs on
    # `GenerationStep`-s and split them per-model there.
    if (
        last_run is not None
        and last_run._generation_step_index == self._curr.index
        and last_run._model_state_after_gen
        and self.model
    ):
        # TODO[drfreund]: Consider moving this to `GenerationStep` or
        # `GenerationNode`.
        current_model = not_none(self.model)
        model_state_kwargs = _extract_model_state_after_gen(
            generator_run=last_run,
            model_class=current_model.model.__class__,
        )

    if not data.df.empty:
        logger.debug(
            "Fitting model with data for trials: "
            f"{sorted(data.df['trial_index'].unique())}"
        )

    self._curr.fit(experiment=self.experiment, data=data, **model_state_kwargs)
    self._model = self._curr.model_spec.fitted_model