in domainbed_measures/experiment/experiment.py [0:0]
def __call__(self,
             path: str,
             measure_or_measure_list: list,
             dataset: str,
             save_path=None,
             num_trials=10):
    """Compute each requested measure ``num_trials`` times and save the values.

    For every test environment listed in the results file under ``path``,
    the model and data loaders are loaded once per measure, the measure is
    evaluated ``num_trials`` times on the train / dirty-OOD loaders, and the
    collected values are handed to ``self.write_results``.

    Args:
        path: Directory containing the results file (``_OUT_FILE_NAME``)
            and the model file (``_MODEL_FILE_NAME``).
        measure_or_measure_list: A single measure name or a list of measure
            names. Every name must contain ``'c2st'`` or ``'hdh'`` and be
            listed in ``MeasureRegistry._VALID_MEASURES``. The caller's
            list is not modified.
        dataset: Not used by this method; kept for interface compatibility.
        save_path: Forwarded unchanged to ``self.write_results``.
        num_trials: Number of ``compute`` calls made per measure.

    Raises:
        ValueError: If ``measure_or_measure_list`` is neither a ``str`` nor
            a ``list``, if a measure is not registered, or if a test
            environment id is not an ``int``.
        NotImplementedError: If any measure is not c2st- or hdh-based.
    """
    logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
                        level=logging.INFO,
                        datefmt='%Y-%m-%d %H:%M:%S')

    if isinstance(measure_or_measure_list, str):
        measures = [measure_or_measure_list]
    elif isinstance(measure_or_measure_list, list):
        # Work on a copy: the shuffle below is in-place and must not
        # reorder the list owned by the caller.
        measures = list(measure_or_measure_list)
    else:
        raise ValueError("Unexpected type for measure_or_measure_list")

    np.random.shuffle(measures)

    if not all('c2st' in m or 'hdh' in m for m in measures):
        raise NotImplementedError("Variance Experiment only implemented "
                                  "for c2st and hdh based measures.")

    # Reject unknown measures before any model or data is loaded, so a typo
    # in a later measure does not waste the work done for earlier ones.
    for m in measures:
        if m not in MeasureRegistry._VALID_MEASURES:
            raise ValueError("Invalid measure.")

    _, test_envs = read_jsonl_result(os.path.join(path, _OUT_FILE_NAME))

    for test_env_idx in test_envs:
        if not isinstance(test_env_idx, int):
            raise ValueError("Expect an integer test environment id.")

        # Maps measure name -> list of values, one entry per trial.
        gen_measure_vals = defaultdict(list)

        for m in measures:
            measure_kwargs = MeasureRegistry._KWARGS[m]

            (algorithm, train_loaders, wd_eval_loaders,
             dirty_ood_eval_loaders, clean_ood_eval_loaders,
             num_classes) = load_model_and_dataloaders(
                 os.path.join(path, _MODEL_FILE_NAME),
                 self._dirty_ood_split, test_env_idx,
                 **measure_kwargs["data_args"])

            MeasureClass = MeasureRegistry()[m]
            measure_class_ood = MeasureClass(
                algorithm,
                train_loaders,
                dirty_ood_eval_loaders,
                num_classes,
                **measure_kwargs["measure_args"])

            for _ in range(num_trials):
                # Only the measure value is kept; the metadata returned
                # alongside it is discarded.
                # NOTE(review): stochastic_fraction_data=0.8 presumably
                # makes each trial use a random 80% subset -- confirm
                # against the measure implementation.
                gen_measure_val_ood, _ = measure_class_ood.compute(
                    stochastic_fraction_data=0.8)
                gen_measure_vals[m].append(gen_measure_val_ood)

        # NOTE(review): results are written once per test environment
        # because the accumulator above is reset per environment; confirm
        # that write_results is meant to be called at this level.
        self.write_results(gen_measure_vals, save_path)