def __call__()

in domainbed_measures/experiment/experiment.py [0:0]


    def __call__(self,
                 path: str,
                 measure_or_measure_list: list,
                 dataset: str,
                 save_path=None):
        """Compute one or more measures for every test environment of a run.

        Reads the run's results file (``_OUT_FILE_NAME`` under ``path``). For
        each test environment listed there, it calls ``self.sanity_check`` to
        obtain the generalization gaps / performances. It then computes every
        requested measure through the parent class's ``__call__``, attaches
        those gaps / performances to the measure's results, and hands the
        results to ``self.write_results``.

        Args:
            path: Directory of the run; the results file is read from
                ``os.path.join(path, _OUT_FILE_NAME)``.
            measure_or_measure_list: A single measure name (``str``) or a list
                of measure names. Every name must be a member of
                ``MeasureRegistry._VALID_MEASURES``. The caller's list is not
                modified.
            dataset: Forwarded unchanged to the parent ``__call__``.
            save_path: Forwarded unchanged to ``self.write_results``.

        Raises:
            ValueError: If ``measure_or_measure_list`` is neither a ``str`` nor
                a ``list``, if any measure name is not registered, or if a test
                environment id read from the results file is not an ``int``.
        """
        logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
                            level=logging.INFO,
                            datefmt='%Y-%m-%d %H:%M:%S')

        if isinstance(measure_or_measure_list, str):
            measures = [measure_or_measure_list]
        elif isinstance(measure_or_measure_list, list):
            # Copy so the in-place shuffle below does not reorder the caller's
            # list as a hidden side effect.
            measures = list(measure_or_measure_list)
        else:
            raise ValueError("Unexpected type for measure_or_measure_list")

        # Validate every name up front. This way a bad name fails before the
        # results file is read, before any measure is computed, and before any
        # partial results are written.
        for m in measures:
            if m not in MeasureRegistry._VALID_MEASURES:
                raise ValueError("Invalid measure.")

        # Evaluation order is randomized. The intent is not visible here;
        # presumably it spreads work across concurrently running jobs (confirm).
        np.random.shuffle(measures)

        out_results, test_envs = read_jsonl_result(
            os.path.join(path, _OUT_FILE_NAME))

        for test_env_idx in test_envs:
            if not isinstance(test_env_idx, int):
                raise ValueError("Expect an integer test environment id.")

            # Check if loaded and computed results match up. The returned
            # values are attached to every measure's results for this env.
            (ood_gen_gap, wd_gen_gap, in_domain_perf, ood_out_domain_perf,
             wd_out_domain_perf) = self.sanity_check(
                 out_results, test_envs, test_env_idx, self._dirty_ood_split,
                 MODEL_SELECTION, path)

            for idx, m in enumerate(measures):
                logging.info(
                    f"Computing measure {m} for {path}, test_env {test_env_idx} -- ({idx + 1}/{len(measures)})"
                )
                results = super(Experiment,
                                self).__call__(path, m, dataset, test_env_idx)
                results["ood_gen_gap"] = ood_gen_gap
                results["wd_gen_gap"] = wd_gen_gap
                results["in_domain_perf"] = in_domain_perf
                results["ood_out_domain_perf"] = ood_out_domain_perf
                results["wd_out_domain_perf"] = wd_out_domain_perf

                logging.info(f"Finished measure {m} for {path}")
                self.write_results(results, save_path)