in tensorflow_2_managed_spot_training_checkpointing/mnist.py [0:0]
def load_model_from_checkpoints(checkpoint_path):
    """Load the Keras model stored in the most recent epoch checkpoint.

    Scans ``checkpoint_path`` for ``.h5`` files, reads the epoch number from
    each file name (the first run of digits immediately followed by a dot,
    e.g. ``checkpoint-12.h5`` -> 12), and loads the file with the numerically
    highest epoch.

    Args:
        checkpoint_path: Directory containing the ``.h5`` checkpoint files.

    Returns:
        A ``(model, epoch)`` tuple: the object returned by
        ``tf.keras.models.load_model`` for the selected file, and the epoch
        number parsed from its name as an ``int``.

    Raises:
        ValueError: If the directory holds no ``.h5`` file whose name
            contains an epoch number.
    """
    checkpoint_files = [
        name for name in os.listdir(checkpoint_path) if name.endswith('.h5')
    ]
    print('------------------------------------------------------')
    print("Available checkpoint files: {}".format(checkpoint_files))

    # Capture the whole digit run, not just the last digit, so that epoch 10
    # is read as 10 rather than 0.
    epoch_pattern = re.compile(r'\d+(?=\.)')
    epoch_by_file = {}
    for name in checkpoint_files:
        match = epoch_pattern.search(name)
        if match is None:
            # An .h5 file without an epoch number cannot be ranked; skip it
            # instead of failing on the missing match.
            continue
        epoch_by_file[name] = int(match.group())

    if not epoch_by_file:
        raise ValueError(
            'No .h5 checkpoint with an epoch number found in {!r}'.format(
                checkpoint_path
            )
        )

    # Compare epochs as integers: string comparison would rank "9" above "10".
    max_epoch_filename = max(epoch_by_file, key=epoch_by_file.get)
    max_epoch_number = epoch_by_file[max_epoch_filename]

    print('Latest epoch checkpoint file name: {}'.format(max_epoch_filename))
    print('Resuming training from epoch: {}'.format(max_epoch_number + 1))
    print('------------------------------------------------------')

    resumed_model_from_checkpoints = tf.keras.models.load_model(
        os.path.join(checkpoint_path, max_epoch_filename)
    )
    return resumed_model_from_checkpoints, max_epoch_number