in domainbed_measures/experiment/experiment.py [0:0]
def __call__(self, path, measure, dataset, test_env_idx):
    """Compute a generalization measure for one trained model and test env.

    Args:
        path: Directory containing the saved model checkpoint
            (``_MODEL_FILE_NAME`` is joined onto it).
        measure: Key into ``MeasureRegistry`` selecting which measure to
            compute.
        dataset: Dataset name; used only to special-case MNIST for
            fisher-based measures and echoed back in the result dict.
        test_env_idx: Index of the held-out test environment.

    Returns:
        Dict with within-distribution (WD) and out-of-distribution (OOD)
        measure values (as floats), their metadata, and the identifying
        arguments (measure, dataset, path, test_env).
    """
    ######## Compute the generalization measure #####################
    # Need to reload the data since the config of the dataloader also depends
    # on which measure we use.
    (algorithm, train_loaders, wd_eval_loaders, dirty_ood_eval_loaders,
     _clean_ood_eval_loaders, num_classes) = load_model_and_dataloaders(
         os.path.join(path, _MODEL_FILE_NAME), self._dirty_ood_split,
         test_env_idx, **MeasureRegistry._KWARGS[measure]["data_args"])
    logging.info("Computing WD generalization gap")
    MeasureClass = MeasureRegistry()[measure]
    # Fisher-based measures need more examples on MNIST to be stable.
    if 'fisher' in measure and "MNIST" in dataset:
        logging.info(
            f"Increasing number of examples for MNIST to {_MNIST_NUM_EXAMPLES_FISHER}."
        )
        # NOTE(review): this mutates the shared class-level registry in
        # place, so the larger example count persists for later calls in
        # this process — confirm that is intended.
        MeasureRegistry._KWARGS[measure]["measure_args"][
            "max_num_examples"] = _MNIST_NUM_EXAMPLES_FISHER
    # Hoist the repeated registry lookup; all uses below read the same dict.
    measure_args = MeasureRegistry._KWARGS[measure]["measure_args"]
    measure_class_wd = MeasureClass(
        algorithm,
        train_loaders,
        wd_eval_loaders,
        num_classes,
        **measure_args)
    gen_measure_val_wd, metadata_wd = measure_class_wd.compute()
    # We only need to compute the generalization measure for out of
    # distribution if the measure uses out of distribution data; otherwise
    # the WD result is reused verbatim.
    if measure_args.get('use_eval_data'):  # truthy check instead of `== True` (E712)
        logging.info("Computing OOD generalization gap")
        measure_class_ood = MeasureClass(
            algorithm,
            train_loaders,
            dirty_ood_eval_loaders,
            num_classes,
            **measure_args)
        gen_measure_val_ood, metadata_ood = measure_class_ood.compute()
    else:
        gen_measure_val_ood, metadata_ood = gen_measure_val_wd, metadata_wd
    return {
        "gen_measure_val_wd": float(gen_measure_val_wd),
        "gen_measure_val_ood": float(gen_measure_val_ood),
        "metadata_wd": metadata_wd,
        "metadata_ood": metadata_ood,
        "measure": measure,
        "dataset": dataset,
        "path": path,
        "test_env": test_env_idx,
    }