def train_net()

in tools/train_net.py [0:0]


def train_net(opts):
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
    logger = logging.getLogger(__name__)
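    # seed numpy's RNG so runs with the same cfg.RNG_SEED are reproducible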
    np.random.seed(cfg.RNG_SEED)

    prefix, device = helpers.get_prefix_and_device()
    device_prefix = '{}{}'.format(prefix, 0)
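    # device_prefix names device 0's blob namescope (presumably something like
    # 'gpu_0' or 'cpu_0'); the model wrappers below are built against it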

    ############################################################################
    total_test_iters, test_metrics_calculator = 0, None
    if cfg.MODEL.TEST_MODEL:
        # Build test_model: we do this first so we don't overwrite init (if any)
        test_model, test_metrics_calculator, test_timer = build_wrapper(
            is_train=False, prefix=device_prefix
        )
        total_test_iters = helpers.get_num_test_iter(test_model.input_db)
        logger.info('Test epoch iters: {}'.format(total_test_iters))

    ############################################################################
    # create training model and metrics
    train_model, train_metrics_calculator, train_timer = build_wrapper(
        is_train=True, prefix=device_prefix
    )
    # save proto for debugging
    helpers.save_model_proto(train_model)

    ############################################################################
    # setup the checkpoint directory and load model from checkpoint
    if cfg.CHECKPOINT.CHECKPOINT_ON:
        checkpoint_dir = checkpoints.get_checkpoint_directory()
        logger.info('Checkpoint directory: {}'.format(checkpoint_dir))

    # checkpoint_exists tracks whether we are resuming a previously started
    # (possibly interrupted) run: if any model checkpoints are found in
    # checkpoint_dir, it is set to True.
    start_model_iter, prev_checkpointed_lr, checkpoint_exists = 0, None, False
    if cfg.CHECKPOINT.RESUME or cfg.TRAIN.PARAMS_FILE:
        start_model_iter, prev_checkpointed_lr, checkpoint_exists = (
            checkpoints.load_model_from_params_file(
                train_model, params_file=cfg.TRAIN.PARAMS_FILE,
                checkpoint_dir=checkpoint_dir
            )
        )
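        # start_model_iter is the iteration to resume from; prev_checkpointed_lr
        # is handed back to the LR schedule below so that, on resume, the
        # learning rate presumably picks up where the checkpoint left off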
        # if we are fine-tuning, the params file only provides the
        # initialization. Training might still get stopped/killed later, in
        # which case a checkpoint will exist and we resume from it; only when
        # no checkpoint exists yet do we start counting iterations from 0.
        if cfg.MODEL.FINE_TUNE and not checkpoint_exists:
            start_model_iter = 0

    ############################################################################
    logger.info("=> Training model...")
    model_flops, model_params = 0, 0
    lr_iters = lr_utils.get_lr_steps()
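    # lr_iters presumably holds the iteration boundaries at which the step
    # schedule changes the learning rate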
    for curr_iter in range(start_model_iter, cfg.SOLVER.NUM_ITERATIONS):
        # set LR
        lr_utils.add_variable_stepsize_lr(
            curr_iter + 1, cfg.NUM_DEVICES, lr_iters, start_model_iter + 1,
            cfg.TRAIN.EVALUATION_FREQUENCY, train_model, prev_checkpointed_lr
        )

        # run the model training iteration
        train_timer.tic()
        workspace.RunNet(train_model.net.Proto().name)
        train_timer.toc(average=False)

        # logging after 1st iteration
        if curr_iter == start_model_iter:
            helpers.print_net(train_model)
            os.system('nvidia-smi')
            model_flops, model_params = helpers.get_flops_params(train_model)
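            # flops/params are computed once on the first iteration and reused
            # below when dumping the json stats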

        # check for NaN losses
        helpers.check_nan_losses(cfg.NUM_DEVICES)

        # calculate training metrics and log them every cfg.LOGGER_FREQUENCY iterations
        rem_train_iters = cfg.SOLVER.NUM_ITERATIONS - curr_iter - 1
        train_metrics_calculator.calculate_and_log_train_iter_metrics(
            curr_iter, train_timer, rem_train_iters, cfg.SOLVER.NUM_ITERATIONS,
            train_model.data_loader.minibatch_queue_size()
        )

        # checkpoint model at CHECKPOINT_PERIOD
        if (
            cfg.CHECKPOINT.CHECKPOINT_ON
            and (curr_iter + 1) % cfg.CHECKPOINT.CHECKPOINT_PERIOD == 0
        ):
            params_file = os.path.join(
                checkpoint_dir, 'c2_model_iter{}.pkl'.format(curr_iter + 1)
            )
            checkpoints.save_model_params(
                model=train_model, params_file=params_file,
                model_iter=curr_iter, checkpoint_dir=checkpoint_dir
            )

        if (curr_iter + 1) % cfg.TRAIN.EVALUATION_FREQUENCY == 0:
            train_metrics_calculator.finalize_metrics()
            # test the model if testing is enabled
            if cfg.MODEL.TEST_MODEL:
                test_metrics_calculator.reset()
                logger.info("=> Testing model...")
                for test_iter in range(0, total_test_iters):
                    # run a test iteration
                    test_timer.tic()
                    workspace.RunNet(test_model.net.Proto().name)
                    test_timer.toc()
                    rem_test_iters = (total_test_iters - test_iter - 1)
                    num_rem_iter = (cfg.SOLVER.NUM_ITERATIONS - curr_iter - 1)
                    num_rem_ep = num_rem_iter / cfg.TRAIN.EVALUATION_FREQUENCY
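                    # on logging iterations, inflate the remaining-iteration
                    # count with the test iterations of all future evaluation
                    # epochs, presumably to get a full-run ETA estimate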
                    if (test_iter + 1) % cfg.LOGGER_FREQUENCY == 0:
                        rem_test_iters += int(total_test_iters * num_rem_ep)
                    test_metrics_calculator.calculate_and_log_test_iter_metrics(
                        test_iter, test_timer, rem_test_iters, total_test_iters
                    )
                test_metrics_calculator.finalize_metrics()
                test_metrics_calculator.compute_and_log_epoch_best_metric(
                    model_iter=curr_iter
                )
            json_stats = metrics_helper.get_json_stats_dict(
                train_metrics_calculator, test_metrics_calculator, curr_iter,
                model_flops, model_params,
            )

            json_stats['average_time'] = round(
                train_timer.average_time + test_timer.average_time, 3
            )
            metrics_helper.print_json_stats(json_stats)
            train_metrics_calculator.reset()

    if test_metrics_calculator is not None:
        test_metrics_calculator.log_best_model_metrics(
            model_iter=curr_iter, total_iters=cfg.SOLVER.NUM_ITERATIONS,
        )

    train_model.data_loader.shutdown_dataloader()
    if cfg.MODEL.TEST_MODEL:
        test_model.data_loader.shutdown_dataloader()

    logger.info('Training has successfully finished...exiting!')
    os._exit(0)
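
For context, a minimal sketch of how a function like this is typically driven from
the command line. The flag names and the config-merging step are illustrative
assumptions, not the repository's actual interface:

# Hypothetical driver for train_net(); flag names are assumptions.
import argparse
import sys


def main():
    parser = argparse.ArgumentParser(description='Train a model via train_net()')
    parser.add_argument('--config_file', type=str, default=None,
                        help='path to a yaml config overriding the default cfg')
    parser.add_argument('opts', nargs=argparse.REMAINDER, default=None,
                        help='extra KEY VALUE pairs to override cfg entries')
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()
    # in practice the yaml file / KEY VALUE overrides would be merged into the
    # global cfg here before any nets are built; the exact helper depends on
    # the repository's config module
    train_net(args)


if __name__ == '__main__':
    main()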