def prepare_for_metric()

in mlebench/competitions/icecube-neutrinos-in-deep-ice/grade.py [0:0]


def prepare_for_metric(submission: pd.DataFrame, answers: pd.DataFrame) -> dict:
    """Validate a submission against the answers and align them for scoring.

    Both frames must have exactly the columns ``azimuth``, ``zenith`` and
    ``event_id``. The angle columns must be numeric and contain no NaN or
    infinite values. The two frames must cover the same set of event ids.

    Args:
        submission: Predicted angles, one row per event.
        answers: Ground-truth angles, one row per event.

    Returns:
        A dict with keys ``az_true``, ``zen_true``, ``az_pred`` and
        ``zen_pred``. Each value is a numpy array, and both frames are sorted
        by ``event_id`` first, so the arrays line up element-wise.

    Raises:
        InvalidSubmissionError: If the submission is malformed or does not
            match the answers' length / event ids.
        AssertionError: If the answers frame itself is malformed.
    """
    # submission
    if set(submission.columns) != {"azimuth", "zenith", "event_id"}:
        raise InvalidSubmissionError(
            "Submission must contain columns 'azimuth','zenith' and 'event_id'"
        )

    # Dtype checks come first: np.isfinite would raise on non-numeric columns.
    if not np.issubdtype(submission["azimuth"].dtype, np.number):
        raise InvalidSubmissionError("Azimuth must be a number")
    if not np.issubdtype(submission["zenith"].dtype, np.number):
        raise InvalidSubmissionError("Zenith must be a number")
    # NaN is checked before infinity. np.isfinite(NaN) is False, so checking
    # finiteness first would mislabel NaN as "infinite" and make the NaN
    # checks unreachable.
    if submission["azimuth"].isnull().any():
        raise InvalidSubmissionError("Azimuth must not be NaN")
    if submission["zenith"].isnull().any():
        raise InvalidSubmissionError("Zenith must not be NaN")
    if not np.all(np.isfinite(submission["azimuth"])):
        raise InvalidSubmissionError("Azimuth must not be infinite")
    if not np.all(np.isfinite(submission["zenith"])):
        raise InvalidSubmissionError("Zenith must not be infinite")

    # answers (trusted input; violations are internal errors, not user errors)
    assert set(answers.columns) == {
        "azimuth",
        "zenith",
        "event_id",
    }, "Answers must contain columns 'azimuth','zenith' and 'event_id'"
    assert np.issubdtype(answers["azimuth"].dtype, np.number), "Azimuth must be a number"
    assert np.issubdtype(answers["zenith"].dtype, np.number), "Zenith must be a number"
    # Same ordering rationale as above: NaN before infinity.
    assert not answers["azimuth"].isnull().any(), "Azimuth must not be NaN"
    assert not answers["zenith"].isnull().any(), "Zenith must not be NaN"
    assert np.all(np.isfinite(answers["azimuth"])), "Azimuth must not be infinite"
    assert np.all(np.isfinite(answers["zenith"])), "Zenith must not be infinite"

    # both
    if len(submission) != len(answers):
        raise InvalidSubmissionError("Submission and answers must have the same length")
    if set(submission["event_id"]) != set(answers["event_id"]):
        raise InvalidSubmissionError("Submission and answers must have the same event_ids")

    # sort values by id so that the order is correct
    submission = submission.sort_values("event_id")
    answers = answers.sort_values("event_id")

    return {
        "az_true": answers["azimuth"].to_numpy(),
        "zen_true": answers["zenith"].to_numpy(),
        "az_pred": submission["azimuth"].to_numpy(),
        "zen_pred": submission["zenith"].to_numpy(),
    }