in dev-support/mini-submarine/submarine/mnist_distributed.py [0:0]
def main(_):
    logging.getLogger().setLevel(logging.INFO)

    cluster_spec_str = os.environ["CLUSTER_SPEC"]
    cluster_spec = json.loads(cluster_spec_str)
    ps_hosts = cluster_spec["ps"]
    worker_hosts = cluster_spec["worker"]
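    # Illustrative CLUSTER_SPEC value (host names and ports are examples only;
    # the launcher that starts each task is assumed to inject the real ones):
    # {"ps": ["ps0.example.com:2222"],
    #  "worker": ["worker0.example.com:2222", "worker1.example.com:2222"]}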
    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task.
    job_name = os.environ["JOB_NAME"]
    task_index = int(os.environ["TASK_INDEX"])
    server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

    if job_name == "ps":
        # Parameter servers only host variables; block here and serve them
        # until the job is torn down.
        server.join()
    elif job_name == "worker":
        # Create our model graph. replica_device_setter assigns ops to the local
        # worker by default and places the variables on the parameter servers.
        with tf.device(
            tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % task_index, cluster=cluster)
        ):
            features, labels, keep_prob, global_step, train_step, accuracy, merged = create_model()

        if task_index == 0:  # chief worker
            tf.gfile.MakeDirs(FLAGS.working_dir)
            start_tensorboard(FLAGS.working_dir)
        # The StopAtStepHook handles stopping after running the given number of steps.
        hooks = [tf.train.StopAtStepHook(num_steps=FLAGS.steps)]

        # Filter all connections except those between ps and this worker to
        # avoid hanging issues when one worker finishes. We are using
        # asynchronous training so there is no need for the workers to
        # communicate.
        config_proto = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:%d" % task_index])
        with tf.train.MonitoredTrainingSession(
            master=server.target,
            is_chief=(task_index == 0),
            checkpoint_dir=FLAGS.working_dir,
            hooks=hooks,
            config=config_proto,
        ) as sess:
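            # MonitoredTrainingSession initializes variables, restores from and
            # writes checkpoints to checkpoint_dir on the chief, and reports
            # should_stop() once StopAtStepHook has run FLAGS.steps global steps.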
            # Import data.
            logging.info("Extracting and loading input data...")
            # Use a task-specific data dir name to work around the "file already
            # exists" issue when several tasks download the dataset on the same node.
            if FLAGS.mnist_data_url == "":
                logging.info("Getting mnist data from default url")
                mnist = input_data.read_data_sets(FLAGS.data_dir + str(task_index))
            else:
                logging.info("Getting mnist data from " + FLAGS.mnist_data_url)
                mnist = input_data.read_data_sets(
                    FLAGS.data_dir + str(task_index), source_url=FLAGS.mnist_data_url
                )
            # Train.
            logging.info("Starting training")
            i = 0
            while not sess.should_stop():
                # Before using submarine-sdk, start the MySQL server first.
                # submarine.log_param("batch_size", FLAGS.batch_size)
                batch = mnist.train.next_batch(FLAGS.batch_size)
                if i % 100 == 0:
                    # Every 100 steps, also compute accuracy with dropout disabled (keep_prob=1.0).
                    step, _, train_accuracy = sess.run(
                        [global_step, train_step, accuracy],
                        feed_dict={features: batch[0], labels: batch[1], keep_prob: 1.0},
                    )
                    logging.info("Step %d, training accuracy: %g" % (step, train_accuracy))
                    # Before using submarine-sdk, start the MySQL server first.
                    # submarine.log_metric("accuracy", train_accuracy, i)
                else:
                    sess.run(
                        [global_step, train_step],
                        feed_dict={features: batch[0], labels: batch[1], keep_prob: 0.5},
                    )
                i += 1
logging.info("Done training!")
sys.exit()