def _collect_oodomain_records()

in uimnet/evaluation/oodomain.py [0:0]


def _result_for_partition(results_by_partition, partition, train_cfg):
  """Return the result stored under `partition`, or raise a descriptive KeyError."""
  try:
    return results_by_partition[partition]
  except KeyError:
    raise KeyError(
        f'No results for partition {partition!r} under train_cfg '
        f'{train_cfg!r}; available partitions: {list(results_by_partition)!r}'
    ) from None


def _collect_oodomain_records(cfg, all_results, ood_partitions=('easy',),
                              in_partition='in'):
  """Build flat metric records comparing in-domain and out-of-domain results.

  Results are grouped by their ``'train_cfg'`` entry. Within each group, the
  result whose ``eval_cfg.dataset.partition`` equals ``in_partition`` is
  compared against the result of every partition in ``ood_partitions``. For
  every metric class named in ``cfg.metrics`` and every measurement name in
  the in-domain ``'test_measurements'``:

  - the metric is instantiated with the in-domain validation measurement;
  - it is then called on the in/out test measurements and the in/out test
    tables.

  Args:
    cfg: Config object; ``cfg.metrics`` is an iterable of names looked up in
      the ``metrics`` module namespace.
    all_results: Iterable of result dicts. Each provides ``'train_cfg'``
      (used as a dict key, so it must be hashable) and ``'eval_cfg'`` (with
      ``.dataset.partition`` and ``.temperature_mode``). Each also provides
      ``'test_measurements'``, ``'valid_measurements'`` and
      ``'test_tables'``.
    ood_partitions: Out-of-domain partition names to compare against, in
      order. A single string is treated as one partition name.
    in_partition: Partition name of the in-domain reference results.

  Returns:
    A list of record dicts with keys ``metric``, ``value``, ``measurement``,
    ``ood_partition`` and ``temperature_mode``, updated with the flattened
    training config. Records are ordered by metric, then OOD partition,
    then measurement name.

  Raises:
    KeyError: If a metric name is unknown, or if a training config lacks
      results for ``in_partition`` or one of ``ood_partitions`` (only
      checked when at least one metric is requested).
  """
  if isinstance(ood_partitions, str):
    # Avoid iterating a bare string character by character.
    ood_partitions = (ood_partitions,)

  # Group evaluation results by the training configuration that produced them.
  results_by_train_cfg = collections.defaultdict(list)
  for result in all_results:
    results_by_train_cfg[result['train_cfg']].append(result)

  Metrics = [metrics.__dict__[el] for el in cfg.metrics]
  all_records = []
  if not Metrics:
    # No metrics requested, so no records can be produced.
    return all_records

  for train_cfg, results in results_by_train_cfg.items():
    # Index this group's results by evaluation partition; if two results
    # share a partition, the later one wins.
    by_partition = {r['eval_cfg'].dataset.partition: r for r in results}
    in_results = _result_for_partition(by_partition, in_partition, train_cfg)
    oo_results_by_partition = {
        p: _result_for_partition(by_partition, p, train_cfg)
        for p in ood_partitions}
    # The flattened training config is identical for every record of a group.
    flat_train_cfg = utils.flatten_nested_dicts(train_cfg)

    for Metric in Metrics:
      for ood_partition in ood_partitions:
        oo_results = oo_results_by_partition[ood_partition]
        for measure_name in in_results['test_measurements']:
          # Each measurement gets its own metric instance, fitted on the
          # matching in-domain validation measurement.
          metric = Metric(
              measurement_in_val=in_results['valid_measurements'][measure_name])
          value = metric(in_results['test_measurements'][measure_name],
                         oo_results['test_measurements'][measure_name],
                         in_results['test_tables'],
                         oo_results['test_tables'])

          record = dict(metric=metric.__class__.__name__,
                        value=value,
                        measurement=measure_name,
                        ood_partition=ood_partition,
                        temperature_mode=in_results['eval_cfg'].temperature_mode
                        )
          record.update(flat_train_cfg)
          all_records.append(record)

  return all_records