def eval_corr_gen()

in utils/evaluate.py [0:0]


def eval_corr_gen(trainer, dataset, mode="train"):
    """Evaluate a classifier and compute generalization-correlation measures.

    Computes some of the best predictors of generalization from each section
    of "Fantastic Generalization Measures and Where to Find Them" (Jiang et
    al., 2020), in addition to the usual classifier evaluation.

    Parameters
    ----------
    trainer : skorch-style trainer wrapping the model (`trainer.module_`).
    dataset : dataset to evaluate on.
    mode : str, optional
        If ``"test"``, only the standard classifier evaluation is returned;
        the (expensive) correlation measures are computed otherwise.

    Returns
    -------
    dict
        The standard evaluation metrics from `eval_clf`, plus — unless
        ``mode == "test"`` — entropy of predictions, path norm, variance of
        gradients, sharpness magnitude, and the H_Q[X|Z]-family measures.
    """
    # Usual classification evaluation (i.e. how well the model generalizes).
    out_eval_clf = eval_clf(trainer, dataset)

    if mode == "test":
        # Correlation measures are only computed in "train" mode.
        return out_eval_clf

    # Work on a clone so the measures below (which perturb / freeze the
    # network) don't affect the caller's trainer.
    trainer = clone_trainer(trainer)
    logger.info(f"len(dataset)={len(dataset)}")

    # Variance of gradients (for classifier and transformer).
    logger.info("var_grad")
    var_grad = get_var_grad(trainer, dataset)
    # BUGFIX: previously logged the logger object itself; log the value.
    logger.info(f"var_grad={var_grad}")

    logger.info("d_H_Q_xCz")
    # Computed before freezing the net.
    d_H_Q_xCz = get_H_Q_xCz(
        trainer, dataset, "d_H_Q_xCz", conditional="H_Q[X|Z]-H_Q[Y|Z]"
    )

    logger.info("H_Q_xCz")
    # H_Q[X|Z]
    H_Q_xCz = get_H_Q_xCz(trainer, dataset, "H_Q_xCz")

    # H_Q+[X|Z]: same measure with a much larger variational family Q.
    # NOTE(review): logs "d_H_Q+_xCz" but passes measure name "H_Q_xCz"
    # (not "d_H_Q_xCz") — confirm this is intentional.
    logger.info("d_H_Q+_xCz")
    d_H_Qp_xCz = get_H_Q_xCz(
        trainer,
        dataset,
        "H_Q_xCz",
        Q_zx=partial(
            MLP, hidden_size=2048, n_hidden_layers=trainer.module_.clf.n_hidden_layers
        ),
    )

    # H_Q-[X|Z]: same measure with a much smaller variational family Q.
    logger.info("d_H_Q-_xCz")
    d_H_Qm_xCz = get_H_Q_xCz(
        trainer,
        dataset,
        "H_Q_xCz",
        Q_zx=partial(
            MLP, hidden_size=2, n_hidden_layers=trainer.module_.clf.n_hidden_layers
        ),
    )

    # Freeze all batchnorm layers by converting them to convolutions, so the
    # remaining measures see a deterministic network.
    trainer.module_.eval()
    batchnorms2convs_(trainer.module_)

    # Entropy of the predicted class distribution, averaged over the dataset.
    logger.info("entropy")
    y_pred_proba = trainer.predict_proba(dataset)
    y_pred_ent = scipy.stats.entropy(
        y_pred_proba, axis=1, base=BASE_LOG).mean()

    # Path norm (for classifier and transformer).
    logger.info("path_norm")
    path_norm = get_path_norm(trainer, dataset)

    # Sharpness magnitude: max (relative) change in weights that causes less
    # than 1 difference in log-likelihood.
    logger.info("sharp_mag")
    sharp_mag = get_sharp_mag(trainer, dataset)

    return dict(
        y_pred_ent=y_pred_ent,
        path_norm=path_norm,
        var_grad=var_grad,
        sharp_mag=sharp_mag,
        H_Q_xCz=H_Q_xCz,
        d_H_Qp_xCz=d_H_Qp_xCz,
        d_H_Qm_xCz=d_H_Qm_xCz,
        d_H_Q_xCz=d_H_Q_xCz,
        **out_eval_clf,
    )