def main()

in legacy/models/resnet/tensorflow/train_imagenet_resnet_hvd.py [0:0]


def main():
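    # Give each GPU its own dedicated host threads for kernel launches and enable
    # the faster cuDNN persistent batch-norm and non-fused Winograd convolution kernels.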
    gpu_thread_count = 2
    os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
    os.environ['TF_GPU_THREAD_COUNT'] = str(gpu_thread_count)
    os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
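    # Initialize Horovod; hvd.rank(), hvd.local_rank() and hvd.size() are valid from here on.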
    hvd.init()

    # random.seed(5 * (1 + hvd.rank()))
    # np.random.seed(7 * (1 + hvd.rank()))
    # tf.set_random_seed(31 * (1 + hvd.rank()))

    cmdline = add_cli_args()
    FLAGS, unknown_args = cmdline.parse_known_args()
    if len(unknown_args) > 0:
        for bad_arg in unknown_args:
            print("ERROR: Unknown command line arg: %s" % bad_arg)
        raise ValueError("Invalid command line arg(s)")


    config = tf.ConfigProto()
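    # Pin each process to the GPU matching its local rank and use pinned (page-locked)
    # host memory so host-to-device copies are faster.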
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.force_gpu_compatible = True  # Force pinned memory
    config.intra_op_parallelism_threads = FLAGS.intra_op_parallelism_threads
    config.inter_op_parallelism_threads = FLAGS.inter_op_parallelism_threads

    FLAGS.data_dir = None if FLAGS.data_dir == "" else FLAGS.data_dir
    FLAGS.log_dir = None if FLAGS.log_dir == "" else FLAGS.log_dir

    if FLAGS.eval:
        FLAGS.log_name = 'eval_' + FLAGS.log_name
    if FLAGS.local_ckpt:
        do_checkpoint = hvd.local_rank() == 0
    else:
        do_checkpoint = hvd.rank() == 0
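    # Only local rank 0 clears/creates the log directory; the dummy allreduce below acts
    # as a barrier so the remaining ranks wait until the filesystem work is done.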
    if hvd.local_rank() == 0 and FLAGS.clear_log and os.path.isdir(FLAGS.log_dir):
        shutil.rmtree(FLAGS.log_dir)
    barrier = hvd.allreduce(tf.constant(0, dtype=tf.float32))
    tf.Session(config=config).run(barrier)

    if hvd.local_rank() == 0 and not os.path.isdir(FLAGS.log_dir):
        os.makedirs(FLAGS.log_dir)
    barrier = hvd.allreduce(tf.constant(0, dtype=tf.float32))
    tf.Session(config=config).run(barrier)
    
    logger = logging.getLogger(FLAGS.log_name)
    logger.setLevel(logging.INFO)  # INFO, ERROR
    # console handler (a file handler is added below on rank 0 only)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    # add formatter to the handlers
    # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    formatter = logging.Formatter('%(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    if hvd.rank() == 0:
        fh = logging.FileHandler(os.path.join(FLAGS.log_dir, FLAGS.log_name))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        # add handlers to logger
        logger.addHandler(fh)
    
    height, width = 224, 224
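    # Global batch per optimizer step = per-GPU batch size * number of Horovod ranks.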
    global_batch_size = FLAGS.batch_size * hvd.size()
    rank0log(logger, 'PY ' + str(sys.version) + ' TF ' + str(tf.__version__))
    rank0log(logger, "Horovod size: ", hvd.size())

    if FLAGS.data_dir:
        filename_pattern = os.path.join(FLAGS.data_dir, '%s-*')
        train_filenames = sorted(tf.gfile.Glob(filename_pattern % 'train'))
        eval_filenames = sorted(tf.gfile.Glob(filename_pattern % 'validation'))
        num_training_samples = get_num_records(train_filenames)
        rank0log(logger, "Using data from: ", FLAGS.data_dir)
        if not FLAGS.eval:
            rank0log(logger, 'Found ', num_training_samples, ' training samples')
    else:
        if not FLAGS.synthetic:
            raise ValueError('data_dir is missing. Pass --synthetic to run on synthetic data, '
                             'or pass --data_dir.')
        train_filenames = eval_filenames = []
        num_training_samples = 1281167
    training_samples_per_rank = num_training_samples // hvd.size()

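    # Convert --num_epochs into a total number of optimizer steps, or derive the epoch
    # count from an explicit --num_batches.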
    if FLAGS.num_epochs:
        nstep = num_training_samples * FLAGS.num_epochs // global_batch_size
    elif FLAGS.num_batches:
        nstep = FLAGS.num_batches
        FLAGS.num_epochs = max(nstep * global_batch_size // num_training_samples, 1)
    else:
        raise ValueError("Either num_epochs or num_batches has to be passed")
    nstep_per_epoch = num_training_samples // global_batch_size
    decay_steps = nstep

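    # In 'steps' mode the schedule is piecewise constant: decay epochs from
    # --lr_decay_steps are converted to step counts and paired with the LRs in --lr_decay_lrs.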
    if FLAGS.lr_decay_mode == 'steps':
        steps = [int(x) * nstep_per_epoch for x in FLAGS.lr_decay_steps.split(',')]
        lr_steps = [float(x) for x in FLAGS.lr_decay_lrs.split(',')]
    else:
        steps = []
        lr_steps = []

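    # Default learning rate: the linear scaling rule (0.1 per 256 images of global batch),
    # or a fixed default when LARC is enabled.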
    if not FLAGS.lr:
        if FLAGS.use_larc:
            FLAGS.lr = 3.7
        else:
            FLAGS.lr = (hvd.size() * FLAGS.batch_size * 0.1) / 256
    if not FLAGS.save_checkpoints_steps:
        # default to save one checkpoint per epoch
        FLAGS.save_checkpoints_steps = nstep_per_epoch
    if not FLAGS.save_summary_steps:
        # default to save one summary per epoch
        FLAGS.save_summary_steps = nstep_per_epoch
    
    if not FLAGS.eval:
        rank0log(logger, 'Using a learning rate of ', FLAGS.lr)
        rank0log(logger, 'Checkpointing every ' + str(FLAGS.save_checkpoints_steps) + ' steps')
        rank0log(logger, 'Saving summary every ' + str(FLAGS.save_summary_steps) + ' steps')

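    # Length of the learning-rate warmup phase, expressed in optimizer steps.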
    warmup_it = nstep_per_epoch * FLAGS.warmup_epochs

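    # Build the Estimator. Hyperparameters are forwarded to cnn_model_function via params;
    # only the checkpointing rank writes summaries and checkpoints (the others pass None).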
    classifier = tf.estimator.Estimator(
        model_fn=cnn_model_function,
        model_dir=FLAGS.log_dir,
        params={
            'model': FLAGS.model,
            'decay_steps': decay_steps,
            'n_classes': 1000,
            'dtype': tf.float16 if FLAGS.fp16 else tf.float32,
            'format': 'channels_first',
            'device': '/gpu:0',
            'lr': FLAGS.lr,
            'mom': FLAGS.mom,
            'wdecay': FLAGS.wdecay,
            'use_larc': FLAGS.use_larc,
            'leta': FLAGS.leta,
            'steps': steps,
            'lr_steps': lr_steps,
            'lr_decay_mode': FLAGS.lr_decay_mode,
            'warmup_it': warmup_it,
            'warmup_lr': FLAGS.warmup_lr,
            'cdr_first_decay_ratio': FLAGS.cdr_first_decay_ratio,
            'cdr_t_mul': FLAGS.cdr_t_mul,
            'cdr_m_mul': FLAGS.cdr_m_mul,
            'cdr_alpha': FLAGS.cdr_alpha,
            'lc_periods': FLAGS.lc_periods,
            'lc_alpha': FLAGS.lc_alpha,
            'lc_beta': FLAGS.lc_beta,
            'loss_scale': FLAGS.loss_scale,
            'adv_bn_init': FLAGS.adv_bn_init,
            'conv_init': tf.variance_scaling_initializer() if FLAGS.adv_conv_init else None
        },
        config=tf.estimator.RunConfig(
            # tf_random_seed=31 * (1 + hvd.rank()),
            session_config=config,
            save_summary_steps=FLAGS.save_summary_steps if do_checkpoint else None,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps if do_checkpoint else None,
            keep_checkpoint_max=None))

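    # Training path: BroadcastGlobalVariablesHook syncs the initial weights from rank 0
    # to every rank before the first step; per-step progress is logged on rank 0 only.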
    if not FLAGS.eval:
        num_preproc_threads = FLAGS.num_parallel_calls
        rank0log(logger, "Using preprocessing threads per GPU: ", num_preproc_threads)
        training_hooks = [hvd.BroadcastGlobalVariablesHook(0),
                          PrefillStagingAreasHook()]
        if hvd.rank() == 0:
            training_hooks.append(
                LogSessionRunHook(global_batch_size,
                                  num_training_samples,
                                  FLAGS.display_every, logger))
        try:
            start_time = time.time()
            classifier.train(
                input_fn=lambda: make_dataset(
                    train_filenames,
                    training_samples_per_rank,
                    FLAGS.batch_size, height, width, 
                    FLAGS.brightness, FLAGS.contrast, FLAGS.saturation, FLAGS.hue, 
                    training=True, num_threads=num_preproc_threads, 
                    shard=True, synthetic=FLAGS.synthetic, increased_aug=FLAGS.increased_aug),
                max_steps=nstep,
                hooks=training_hooks)
            rank0log(logger, "Finished in ", time.time() - start_time)
        except KeyboardInterrupt:
            print("Keyboard interrupt")
    elif FLAGS.eval and not FLAGS.synthetic:
        rank0log(logger, "Evaluating")
        rank0log(logger, "Validation dataset size: {}".format(get_num_records(eval_filenames)))
        barrier = hvd.allreduce(tf.constant(0, dtype=tf.float32))
        tf.Session(config=config).run(barrier)
        time.sleep(5)  # a little extra margin...
        if FLAGS.num_gpus == 1:
            rank0log(logger, """If you are evaluating checkpoints of a multi-GPU run on a single GPU,