def cross_validate_by_trial()

in ax/modelbridge/cross_validation.py [0:0]


def cross_validate_by_trial(model: ModelBridge, trial: int = -1) -> List[CVResult]:
    """Cross validate a fitted model by holding out a single trial.

    The in-design observations from every trial that precedes the chosen
    trial are used as training data, and the model is asked to predict the
    in-design observations of the chosen trial itself. Observations with no
    trial index, and observations from later trials, are used for neither.

    Args:
        model: Fitted model (ModelBridge) to cross validate.
        trial: Index of the trial to hold out and predict. Any negative
            value (the default is -1) selects the largest trial index
            present in the in-design training data.

    Returns:
        One CVResult per in-design observation of the held-out trial, in
        training-data order, pairing the observation with the prediction
        found at the same position in the output of
        ``model.cross_validate``.

    Raises:
        ValueError: If the in-design training data spans fewer than two
            distinct trials, or if a non-negative ``trial`` does not occur
            in it.
    """
    # Keep only the observations the model flags as in-design.
    in_design_obs = []
    for idx, observation in enumerate(model.get_training_data()):
        if model.training_in_design[idx]:
            in_design_obs.append(observation)

    # Distinct trial indices present among those observations.
    seen_trials = set()
    for observation in in_design_obs:
        if observation.features.trial_index is not None:
            seen_trials.add(int(observation.features.trial_index))

    # At least one earlier trial is needed to train on.
    if len(seen_trials) < 2:
        raise ValueError(f"Training data has fewer than 2 trials ({seen_trials})")

    # Resolve which trial is held out: negative means "the latest one".
    if trial < 0:
        held_out_trial = max(seen_trials)
    elif trial in seen_trials:
        held_out_trial = trial
    else:
        raise ValueError(f"Trial {trial} not found in training data")

    # Split into earlier-trial training observations and held-out ones.
    earlier_obs = []
    held_out_obs = []
    for observation in in_design_obs:
        obs_trial = observation.features.trial_index
        if obs_trial is None:
            continue
        if obs_trial < held_out_trial:
            earlier_obs.append(observation)
        elif obs_trial == held_out_trial:
            held_out_obs.append(observation)
    held_out_features = [observation.features for observation in held_out_obs]

    # Predict the held-out points from the earlier trials only.
    predictions = model.cross_validate(
        cv_training_data=earlier_obs, cv_test_points=held_out_features
    )

    # Pair each held-out observation with the prediction at its position.
    return [
        CVResult(observed=observation, predicted=predictions[position])
        for position, observation in enumerate(held_out_obs)
    ]