def main()

in dev-support/mini-submarine/submarine/mnist_distributed.py [0:0]
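
For context, this excerpt relies on module-level imports and helpers defined earlier in the same file. A reconstructed sketch, assuming the stock TF 1.x names:

import json
import logging
import os
import sys

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# FLAGS, create_model() and start_tensorboard() are also defined at module
# level in this script and are used by main() below.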


def main(_):
    logging.getLogger().setLevel(logging.INFO)

    cluster_spec_str = os.environ["CLUSTER_SPEC"]
    cluster_spec = json.loads(cluster_spec_str)
    ps_hosts = cluster_spec["ps"]
    worker_hosts = cluster_spec["worker"]
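    # CLUSTER_SPEC is a JSON map from job name to a list of host:port strings.
    # Illustrative value only (the real one is injected by the launcher):
    #   {"ps": ["ps-0:2222"], "worker": ["worker-0:2222", "worker-1:2222"]}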

    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # Create and start a server for the local task.
    job_name = os.environ["JOB_NAME"]
    task_index = int(os.environ["TASK_INDEX"])
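    # Illustrative values only: the launcher sets e.g. JOB_NAME=worker,
    # TASK_INDEX=0 for the first worker, or JOB_NAME=ps, TASK_INDEX=0 for the
    # parameter server.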
    server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)

    if job_name == "ps":
        server.join()
    elif job_name == "worker":
        # Create our model graph. Assigns ops to the local worker by default.
        with tf.device(
            tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % task_index, cluster=cluster)
        ):
            features, labels, keep_prob, global_step, train_step, accuracy, merged = create_model()
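            # replica_device_setter keeps non-variable ops on this worker and
            # places the variables on the ps tasks; create_model() (defined
            # earlier in this script) returns the input placeholders, the
            # global step, the train op, the accuracy op and the merged
            # summary op.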

        if task_index == 0:  # chief worker
            tf.gfile.MakeDirs(FLAGS.working_dir)
            start_tensorboard(FLAGS.working_dir)

        # The StopAtStepHook handles stopping after running given steps.
        hooks = [tf.train.StopAtStepHook(num_steps=FLAGS.steps)]

        # Filter out all connections except those between the ps tasks and this
        # worker, to avoid hangs when another worker finishes first. Training is
        # asynchronous, so the workers never need to talk to each other.
        config_proto = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:%d" % task_index])
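        # For worker task 0 this resolves to ["/job:ps", "/job:worker/task:0"],
        # so this process never opens channels to the other workers.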

        with tf.train.MonitoredTrainingSession(
            master=server.target,
            is_chief=(task_index == 0),
            checkpoint_dir=FLAGS.working_dir,
            hooks=hooks,
            config=config_proto,
        ) as sess:
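            # The chief (task 0) initializes the shared variables on the ps
            # tasks and writes checkpoints to FLAGS.working_dir; the other
            # workers simply wait until that initialization is done.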
            # Import data
            logging.info("Extracting and loading input data...")
            # Use a per-task data dir name to work around a "file already exists"
            # error when several tasks download the dataset on the same node.
            if FLAGS.mnist_data_url == "":
                logging.info("Getting mnist data from default url")
                mnist = input_data.read_data_sets(FLAGS.data_dir + str(task_index))
            else:
                logging.info("Getting mnist data from " + FLAGS.mnist_data_url)
                mnist = input_data.read_data_sets(
                    FLAGS.data_dir + str(task_index), source_url=FLAGS.mnist_data_url
                )

            # Train
            logging.info("Starting training")
            i = 0
            while not sess.should_stop():
                # Before using the submarine-sdk, start the MySQL server first.
                # submarine.log_param("batch_size", FLAGS.batch_size)
                batch = mnist.train.next_batch(FLAGS.batch_size)
                if i % 100 == 0:
                    step, _, train_accuracy = sess.run(
                        [global_step, train_step, accuracy],
                        feed_dict={features: batch[0], labels: batch[1], keep_prob: 1.0},
                    )
                    logging.info("Step %d, training accuracy: %g" % (step, train_accuracy))
                    # Before using the submarine-sdk, start the MySQL server first.
                    # submarine.log_metric("accuracy", train_accuracy, i)
                else:
                    sess.run(
                        [global_step, train_step],
                        feed_dict={features: batch[0], labels: batch[1], keep_prob: 0.5},
                    )
                i += 1

        logging.info("Done training!")
        sys.exit()
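
In a Submarine job, the CLUSTER_SPEC, JOB_NAME and TASK_INDEX variables read above are exported by the runtime before main() runs. For local experimentation, a minimal launcher sketch (illustrative only; host names, ports and the script path are assumptions) can spawn a 1-ps/2-worker cluster itself:

import json
import os
import subprocess

cluster = {"ps": ["localhost:2222"], "worker": ["localhost:2223", "localhost:2224"]}

ps_procs, worker_procs = [], []
for job_name, hosts in cluster.items():
    for task_index in range(len(hosts)):
        env = dict(os.environ,
                   CLUSTER_SPEC=json.dumps(cluster),
                   JOB_NAME=job_name,
                   TASK_INDEX=str(task_index))
        proc = subprocess.Popen(["python", "mnist_distributed.py"], env=env)
        (ps_procs if job_name == "ps" else worker_procs).append(proc)

# Workers exit on their own once FLAGS.steps is reached; the ps process blocks
# in server.join() and has to be terminated explicitly.
for p in worker_procs:
    p.wait()
for p in ps_procs:
    p.terminate()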