in tools/train_net.py [0:0]
def train(cfg):
    """
    Train a video model for many epochs on train set and evaluate it on val set.

    The loop runs from the (possibly resumed) start epoch up to
    ``cfg.SOLVER.MAX_EPOCH``. Each epoch:

    * optionally applies the multigrid long-cycle schedule, rebuilding the
      trainer when the schedule reports a change;
    * trains for one epoch and logs timing statistics;
    * recomputes precise BN statistics on checkpoint/eval epochs;
    * saves a checkpoint and/or evaluates on the val set, as scheduled.

    The TensorBoard writer (if any) is always closed, even if training raises.

    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py

    Raises:
        RuntimeError: if, after a multigrid long-cycle change, the latest
            checkpoint found in ``cfg.OUTPUT_DIR`` does not belong to the
            current epoch.
    """
    # Set up environment.
    du.init_distributed_training(cfg)
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    # Setup logging format.
    logging.setup_logging(cfg.OUTPUT_DIR)

    # Init multigrid.
    multigrid = None
    if cfg.MULTIGRID.LONG_CYCLE or cfg.MULTIGRID.SHORT_CYCLE:
        multigrid = MultigridSchedule()
        cfg = multigrid.init_multigrid(cfg)
        if cfg.MULTIGRID.LONG_CYCLE:
            cfg, _ = multigrid.update_long_cycle(cfg, cur_epoch=0)

    # Print config.
    logger.info("Train with config:")
    logger.info(pprint.pformat(cfg))

    # Build the video model and print model statistics.
    model = build_model(cfg)
    if du.is_master_proc() and cfg.LOG_MODEL_INFO:
        misc.log_model_info(model, cfg, use_train_input=True)

    # Construct the optimizer.
    optimizer = optim.construct_optimizer(model, cfg)
    # Create a GradScaler for mixed precision training. When disabled, the
    # scaler's methods are no-ops, so it can be handed to train_epoch as-is.
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
    # Load a checkpoint to resume training if applicable. The scaler is only
    # passed to the checkpoint utilities when mixed precision is enabled.
    start_epoch = cu.load_train_checkpoint(
        cfg, model, optimizer, scaler if cfg.TRAIN.MIXED_PRECISION else None
    )

    # Create the video train and val loaders.
    train_loader = loader.construct_loader(cfg, "train")
    val_loader = loader.construct_loader(cfg, "val")
    # A dedicated train loader is only needed when precise BN is enabled.
    precise_bn_loader = (
        loader.construct_loader(cfg, "train", is_precise_bn=True)
        if cfg.BN.USE_PRECISE_STATS
        else None
    )

    # Create meters.
    if cfg.DETECTION.ENABLE:
        train_meter = AVAMeter(len(train_loader), cfg, mode="train")
        val_meter = AVAMeter(len(val_loader), cfg, mode="val")
    else:
        train_meter = TrainMeter(len(train_loader), cfg)
        val_meter = ValMeter(len(val_loader), cfg)

    # set up writer for logging to Tensorboard format.
    if cfg.TENSORBOARD.ENABLE and du.is_master_proc(
        cfg.NUM_GPUS * cfg.NUM_SHARDS
    ):
        writer = tb.TensorboardWriter(cfg)
    else:
        writer = None

    # Perform the training loop.
    logger.info("Start epoch: {}".format(start_epoch + 1))
    epoch_timer = EpochTimer()
    try:
        for cur_epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH):
            if cfg.MULTIGRID.LONG_CYCLE:
                cfg, changed = multigrid.update_long_cycle(cfg, cur_epoch)
                if changed:
                    # The long-cycle schedule changed: rebuild model,
                    # optimizer, loaders and meters from the updated cfg.
                    (
                        model,
                        optimizer,
                        train_loader,
                        val_loader,
                        precise_bn_loader,
                        train_meter,
                        val_meter,
                    ) = build_trainer(cfg)
                    # Load checkpoint: prefer the latest one in OUTPUT_DIR,
                    # otherwise fall back to the configured file path.
                    if cu.has_checkpoint(cfg.OUTPUT_DIR):
                        last_checkpoint = cu.get_last_checkpoint(cfg.OUTPUT_DIR)
                        expected_suffix = "{:05d}.pyth".format(cur_epoch)
                        # Explicit check (not `assert`) so it survives `python -O`.
                        if expected_suffix not in last_checkpoint:
                            raise RuntimeError(
                                "Expected the latest checkpoint to match epoch "
                                "{} ('{}'), but found '{}'.".format(
                                    cur_epoch, expected_suffix, last_checkpoint
                                )
                            )
                    else:
                        last_checkpoint = cfg.TRAIN.CHECKPOINT_FILE_PATH
                    logger.info("Load from {}".format(last_checkpoint))
                    cu.load_checkpoint(
                        last_checkpoint, model, cfg.NUM_GPUS > 1, optimizer
                    )

            # Shuffle the dataset.
            loader.shuffle_dataset(train_loader, cur_epoch)

            # Train for one epoch.
            epoch_timer.epoch_tic()
            train_epoch(
                train_loader,
                model,
                optimizer,
                scaler,
                train_meter,
                cur_epoch,
                cfg,
                writer,
            )
            epoch_timer.epoch_toc()
            logger.info(
                f"Epoch {cur_epoch} takes {epoch_timer.last_epoch_time():.2f}s. Epochs "
                f"from {start_epoch} to {cur_epoch} take "
                f"{epoch_timer.avg_epoch_time():.2f}s in average and "
                f"{epoch_timer.median_epoch_time():.2f}s in median."
            )
            num_iters = len(train_loader)
            # Skip the per-iteration breakdown for an empty loader; dividing by
            # zero here would crash after an otherwise completed epoch.
            if num_iters > 0:
                logger.info(
                    f"For epoch {cur_epoch}, each iteration takes "
                    f"{epoch_timer.last_epoch_time() / num_iters:.2f}s in average. "
                    f"From epoch {start_epoch} to {cur_epoch}, each iteration takes "
                    f"{epoch_timer.avg_epoch_time() / num_iters:.2f}s in average."
                )

            is_checkp_epoch = cu.is_checkpoint_epoch(
                cfg,
                cur_epoch,
                None if multigrid is None else multigrid.schedule,
            )
            is_eval_epoch = misc.is_eval_epoch(
                cfg, cur_epoch, None if multigrid is None else multigrid.schedule
            )

            # Compute precise BN stats only on epochs whose weights will be
            # saved or evaluated, and only if the model has BN modules.
            if (
                (is_checkp_epoch or is_eval_epoch)
                and cfg.BN.USE_PRECISE_STATS
                and len(get_bn_modules(model)) > 0
            ):
                calculate_and_update_precise_bn(
                    precise_bn_loader,
                    model,
                    min(cfg.BN.NUM_BATCHES_PRECISE, len(precise_bn_loader)),
                    cfg.NUM_GPUS > 0,
                )
            # Presumably folds sub-BN statistics back into the regular BN
            # layers before saving/evaluating — see misc.aggregate_sub_bn_stats.
            _ = misc.aggregate_sub_bn_stats(model)

            # Save a checkpoint.
            if is_checkp_epoch:
                cu.save_checkpoint(
                    cfg.OUTPUT_DIR,
                    model,
                    optimizer,
                    cur_epoch,
                    cfg,
                    scaler if cfg.TRAIN.MIXED_PRECISION else None,
                )
            # Evaluate the model on validation set.
            if is_eval_epoch:
                eval_epoch(val_loader, model, val_meter, cur_epoch, cfg, writer)
    finally:
        # Always flush/close TensorBoard output, even if training fails.
        if writer is not None:
            writer.close()