def __call__()

in uimnet/workers/trainer.py


  def __call__(self, cfg, Algorithm, dataset):
    start_time = time.time()  # wall-clock start; reported as 'elapsed_time' in the return value

    # First store an immutable copy of the arguments. Immutability is
    # a surprisingly hard property to enforce in Python.
    self._cfg = copy.deepcopy(cfg)
    self.Algorithm = Algorithm
    self.dataset = dataset
    ## The attributes above will be accessed in the checkpoint callback, in
    ## order to serialize the instantiation of this class and its arguments.

    self.setup(cfg)  # setup() modifies cfg in place and initializes state on the worker.
    self.cfg = cfg
    if utils.is_not_distributed_or_is_rank0():
      utils.write_trace('train.running', dir_=cfg.output_dir)

    utils.message(cfg)

    self.datanode = datasets.SplitDataNode(
      dataset=dataset,
      transforms=datasets.TRANSFORMS,
      splits_props=cfg.dataset.splits_props,
      seed=cfg.dataset.seed)

    num_classes = self.datanode.splits['train'].num_classes
    self.algorithm = Algorithm(num_classes=num_classes,
                               arch=cfg.algorithm.arch,
                               device=cfg.experiment.device,
                               use_mixed_precision=cfg.algorithm.use_mixed_precision,
                               seed=cfg.algorithm.seed,
                               sn=cfg.algorithm.sn,
                               sn_coef=cfg.algorithm.sn_coef,
                               sn_bn=cfg.algorithm.sn_bn)

    self.algorithm.initialize(self.datanode.dataset)
    utils.message(self.algorithm)

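    # Distributed runs presumably aggregate prediction metrics across ranks via
    # the fused variant (assumption based on the class name); single-process
    # runs use the plain implementation.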
    if utils.is_distributed():
      self.prediction_metrics = metrics.FusedPredictionMetrics()
    else:
      self.prediction_metrics = metrics.PredictionMetrics()

    self.maybe_load_checkpoint(cfg)
    utils.maybe_synchronize()

    utils.message('Starting mainloop.')

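    # Resume one epoch past self.current_epoch, which maybe_load_checkpoint is
    # expected to have restored if a checkpoint was found.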
    for epoch in range(self.current_epoch + 1, cfg.experiment.num_epochs + 1):
      utils.message(f'Starting epoch {epoch}')
      self.current_epoch = epoch
      self.train_epoch(cfg)

      is_last_epoch = epoch == cfg.experiment.num_epochs
      if epoch % cfg.experiment.evaluate_every == 0 or is_last_epoch:
        self.evaluate(cfg)

      if epoch % cfg.experiment.checkpoint_every == 0 or is_last_epoch:
        self.save(cfg)

      utils.maybe_synchronize()

    utils.message('Training complete. Finalizing...')
    self.finalize(cfg)

    return {'data': self.records,
            'cfg': cfg,
            'elapsed_time': time.time() - start_time,
            'status': 'done'}
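
The comments above note that `_cfg`, `Algorithm`, and `dataset` are stored so a checkpoint callback can re-serialize this call. Below is a minimal sketch of what such a callback could look like, assuming a submitit-style job; the scheduler, the class name `Trainer`, and the `checkpoint` hook are assumptions and are not shown in this listing.

import submitit

class Trainer:  # hypothetical stand-in for the class defining __call__ above
  def checkpoint(self, *args, **kwargs):
    # Requeue the interrupted job with the original, unmutated arguments that
    # __call__ stored before setup() modified cfg. Passing a fresh instance
    # keeps the pickled payload small.
    return submitit.helpers.DelayedSubmission(
        type(self)(), self._cfg, self.Algorithm, self.dataset)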