in ax/modelbridge/cross_validation.py [0:0]
def cross_validate_by_trial(model: ModelBridge, trial: int = -1) -> List[CVResult]:
    """Cross validation for model predictions on a particular trial.

    Uses all of the in-design training data from trials strictly before the
    specified trial to predict each of the arms that were launched in that
    trial. Defaults to the last trial.

    Args:
        model: Fitted model (ModelBridge) to cross validate.
        trial: Trial for which predictions are evaluated. Any negative value
            selects the last (highest-index) trial in the training data.

    Returns:
        A CVResult for each in-design observation belonging to the selected
        trial, in the order they appear in the model's training data.

    Raises:
        ValueError: If the in-design training data spans fewer than 2 trials,
            if ``trial`` is not a trial present in the training data, or if
            no training data precedes ``trial`` (e.g. it is the first trial).
    """
    # Get in-design training points; training_in_design is indexed in
    # parallel with get_training_data().
    training_data = [
        obs
        for i, obs in enumerate(model.get_training_data())
        if model.training_in_design[i]
    ]
    # Observations without a trial index cannot be assigned to either side
    # of the split, so they are ignored here and below.
    all_trials = {
        int(d.features.trial_index)
        for d in training_data
        if d.features.trial_index is not None
    }
    if len(all_trials) < 2:
        raise ValueError(f"Training data has fewer than 2 trials ({all_trials})")
    if trial < 0:
        trial = max(all_trials)
    elif trial not in all_trials:
        raise ValueError(f"Trial {trial} not found in training data")

    # Construct train/test data: earlier trials train, the selected trial tests.
    cv_training_data = []
    cv_test_data = []
    for obs in training_data:
        trial_index = obs.features.trial_index
        if trial_index is None:
            continue
        if trial_index < trial:
            cv_training_data.append(obs)
        elif trial_index == trial:
            cv_test_data.append(obs)

    # Selecting the earliest trial leaves nothing to fit on; fail clearly
    # rather than handing an empty training set to the model.
    if not cv_training_data:
        raise ValueError(
            f"No training data precedes trial {trial}; cannot cross validate "
            "on the earliest trial."
        )

    # Make the prediction; one prediction is returned per test point.
    cv_test_points = [obs.features for obs in cv_test_data]
    cv_test_predictions = model.cross_validate(
        cv_training_data=cv_training_data, cv_test_points=cv_test_points
    )
    # Pair each held-out observation with its prediction.
    return [
        CVResult(observed=obs, predicted=pred)
        for obs, pred in zip(cv_test_data, cv_test_predictions)
    ]