def __call__()

in uimnet/workers/calibrator.py
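Entry point of the calibration worker. It reloads a trained model, collects logits and targets over the evaluation split, fits the softmax temperature by minimizing the negative log-likelihood with L-BFGS (post-hoc temperature scaling), and serializes the calibrated state.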


  def __call__(self, calibration_cfg, train_cfg, Algorithm, dataset):

    start_time = time.time()
    self.train_cfg = train_cfg
    self.calibration_cfg = calibration_cfg
    self.Algorithm = Algorithm
    self.dataset = dataset

    self.setup(calibration_cfg)  # setup() mutates the config, so it must run as a stateful method on the worker.
    utils.message(calibration_cfg)

    if utils.is_not_distributed_or_is_rank0():
      if not utils.trace_exists('train.done', dir_=train_cfg.output_dir):
        raise RuntimeError('Training has not finished: no train.done trace found.')
      utils.write_trace('calibration.running', dir_=train_cfg.output_dir)

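    # Rebuild the dataset splits with the training-time proportions and seed,
    # so the evaluation split is the same one held out during training.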
    self.datanode = datasets.SplitDataNode(
      dataset=dataset,
      transforms=datasets.TRANSFORMS,
      splits_props=train_cfg.dataset.splits_props,
      seed=train_cfg.dataset.seed)
    num_classes = self.datanode.splits['train'].num_classes

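    # Re-instantiate the trained algorithm with its training-time
    # hyperparameters; only the device comes from the calibration config.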
    self.algorithm = Algorithm(num_classes=num_classes,
                               arch=train_cfg.algorithm.arch,
                               device=calibration_cfg.experiment.device,
                               use_mixed_precision=train_cfg.algorithm.use_mixed_precision,
                               seed=train_cfg.algorithm.seed,
                               sn=train_cfg.algorithm.sn,
                               sn_coef=train_cfg.algorithm.sn_coef,
                               sn_bn=train_cfg.algorithm.sn_bn
                               )

    self.algorithm.initialize()
    utils.message(self.algorithm)

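    # If the checkpoint was written by a distributed run but this process is
    # not distributed, ask load_state to adapt the state dict (e.g. strip
    # DistributedDataParallel key prefixes).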
    adapt_state = train_cfg.experiment.distributed and not utils.is_distributed()
    self.algorithm.load_state(train_cfg.output_dir, map_location=calibration_cfg.experiment.device,
                              adapt_state=adapt_state)

    # Preparing for calibration
    self.datanode.eval()
    self.algorithm.eval()
    eval_loader = self.datanode.get_loader('eval',
                                           batch_size=calibration_cfg.dataset.batch_size,
                                           shuffle=False,
                                           pin_memory='cuda' in calibration_cfg.experiment.device,
                                           num_workers=calibration_cfg.experiment.num_workers)

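    # Run the frozen model over the evaluation split once, gathering logits
    # and targets from all ranks when distributed, and cache them on the CPU.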
    collected = collections.defaultdict(list)
    utils.message('Collecting logits and targets')
    with torch.no_grad():
      for i, batch in enumerate(eval_loader):
        batch = utils.apply_fun(functools.partial(utils.to_device, device=calibration_cfg.experiment.device), batch)
        _logits = self.algorithm(batch['x'])
        _y = batch['y']
        if utils.is_distributed():
          _logits = torch.cat(utils.all_gather(_logits), dim=0)
          _y = torch.cat(utils.all_gather(_y), dim=0)
        collected['logits'] += [_logits.cpu()]
        collected['y'] += [_y.cpu()]
        if __DEBUG__ and i > 2:  # only process a few batches in debug mode
          break
      collected = dict(collected)
      collected = {k: torch.cat(v, dim=0) for k, v in collected.items()}
    utils.message(f'logits.shape: {collected["logits"].shape}, y.shape: {collected["y"].shape}')
    self.algorithm.cpu()  # move the model to CPU; calibration runs on the cached CPU logits

    utils.message(f'Temperature before reinitialization: {self.algorithm.temperature.tau}')
    utils.message('Reinitializing temperature')
    self.algorithm.temperature.reinitialize_temperature()
    utils.message(f'Temperature after reinitialization: {self.algorithm.temperature.tau}')
    # Start the search from a fixed value and make tau the only trainable parameter.
    with torch.no_grad():
      self.algorithm.temperature.tau.fill_(1.5)
    self.algorithm.temperature.tau.requires_grad = True

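    # Fit tau by minimizing the cross-entropy (negative log-likelihood) of the
    # temperature-scaled logits on the held-out split: standard temperature
    # scaling (Guo et al., 2017).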
    tau_initial = self.algorithm.temperature.tau.data.clone()
    utils.message(f'Temperature before calibration: {tau_initial}')
    optimizer = torch.optim.LBFGS([self.algorithm.temperature.tau], **self.OPTIM_KWARGS)

    def _closure():
      # L-BFGS evaluates the closure several times per step; clear any stale
      # gradient on tau before each backward pass.
      optimizer.zero_grad()
      loss_value = F.cross_entropy(collected['logits'] / self.algorithm.temperature.tau,
                                   collected['y'])
      utils.message(f'Calibration loss value: {loss_value.item()}')
      loss_value.backward()
      utils.message(f'Temperature gradient after backward: {self.algorithm.temperature.tau.grad}')
      return loss_value

    optimizer.step(_closure)
    self.algorithm.temperature.tau.requires_grad = False
    utils.message(f'Temperature after calibration: {self.algorithm.temperature.tau}')

    utils.message('Finalizing calibration')
    utils.message('Serializing model')
    if utils.is_not_distributed_or_is_rank0():
      self.algorithm.save_state(train_cfg.output_dir)
    self.finalize(train_cfg)

    return {'data': dict(tau_star=self.algorithm.temperature.tau.detach().cpu(),
                         tau_initial=tau_initial.detach().cpu()),
            'calibration_cfg': calibration_cfg,
            'train_cfg': train_cfg,
            'elapsed_time': time.time() - start_time,
            'status': 'done'}
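
For reference, the optimization at the heart of this worker reduces to the short, self-contained sketch below. It is a minimal illustration in plain PyTorch, not part of the project's API: fit_temperature is a hypothetical helper, and the random tensors stand in for collected['logits'] and collected['y'].

import torch
import torch.nn.functional as F

def fit_temperature(logits: torch.Tensor, targets: torch.Tensor,
                    init_tau: float = 1.5, max_iter: int = 50) -> torch.Tensor:
  """Fits a scalar temperature on held-out logits by minimizing the NLL."""
  tau = torch.full((), init_tau, requires_grad=True)
  optimizer = torch.optim.LBFGS([tau], lr=0.1, max_iter=max_iter)

  def closure():
    # L-BFGS re-evaluates the closure, so gradients are cleared each time.
    optimizer.zero_grad()
    loss = F.cross_entropy(logits / tau, targets)
    loss.backward()
    return loss

  optimizer.step(closure)
  return tau.detach()

# Random stand-ins for the collected evaluation logits and labels.
logits = torch.randn(1024, 10)
targets = torch.randint(0, 10, (1024,))
tau_star = fit_temperature(logits, targets)
print(f'tau* = {tau_star.item():.3f}')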