def load_model_from_checkpoints()

in tensorflow_managed_spot_training_checkpointing/source_dir/cifar10_keras_main.py [0:0]


def load_model_from_checkpoints(checkpoint_path):
    """Load the model stored in the most recent epoch checkpoint.

    Scans ``checkpoint_path`` for ``.h5`` files, reads the epoch number from
    each file name (the first run of digits directly followed by a dot, e.g.
    ``checkpoint-12.h5`` -> 12) and loads the file with the highest epoch.
    Files whose name carries no such number are skipped with a warning.

    Args:
        checkpoint_path: Directory containing the ``.h5`` checkpoint files.

    Returns:
        A tuple ``(model, epoch)``: the object returned by ``load_model`` for
        the newest checkpoint (presumably a Keras model -- confirm against the
        file's imports) and that checkpoint's epoch number as an ``int``.

    Raises:
        ValueError: If the directory holds no ``.h5`` file with an epoch
            number in its name.
    """
    checkpoint_files = [file for file in os.listdir(checkpoint_path) if file.endswith('.h5')]
    logging.info('------------------------------------------------------')
    logging.info("Available checkpoint files: {}".format(checkpoint_files))

    # Capture the whole digit run, not just its last digit, so that
    # "checkpoint-10.h5" is read as epoch 10 rather than 0.
    epoch_pattern = re.compile(r'\d+(?=\.)')
    epoch_by_file = {}
    for file in checkpoint_files:
        match = epoch_pattern.search(file)
        if match is None:
            logging.warning('Skipping checkpoint file without an epoch number: {}'.format(file))
            continue
        # Store epochs as ints: comparing them as strings would rank '9' above '10'.
        epoch_by_file[file] = int(match.group())

    if not epoch_by_file:
        raise ValueError(
            'No .h5 checkpoint files with an epoch number found in {}'.format(checkpoint_path)
        )

    # On ties, max() keeps the first file in listing order.
    max_epoch_filename = max(epoch_by_file, key=epoch_by_file.get)
    max_epoch_number = epoch_by_file[max_epoch_filename]

    logging.info('Latest epoch checkpoint file name: {}'.format(max_epoch_filename))
    logging.info('Resuming training from epoch: {}'.format(max_epoch_number + 1))
    logging.info('------------------------------------------------------')

    resumed_model_from_checkpoints = load_model(os.path.join(checkpoint_path, max_epoch_filename))
    return resumed_model_from_checkpoints, max_epoch_number