in tools/handobj/train_net.py [0:0]
def train(cfg):
    """
    Train a video model for many epochs on train set and evaluate it on val set.

    The function initializes distributed training, seeds numpy and torch from
    ``cfg.RNG_SEED``, builds the model, optimizer, data loaders and meters,
    optionally resumes from a checkpoint, and then runs the epoch loop
    (train -> optional precise-BN update -> optional checkpoint -> optional
    evaluation). The TensorBoard writer, when one is created, is always closed
    on exit, including when the epoch loop raises.

    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
    """
    # Set up environment.
    du.init_distributed_training(cfg)
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)

    # Setup logging format.
    logging.setup_logging(cfg.OUTPUT_DIR)

    # Init multigrid. `multigrid` stays None when neither cycle is enabled;
    # the checkpoint/eval schedule checks below rely on that.
    multigrid = None
    if cfg.MULTIGRID.LONG_CYCLE or cfg.MULTIGRID.SHORT_CYCLE:
        multigrid = MultigridSchedule()
        cfg = multigrid.init_multigrid(cfg)
        if cfg.MULTIGRID.LONG_CYCLE:
            cfg, _ = multigrid.update_long_cycle(cfg, cur_epoch=0)

    # Print config.
    logger.info("Train with config:")
    logger.info(pprint.pformat(cfg))

    # Build the video model and print model statistics.
    model = build_model(cfg)
    if du.is_master_proc() and cfg.LOG_MODEL_INFO:
        # Logging model statistics is best-effort: a failure here is recorded
        # and training proceeds.
        try:
            misc.log_model_info(model, cfg, use_train_input=True)
        except Exception as e:
            logger.info(e)

    # Construct the optimizer.
    optimizer = optim.construct_optimizer(model, cfg)

    # Load a checkpoint to resume training if applicable.
    start_epoch = cu.load_train_checkpoint(cfg, model, optimizer)

    # Create the video train and val loaders.
    train_loader = loader.construct_loader(cfg, "train")
    val_loader = loader.construct_loader(cfg, "val")
    precise_bn_loader = loader.construct_loader(
        cfg, "train", is_precise_bn=True
    )

    # Create meters.
    if cfg.DETECTION.ENABLE:
        train_meter = AVAMeter(len(train_loader), cfg, mode="train")
        val_meter = AVAMeter(len(val_loader), cfg, mode="val")
    else:
        train_meter = TrainMeter(len(train_loader), cfg)
        val_meter = ValMeter(len(val_loader), cfg)

    # Set up writer for logging to Tensorboard format.
    if cfg.TENSORBOARD.ENABLE and du.is_master_proc(
        cfg.NUM_GPUS * cfg.NUM_SHARDS
    ):
        writer = tb.TensorboardWriter(cfg)
    else:
        writer = None

    # Perform the training loop.
    logger.info("Start epoch: {}".format(start_epoch + 1))

    # The loop runs under try/finally so that the writer is closed even if
    # training, checkpointing or evaluation raises part-way through.
    try:
        for cur_epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH):
            if cfg.MULTIGRID.LONG_CYCLE:
                cfg, changed = multigrid.update_long_cycle(cfg, cur_epoch)
                if changed:
                    # The long-cycle update changed the config: rebuild the
                    # whole trainer state from the new config.
                    (
                        model,
                        optimizer,
                        train_loader,
                        val_loader,
                        precise_bn_loader,
                        train_meter,
                        val_meter,
                    ) = build_trainer(cfg)

                    # Load checkpoint into the freshly built model/optimizer.
                    if cu.has_checkpoint(cfg.OUTPUT_DIR):
                        last_checkpoint = cu.get_last_checkpoint(
                            cfg.OUTPUT_DIR
                        )
                        # Sanity check: the newest checkpoint file name must
                        # carry the current epoch number.
                        assert (
                            "{:05d}.pyth".format(cur_epoch) in last_checkpoint
                        )
                    else:
                        last_checkpoint = cfg.TRAIN.CHECKPOINT_FILE_PATH
                    logger.info("Load from {}".format(last_checkpoint))
                    cu.load_checkpoint(
                        last_checkpoint, model, cfg.NUM_GPUS > 1, optimizer
                    )

            # Shuffle the dataset.
            loader.shuffle_dataset(train_loader, cur_epoch)

            # Train for one epoch.
            train_epoch(
                train_loader,
                model,
                optimizer,
                train_meter,
                cur_epoch,
                cfg,
                writer,
            )

            # Compute precise BN stats, using at most as many batches as the
            # precise-BN loader provides.
            if cfg.BN.USE_PRECISE_STATS and len(get_bn_modules(model)) > 0:
                calculate_and_update_precise_bn(
                    precise_bn_loader,
                    model,
                    min(cfg.BN.NUM_BATCHES_PRECISE, len(precise_bn_loader)),
                )
            # The return value is intentionally discarded.
            _ = misc.aggregate_sub_bn_stats(model)

            # Save a checkpoint.
            if cu.is_checkpoint_epoch(
                cfg,
                cur_epoch,
                None if multigrid is None else multigrid.schedule,
            ):
                cu.save_checkpoint(
                    cfg.OUTPUT_DIR, model, optimizer, cur_epoch, cfg
                )

            # Evaluate the model on validation set.
            if misc.is_eval_epoch(
                cfg,
                cur_epoch,
                None if multigrid is None else multigrid.schedule,
            ):
                eval_epoch(
                    val_loader, model, val_meter, cur_epoch, cfg, writer
                )
    finally:
        if writer is not None:
            writer.close()