def main()

in workshop/lab_bring-your-own-model/tensorflow/cnn_mnist_after.py [0:0]


def _load_split(data_dir):
  """Load one dataset split stored as ``image.npy`` / ``label.npy``.

  Args:
    data_dir: Directory containing ``image.npy`` and ``label.npy``.

  Returns:
    ``(images, labels)``: images cast to float32 and divided by 255
    (presumably raw 0-255 pixel values -- confirm against the data-prep
    step), labels cast to int32.
  """
  # Kept as a local import (as in the original main()) so this block does
  # not depend on a module-level import.
  import os

  images = np.load(os.path.join(data_dir, 'image.npy')).astype(np.float32) / 255.0
  labels = np.load(os.path.join(data_dir, 'label.npy')).astype(np.int32)
  return images, labels


def main(unused_argv):
  """Train, evaluate, and export the MNIST CNN estimator.

  Configuration comes from ``parse_args()``: the ``train`` and ``test``
  directories (each holding ``image.npy`` and ``label.npy``), the estimator
  ``model_dir``, the SavedModel export directory ``sm_model_dir``, and the
  number of ``training_steps``.

  Args:
    unused_argv: Ignored. Presumably supplied by ``tf.app.run`` -- confirm
      against the caller.
  """
  # Read every argument up front so a missing one fails before training.
  args = parse_args()
  train_dir = args.train
  test_dir = args.test
  model_dir = args.model_dir
  sm_model_dir = args.sm_model_dir
  training_steps = args.training_steps

  # Load training and eval data.
  train_data, train_labels = _load_split(train_dir)
  eval_data, eval_labels = _load_split(test_dir)

  # Create the Estimator around the model function defined in this file.
  mnist_classifier = tf.estimator.Estimator(
      model_fn=cnn_model_fn, model_dir=model_dir)

  # Log the values of the "softmax_tensor" tensor under the label
  # "probabilities" every 50 training iterations.
  tensors_to_log = {"probabilities": "softmax_tensor"}
  logging_hook = tf.train.LoggingTensorHook(
      tensors=tensors_to_log, every_n_iter=50)

  # Train the model. With num_epochs=None the input function keeps cycling
  # through the shuffled data, so training length is bounded by `steps`.
  train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
      x={"x": train_data},
      y=train_labels,
      batch_size=100,
      num_epochs=None,
      shuffle=True)
  mnist_classifier.train(
      input_fn=train_input_fn,
      steps=training_steps,  # default: 20000 -- presumably set in parse_args()
      hooks=[logging_hook])

  # Evaluate once over the whole eval set (single epoch, no shuffling) and
  # print the resulting metrics.
  eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
      x={"x": eval_data}, y=eval_labels, num_epochs=1, shuffle=False)
  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
  print(eval_results)

  # Export a SavedModel using the serving input function defined in this file.
  # NOTE(review): newer TF releases deprecate export_savedmodel in favour of
  # export_saved_model; kept as-is for the TF version this lab targets --
  # confirm before switching.
  mnist_classifier.export_savedmodel(sm_model_dir, serving_input_fn)