def train_model()

in tools/train_net.py [0:0]


def _lookup_auto_match_baseline():
    """Return ``(pre_repeat, stats_baseline)`` for the active config.

    ``stats_baseline`` is the hard-coded target that auto-matching tries to
    reach and ``pre_repeat`` is the number of coarse rescaling iterations.
    The values are selected from ``cfg.TRAIN.DATASET``, ``cfg.MODEL.TYPE``,
    ``cfg.MODEL.LAYERS``, ``cfg.RGRAPH.DIM_LIST[0]``, ``cfg.OUT_DIR`` and
    ``cfg.RESNET.TRANS_FUN``.

    Raises:
        ValueError: if the config does not match any hard-coded baseline.
    """
    pre_repeat = None
    stats_baseline = None
    if cfg.TRAIN.DATASET == 'cifar10':
        pre_repeat = 15
        if cfg.MODEL.TYPE == 'resnet':  # ResNet20
            stats_baseline = 40813184
        elif cfg.MODEL.TYPE == 'mlpnet':  # 5-layer MLP. cfg.MODEL.LAYERS exclude stem and head layers
            if cfg.MODEL.LAYERS == 3:
                stats_baseline = {
                    256: 985600,
                    512: 2364416,
                    1024: 6301696,
                }.get(cfg.RGRAPH.DIM_LIST[0])
        elif cfg.MODEL.TYPE == 'cnn':
            if cfg.MODEL.LAYERS == 3:
                stats_baseline = {
                    512: 806884352,
                    16: 1216672,
                }.get(cfg.RGRAPH.DIM_LIST[0])
            elif cfg.MODEL.LAYERS == 6:
                if '64d' in cfg.OUT_DIR:
                    stats_baseline = 48957952
                elif '16d' in cfg.OUT_DIR:
                    stats_baseline = 3392128
    elif cfg.TRAIN.DATASET == 'imagenet':
        pre_repeat = 9
        if cfg.MODEL.TYPE == 'resnet':
            if 'basic' in cfg.RESNET.TRANS_FUN:  # ResNet34
                stats_baseline = 3663761408
            elif 'sep' in cfg.RESNET.TRANS_FUN:  # ResNet34-sep
                stats_baseline = 553614592
            elif 'bottleneck' in cfg.RESNET.TRANS_FUN:  # ResNet50
                stats_baseline = 4089184256
        elif cfg.MODEL.TYPE == 'efficientnet':  # EfficientNet
            stats_baseline = 385824092
        elif cfg.MODEL.TYPE == 'cnn':  # CNN
            if cfg.MODEL.LAYERS == 6:
                if '64d' in cfg.OUT_DIR:
                    stats_baseline = 166438912
    if stats_baseline is None:
        # Previously this case surfaced later as an UnboundLocalError.
        raise ValueError(
            'TRAIN.AUTO_MATCH has no baseline for dataset={!r}, model={!r}, '
            'layers={!r}, dim_list={!r}, out_dir={!r}'.format(
                cfg.TRAIN.DATASET, cfg.MODEL.TYPE, cfg.MODEL.LAYERS,
                cfg.RGRAPH.DIM_LIST, cfg.OUT_DIR
            )
        )
    return pre_repeat, stats_baseline


def _rescale_dim_list(new_first):
    """Replace ``cfg.RGRAPH.DIM_LIST`` with one scaled to start at ``new_first``.

    Each entry keeps its ratio to the current first entry and is rounded to
    the nearest integer after scaling.
    """
    old_first = cfg.RGRAPH.DIM_LIST[0]
    ratios = [dim / old_first for dim in cfg.RGRAPH.DIM_LIST]
    cfg.RGRAPH.DIM_LIST = [int(round(new_first * ratio)) for ratio in ratios]


def _auto_match_to_baseline(mode='flops'):
    """Adjust ``cfg.RGRAPH.DIM_LIST`` so the model stats approach the baseline.

    Args:
        mode: forwarded to ``model_builder.build_model_stats``; the original
            code notes the choices as 'flops' or 'params'.

    Raises:
        ValueError: if no hard-coded baseline exists for the active config.

    NOTE: ``cfg.defrost()`` is called here and is not paired with a freeze.
    """
    pre_repeat, stats_baseline = _lookup_auto_match_baseline()
    cfg.defrost()
    stats = model_builder.build_model_stats(mode)
    if stats != stats_baseline:
        # 1st round: set first stage dim
        # Coarse phase: rescale every dim by sqrt(baseline / stats).
        for _ in range(pre_repeat):
            scale = round(math.sqrt(stats_baseline / stats), 2)
            _rescale_dim_list(int(round(cfg.RGRAPH.DIM_LIST[0] * scale)))
            stats = model_builder.build_model_stats(mode)
        # Fine phase: move the first dim one step at a time until the stats
        # hit the baseline exactly or cross to the other side of it.
        flag_init = 1 if stats < stats_baseline else -1
        step = 1
        while True:
            _rescale_dim_list(cfg.RGRAPH.DIM_LIST[0] + flag_init * step)
            stats = model_builder.build_model_stats(mode)
            if stats == stats_baseline:
                break
            flag = 1 if stats < stats_baseline else -1
            if flag != flag_init:
                # Kept as ``== False`` so non-bool values behave as before.
                want_smaller = cfg.RGRAPH.UPPER == False  # noqa: E712
                # UPPER False: make sure the stats is SMALLER than baseline;
                # otherwise make sure it is not smaller. Step back if the
                # crossing landed on the wrong side.
                if (want_smaller and flag < 0) or (not want_smaller and flag > 0):
                    _rescale_dim_list(cfg.RGRAPH.DIM_LIST[0] - flag_init * step)
                break
        # 2nd round: set other stage dim
        first = cfg.RGRAPH.DIM_LIST[0]
        ratio_list = [int(round(dim / first)) for dim in cfg.RGRAPH.DIM_LIST]
        stats = model_builder.build_model_stats(mode)
        flag_init = 1 if stats < stats_baseline else -1
        if 'share' not in cfg.RESNET.TRANS_FUN:
            for i in range(1, len(cfg.RGRAPH.DIM_LIST)):
                # Nudge stage i by up to ratio_list[i] unit steps and revert
                # the step that moves the stats across the baseline.
                for _ in range(ratio_list[i]):
                    cfg.RGRAPH.DIM_LIST[i] += flag_init
                    stats = model_builder.build_model_stats(mode)
                    flag = 1 if stats < stats_baseline else -1
                    if flag_init != flag:
                        cfg.RGRAPH.DIM_LIST[i] -= flag_init
                        break
    stats = model_builder.build_model_stats(mode)
    print('FINAL', cfg.RGRAPH.GROUP_NUM, cfg.RGRAPH.DIM_LIST, stats, stats_baseline, stats < stats_baseline)


