in domainbed_measures/experiment/io_utils.py [0:0]
def fix_hdh_c2st_divergence_from_sum_to_mean(hdh_or_c2st_divergence,
                                             dataset,
                                             num_env_test,
                                             is_hdh,
                                             per_env=True):
    """Rescale a divergence summed over training environments into a mean.

    The originally computed H-delta-H divergence has the form::

        2 * max_{h, h'} E_{source} p[h'(x) != h(x)] - E_{target} p[h'(x) != h(x)]

    For use in the generalization bounds we need a factor of 1/2 on the
    H-delta-H divergence. In the multi-source setting we also need the mean,
    rather than the sum, of the divergences across the training (source)
    environments.

    Args:
        hdh_or_c2st_divergence: The divergence value to rescale. Any object
            supporting multiplication by a float and division by an int.
        dataset: Key into the module-level ``DATASETS_TO_NUM_ENVS`` mapping.
            That mapping gives the total number of environments of the
            dataset. Only consulted when ``per_env`` is true.
        num_env_test: Number of held-out (test) environments. The remaining
            environments of the dataset count as training environments.
        is_hdh: If true, the value is an H-delta-H divergence and is halved.
            Otherwise (the C2ST case) no 1/2 factor is applied.
        per_env: If true, divide by the number of training environments.
            If false, no averaging is applied.

    Returns:
        The rescaled divergence.

    Raises:
        ValueError: If ``dataset`` is not a key of ``DATASETS_TO_NUM_ENVS``,
            or if removing ``num_env_test`` environments leaves no training
            environments (which would divide by zero or flip the sign).
    """
    if per_env:
        num_env_total = DATASETS_TO_NUM_ENVS.get(dataset)
        if num_env_total is None:
            # Without this check, `None - num_env_test` raises an opaque TypeError.
            raise ValueError(
                "Unknown dataset {!r}: not found in DATASETS_TO_NUM_ENVS".format(
                    dataset))
        num_env_train = num_env_total - num_env_test
        if num_env_train <= 0:
            raise ValueError(
                "Dataset {!r} has {} environments; holding out {} leaves no "
                "training environments to average over".format(
                    dataset, num_env_total, num_env_test))
    else:
        # No averaging: the divergence is treated as already per-environment.
        num_env_train = 1

    # The 1/2 factor only applies to the H-delta-H divergence, not to C2ST.
    scaling_factor = 0.5 if is_hdh else 1.0
    return scaling_factor * hdh_or_c2st_divergence / num_env_train