in utils/evaluate.py [0:0]
def eval_corr_gen(trainer, dataset, mode="train"):
    """
    Evaluate a classifier and, unless in test mode, measures to correlate with generalization.

    Always performs the usual classifier evaluation through `eval_clf`. When `mode` is not
    "test", it additionally computes some of the best predictors of generalization from each
    section of "FANTASTIC GENERALIZATION MEASURES AND WHERE TO FIND THEM", working on a clone
    of `trainer` (obtained with `clone_trainer`).

    Parameters
    ----------
    trainer :
        Fitted trainer. In non-test mode it must expose `module_` (with
        `module_.clf.n_hidden_layers` and `module_.eval()`) and `predict_proba`.

    dataset :
        Dataset on which to evaluate. Must support `len()`.

    mode : str, optional
        If "test", only the output of `eval_clf` is returned. Any other value also triggers
        the computation of the generalization measures.

    Returns
    -------
    dict
        The output of `eval_clf` in test mode. Otherwise a dictionary containing the entries
        of `eval_clf` together with the keys `y_pred_ent`, `path_norm`, `var_grad`,
        `sharp_mag`, `H_Q_xCz`, `d_H_Qp_xCz`, `d_H_Qm_xCz` and `d_H_Q_xCz`.
    """
    # measure usual classification (i.e. how well generalizes)
    out_eval_clf = eval_clf(trainer, dataset)

    if mode == "test":
        # only do correlation measure if "train mode"
        return out_eval_clf

    # everything below operates on a clone rather than on the trainer given by the caller
    trainer = clone_trainer(trainer)
    logger.info(f"len(dataset)={len(dataset)}")

    # Variance of gradients (for classifier and transformer)
    logger.info("var_grad")
    var_grad = get_var_grad(trainer, dataset)

    # the H_Q measures are all computed before freezing the net
    logger.info("d_H_Q_xCz")
    d_H_Q_xCz = get_H_Q_xCz(
        trainer, dataset, "d_H_Q_xCz", conditional="H_Q[X|Z]-H_Q[Y|Z]"
    )

    # H_Q[X|Z]
    logger.info("H_Q_xCz")
    H_Q_xCz = get_H_Q_xCz(trainer, dataset, "H_Q_xCz")

    # H_Q+[X|Z] : same measure but with a wide (2048 hidden units) MLP as Q_zx
    # NOTE(review): the name given here is "H_Q_xCz", identical to the call above, although
    # the log label is "d_H_Q+_xCz" -- confirm that get_H_Q_xCz does not need a distinct name.
    logger.info("d_H_Q+_xCz")
    d_H_Qp_xCz = get_H_Q_xCz(
        trainer,
        dataset,
        "H_Q_xCz",
        Q_zx=partial(
            MLP, hidden_size=2048, n_hidden_layers=trainer.module_.clf.n_hidden_layers
        ),
    )

    # H_Q-[X|Z] : same measure but with a narrow (2 hidden units) MLP as Q_zx
    # NOTE(review): same remark as above regarding the reused "H_Q_xCz" name.
    logger.info("d_H_Q-_xCz")
    d_H_Qm_xCz = get_H_Q_xCz(
        trainer,
        dataset,
        "H_Q_xCz",
        Q_zx=partial(
            MLP, hidden_size=2, n_hidden_layers=trainer.module_.clf.n_hidden_layers
        ),
    )

    # freezes all batchnorm layers by converting them to convolutions
    trainer.module_.eval()
    batchnorms2convs_(trainer.module_)

    # Entropy of the predicted probabilities: per-row entropy (axis=1), then averaged
    logger.info("entropy")
    y_pred_proba = trainer.predict_proba(dataset)
    y_pred_ent = scipy.stats.entropy(y_pred_proba, axis=1, base=BASE_LOG).mean()

    # Path Norm (for classifier and transformer)
    logger.info("path_norm")
    path_norm = get_path_norm(trainer, dataset)

    # Sharpness magnitude => max (relative) change in weights that cause less than 1 diff in log like
    logger.info("sharp_mag")
    sharp_mag = get_sharp_mag(trainer, dataset)

    return dict(
        y_pred_ent=y_pred_ent,
        path_norm=path_norm,
        var_grad=var_grad,
        sharp_mag=sharp_mag,
        H_Q_xCz=H_Q_xCz,
        d_H_Qp_xCz=d_H_Qp_xCz,
        d_H_Qm_xCz=d_H_Qm_xCz,
        d_H_Q_xCz=d_H_Q_xCz,
        **out_eval_clf,
    )