def train_model(writer_train=None, writer_eval=None, is_master=False):
    """Trains the model.

    Optionally auto-matches ``cfg.RGRAPH.DIM_LIST`` to a hard-coded baseline,
    then builds the model, resumes from the last checkpoint when enabled, and
    runs the train / checkpoint / eval loop up to ``cfg.OPTIM.MAX_EPOCH``.

    Args:
        writer_train: forwarded to ``train_epoch``.
        writer_eval: forwarded to ``log_model_info`` and ``eval_epoch``.
        is_master: forwarded to ``train_epoch`` and ``eval_epoch``.

    Raises:
        ValueError: if auto-matching is enabled but the config has no
            hard-coded baseline.
        SystemExit: if the resumed checkpoint epoch equals
            ``cfg.OPTIM.MAX_EPOCH``.
    """
    # Fit flops/params
    if cfg.TRAIN.AUTO_MATCH and cfg.RGRAPH.SEED_TRAIN == cfg.RGRAPH.SEED_TRAIN_START:
        _auto_match_to_baseline(mode='flops')  # flops or params

    # Build the model (before the loaders to ease debugging)
    model = model_builder.build_model()
    params, flops = log_model_info(model, writer_eval)

    # Define the loss function
    loss_fun = losses.get_loss_fun()
    # Construct the optimizer
    optimizer = optim.construct_optimizer(model)

    # Load a checkpoint if applicable
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cu.has_checkpoint():
        last_checkpoint = cu.get_checkpoint_last()
        checkpoint_epoch = cu.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info('Loaded checkpoint from: {}'.format(last_checkpoint))
        if checkpoint_epoch == cfg.OPTIM.MAX_EPOCH:
            # Stop the process here. sys.exit() replaces the site-provided
            # exit(), which is intended for interactive use only.
            import sys
            sys.exit()
        start_epoch = checkpoint_epoch + 1

    # Create data loaders
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()

    # Create meters
    train_meter = TrainMeter(len(train_loader))
    test_meter = TestMeter(len(test_loader))

    if cfg.ONLINE_FLOPS:
        # Measure FLOPs/params on a throwaway copy of the model.
        model_dummy = model_builder.build_model()

        IMAGE_SIZE = 224
        n_flops, n_params = mu.measure_model(model_dummy, IMAGE_SIZE, IMAGE_SIZE)

        logger.info('FLOPs: %.2fM, Params: %.2fM' % (n_flops / 1e6, n_params / 1e6))

        del model_dummy

    # Perform the training loop
    logger.info('Start epoch: {}'.format(start_epoch + 1))

    # do eval at initialization
    eval_epoch(test_loader, model, test_meter, -1,
               writer_eval, params, flops, is_master=is_master)

    if start_epoch == cfg.OPTIM.MAX_EPOCH:
        # No epochs left to train: evaluate once at the last epoch index.
        cur_epoch = start_epoch - 1
        eval_epoch(test_loader, model, test_meter, cur_epoch,
                   writer_eval, params, flops, is_master=is_master)
    else:
        for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
            # Train for one epoch
            train_epoch(
                train_loader, model, loss_fun, optimizer, train_meter, cur_epoch,
                writer_train, is_master=is_master
            )
            # Compute precise BN stats
            if cfg.BN.USE_PRECISE_STATS:
                nu.compute_precise_bn_stats(model, train_loader)
            # Save a checkpoint
            if cu.is_checkpoint_epoch(cur_epoch):
                checkpoint_file = cu.save_checkpoint(model, optimizer, cur_epoch)
                logger.info('Wrote checkpoint to: {}'.format(checkpoint_file))
            # Evaluate the model
            if is_eval_epoch(cur_epoch):
                eval_epoch(test_loader, model, test_meter, cur_epoch,
                           writer_eval, params, flops, is_master=is_master)