def main()

in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]


def main():
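    # Performance-related environment variables: private per-GPU threads, persistent
    # cuDNN batch-norm, and non-fused Winograd convolutions.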
    gpu_thread_count = 2
    os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"
    os.environ["TF_GPU_THREAD_COUNT"] = str(gpu_thread_count)
    os.environ["TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT"] = "1"
    os.environ["TF_ENABLE_WINOGRAD_NONFUSED"] = "1"
    hvd.init()

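    # Each Horovod process sees only the GPU matching its local rank.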
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.force_gpu_compatible = True  # Force pinned memory
    config.intra_op_parallelism_threads = 1  # Avoid pool of Eigen threads
    config.inter_op_parallelism_threads = 5

    # random.seed(5 * (1 + hvd.rank()))
    # np.random.seed(7 * (1 + hvd.rank()))
    # tf.set_random_seed(31 * (1 + hvd.rank()))

    cmdline = add_cli_args()
    FLAGS, unknown_args = cmdline.parse_known_args()
    if len(unknown_args) > 0:
        for bad_arg in unknown_args:
            print("ERROR: Unknown command line arg: %s" % bad_arg)
        raise ValueError("Invalid command line arg(s)")

    FLAGS.data_dir = None if FLAGS.data_dir == "" else FLAGS.data_dir
    FLAGS.log_dir = None if FLAGS.log_dir == "" else FLAGS.log_dir

    if FLAGS.eval:
        FLAGS.log_name = "eval_" + FLAGS.log_name
        if hvd.rank() != 0:
            return
    if FLAGS.local_ckpt:
        do_checkpoint = hvd.local_rank() == 0
    else:
        do_checkpoint = hvd.rank() == 0
    if hvd.local_rank() == 0 and FLAGS.clear_log and os.path.isdir(FLAGS.log_dir):
        shutil.rmtree(FLAGS.log_dir)
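    # An allreduce on a dummy tensor acts as a barrier: no rank proceeds until
    # local rank 0 has finished clearing the log directory.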
    barrier = hvd.allreduce(tf.constant(0, dtype=tf.float32))
    tf.Session(config=config).run(barrier)

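    # Recreate the log directory on each node, then barrier again before logging starts.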
    if hvd.local_rank() == 0 and not os.path.isdir(FLAGS.log_dir):
        os.makedirs(FLAGS.log_dir)
    barrier = hvd.allreduce(tf.constant(0, dtype=tf.float32))
    tf.Session(config=config).run(barrier)

    logger = logging.getLogger(FLAGS.log_name)
    logger.setLevel(logging.INFO)  # INFO or ERROR
    # console handler (INFO) on every rank
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    # add formatter to the handlers
    # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    formatter = logging.Formatter("%(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)
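    # A DEBUG-level file handler is added only on rank 0, so a single process writes the log file.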
    if not hvd.rank():
        fh = logging.FileHandler(os.path.join(FLAGS.log_dir, FLAGS.log_name))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        # add handlers to logger
        logger.addHandler(fh)

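    # Global batch size is the per-GPU batch size times the number of Horovod workers.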
    height, width = 224, 224
    global_batch_size = FLAGS.batch_size * hvd.size()
    rank0log(logger, "PY " + str(sys.version) + " TF " + str(tf.__version__))
    rank0log(logger, "Horovod size: ", hvd.size())

    if FLAGS.data_dir:
        filename_pattern = os.path.join(FLAGS.data_dir, "%s-*")
        train_filenames = sorted(tf.gfile.Glob(filename_pattern % "train"))
        eval_filenames = sorted(tf.gfile.Glob(filename_pattern % "validation"))
        num_training_samples = get_num_records(train_filenames)
        rank0log(logger, "Using data from: ", FLAGS.data_dir)
        if not FLAGS.eval:
            rank0log(logger, "Found ", num_training_samples, " training samples")
    else:
        if not FLAGS.synthetic:
            raise ValueError(
                "data_dir missing. Pass --synthetic to run on synthetic data, or pass --data_dir."
            )
        train_filenames = eval_filenames = []
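        # ImageNet-1k training set size, used to size epochs when running on synthetic data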
        num_training_samples = 1281167
    training_samples_per_rank = num_training_samples // hvd.size()

    if FLAGS.num_epochs:
        nstep = num_training_samples * FLAGS.num_epochs // global_batch_size
    elif FLAGS.num_batches:
        nstep = FLAGS.num_batches
        FLAGS.num_epochs = max(nstep * global_batch_size // num_training_samples, 1)
    else:
        raise ValueError("Either num_epochs or num_batches has to be passed")
    nstep_per_epoch = num_training_samples // global_batch_size
    decay_steps = nstep

    if FLAGS.lr_decay_mode == "steps":
        steps = [int(x) * nstep_per_epoch for x in FLAGS.lr_decay_steps.split(",")]
        lr_steps = [float(x) for x in FLAGS.lr_decay_lrs.split(",")]
    else:
        steps = []
        lr_steps = []

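    # Default LR follows the linear scaling rule (0.1 per 256 samples of global batch);
    # LARC training uses its own fixed default.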
    if not FLAGS.lr:
        if FLAGS.use_larc:
            FLAGS.lr = 3.7
        else:
            FLAGS.lr = (hvd.size() * FLAGS.batch_size * 0.1) / 256
    if not FLAGS.save_checkpoints_steps:
        # default to save one checkpoint per epoch
        FLAGS.save_checkpoints_steps = nstep_per_epoch
    if not FLAGS.save_summary_steps:
        # default to save summaries once per epoch
        FLAGS.save_summary_steps = nstep_per_epoch

    if not FLAGS.eval:
        rank0log(logger, "Using a learning rate of ", FLAGS.lr)
        rank0log(logger, "Checkpointing every " + str(FLAGS.save_checkpoints_steps) + " steps")
        rank0log(logger, "Saving summary every " + str(FLAGS.save_summary_steps) + " steps")

    warmup_it = nstep_per_epoch * FLAGS.warmup_epochs

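    # Build the Estimator; only the checkpointing rank saves summaries and checkpoints
    # (the other ranks pass None to disable them).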
    classifier = tf.estimator.Estimator(
        model_fn=cnn_model_function,
        model_dir=FLAGS.log_dir,
        params={
            "model": FLAGS.model,
            "decay_steps": decay_steps,
            "n_classes": 1000,
            "dtype": tf.float16 if FLAGS.fp16 else tf.float32,
            "format": "channels_first",
            "device": "/gpu:0",
            "lr": FLAGS.lr,
            "mom": FLAGS.mom,
            "wdecay": FLAGS.wdecay,
            "use_larc": FLAGS.use_larc,
            "leta": FLAGS.leta,
            "steps": steps,
            "lr_steps": lr_steps,
            "lr_decay_mode": FLAGS.lr_decay_mode,
            "warmup_it": warmup_it,
            "warmup_lr": FLAGS.warmup_lr,
            "cdr_first_decay_ratio": FLAGS.cdr_first_decay_ratio,
            "cdr_t_mul": FLAGS.cdr_t_mul,
            "cdr_m_mul": FLAGS.cdr_m_mul,
            "cdr_alpha": FLAGS.cdr_alpha,
            "lc_periods": FLAGS.lc_periods,
            "lc_alpha": FLAGS.lc_alpha,
            "lc_beta": FLAGS.lc_beta,
            "loss_scale": FLAGS.loss_scale,
            "adv_bn_init": FLAGS.adv_bn_init,
            "conv_init": tf.variance_scaling_initializer() if FLAGS.adv_conv_init else None,
        },
        config=tf.estimator.RunConfig(
            # tf_random_seed=31 * (1 + hvd.rank()),
            session_config=config,
            save_summary_steps=FLAGS.save_summary_steps if do_checkpoint else None,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps if do_checkpoint else None,
            keep_checkpoint_max=None,
        ),
    )

    if not FLAGS.eval:
        num_preproc_threads = 5
        rank0log(logger, "Using preprocessing threads per GPU: ", num_preproc_threads)
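        # Broadcast initial variables from rank 0 so every worker starts from identical weights.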
        training_hooks = [hvd.BroadcastGlobalVariablesHook(0), PrefillStagingAreasHook()]
        if hvd.rank() == 0:
            training_hooks.append(
                LogSessionRunHook(
                    global_batch_size, num_training_samples, FLAGS.display_every, logger
                )
            )
        try:
            start_time = time.time()
            classifier.train(
                input_fn=lambda: make_dataset(
                    train_filenames,
                    training_samples_per_rank,
                    FLAGS.batch_size,
                    height,
                    width,
                    FLAGS.brightness,
                    FLAGS.contrast,
                    FLAGS.saturation,
                    FLAGS.hue,
                    training=True,
                    num_threads=num_preproc_threads,
                    shard=True,
                    synthetic=FLAGS.synthetic,
                    increased_aug=FLAGS.increased_aug,
                ),
                max_steps=nstep,
                hooks=training_hooks,
            )
            rank0log(logger, "Finished in ", time.time() - start_time)
        except KeyboardInterrupt:
            print("Keyboard interrupt")
    elif FLAGS.eval and not FLAGS.synthetic:
        rank0log(logger, "Evaluating")
        rank0log(logger, "Validation dataset size: {}".format(get_num_records(eval_filenames)))
        barrier = hvd.allreduce(tf.constant(0, dtype=tf.float32))
        tf.Session(config=config).run(barrier)
        time.sleep(5)  # a little extra margin...
        if FLAGS.num_gpus == 1:
            rank0log(
                logger,