def __init__()

in engine/training_engine.py [0:0]


    def __init__(self, opts,
                 model,
                 validation_loader,
                 training_loader,
                 criterion,
                 optimizer,
                 scheduler,
                 gradient_scalar,
                 start_epoch: int = 0,
                 start_iteration: int = 0,
                 best_metric: float = 0.0,
                 model_ema=None,
                 *args, **kwargs) -> None:
        """Store the training components and read training settings from ``opts``.

        Args:
            opts: Options object. Every setting is read with ``getattr`` using a
                dotted attribute name (e.g. ``"scheduler.max_iterations"``) and
                falls back to a default when the attribute is absent.
            model: The model being trained.
            validation_loader: Validation data loader (stored as ``val_loader``).
            training_loader: Training data loader (stored as ``train_loader``).
            criterion: Loss function (stored as ``criteria``).
            optimizer: Optimizer.
            scheduler: Learning-rate scheduler.
            gradient_scalar: Gradient scaler; presumably a mixed-precision
                scaler such as ``torch.cuda.amp.GradScaler`` -- TODO confirm.
            start_epoch: Initial value for ``self.start_epoch``.
            start_iteration: Initial value for the iteration counter
                ``self.train_iterations``.
            best_metric: Initial value for ``self.best_metric``; presumably the
                best checkpoint-metric value from a resumed run.
            model_ema: Optional second model, presumably an EMA copy of
                ``model`` -- verify against caller.
            *args, **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            AssertionError: if ``stats.name`` is neither a string nor a list,
                or if ``stats.checkpoint_metric`` is not one of the metric names.
        """
        super(Trainer, self).__init__()

        self.opts = opts

        self.model = model
        self.model_ema = model_ema
        self.criteria = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.gradient_scalar = gradient_scalar

        self.val_loader = validation_loader
        self.train_loader = training_loader

        self.device = getattr(opts, "dev.device", torch.device("cpu"))

        self.start_epoch = start_epoch
        self.best_metric = best_metric
        self.train_iterations = start_iteration

        self.is_master_node = is_master(opts)
        self.max_iterations_reached = False
        self.max_iterations = getattr(opts, "scheduler.max_iterations", DEFAULT_ITERATIONS)
        self.use_distributed = getattr(opts, "ddp.use_distributed", False)
        self.log_freq = getattr(opts, "common.log_freq", DEFAULT_LOG_FREQ)
        self.accum_freq = getattr(opts, "common.accum_freq", 1)
        self.accum_after_epoch = getattr(opts, "common.accum_after_epoch", 0)

        self.mixed_precision_training = getattr(opts, "common.mixed_precision", False)

        metric_names = getattr(opts, "stats.name", ['loss'])
        if isinstance(metric_names, str):
            metric_names = [metric_names]
        assert isinstance(metric_names, list), "Type of metric names should be list. Got: {}".format(
            type(metric_names))
        # Copy so that appending below does not mutate the list stored on `opts`.
        self.metric_names = list(metric_names)

        # The loss is always tracked, even if the user did not request it.
        if 'loss' not in self.metric_names:
            self.metric_names.append('loss')

        self.ckpt_metric = getattr(opts, "stats.checkpoint_metric", "loss")
        assert self.ckpt_metric in self.metric_names, \
            "Checkpoint metric should be part of metric names. Metric names: {}, Checkpoint metric: {}".format(
                self.metric_names, self.ckpt_metric)
        # Lower-cased only after the membership check above, which uses the raw name.
        self.ckpt_metric = self.ckpt_metric.lower()

        # Attribute name (including spelling) is kept as-is for other code that reads it.
        self.tb_log_writter = None
        # The log writer is only set up on the master node, and only when SummaryWriter is available.
        if SummaryWriter is not None and self.is_master_node:
            self.setup_log_writer()

        if self.is_master_node:
            print_summary(opts=self.opts,
                          model=self.model,
                          criteria=self.criteria,
                          optimizer=self.optimizer,
                          scheduler=self.scheduler)

        self.adjust_norm_mom = None
        if getattr(opts, "adjust_bn_momentum.enable", True):
            # Local import: the dependency is only loaded when momentum adjustment is enabled.
            from cvnets.layers import AdjustBatchNormMomentum
            self.adjust_norm_mom = AdjustBatchNormMomentum(opts=opts)
            if self.is_master_node:
                logger.log("Batch normalization momentum will be annealed during training.")
                print(self.adjust_norm_mom)