in mxnet_managed_spot_training_checkpointing/source_dir/mnist.py [0:0]
def load_model_from_checkpoints(checkpoint_path, prefix='mnist'):
    """Rebuild the model from the newest checkpoint stored in ``checkpoint_path``.

    Checkpoint parameter files are expected to use MXNet's
    ``save_checkpoint`` naming scheme, ``<prefix>-<epoch>.params``
    (the epoch is zero-padded, e.g. ``mnist-0012.params``). The matching
    ``<prefix>-symbol.json`` must sit in the same directory.

    Args:
        checkpoint_path: Directory that holds the checkpoint files.
        prefix: File-name prefix the checkpoints were saved with. Defaults
            to ``'mnist'``, the prefix this script has always loaded.

    Returns:
        A ``(module, epoch)`` tuple. ``module`` is an ``mx.mod.Module``
        built from the saved symbol with the saved arg/aux parameters
        attached. ``epoch`` is the integer epoch of the loaded checkpoint.

    Raises:
        ValueError: If ``checkpoint_path`` holds no ``<prefix>-<epoch>.params``
            file.
    """
    # Match the whole file name: unrelated .params files are skipped instead of
    # crashing, and every digit of the epoch is captured (0012 -> 12, not 2).
    epoch_pattern = re.compile(r'{}-(\d+)\.params'.format(re.escape(prefix)))

    checkpoint_files = [file for file in os.listdir(checkpoint_path) if file.endswith('.' + 'params')]
    logging.info('------------------------------------------------------')
    logging.info("Available checkpoint files: {}".format(checkpoint_files))

    # Map integer epoch -> file name so the newest epoch is picked numerically,
    # not lexicographically.
    files_by_epoch = {}
    for file in checkpoint_files:
        match = epoch_pattern.fullmatch(file)
        if match:
            files_by_epoch[int(match.group(1))] = file

    if not files_by_epoch:
        raise ValueError(
            'No checkpoint matching {}-<epoch>.params found in {}'.format(prefix, checkpoint_path))

    max_epoch_number = max(files_by_epoch)
    max_epoch_filename = files_by_epoch[max_epoch_number]
    logging.info('Latest epoch checkpoint file name: {}'.format(max_epoch_filename))
    logging.info('Resuming training from epoch: {}'.format(max_epoch_number))
    logging.info('------------------------------------------------------')

    # Module.load reads the symbol AND attaches the saved arg/aux params to the
    # module. Building Module(symbol=sym) alone would drop the checkpointed
    # weights, and training would not resume from them.
    mlp_model = mx.mod.Module.load(os.path.join(checkpoint_path, prefix), max_epoch_number)
    return mlp_model, max_epoch_number