in tools/train_net.py [0:0]
def train(opts):
"""Train a model."""
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
logging.getLogger(__name__)
# Generate seed.
misc.generate_random_seed(opts)
# Create checkpoint dir.
checkpoint_dir = checkpoints.create_and_get_checkpoint_directory()
logger.info('Checkpoint directory created: {}'.format(checkpoint_dir))
# Setting training-time-specific configurations.
cfg.AVA.FULL_EVAL = cfg.AVA.FULL_EVAL_DURING_TRAINING
cfg.AVA.DETECTION_SCORE_THRESH = cfg.AVA.DETECTION_SCORE_THRESH_TRAIN
cfg.CHARADES.NUM_TEST_CLIPS = cfg.CHARADES.NUM_TEST_CLIPS_DURING_TRAINING
test_lfb, train_lfb = None, None
if cfg.LFB.ENABLED:
test_lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE, is_train=False)
train_lfb = get_lfb(cfg.LFB.MODEL_PARAMS_FILE, is_train=True)
# Build test_model.
# We build test_model first, so that we don't overwrite init.
test_model, test_timer, test_meter = create_wrapper(
is_train=False,
lfb=test_lfb,
)
total_test_iters = misc.get_total_test_iters(test_model)
logger.info('Test iters: {}'.format(total_test_iters))
# Build train_model.
train_model, train_timer, train_meter = create_wrapper(
is_train=True,
lfb=train_lfb,
)
# Bould BN auxilary model.
if cfg.TRAIN.COMPUTE_PRECISE_BN:
bn_aux = bn_helper.BatchNormHelper()
bn_aux.create_bn_aux_model(node_id=opts.node_id)
# Load checkpoint or pre-trained weight.
# See checkpoints.load_model_from_params_file for more details.
start_model_iter = 0
if cfg.CHECKPOINT.RESUME or cfg.TRAIN.PARAMS_FILE:
start_model_iter = checkpoints.load_model_from_params_file(train_model)
logger.info("------------- Training model... -------------")
train_meter.reset()
last_checkpoint = checkpoints.get_checkpoint_resume_file()
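    # Main training loop: one SGD step per iteration, with periodic
    # precise-BN updates, checkpointing, and evaluation as configured above.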
    for curr_iter in range(start_model_iter, cfg.SOLVER.MAX_ITER):
        train_model.UpdateWorkspaceLr(curr_iter)

        train_timer.tic()
        # SGD step.
        workspace.RunNet(train_model.net.Proto().name)
        train_timer.toc()

        if curr_iter == start_model_iter:
            misc.print_net(train_model)
            os.system('nvidia-smi')
            misc.show_flops_params(train_model)

        misc.check_nan_losses()

        # Checkpoint.
        if (curr_iter + 1) % cfg.CHECKPOINT.CHECKPOINT_PERIOD == 0 \
                or curr_iter + 1 == cfg.SOLVER.MAX_ITER:
            if cfg.TRAIN.COMPUTE_PRECISE_BN:
                bn_aux.compute_and_update_bn_stats(curr_iter)
            last_checkpoint = os.path.join(
                checkpoint_dir,
                'c2_model_iter{}.pkl'.format(curr_iter + 1))
            checkpoints.save_model_params(
                model=train_model,
                params_file=last_checkpoint,
                model_iter=curr_iter)

        train_meter.calculate_and_log_all_metrics_train(
            curr_iter, train_timer, suffix='_train')

        # Evaluation.
        if (curr_iter + 1) % cfg.TRAIN.EVAL_PERIOD == 0:
            if cfg.TRAIN.COMPUTE_PRECISE_BN:
                bn_aux.compute_and_update_bn_stats(curr_iter)

            test_meter.reset()
            logger.info("=> Testing model")
            for test_iter in range(0, total_test_iters):
                test_timer.tic()
                workspace.RunNet(test_model.net.Proto().name)
                test_timer.toc()
                test_meter.calculate_and_log_all_metrics_test(
                    test_iter, test_timer, total_test_iters, suffix='_test')

            test_meter.finalize_metrics()
            test_meter.compute_and_log_best()
            test_meter.log_final_metrics(curr_iter)

            # Finalize and reset train_meter after test.
            train_meter.finalize_metrics(is_train=True)
            json_stats = metrics.get_json_stats_dict(
                train_meter, test_meter, curr_iter)
            misc.log_json_stats(json_stats)
            train_meter.reset()

    train_model.shutdown_data_loader()
    test_model.shutdown_data_loader()

    if cfg.TRAIN.TEST_AFTER_TRAIN:
        cfg.TEST.PARAMS_FILE = last_checkpoint
        test_net(test_lfb)
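
A minimal usage sketch, assuming an argparse-based driver; this is illustrative only, not the repo's actual entry point. The --node_id flag is inferred from the opts.node_id access above, and the config is assumed to be loaded elsewhere before train() runs.

import argparse

def main():
    # Hypothetical driver: parse the one option train() is known to read.
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('--node_id', type=int, default=0)
    opts = parser.parse_args()
    train(opts)

if __name__ == '__main__':
    main()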