in uimnet/workers/trainer.py [0:0]
def __call__(self, cfg, Algorithm, dataset):
    """Run a full training job for `Algorithm` on `dataset`.

    Orchestrates worker setup, dataset splitting, algorithm construction,
    the train/evaluate/checkpoint main loop, and finalization.

    Args:
      cfg: experiment configuration object. NOTE: mutated in place by
        `self.setup(cfg)` before being stored as `self.cfg`.
      Algorithm: algorithm class, instantiated with config hyperparameters.
      dataset: dataset object handed to the split data node and to
        `algorithm.initialize`.

    Returns:
      dict with keys 'data' (self.records), 'cfg' (the mutated config),
      'elapsed_time' (seconds, float), and 'status' ('done').
    """
    # Monotonic clock for duration measurement: time.time() is wall-clock
    # and can jump under NTP/DST adjustments, corrupting elapsed_time.
    start_time = time.monotonic()
    # First store an immutable copy of the arguments. Immutability is
    # surprisingly a hard property to attribute in python.
    self._cfg = copy.deepcopy(cfg)
    self.Algorithm = Algorithm
    self.dataset = dataset
    # The attributes above are read in the checkpoint callback, in order
    # to serialize the instantiation of this class and its arguments.
    self.setup(cfg)  # Setup modifies cfg. It needs a state on the worker.
    self.cfg = cfg
    # Only rank 0 (or a non-distributed run) writes the running-trace file.
    if utils.is_not_distributed_or_is_rank0():
        utils.write_trace('train.running', dir_=cfg.output_dir)
    utils.message(cfg)
    self.datanode = datasets.SplitDataNode(
        dataset=dataset,
        transforms=datasets.TRANSFORMS,
        splits_props=cfg.dataset.splits_props,
        seed=cfg.dataset.seed)
    num_classes = self.datanode.splits['train'].num_classes
    self.algorithm = Algorithm(num_classes=num_classes,
                               arch=cfg.algorithm.arch,
                               device=cfg.experiment.device,
                               use_mixed_precision=cfg.algorithm.use_mixed_precision,
                               seed=cfg.algorithm.seed,
                               sn=cfg.algorithm.sn,
                               sn_coef=cfg.algorithm.sn_coef,
                               sn_bn=cfg.algorithm.sn_bn)
    self.algorithm.initialize(self.datanode.dataset)
    utils.message(self.algorithm)
    # Fused metrics aggregate predictions across workers when distributed.
    if utils.is_distributed():
        self.prediction_metrics = metrics.FusedPredictionMetrics()
    else:
        self.prediction_metrics = metrics.PredictionMetrics()
    self.maybe_load_checkpoint(cfg)
    utils.maybe_synchronize()
    utils.message('Starting mainloop.')
    # Epochs are 1-indexed; resume from the epoch after the last completed
    # one (maybe_load_checkpoint may have advanced self.current_epoch).
    for epoch in range(self.current_epoch + 1, cfg.experiment.num_epochs + 1):
        utils.message(f'Starting epoch {epoch}')
        self.current_epoch = epoch
        self.train_epoch(cfg)
        is_last_epoch = epoch == cfg.experiment.num_epochs
        # Always evaluate and checkpoint on the final epoch, regardless of
        # the configured cadence.
        if epoch % cfg.experiment.evaluate_every == 0 or is_last_epoch:
            self.evaluate(cfg)
        if epoch % cfg.experiment.checkpoint_every == 0 or is_last_epoch:
            self.save(cfg)
    utils.maybe_synchronize()
    utils.message('Training complete. Finalizing...')
    self.finalize(cfg)
    return {'data': self.records,
            'cfg': cfg,
            'elapsed_time': time.monotonic() - start_time,
            'status': 'done'}