in engine/training_engine.py [0:0]
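# Trainer.__init__ wires together everything a training run needs: model, data
# loaders, loss criterion, optimizer, LR scheduler, and the AMP gradient
# scaler, plus bookkeeping read from `opts` (logging, DDP, metrics, EMA).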
def __init__(self, opts,
             model,
             validation_loader,
             training_loader,
             criterion,
             optimizer,
             scheduler,
             gradient_scalar,
             start_epoch: int = 0,
             start_iteration: int = 0,
             best_metric: float = 0.0,
             model_ema=None,
             *args, **kwargs) -> None:
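    """Initialize the training engine.

    Note: `opts` is expected to expose dotted option keys (e.g.
    "scheduler.max_iterations") as attributes, as the `getattr` calls
    below assume.
    """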
    super(Trainer, self).__init__()
    self.opts = opts
    self.model = model
    self.model_ema = model_ema
    self.criteria = criterion
    self.optimizer = optimizer
    self.scheduler = scheduler
    self.gradient_scalar = gradient_scalar
    self.val_loader = validation_loader
    self.train_loader = training_loader
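    # Resolve the compute device from opts; fall back to CPU when
    # "dev.device" is not set.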
    self.device = getattr(opts, "dev.device", torch.device("cpu"))
    self.start_epoch = start_epoch
    self.best_metric = best_metric
    self.train_iterations = start_iteration
    self.is_master_node = is_master(opts)
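    # Training can also be bounded by an iteration budget; DEFAULT_ITERATIONS
    # is used when "scheduler.max_iterations" is unset.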
    self.max_iterations_reached = False
    self.max_iterations = getattr(self.opts, "scheduler.max_iterations", DEFAULT_ITERATIONS)
    self.use_distributed = getattr(self.opts, "ddp.use_distributed", False)
    self.log_freq = getattr(self.opts, "common.log_freq", DEFAULT_LOG_FREQ)
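    # Gradient accumulation: gradients are accumulated over `accum_freq`
    # batches, starting from epoch `accum_after_epoch`.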
    self.accum_freq = getattr(self.opts, "common.accum_freq", 1)
    self.accum_after_epoch = getattr(self.opts, "common.accum_after_epoch", 0)
    self.mixed_precision_training = getattr(opts, "common.mixed_precision", False)
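    # Metrics to track during training/validation; normalize to a list and
    # make sure the loss is always among them.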
    self.metric_names = getattr(opts, "stats.name", ['loss'])
    if isinstance(self.metric_names, str):
        self.metric_names = [self.metric_names]
    assert isinstance(self.metric_names, list), "Type of metric names should be list. Got: {}".format(
        type(self.metric_names))
    if 'loss' not in self.metric_names:
        self.metric_names.append('loss')
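    # The metric used for selecting the best checkpoint must be one of the
    # tracked metrics.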
    self.ckpt_metric = getattr(self.opts, "stats.checkpoint_metric", "loss")
    assert self.ckpt_metric in self.metric_names, \
        "Checkpoint metric should be part of metric names. Metric names: {}, Checkpoint metric: {}".format(
            self.metric_names, self.ckpt_metric)
    self.ckpt_metric = self.ckpt_metric.lower()
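    # TensorBoard logging is set up only on the master node, and only when
    # tensorboard is available (SummaryWriter is not None).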
    self.tb_log_writter = None
    if SummaryWriter is not None and self.is_master_node:
        self.setup_log_writer()
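    # Print a one-time summary of the model, criterion, optimizer, and
    # scheduler on the master node.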
    if self.is_master_node:
        print_summary(opts=self.opts,
                      model=self.model,
                      criteria=self.criteria,
                      optimizer=self.optimizer,
                      scheduler=self.scheduler)
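    # Optionally anneal BatchNorm momentum over the course of training
    # (enabled by default via "adjust_bn_momentum.enable").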
    self.adjust_norm_mom = None
    if getattr(opts, "adjust_bn_momentum.enable", True):
        from cvnets.layers import AdjustBatchNormMomentum
        self.adjust_norm_mom = AdjustBatchNormMomentum(opts=opts)
        if self.is_master_node:
            logger.log("Batch normalization momentum will be annealed during training.")
            print(self.adjust_norm_mom)
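
# Usage sketch (not from the source): assuming `opts` is an argparse.Namespace
# whose dotted keys (e.g. "dev.device") have been set via setattr, and with
# hypothetical `model`, `val_loader`, `train_loader`, `criterion`, `optimizer`,
# and `scheduler` objects, a Trainer might be constructed roughly as follows:
#
#   trainer = Trainer(
#       opts=opts,
#       model=model,
#       validation_loader=val_loader,
#       training_loader=train_loader,
#       criterion=criterion,
#       optimizer=optimizer,
#       scheduler=scheduler,
#       gradient_scalar=torch.cuda.amp.GradScaler(
#           enabled=getattr(opts, "common.mixed_precision", False)),
#   )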