in tools/train_net.py [0:0]
def _rescaled_dims(dim_list, new_first):
    """Return a new list of stage widths whose first entry is ``new_first``.

    Every entry keeps its ratio to the current first entry and is rounded
    with Python's ``round`` (ties go to the even integer).
    """
    first = dim_list[0]
    return [int(round(new_first * (dim / first))) for dim in dim_list]


def _auto_match_baseline():
    """Look up the target stats used by AUTO_MATCH for the current ``cfg``.

    The baseline depends on ``cfg.TRAIN.DATASET`` and ``cfg.MODEL.TYPE`` and,
    for some models, on ``cfg.MODEL.LAYERS``, ``cfg.RGRAPH.DIM_LIST[0]``,
    ``cfg.RESNET.TRANS_FUN`` or substrings of ``cfg.OUT_DIR``.

    Returns:
        (pre_repeat, stats_baseline): number of coarse rescaling iterations
        and the target value of ``model_builder.build_model_stats``.

    Raises:
        ValueError: if no baseline is known for this configuration
            (previously this surfaced later as an ``UnboundLocalError``).
    """
    dataset = cfg.TRAIN.DATASET
    model_type = cfg.MODEL.TYPE
    pre_repeat = None
    stats_baseline = None
    if dataset == 'cifar10':
        pre_repeat = 15
        if model_type == 'resnet':  # ResNet20
            stats_baseline = 40813184
        elif model_type == 'mlpnet':  # 5-layer MLP. cfg.MODEL.LAYERS exclude stem and head layers
            if cfg.MODEL.LAYERS == 3:
                stats_baseline = {
                    256: 985600,
                    512: 2364416,
                    1024: 6301696,
                }.get(cfg.RGRAPH.DIM_LIST[0])
        elif model_type == 'cnn':
            if cfg.MODEL.LAYERS == 3:
                stats_baseline = {
                    512: 806884352,
                    16: 1216672,
                }.get(cfg.RGRAPH.DIM_LIST[0])
            elif cfg.MODEL.LAYERS == 6:
                if '64d' in cfg.OUT_DIR:
                    stats_baseline = 48957952
                elif '16d' in cfg.OUT_DIR:
                    stats_baseline = 3392128
    elif dataset == 'imagenet':
        pre_repeat = 9
        if model_type == 'resnet':
            # Substring checks are order-sensitive: 'basic' wins over 'sep'.
            if 'basic' in cfg.RESNET.TRANS_FUN:  # ResNet34
                stats_baseline = 3663761408
            elif 'sep' in cfg.RESNET.TRANS_FUN:  # ResNet34-sep
                stats_baseline = 553614592
            elif 'bottleneck' in cfg.RESNET.TRANS_FUN:  # ResNet50
                stats_baseline = 4089184256
        elif model_type == 'efficientnet':  # EfficientNet
            stats_baseline = 385824092
        elif model_type == 'cnn':  # CNN
            if cfg.MODEL.LAYERS == 6 and '64d' in cfg.OUT_DIR:
                stats_baseline = 166438912
    if pre_repeat is None or stats_baseline is None:
        raise ValueError(
            'AUTO_MATCH has no stats baseline for dataset={!r}, model={!r}, '
            'layers={!r}, out_dir={!r}'.format(
                dataset, model_type, cfg.MODEL.LAYERS, cfg.OUT_DIR))
    return pre_repeat, stats_baseline


