in lib/utils/checkpoints.py [0:0]
def load_model_from_params_file(model):
    """Populate ``model`` from a resume checkpoint or a pre-trained params file.

    The weight source is chosen in this order:

    * ``cfg.CHECKPOINT.RESUME`` is set and ``find_checkpoint()`` is truthy:
      load ``get_checkpoint_resume_file()``. ``load_momentum`` is left at
      ``initialize_params_from_file``'s default.
    * Otherwise, if ``cfg.TRAIN.PARAMS_FILE`` is set: load that file with
      ``load_momentum=False``. This covers both "RESUME is False" and
      "RESUME is True but no checkpoint exists". Afterwards the start
      iteration may be remapped (``cfg.TRAIN.RESUME_FROM_BATCH_SIZE > 0``)
      and/or zeroed (``cfg.TRAIN.RESET_START_ITER``).
    * Otherwise (no usable checkpoint and no params file): nothing is loaded
      and training starts from iteration 0.

    Args:
        model: Object passed to ``initialize_params_from_file``. Whenever a
            file is loaded, its ``current_lr`` attribute is set to the
            learning rate returned by the loader.

    Returns:
        The iteration index training should start from.
    """
    resume_available = cfg.CHECKPOINT.RESUME and find_checkpoint()

    if resume_available:
        # An existing checkpoint takes priority over TRAIN.PARAMS_FILE.
        logger.info('Initializing from checkpoints...')
        start_iter, lr = initialize_params_from_file(
            model=model,
            weights_file=get_checkpoint_resume_file(),
        )
        logger.info(
            'Loaded: start_model_iter: {}; prev_lr: {:.8f}'.format(start_iter, lr)
        )
        model.current_lr = lr
        return start_iter

    if not cfg.TRAIN.PARAMS_FILE:
        # Neither a checkpoint nor a params file: start from scratch.
        logger.info('No checkpoint found; training from scratch...')
        return 0

    logger.info('Initializing from pre-trained file...')
    # Momentum is deliberately not restored from a pre-trained file.
    start_iter, lr = initialize_params_from_file(
        model=model,
        weights_file=cfg.TRAIN.PARAMS_FILE,
        load_momentum=False,
    )
    logger.info(
        'Loaded: start_model_iter: {}; prev_lr: {:.8f}'.format(start_iter, lr)
    )
    model.current_lr = lr

    # Remap the iteration count when the pre-training run used a different
    # batch size (mainly used for 1-node warmup).
    if cfg.TRAIN.RESUME_FROM_BATCH_SIZE > 0:
        start_iter = misc.resume_from(start_iter)

    # Keep only the weights; restart the iteration counter.
    if cfg.TRAIN.RESET_START_ITER:
        start_iter = 0

    return start_iter