in tools/train_net.py [0:0]
def train(cfg):
    """
    Train a video model for many epochs on train set and evaluate it on val set.

    The routine initializes distributed training, seeds the numpy and torch
    RNGs, builds the model / optimizer / data loaders / meters, optionally
    resumes from a checkpoint, and then runs the epoch loop. After each epoch
    it optionally recomputes precise BN statistics, saves a checkpoint and
    evaluates on the val loader, depending on the schedule helpers.

    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
    """
    # Set up environment.
    du.init_distributed_training(cfg)
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)

    # Setup logging format.
    logging.setup_logging(cfg.OUTPUT_DIR)

    # Init multigrid.
    multigrid = None
    if cfg.MULTIGRID.LONG_CYCLE or cfg.MULTIGRID.SHORT_CYCLE:
        multigrid = MultigridSchedule()
        cfg = multigrid.init_multigrid(cfg)
        if cfg.MULTIGRID.LONG_CYCLE:
            cfg, _ = multigrid.update_long_cycle(cfg, cur_epoch=0)

    # Print config.
    logger.info("Train with config:")
    logger.info(pprint.pformat(cfg))

    # Build the video model and print model statistics.
    model = build_model(cfg)
    if du.is_master_proc() and cfg.LOG_MODEL_INFO:
        misc.log_model_info(model, cfg, use_train_input=True)

    # Construct the optimizer.
    optimizer = optim.construct_optimizer(model, cfg)

    # Mixed Precision Training Scaler
    loss_scaler = NativeScaler() if cfg.SOLVER.USE_MIXED_PRECISION else None

    # Load a checkpoint to resume training if applicable.
    start_epoch = cu.load_train_checkpoint(
        cfg, model, optimizer, loss_scaler=loss_scaler
    )

    # Create the video train and val loaders.
    train_loader = loader.construct_loader(cfg, "train")
    val_loader = loader.construct_loader(cfg, "val")
    precise_bn_loader = (
        loader.construct_loader(cfg, "train", is_precise_bn=True)
        if cfg.BN.USE_PRECISE_STATS
        else None
    )

    # Create meters.
    if cfg.TRAIN.DATASET == 'Epickitchens':
        train_meter = EPICTrainMeter(len(train_loader), cfg)
        val_meter = EPICValMeter(len(val_loader), cfg)
    else:
        train_meter = TrainMeter(len(train_loader), cfg)
        val_meter = ValMeter(len(val_loader), cfg)

    # set up writer for logging to Tensorboard format.
    if cfg.TENSORBOARD.ENABLE and du.is_master_proc(
        cfg.NUM_GPUS * cfg.NUM_SHARDS
    ):
        writer = tb.TensorboardWriter(cfg)
    else:
        writer = None

    # Everything from here on runs under try/finally so that the Tensorboard
    # writer is closed even when training aborts with an exception.
    try:
        # Perform the training loop.
        logger.info("Start epoch: {}".format(start_epoch + 1))

        # Mixup/CutMix is enabled when any of its strength knobs is set.
        mixup_fn = None
        mixup_active = (
            cfg.MIXUP.MIXUP_ALPHA > 0
            or cfg.MIXUP.CUTMIX_ALPHA > 0
            or cfg.MIXUP.CUTMIX_MINMAX is not None
        )
        if mixup_active:
            mixup_fn = Mixup(
                mixup_alpha=cfg.MIXUP.MIXUP_ALPHA,
                cutmix_alpha=cfg.MIXUP.CUTMIX_ALPHA,
                cutmix_minmax=cfg.MIXUP.CUTMIX_MINMAX,
                prob=cfg.MIXUP.MIXUP_PROB,
                switch_prob=cfg.MIXUP.MIXUP_SWITCH_PROB,
                mode=cfg.MIXUP.MIXUP_MODE,
                label_smoothing=cfg.SOLVER.SMOOTHING,
                num_classes=cfg.MODEL.NUM_CLASSES,
            )

        # Explicitly declare reduction to mean.
        # The soft-target loss must be used whenever mixup_fn is built
        # (including CutMix-only configs), not only when MIXUP_ALPHA > 0;
        # otherwise the mixed targets would be paired with a hard-label loss.
        if mixup_active:
            # smoothing is handled with mixup label transform
            loss_fun = losses.get_loss_func("soft_target_cross_entropy")()
        elif cfg.SOLVER.SMOOTHING > 0.0:
            loss_fun = losses.get_loss_func("label_smoothing_cross_entropy")(
                smoothing=cfg.SOLVER.SMOOTHING
            )
        else:
            loss_fun = losses.get_loss_func(cfg.MODEL.LOSS_FUNC)(
                reduction="mean"
            )

        for cur_epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH):
            if cfg.MULTIGRID.LONG_CYCLE:
                cfg, changed = multigrid.update_long_cycle(cfg, cur_epoch)
                if changed:
                    # The long-cycle config changed: rebuild the trainer
                    # components from the updated cfg.
                    (
                        model,
                        optimizer,
                        train_loader,
                        val_loader,
                        precise_bn_loader,
                        train_meter,
                        val_meter,
                    ) = build_trainer(cfg)

                    # Load checkpoint.
                    if cu.has_checkpoint(cfg.OUTPUT_DIR):
                        last_checkpoint = cu.get_last_checkpoint(
                            cfg.OUTPUT_DIR
                        )
                        assert (
                            "{:05d}.pyth".format(cur_epoch) in last_checkpoint
                        )
                    else:
                        last_checkpoint = cfg.TRAIN.CHECKPOINT_FILE_PATH
                    logger.info("Load from {}".format(last_checkpoint))
                    # NOTE(review): unlike the initial resume above,
                    # loss_scaler is not passed here -- confirm whether the
                    # scaler state should be restored as well.
                    cu.load_checkpoint(
                        last_checkpoint, model, cfg.NUM_GPUS > 1, optimizer
                    )

            # Shuffle the dataset.
            loader.shuffle_dataset(train_loader, cur_epoch)

            # Train for one epoch.
            train_epoch(
                train_loader,
                model,
                optimizer,
                train_meter,
                cur_epoch,
                cfg,
                writer,
                loss_scaler=loss_scaler,
                loss_fun=loss_fun,
                mixup_fn=mixup_fn,
            )

            schedule = None if multigrid is None else multigrid.schedule
            is_checkp_epoch = cu.is_checkpoint_epoch(cfg, cur_epoch, schedule)
            is_eval_epoch = misc.is_eval_epoch(cfg, cur_epoch, schedule)

            # Compute precise BN stats.
            if (
                (is_checkp_epoch or is_eval_epoch)
                and cfg.BN.USE_PRECISE_STATS
                and len(get_bn_modules(model)) > 0
            ):
                calculate_and_update_precise_bn(
                    precise_bn_loader,
                    model,
                    min(cfg.BN.NUM_BATCHES_PRECISE, len(precise_bn_loader)),
                    cfg.NUM_GPUS > 0,
                )
            _ = misc.aggregate_sub_bn_stats(model)

            # Save a checkpoint.
            if is_checkp_epoch:
                cu.save_checkpoint(
                    cfg.OUTPUT_DIR,
                    model,
                    optimizer,
                    cur_epoch,
                    cfg,
                    loss_scaler=loss_scaler,
                )

            # Evaluate the model on validation set.
            if is_eval_epoch:
                eval_epoch(
                    val_loader, model, val_meter, cur_epoch, cfg, writer
                )
    finally:
        if writer is not None:
            writer.close()