def _match_first_stage_dim(mode, stats, stats_baseline, pre_repeat):
    """Round 1: adjust ``cfg.RGRAPH.DIM_LIST[0]``, scaling all stages with it.

    First ``pre_repeat`` coarse steps multiply the first width by
    ``sqrt(baseline / stats)`` (presumably because FLOPs grow roughly
    quadratically with width). Then the first width is moved one unit at a
    time toward the baseline until the stats hit it exactly or cross it.
    On a crossing, the last step is undone when it left the stats on the
    wrong side of the baseline for ``cfg.RGRAPH.UPPER``.
    """
    for _ in range(pre_repeat):
        scale = round(math.sqrt(stats_baseline / stats), 2)
        new_first = int(round(cfg.RGRAPH.DIM_LIST[0] * scale))
        cfg.RGRAPH.DIM_LIST = _rescaled_dims(cfg.RGRAPH.DIM_LIST, new_first)
        stats = model_builder.build_model_stats(mode)
    # +1: grow because we are below the baseline, -1: shrink because above.
    flag_init = 1 if stats < stats_baseline else -1
    step = 1
    while True:
        cfg.RGRAPH.DIM_LIST = _rescaled_dims(
            cfg.RGRAPH.DIM_LIST, cfg.RGRAPH.DIM_LIST[0] + flag_init * step)
        stats = model_builder.build_model_stats(mode)
        if stats == stats_baseline:
            break
        flag = 1 if stats < stats_baseline else -1
        if flag != flag_init:
            if cfg.RGRAPH.UPPER:
                # Want stats LARGER than baseline: undo if we are now below.
                overshot = flag > 0
            else:
                # Want stats SMALLER than baseline: undo if we are now above.
                overshot = flag < 0
            if overshot:
                cfg.RGRAPH.DIM_LIST = _rescaled_dims(
                    cfg.RGRAPH.DIM_LIST, cfg.RGRAPH.DIM_LIST[0] - flag_init * step)
            break


def _match_other_stage_dims(mode, stats_baseline):
    """Round 2: nudge stages 1..N-1 of ``cfg.RGRAPH.DIM_LIST`` individually.

    Each later stage is moved one unit at a time toward the baseline, at most
    ``round(dim / first)`` times. The move stops, and its last step is undone,
    as soon as the stats cross the baseline. Skipped when
    ``cfg.RESNET.TRANS_FUN`` contains 'share'.
    """
    first = cfg.RGRAPH.DIM_LIST[0]
    max_steps = [int(round(dim / first)) for dim in cfg.RGRAPH.DIM_LIST]
    stats = model_builder.build_model_stats(mode)
    flag_init = 1 if stats < stats_baseline else -1
    if 'share' in cfg.RESNET.TRANS_FUN:
        return
    for i in range(1, len(cfg.RGRAPH.DIM_LIST)):
        for _ in range(max_steps[i]):
            cfg.RGRAPH.DIM_LIST[i] += flag_init
            stats = model_builder.build_model_stats(mode)
            flag = 1 if stats < stats_baseline else -1
            if flag != flag_init:
                cfg.RGRAPH.DIM_LIST[i] -= flag_init
                break


def _auto_match_model_stats(mode='flops'):
    """Rewrite ``cfg.RGRAPH.DIM_LIST`` so the model's stats match a baseline.

    Args:
        mode: statistic passed to ``model_builder.build_model_stats``
            ('flops' or 'params').

    Side effects: calls ``cfg.defrost()`` (cfg is not re-frozen here),
    mutates ``cfg.RGRAPH.DIM_LIST`` and prints a 'FINAL' summary line.

    Raises:
        ValueError: if no baseline is known for the current configuration.
    """
    pre_repeat, stats_baseline = _auto_match_baseline()
    cfg.defrost()
    stats = model_builder.build_model_stats(mode)
    if stats != stats_baseline:
        _match_first_stage_dim(mode, stats, stats_baseline, pre_repeat)
        _match_other_stage_dims(mode, stats_baseline)
    stats = model_builder.build_model_stats(mode)
    print('FINAL', cfg.RGRAPH.GROUP_NUM, cfg.RGRAPH.DIM_LIST, stats, stats_baseline, stats < stats_baseline)


