in domainbed_measures/experiment/io_utils.py [0:0]
def _read_test_envs(model_dir, line_index=22):
    """Parse the sorted test environments recorded in a run's ``out.txt``.

    Args:
        model_dir (str): Directory that contains the ``out.txt`` file.
        line_index (int): Zero-based index of the line expected to hold the
            ``test_envs: <literal>`` entry.

    Returns:
        list: Sorted result of evaluating the literal after ``test_envs: ``.

    Raises:
        ValueError: If the selected line does not contain ``test_envs: ``.
        IndexError: If ``out.txt`` has fewer than ``line_index + 1`` lines.
    """
    with open(os.path.join(model_dir, 'out.txt')) as f:
        test_env_str = f.readlines()[line_index]
    if 'test_envs: ' not in test_env_str:
        raise ValueError("Reading wrong line for test envs")
    # Keep only what follows the marker. ``str.lstrip("test_envs: ")`` would
    # strip a *set of characters* rather than the prefix, which only works by
    # accident when the value happens to start with a character outside it.
    _, _, envs_literal = test_env_str.partition('test_envs: ')
    return sorted(ast.literal_eval(envs_literal.strip()))


def load_results(run_dir, experiment_path):
    """Load results from gen-ood runs.

    Every ``results.jsonl`` found under
    ``<run_dir>/<experiment_path>/slurm_files/*/`` is read line by line. Lines
    that are not valid JSON are skipped with a warning, as are records whose
    ``gen_measure_val_ood`` exceeds ``MAX_VAL``. Each remaining record is
    augmented with ``all_test_envs``, ``lambda_wd``, ``lambda_ood`` and
    ``path_and_test_envs`` before being collected.

    Args:
        run_dir (str): Base directory where we do all the runs and store
            results for a given domainbed run
        experiment_path (str): Format is _{in/out}_{algorithm}_{dataset}, name
            of the particular dataset and run we want to analyze

    Returns:
        tuple: ``(results, measures)`` where ``results`` is a ``pd.DataFrame``
        with one row per kept record (infinite values are turned into NaN and
        rows containing any NaN are dropped) and ``measures`` is the sorted
        list of unique values in its ``measure`` column. When nothing could be
        loaded, an empty dataframe and an empty list are returned.

    Raises:
        ValueError: If the expected line of a run's ``out.txt`` does not
            contain the ``test_envs: `` entry.
    """
    result_file_regex = os.path.join(run_dir, experiment_path, "slurm_files",
                                     "*/results.jsonl")
    result_files = glob.glob(result_file_regex)

    # out.txt only depends on the run path, so parse it once per path instead
    # of re-reading the whole file for every result line.
    test_envs_by_path = {}
    all_results = []
    for rfidx, rf in enumerate(result_files):
        with open(rf, 'r') as file:
            for line_idx, line in enumerate(file):
                try:
                    res = json.loads(line)
                except ValueError:
                    # json.JSONDecodeError is a ValueError; anything else
                    # (e.g. KeyboardInterrupt) must not be swallowed here.
                    print(
                        f"Warning: Invalid line {line_idx} in {rf}. Skipping.."
                    )
                    continue

                if res['gen_measure_val_ood'] > MAX_VAL:
                    print(f"Skpping value greater than {MAX_VAL}")
                    continue

                run_path = res['path']
                if run_path not in test_envs_by_path:
                    test_envs_by_path[run_path] = _read_test_envs(run_path)
                # Copy so that records never share one mutable list.
                all_test_envs = list(test_envs_by_path[run_path])
                num_test_envs = len(all_test_envs)
                per_env = "per_env" in res["measure"]

                # A missing lambda_closeness is stored as -1 (presumably so
                # the row survives the dropna below -- TODO confirm intent).
                lambda_wd = res['metadata_wd'].get('lambda_closeness')
                lambda_ood = res['metadata_ood'].get('lambda_closeness')
                lambda_wd = fix_lambda_closeness(
                    lambda_wd, res['dataset'], num_test_envs,
                    per_env) if lambda_wd is not None else -1
                lambda_ood = fix_lambda_closeness(
                    lambda_ood, res['dataset'], num_test_envs,
                    per_env) if lambda_ood is not None else -1

                if 'hdh' in res['measure'] or 'c2st' in res['measure']:
                    is_hdh = "hdh" in res['measure']
                    for key in ('gen_measure_val_ood', 'gen_measure_val_wd'):
                        res[key] = fix_hdh_c2st_divergence_from_sum_to_mean(
                            res[key], res['dataset'], num_test_envs, is_hdh,
                            per_env)

                res['all_test_envs'] = all_test_envs
                res['lambda_wd'] = lambda_wd
                res['lambda_ood'] = lambda_ood
                res['path_and_test_envs'] = "%s_%s" % (
                    res['path'], str(res['all_test_envs']))
                all_results.append(res)
        # rfidx is zero-based; report the number of files actually finished.
        print(f"{rfidx + 1}/{len(result_files)} done \r", end="\r")

    all_results = pd.DataFrame(all_results)
    all_results = all_results.replace([np.inf, -np.inf], np.nan)
    all_results = all_results.dropna(axis=0)

    # With no usable record the frame has no columns at all; report an empty
    # measure list instead of failing on the missing 'measure' column.
    if 'measure' in all_results.columns:
        measures = sorted(all_results['measure'].unique())
    else:
        measures = []

    print("Jobs per measure:")
    for m in measures:
        print(
            "%s: %d/%d" %
            (m, len(all_results[all_results['measure'] == m]['path'].unique()),
             len(all_results['path'].unique())))
    return all_results, measures