def load_model_from_checkpoints()

in mxnet_managed_spot_training_checkpointing/source_dir/mnist.py [0:0]


def _list_checkpoint_epochs(checkpoint_path, prefix):
    """Map epoch number -> file name for each ``<prefix>-<epoch>.params`` file.

    Files that do not match that exact pattern (for example
    ``<prefix>-symbol.json`` or parameter files saved under another prefix)
    are ignored instead of crashing the lookup.
    """
    pattern = re.compile(re.escape(prefix) + r'-(\d+)\.params')
    epochs = {}
    # sorted() only makes the result deterministic if two names map to the
    # same epoch (e.g. "mnist-10.params" and "mnist-0010.params").
    for file_name in sorted(os.listdir(checkpoint_path)):
        match = pattern.fullmatch(file_name)
        if match is not None:
            # Parse the whole digit run as an int so epoch 10 ranks above 9.
            epochs[int(match.group(1))] = file_name
    return epochs


def load_model_from_checkpoints(checkpoint_path, prefix='mnist'):
    """Rebuild the MXNet Module from the newest checkpoint in ``checkpoint_path``.

    Args:
        checkpoint_path: Directory holding checkpoint files named
            ``<prefix>-<epoch>.params`` (plus the matching symbol file that
            ``mx.model.load_checkpoint`` expects).
        prefix: Checkpoint prefix the files were saved under. Defaults to
            ``'mnist'``, the prefix this function always used.

    Returns:
        ``(module, epoch)``. ``module`` is created with ``mx.mod.Module.load``,
        so the checkpointed arg/aux params are attached to it rather than
        discarded. ``epoch`` is the ``int`` epoch of the loaded checkpoint.

    Raises:
        ValueError: if no ``<prefix>-<epoch>.params`` file is present.
    """
    epochs = _list_checkpoint_epochs(checkpoint_path, prefix)
    logging.info('------------------------------------------------------')
    logging.info("Available checkpoint files: {}".format(sorted(epochs.values())))
    if not epochs:
        raise ValueError(
            'No "{}-<epoch>.params" checkpoint files found in {}'.format(prefix, checkpoint_path))

    max_epoch_number = max(epochs)
    max_epoch_filename = epochs[max_epoch_number]

    logging.info('Latest epoch checkpoint file name: {}'.format(max_epoch_filename))
    logging.info('Resuming training from epoch: {}'.format(max_epoch_number))
    logging.info('------------------------------------------------------')

    # Module.load wraps load_checkpoint and keeps the loaded parameters on the
    # module. A bare Module(symbol=sym) would drop them, and a later fit()
    # would start again from freshly initialised weights.
    mlp_model = mx.mod.Module.load(os.path.join(checkpoint_path, prefix), max_epoch_number)
    return mlp_model, max_epoch_number