def load_model_from_checkpoints()

in tensorflow_2_managed_spot_training_checkpointing/mnist.py [0:0]


def load_model_from_checkpoints(checkpoint_path):
    """Load the Keras model saved by the most recent epoch checkpoint.

    Lists the ``.h5`` files in ``checkpoint_path``, reads the epoch number
    from the digits that sit directly before the ``.h5`` extension (for
    example ``checkpoint-12.h5`` -> ``12``), selects the file with the
    numerically highest epoch and loads it with
    ``tf.keras.models.load_model``.

    Args:
        checkpoint_path: Directory that contains the ``.h5`` checkpoint files.

    Returns:
        A ``(model, epoch)`` tuple: the loaded model and the integer epoch
        number parsed from the selected checkpoint's file name.

    Raises:
        ValueError: If the directory holds no ``.h5`` file whose name ends
            with an epoch number.
    """
    checkpoint_files = [file for file in os.listdir(checkpoint_path) if file.endswith('.h5')]
    print('------------------------------------------------------')
    print("Available checkpoint files: {}".format(checkpoint_files))

    # Capture the WHOLE digit run before the extension so multi-digit epochs
    # (10, 11, ...) are not truncated to their last digit.
    epoch_pattern = re.compile(r'(\d+)\.h5$')
    epoch_by_file = {}
    for file in checkpoint_files:
        match = epoch_pattern.search(file)
        if match:  # .h5 files without an epoch number are ignored
            epoch_by_file[file] = int(match.group(1))

    if not epoch_by_file:
        raise ValueError(
            "No '.h5' checkpoint with an epoch number found in {!r}".format(checkpoint_path)
        )

    # Compare epochs as integers; comparing the strings would rank '9' above '10'.
    max_epoch_filename = max(epoch_by_file, key=epoch_by_file.get)
    max_epoch_number = epoch_by_file[max_epoch_filename]

    print('Latest epoch checkpoint file name: {}'.format(max_epoch_filename))
    print('Resuming training from epoch: {}'.format(max_epoch_number + 1))
    print('------------------------------------------------------')

    resumed_model_from_checkpoints = tf.keras.models.load_model(
        os.path.join(checkpoint_path, max_epoch_filename)
    )
    return resumed_model_from_checkpoints, max_epoch_number