def __call__()

in uimnet/workers/evaluator2.py [0:0]


  def __call__(self, model_dir, eval_cfg, train_cfg, Algorithm, Measure, partitions):
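    """Run OOD evaluation for a single trained model.

    Verifies the train-completion trace, builds deterministic loaders for
    every (partition, split) pair, restores the trained algorithm, scores
    batches with the given Measure (all-gathered across ranks when running
    distributed), reduces the measurements with every Metric in METRICS, and
    pickles the resulting records to the model directory.
    """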


    self.model_dir = model_dir
    self.eval_cfg = copy.deepcopy(eval_cfg)
    self.train_cfg = copy.deepcopy(train_cfg)
    self.Algorithm = Algorithm
    self.partitions = partitions

    self.setup(eval_cfg)
    eval_cfg, train_cfg = [guard(el) for el in (eval_cfg, train_cfg)]

    self.trace = f'ood_{Measure.__name__}'

    # Check if training was completed
    path = Path(self.model_dir)
    if not utils.trace_exists('train.done', dir_=str(path)):
      utils.message('Train completion trace missing!')
      return dict(status='missing', records=None)

    utils.message('Train completion trace found.')
    # utils.write_trace('ood_evaluation.running', dir_=str(path))

    ##############
    ## Datasets ##
    ##############
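    # Deterministic (unshuffled) loaders over every (partition, split) pair.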

    loaders_kwargs = dict(
      batch_size=train_cfg.dataset.batch_size,
      shuffle=False,
      pin_memory='cuda' in eval_cfg.experiment.device,
      num_workers=eval_cfg.experiment.num_workers)
    loaders, datanodes = get_loaders_datanodes(partitions, train_cfg,
                                               loaders_kwargs=loaders_kwargs,
                                               seed=eval_cfg.dataset.seed)
    for datanode in datanodes.values():
      datanode.eval()

    num_classes = partitions[('train', 'in')].num_classes


    ###############
    ## Algorithm ##
    ###############
    self.algorithm = Algorithm(num_classes=num_classes,
                               arch=train_cfg.algorithm.arch,
                               device=eval_cfg.experiment.device,
                               use_mixed_precision=train_cfg.algorithm.use_mixed_precision,
                               seed=train_cfg.algorithm.seed,
                               sn=train_cfg.algorithm.sn,
                               sn_coef=train_cfg.algorithm.sn_coef,
                               sn_bn=train_cfg.algorithm.sn_bn
                               )


    utils.message(eval_cfg)
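    # The constructor arguments mirror the training configuration so that the
    # checkpoint saved under train_cfg.output_dir restores cleanly.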
    self.algorithm.initialize()
    self.algorithm.load_state(train_cfg.output_dir,
                              map_location=eval_cfg.experiment.device)

    records = []
    self.algorithm.eval()
    with torch.no_grad():
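      # Evaluate under both temperature settings: the algorithm's initial
      # value and its learned one.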
      for temperature_mode in ['initial', 'learned']:

        self.algorithm.set_temperature(temperature_mode)
        measure = Measure(algorithm=self.algorithm)
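        # Fit the measure's statistics on the in-distribution training split
        # before scoring the held-out loaders.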
        measure.estimate(loaders[('train', 'in')])

        measurements = collections.defaultdict(list)
        for (partition, split) in [('eval', 'in'),
                                   ('val', 'in'), ('val', OUT)]:
          key = (partition, split)
          for i, batch in enumerate(loaders[key]):
            x, y = batch['x'].cuda(), batch['y'].cuda()
            _measurement = measure(x)
            if utils.is_distributed():
              ranks = list(range(dist.get_world_size()))

              # Gather the per-rank batch sizes first, then gather the
              # measurements themselves into buffers of matching shape.
              N = torch.as_tensor(_measurement.size(0)).long().cuda()
              all_N = [torch.zeros_like(N) for _ in ranks]
              dist.all_gather(all_N, tensor=N)

              all_size = [(int(Ni.item()),) + _measurement.shape[1:] for Ni in all_N]
              all_measurement = [torch.zeros(size=size).float().cuda() for size in all_size]
              dist.all_gather(all_measurement, _measurement)

              _measurement = torch.cat(all_measurement, dim=0)

            measurements[key] += [_measurement.detach().cpu()]
            if __DEBUG__ and i > 1:
              break

        measurements = {k: torch.cat(l) for k, l in dict(measurements).items()}
        # Evaluate each metric: fit on the ('eval', 'in') measurements, then
        # score ('val', 'in') against ('val', OUT).
        for Metric in METRICS:

          metric = Metric(measurements[('eval', 'in')])
          value = metric(measurements[('val', 'in')],
                         measurements[('val', OUT)])

          record = dict(metric=metric.__class__.__name__,
                        measure=Measure.__name__,
                        value=value,
                        temperature_mode=temperature_mode)

          record.update(utils.flatten_nested_dicts(train_cfg))
          records += [record]

    # Saving records
    if utils.is_not_distributed_or_is_rank0():
      save_path = path / f'{Measure.__name__}_results.pkl'
      if save_path.exists():
        save_path.unlink()
      with open(save_path, 'wb') as fp:
        pickle.dump(records, fp, protocol=pickle.HIGHEST_PROTOCOL)
      utils.message(pd.DataFrame.from_records(utils.apply_fun(utils.to_scalar, records)).round(4))
      utils.write_trace(f'{self.trace}.done', dir_=str(path))

    return records
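
The metric loop above assumes a two-stage protocol: each Metric is constructed from the ('eval', 'in') measurements, then called with the ('val', 'in') and ('val', OUT) measurements to produce a scalar. Below is a minimal sketch of that protocol, assuming larger measurements mean "more out-of-distribution"; the AUROC computation is an illustrative stand-in, not necessarily one of the repo's METRICS.

import torch

class AurocMetric:
  """Hypothetical metric following the Metric(calibration)(in, out) protocol."""

  def __init__(self, calibration: torch.Tensor):
    # Measurements from ('eval', 'in'). Plain AUROC does not use them, but
    # the argument mirrors the evaluator's call site.
    self.calibration = calibration

  def __call__(self, in_measurements: torch.Tensor,
               out_measurements: torch.Tensor) -> float:
    # AUROC: probability that a random OOD example scores higher than a
    # random in-distribution one (Mann-Whitney U statistic, no tie handling).
    scores = torch.cat([in_measurements.flatten(), out_measurements.flatten()])
    n_in, n_out = in_measurements.numel(), out_measurements.numel()
    ranks = torch.empty(scores.numel())
    ranks[scores.argsort()] = torch.arange(1, scores.numel() + 1, dtype=torch.float)
    u_out = ranks[n_in:].sum() - n_out * (n_out + 1) / 2
    return (u_out / (n_in * n_out)).item()

Usage mirrors the call sites in the loop: value = AurocMetric(measurements[('eval', 'in')])(measurements[('val', 'in')], measurements[('val', OUT)]).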