def record()

in src/sagemaker_xgboost_container/prediction_utils.py [0:0]


    def record(self, indices: np.ndarray, predictions: np.ndarray) -> None:
        """Record predictions on a single validation fold in-memory.

        Each row in ``indices`` gets its prediction written into that row's next
        free repeat slot (the column given by ``self.cv_repeat_counter``), after
        which the row's counter is advanced by one.

        :param indices: row indices for which the predictions were made.
        :param predictions: predictions for the rows specified in ``indices``, in the
            same order. For classification, a 1-D array is thresholded at 0.5 to
            obtain labels and stored as-is in ``y_prob``; an array with ndim > 1 is
            reduced with argmax over the last axis, and the score of the chosen
            label is stored in ``y_prob``.
        :raises exc.AlgorithmError: if ``predictions.ndim`` differs from the ndim
            recorded on the first call, or if any row in ``indices`` has already
            received ``self.num_cv_round`` predictions.
        """
        # The first call fixes the expected dimensionality for every later fold.
        if self.pred_ndim_ is None:
            self.pred_ndim_ = predictions.ndim
        if self.pred_ndim_ != predictions.ndim:
            raise exc.AlgorithmError(f"Expected predictions with ndim={self.pred_ndim_}, got ndim={predictions.ndim}.")

        # Column (repeat slot) into which each row's prediction will be written.
        cv_repeat_idx = self.cv_repeat_counter[indices]
        exhausted = cv_repeat_idx == self.num_cv_round
        if np.any(exhausted):
            # Report the offending *row indices*; the counter values themselves are
            # all equal to num_cv_round and would carry no information.
            sample_rows = np.asarray(indices)[exhausted][:EXAMPLE_ROWS_EXCEPTION_COUNT]
            raise exc.AlgorithmError(
                f"More than {self.num_cv_round} repeated predictions for same row were provided. "
                f"Example row indices where this is the case: {sample_rows}."
            )

        if self.classification:
            if predictions.ndim > 1:
                labels = np.argmax(predictions, axis=-1)
                proba = predictions[np.arange(len(labels)), labels]
            else:
                labels = 1 * (predictions > 0.5)
                proba = predictions
            self.y_pred[indices, cv_repeat_idx] = labels
            self.y_prob[indices, cv_repeat_idx] = proba
        else:
            self.y_pred[indices, cv_repeat_idx] = predictions
        # NOTE(review): with NumPy fancy indexing, a row repeated inside a single
        # `indices` array is incremented only once and its earlier prediction is
        # overwritten; callers presumably never pass duplicates within one fold.
        self.cv_repeat_counter[indices] += 1