in pycls/core/trainer.py [0:0]
def _load_initial_state(model, ema, optimizer):
    """Resume from the last checkpoint or load initial weights, if configured.

    Mutates model/ema (and optimizer, on resume) in place via cp.load_checkpoint.
    Returns the epoch index to start training from: last checkpoint epoch + 1
    when auto-resuming, otherwise 0 (with cfg.TRAIN.WEIGHTS loaded if set).
    """
    if cfg.TRAIN.AUTO_RESUME and cp.has_checkpoint():
        file = cp.get_last_checkpoint()
        # load_checkpoint returns a tuple; index 0 is the stored epoch
        epoch = cp.load_checkpoint(file, model, ema, optimizer)[0]
        logger.info("Loaded checkpoint from: {}".format(file))
        return epoch + 1
    if cfg.TRAIN.WEIGHTS:
        # Initial weights only: the optimizer state is intentionally not loaded
        train_weights = get_weights_file(cfg.TRAIN.WEIGHTS)
        cp.load_checkpoint(train_weights, model, ema)
        logger.info("Loaded initial weights from: {}".format(train_weights))
    return 0


def train_model():
    """Trains the model (and its EMA copy) for cfg.OPTIM.MAX_EPOCH epochs.

    Orchestrates the full training run: environment setup, model/EMA/loss/
    optimizer construction, checkpoint or initial-weight loading, data loader
    and meter creation, optional timing benchmark, then the per-epoch loop of
    train -> (optional precise-BN recompute) -> evaluate -> checkpoint.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model, its EMA copy, the loss function, and the optimizer
    model = setup_model()
    ema = deepcopy(model)
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)
    # Resume from a checkpoint or load initial weights (0 => fresh start)
    start_epoch = _load_initial_state(model, ema, optimizer)
    # Create data loaders and meters
    train_loader = data_loader.construct_train_loader()
    test_loader = data_loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    ema_meter = meters.TestMeter(len(test_loader), "test_ema")
    # GradScaler is a no-op pass-through unless mixed precision is enabled
    scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
    # Compute model and loader timings (only on a fresh, non-resumed run)
    if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
        benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_epoch(
            train_loader, model, ema, loss_fun, optimizer, scaler, train_meter, cur_epoch
        )
        # Recompute precise BN stats for both model and EMA before evaluation
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
            net.compute_precise_bn_stats(ema, train_loader)
        # Evaluate the model and its EMA copy on the test set
        test_epoch(test_loader, model, test_meter, cur_epoch)
        test_epoch(test_loader, ema, ema_meter, cur_epoch)
        test_err = test_meter.get_epoch_stats(cur_epoch)["top1_err"]
        ema_err = ema_meter.get_epoch_stats(cur_epoch)["top1_err"]
        # Save a checkpoint (records both model and EMA errors)
        file = cp.save_checkpoint(model, ema, optimizer, cur_epoch, test_err, ema_err)
        logger.info("Wrote checkpoint to: {}".format(file))