in ubteacher/engine/trainer.py [0:0]
def __init__(self, cfg):
    """
    Build the student/teacher model pair and wire up the training loop.

    Args:
        cfg (CfgNode): experiment config. It is first rescaled with
            ``DefaultTrainer.auto_scale_workers`` for the current world size;
            the rescaled config is the one stored on ``self.cfg``.

    The teacher and student are checkpointed together through an
    ``EnsembleTSModel`` handed to the custom ``DetectionTSCheckpointer``,
    which loads other backbone models with matching heuristics.
    """
    cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
    train_loader = self.build_train_loader(cfg)

    # Student: the network the optimizer updates. It is built before the
    # teacher; keep this order so the sequence of build_model calls is stable.
    student = self.build_model(cfg)
    student_optimizer = self.build_optimizer(cfg, student)

    # Teacher: a second, separately built instance of the same architecture.
    teacher = self.build_model(cfg)
    self.model_teacher = teacher

    # Only the student is trained, so only the student gets a DDP wrapper
    # (buffer broadcasting disabled); the teacher stays unwrapped.
    if comm.get_world_size() > 1:
        student = DistributedDataParallel(
            student,
            device_ids=[comm.get_local_rank()],
            broadcast_buffers=False,
        )

    TrainerBase.__init__(self)
    loop_cls = AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer
    self._trainer = loop_cls(student, train_loader, student_optimizer)
    self.scheduler = self.build_lr_scheduler(cfg, student_optimizer)

    # Teacher and (possibly DDP-wrapped) student are saved/loaded as one unit.
    self.checkpointer = DetectionTSCheckpointer(
        EnsembleTSModel(teacher, student),
        cfg.OUTPUT_DIR,
        optimizer=student_optimizer,
        scheduler=self.scheduler,
    )

    # These must be set before build_hooks(), which may read them.
    self.cfg = cfg
    self.start_iter = 0
    self.max_iter = cfg.SOLVER.MAX_ITER
    self.register_hooks(self.build_hooks())