def train_and_checkpoint()

in ec2-spot-tensorflow-checkpoint/tensorflow_checkpoint.py [0:0]


def _upload_newest(pattern, bucket):
  """Upload the most recently changed file matching *pattern* to *bucket*.

  "Most recent" is decided by ``os.path.getctime``. When nothing matches,
  a warning is printed and no upload happens, rather than letting ``max()``
  raise ``ValueError`` on an empty list in the middle of training.

  Args:
    pattern: glob pattern, e.g. ``'tf_ckpts/*.index'``.
    bucket: bucket name forwarded to the module-level ``upload_file``.
  """
  newest = max(glob.glob(pattern), key=os.path.getctime, default=None)
  if newest is None:
    print("No file matching {} found; skipping upload.".format(pattern))
    return
  upload_file(newest, bucket, object_name=None)


def train_and_checkpoint(net, manager):
  """Run 5000 training steps, saving and uploading a checkpoint every 10th.

  First restores ``ckpt`` from ``manager.latest_checkpoint`` and reports
  whether a checkpoint was found. Then, for each step, pulls the next
  example from ``iterator``, runs ``train_step`` and increments
  ``ckpt.step``. Whenever the step count is a multiple of 10 it calls
  ``manager.save()`` and uploads the newest ``tf_ckpts/*.index`` file, the
  newest ``tf_ckpts/*.data*`` file and ``tf_ckpts/checkpoint`` to the
  ``pythontfckpt`` bucket.

  Relies on the module-level names ``ckpt``, ``iterator``, ``opt``,
  ``train_step`` and ``upload_file``.

  Args:
    net: model object handed to ``train_step`` unchanged.
    manager: object exposing ``latest_checkpoint`` and ``save()``.
      NOTE(review): presumably a ``tf.train.CheckpointManager`` writing
      into ``tf_ckpts/`` -- the upload paths below assume that directory;
      confirm against the caller.
  """
  bucket = 'pythontfckpt'

  ckpt.restore(manager.latest_checkpoint).expect_partial()
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(5000):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)

    # Convert the step counter once per iteration instead of twice.
    step = int(ckpt.step)
    if step % 10 != 0:
      continue

    save_path = manager.save()
    # Upload order is kept: index file, data file, then the state file.
    _upload_newest('tf_ckpts/*.index', bucket)
    _upload_newest('tf_ckpts/*.data*', bucket)
    upload_file('tf_ckpts/checkpoint', bucket, object_name=None)
    print("Saved checkpoint for step {}: {}".format(step, save_path))
    print("loss {:1.2f}".format(loss.numpy()))