def fix_hdh_c2st_divergence_from_sum_to_mean()

in domainbed_measures/experiment/io_utils.py [0:0]


def fix_hdh_c2st_divergence_from_sum_to_mean(hdh_or_c2st_divergence,
                                             dataset,
                                             num_env_test,
                                             is_hdh,
                                             per_env=True):
    """Rescale a stored HDH / C2ST divergence from a sum into a mean.

    The stored HDH divergence is described by the original author as
        2 * max_{h, h'} E_{source} p[h'(x) != h(x)] - E_{target} p[h'(x) != h(x)]
    For use in the generalization bounds it is multiplied by 1/2, and in the
    multi-source setting it is averaged over the training environments.
    C2ST divergences are only averaged; no 1/2 factor is applied.

    Args:
        hdh_or_c2st_divergence: The divergence value to rescale.
        dataset: Key into ``DATASETS_TO_NUM_ENVS`` giving the total number of
            environments for the dataset. Only consulted when ``per_env`` is
            truthy.
        num_env_test: Number of environments held out for testing; subtracted
            from the dataset's total to get the number of training
            environments. Only consulted when ``per_env`` is truthy.
        is_hdh: If truthy the value is treated as an HDH divergence and
            scaled by 0.5; otherwise the scale is 1.0.
        per_env: If truthy, divide by the number of training environments;
            otherwise divide by 1 (no averaging).

    Returns:
        ``scale * hdh_or_c2st_divergence / num_env_train``.

    Raises:
        KeyError: If ``per_env`` is truthy and ``dataset`` has no entry in
            ``DATASETS_TO_NUM_ENVS``.
        ValueError: If ``per_env`` is truthy and the resulting number of
            training environments is not positive.
    """
    if per_env:
        total_envs = DATASETS_TO_NUM_ENVS.get(dataset)
        if total_envs is None:
            # Fail with the offending key instead of an opaque
            # "NoneType - int" TypeError from the subtraction below.
            raise KeyError(
                "Unknown dataset {!r}: no entry in DATASETS_TO_NUM_ENVS".format(
                    dataset))
        num_env_train = total_envs - num_env_test
        if num_env_train <= 0:
            # Zero would divide by zero; a negative count would silently
            # flip the sign of the divergence.
            raise ValueError(
                "Dataset {!r} has {} environments but num_env_test={}; "
                "at least one training environment is required".format(
                    dataset, total_envs, num_env_test))
    else:
        num_env_train = 1

    # The 1/2 factor only applies to the HDH divergence in the bound.
    scaling_factor = 0.5 if is_hdh else 1.0
    return scaling_factor * hdh_or_c2st_divergence / float(num_env_train)