def __call__()

in domainbed_measures/experiment/experiment.py [0:0]


    def __call__(self,
                 path: str,
                 measure_or_measure_list: list,
                 dataset: str,
                 save_path=None,
                 num_trials=10):
        """Estimate the variance of c2st/hdh generalization measures.

        For every test environment listed in the sweep results under
        ``path``, each requested measure is recomputed ``num_trials`` times
        on a random 80% subsample of the data, and the per-measure value
        lists are written out via ``self.write_results``.

        Args:
            path: Directory holding the sweep output (the jsonl results
                file ``_OUT_FILE_NAME`` and the checkpoint
                ``_MODEL_FILE_NAME``).
            measure_or_measure_list: A single measure name or a list of
                measure names. Only 'c2st' / 'hdh' based measures are
                supported by this experiment.
            dataset: Unused in this method; kept for interface
                compatibility with callers.
            save_path: Forwarded unchanged to ``self.write_results``.
            num_trials: Number of stochastic re-evaluations per measure.

        Raises:
            ValueError: On an unexpected input type, an unknown measure
                name, or a non-integer test environment id.
            NotImplementedError: If any measure is not c2st/hdh based.
        """
        logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
                            level=logging.INFO,
                            datefmt='%Y-%m-%d %H:%M:%S')

        # Normalize to a list. Copy the caller's list so the in-place
        # np.random.shuffle below does not mutate the argument.
        if isinstance(measure_or_measure_list, str):
            measures = [measure_or_measure_list]
        elif isinstance(measure_or_measure_list, list):
            measures = list(measure_or_measure_list)
        else:
            raise ValueError("Unexpected type for measure_or_measure_list")
        np.random.shuffle(measures)

        if not all('c2st' in m or 'hdh' in m for m in measures):
            raise NotImplementedError("Variance Experiment only implemented"
                                      "for c2st and hdh based measures.")

        # Fail fast on unknown measures before any expensive model loading.
        for m in measures:
            if m not in MeasureRegistry._VALID_MEASURES:
                raise ValueError("Invalid measure.")

        _out_results, test_envs = read_jsonl_result(
            os.path.join(path, _OUT_FILE_NAME))

        # Pre-initialize so write_results below cannot hit a NameError
        # when test_envs is empty.
        gen_measure_vals = defaultdict(list)

        for test_env_idx in test_envs:
            if not isinstance(test_env_idx, int):
                raise ValueError("Expect an integer test environment id.")

            gen_measure_vals = defaultdict(list)

            for m in measures:
                # Only the train and dirty-OOD loaders are consumed here;
                # the remaining loaders are part of the loader API's return.
                (algorithm, train_loaders, _wd_eval_loaders,
                 dirty_ood_eval_loaders, _clean_ood_eval_loaders,
                 num_classes) = load_model_and_dataloaders(
                     os.path.join(path, _MODEL_FILE_NAME),
                     self._dirty_ood_split, test_env_idx,
                     **MeasureRegistry._KWARGS[m]["data_args"])
                MeasureClass = MeasureRegistry()[m]
                measure_class_ood = MeasureClass(
                    algorithm,
                    train_loaders,
                    dirty_ood_eval_loaders,
                    num_classes,
                    **MeasureRegistry._KWARGS[m]["measure_args"])

                for _trial in range(num_trials):
                    # Each trial subsamples 80% of the data, yielding a
                    # stochastic estimate whose spread we record.
                    gen_measure_val_ood, _metadata_ood = (
                        measure_class_ood.compute(
                            stochastic_fraction_data=0.8))
                    gen_measure_vals[m].append(gen_measure_val_ood)

        # NOTE(review): gen_measure_vals is rebuilt for every test env, yet
        # results are written only after the env loop, so only the LAST
        # environment's values reach write_results — confirm this is the
        # intended behavior (original code had the same placement).
        self.write_results(gen_measure_vals, save_path)