def train_model(writer_train=None, writer_eval=None, is_master=False):
    """Trains the model.

    When ``cfg.TRAIN.AUTO_MATCH`` is set and this is the first training seed
    (``cfg.RGRAPH.SEED_TRAIN == cfg.RGRAPH.SEED_TRAIN_START``), the stage
    widths in ``cfg.RGRAPH.DIM_LIST`` are first rescaled so the model's FLOPs
    match a hard-coded baseline. The function then builds the model, loss
    and optimizer, resumes from the last checkpoint if
    ``cfg.TRAIN.AUTO_RESUME`` is set, and runs the train/eval loop.

    Args:
        writer_train: forwarded to ``train_epoch``.
        writer_eval: forwarded to ``log_model_info`` and ``eval_epoch``.
        is_master: forwarded to ``train_epoch`` and ``eval_epoch``.

    Raises:
        ValueError: if AUTO_MATCH is enabled for a configuration without a
            known baseline.
        SystemExit: if the resumed checkpoint epoch equals
            ``cfg.OPTIM.MAX_EPOCH``.
    """
    # Fit flops/params
    if cfg.TRAIN.AUTO_MATCH and cfg.RGRAPH.SEED_TRAIN == cfg.RGRAPH.SEED_TRAIN_START:
        _auto_match_model_stats(mode='flops')  # flops or params
    # Build the model (before the loaders to ease debugging)
    model = model_builder.build_model()
    params, flops = log_model_info(model, writer_eval)
    # Define the loss function
    loss_fun = losses.get_loss_fun()
    # Construct the optimizer
    optimizer = optim.construct_optimizer(model)
    # Load a checkpoint if applicable
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and cu.has_checkpoint():
        last_checkpoint = cu.get_checkpoint_last()
        checkpoint_epoch = cu.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info('Loaded checkpoint from: {}'.format(last_checkpoint))
        if checkpoint_epoch == cfg.OPTIM.MAX_EPOCH:
            # sys.exit rather than the site-module ``exit`` helper, which is
            # absent under ``python -S`` and in frozen executables.
            import sys
            sys.exit()
        start_epoch = checkpoint_epoch + 1
    # Create data loaders
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    # Create meters
    train_meter = TrainMeter(len(train_loader))
    test_meter = TestMeter(len(test_loader))
    if cfg.ONLINE_FLOPS:
        # NOTE(review): the input resolution is fixed at 224 regardless of
        # cfg.TRAIN.DATASET — confirm this is intended for CIFAR-10 runs.
        model_dummy = model_builder.build_model()
        image_size = 224
        n_flops, n_params = mu.measure_model(model_dummy, image_size, image_size)
        logger.info('FLOPs: %.2fM, Params: %.2fM', n_flops / 1e6, n_params / 1e6)
        del model_dummy
    # Perform the training loop
    logger.info('Start epoch: {}'.format(start_epoch + 1))
    # Evaluate once before training. This also runs after resuming, and is
    # always reported as epoch -1.
    eval_epoch(test_loader, model, test_meter, -1,
               writer_eval, params, flops, is_master=is_master)
    if start_epoch == cfg.OPTIM.MAX_EPOCH:
        # The checkpoint was the final epoch: just re-evaluate it.
        eval_epoch(test_loader, model, test_meter, start_epoch - 1,
                   writer_eval, params, flops, is_master=is_master)
    else:
        for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
            # Train for one epoch
            train_epoch(
                train_loader, model, loss_fun, optimizer, train_meter, cur_epoch,
                writer_train, is_master=is_master
            )
            # Compute precise BN stats
            if cfg.BN.USE_PRECISE_STATS:
                nu.compute_precise_bn_stats(model, train_loader)
            # Save a checkpoint
            if cu.is_checkpoint_epoch(cur_epoch):
                checkpoint_file = cu.save_checkpoint(model, optimizer, cur_epoch)
                logger.info('Wrote checkpoint to: {}'.format(checkpoint_file))
            # Evaluate the model
            if is_eval_epoch(cur_epoch):
                eval_epoch(test_loader, model, test_meter, cur_epoch,
                           writer_eval, params, flops, is_master=is_master)