in domainbed_measures/experiment/experiment.py [0:0]
def __call__(self,
             path: str,
             measure_or_measure_list: Union[str, List[str]],
             dataset: str,
             save_path: Optional[str] = None):
    """Compute all requested measures for the run stored at `path`.

    Each measure is evaluated once per held-out test environment, and every
    result record is written out via `self.write_results`.
    """
    # Assumes module-level imports: logging, os, numpy as np,
    # typing (Union, List, Optional), plus the helpers read_jsonl_result,
    # _OUT_FILE_NAME, MODEL_SELECTION, and MeasureRegistry defined
    # elsewhere in this file.
    logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
                        level=logging.INFO,
                        datefmt='%Y-%m-%d %H:%M:%S')
    # Normalize the input into a list of measure names.
    if isinstance(measure_or_measure_list, str):
        measures = [measure_or_measure_list]
    elif isinstance(measure_or_measure_list, list):
        # Copy so the in-place shuffle below does not mutate the caller's list.
        measures = list(measure_or_measure_list)
    else:
        raise ValueError(
            f"Unexpected type for measure_or_measure_list: "
            f"{type(measure_or_measure_list)}")
    np.random.shuffle(measures)
    out_results, test_envs = read_jsonl_result(
        os.path.join(path, _OUT_FILE_NAME))
    for test_env_idx in test_envs:
        if not isinstance(test_env_idx, int):
            raise ValueError("Expected an integer test environment id.")
        # Check that the loaded and computed results match up.
        (ood_gen_gap, wd_gen_gap, in_domain_perf, ood_out_domain_perf,
         wd_out_domain_perf) = self.sanity_check(out_results, test_envs,
                                                 test_env_idx,
                                                 self._dirty_ood_split,
                                                 MODEL_SELECTION, path)
        for idx, m in enumerate(measures):
            if m not in MeasureRegistry._VALID_MEASURES:
                raise ValueError(f"Invalid measure: {m}")
            logging.info(
                f"Computing measure {m} for {path}, test_env {test_env_idx}"
                f" -- ({idx + 1}/{len(measures)})")
            results = super(Experiment, self).__call__(
                path, m, dataset, test_env_idx)
            # Attach the generalization-gap and performance statistics
            # from sanity_check to each per-measure record.
            results["ood_gen_gap"] = ood_gen_gap
            results["wd_gen_gap"] = wd_gen_gap
            results["in_domain_perf"] = in_domain_perf
            results["ood_out_domain_perf"] = ood_out_domain_perf
            results["wd_out_domain_perf"] = wd_out_domain_perf
            logging.info(f"Finished measure {m} for {path}")
            self.write_results(results, save_path)
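
# Example usage (a minimal sketch, not from this file): the run directory,
# measure name, and dataset below are hypothetical placeholders, and the
# Experiment constructor arguments are elided because they are not shown here.
#
#     experiment = Experiment(...)
#     experiment(path="sweep/run_0",
#                measure_or_measure_list=["sharpness"],
#                dataset="PACS",
#                save_path="measures_out.jsonl")
#
# Passing a single string is equivalent to passing a one-element list; the
# measures are shuffled before evaluation, so parallel workers pointed at the
# same run are unlikely to process them in the same order.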