in ec2-spot-tensorflow-checkpoint/tensorflow_checkpoint.py [0:0]
def _upload_newest_match(pattern, bucket):
    """Upload the file matching ``pattern`` with the highest ctime.

    Args:
        pattern: Glob pattern selecting the candidate files.
        bucket: Passed through as the second argument of ``upload_file``
            (presumably the destination S3 bucket name -- confirm against
            ``upload_file``'s definition).

    Raises:
        ValueError: If no file matches ``pattern``.
    """
    matches = glob.glob(pattern)
    if not matches:
        # max() on an empty list would raise a bare "max() arg is an empty
        # sequence"; keep the exception type but say what is missing.
        raise ValueError(
            "no checkpoint files match {!r}".format(pattern))
    newest = max(matches, key=os.path.getctime)
    # NOTE(review): upload_file's return value is ignored here, as in the
    # original code -- confirm whether it signals failure that way.
    upload_file(newest, bucket, object_name=None)


def train_and_checkpoint(net, manager, num_steps=5000, checkpoint_every=10,
                         bucket='pythontfckpt'):
    """Train ``net``, periodically saving and uploading checkpoints.

    Restores the module-level ``ckpt`` from ``manager.latest_checkpoint``
    (when there is one), then runs ``num_steps`` training steps using the
    module-level ``iterator``, ``opt`` and ``train_step``. Every
    ``checkpoint_every`` steps the checkpoint is saved through ``manager``
    and the resulting files are handed to ``upload_file``.

    Args:
        net: Model passed to ``train_step``.
        manager: Checkpoint manager providing ``latest_checkpoint`` and
            ``save()``.
        num_steps: Number of training steps to run.
        checkpoint_every: Save/upload interval in steps; must be non-zero
            because it is used as a modulus.
        bucket: Second argument for every ``upload_file`` call (presumably
            an S3 bucket name -- confirm against ``upload_file``).

    Raises:
        ValueError: If a save produced no ``*.index`` or ``*.data*`` file
            under ``tf_ckpts/``.
    """
    ckpt.restore(manager.latest_checkpoint).expect_partial()
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    for _ in range(num_steps):
        example = next(iterator)
        loss = train_step(net, example, opt)
        ckpt.step.assign_add(1)

        step = int(ckpt.step)
        if step % checkpoint_every != 0:
            continue

        save_path = manager.save()
        # The upload paths assume ``manager`` writes into 'tf_ckpts/' --
        # TODO confirm against where the manager is constructed.
        # NOTE(review): only the single newest '*.data*' file is uploaded;
        # if a checkpoint is ever split across several data files the
        # others would be missed.
        _upload_newest_match('tf_ckpts/*.index', bucket)
        _upload_newest_match('tf_ckpts/*.data*', bucket)
        upload_file('tf_ckpts/checkpoint', bucket, object_name=None)
        print("Saved checkpoint for step {}: {}".format(step, save_path))
        print("loss {:1.2f}".format(loss.numpy()))