in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]
def main():
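    """Entry point: configure TF and Horovod, parse CLI flags, build a
    tf.estimator.Estimator, and run ImageNet training or evaluation."""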
    gpu_thread_count = 2
    os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"
    os.environ["TF_GPU_THREAD_COUNT"] = str(gpu_thread_count)
    os.environ["TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT"] = "1"
    os.environ["TF_ENABLE_WINOGRAD_NONFUSED"] = "1"
    hvd.init()
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.force_gpu_compatible = True  # Force pinned memory
    config.intra_op_parallelism_threads = 1  # Avoid pool of Eigen threads
    config.inter_op_parallelism_threads = 5
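    # One process per GPU: each rank only sees the GPU matching its local
    # rank, so "/gpu:0" below maps to a different physical GPU per process.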
    # random.seed(5 * (1 + hvd.rank()))
    # np.random.seed(7 * (1 + hvd.rank()))
    # tf.set_random_seed(31 * (1 + hvd.rank()))
    cmdline = add_cli_args()
    FLAGS, unknown_args = cmdline.parse_known_args()
    if unknown_args:
        for bad_arg in unknown_args:
            print("ERROR: Unknown command line arg: %s" % bad_arg)
        raise ValueError("Invalid command line arg(s)")
    FLAGS.data_dir = None if FLAGS.data_dir == "" else FLAGS.data_dir
    FLAGS.log_dir = None if FLAGS.log_dir == "" else FLAGS.log_dir
    if FLAGS.eval:
        FLAGS.log_name = "eval_" + FLAGS.log_name
        if hvd.rank() != 0:
            return
    if FLAGS.local_ckpt:
        do_checkpoint = hvd.local_rank() == 0
    else:
        do_checkpoint = hvd.rank() == 0
    if hvd.local_rank() == 0 and FLAGS.clear_log and os.path.isdir(FLAGS.log_dir):
        shutil.rmtree(FLAGS.log_dir)
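    # An allreduce over a dummy tensor doubles as a barrier: no rank proceeds
    # until every rank reaches it, so nobody touches log_dir while rank 0 is
    # still clearing or (below) creating it.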
    barrier = hvd.allreduce(tf.constant(0, dtype=tf.float32))
    tf.Session(config=config).run(barrier)
    if hvd.local_rank() == 0 and not os.path.isdir(FLAGS.log_dir):
        os.makedirs(FLAGS.log_dir)
    barrier = hvd.allreduce(tf.constant(0, dtype=tf.float32))
    tf.Session(config=config).run(barrier)
    logger = logging.getLogger(FLAGS.log_name)
    logger.setLevel(logging.INFO)
    # console handler, shared by all ranks
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    formatter = logging.Formatter("%(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    if hvd.rank() == 0:
        # file handler which logs debug messages; only rank 0 writes the file
        fh = logging.FileHandler(os.path.join(FLAGS.log_dir, FLAGS.log_name))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    height, width = 224, 224
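    # Effective batch size is the per-GPU batch times the number of ranks,
    # e.g. batch_size=256 on 8 GPUs gives a global batch of 2048.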
    global_batch_size = FLAGS.batch_size * hvd.size()
    rank0log(logger, "PY " + str(sys.version) + " TF " + str(tf.__version__))
    rank0log(logger, "Horovod size: ", hvd.size())
    if FLAGS.data_dir:
        filename_pattern = os.path.join(FLAGS.data_dir, "%s-*")
        train_filenames = sorted(tf.gfile.Glob(filename_pattern % "train"))
        eval_filenames = sorted(tf.gfile.Glob(filename_pattern % "validation"))
        num_training_samples = get_num_records(train_filenames)
        rank0log(logger, "Using data from: ", FLAGS.data_dir)
        if not FLAGS.eval:
            rank0log(logger, "Found ", num_training_samples, " training samples")
    else:
        if not FLAGS.synthetic:
            raise ValueError(
                "data_dir missing. Pass --synthetic to run on synthetic data, "
                "or pass --data_dir to point at real data."
            )
        train_filenames = eval_filenames = []
        num_training_samples = 1281167
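    # 1281167 is the size of the ImageNet-1k training split; with sharding,
    # each rank processes roughly num_training_samples / hvd.size() examples
    # per epoch.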
    training_samples_per_rank = num_training_samples // hvd.size()
    if FLAGS.num_epochs:
        nstep = num_training_samples * FLAGS.num_epochs // global_batch_size
    elif FLAGS.num_batches:
        nstep = FLAGS.num_batches
        FLAGS.num_epochs = max(nstep * global_batch_size // num_training_samples, 1)
    else:
        raise ValueError("Either --num_epochs or --num_batches must be passed")
    nstep_per_epoch = num_training_samples // global_batch_size
    decay_steps = nstep
    if FLAGS.lr_decay_mode == "steps":
        steps = [int(x) * nstep_per_epoch for x in FLAGS.lr_decay_steps.split(",")]
        lr_steps = [float(x) for x in FLAGS.lr_decay_lrs.split(",")]
    else:
        steps = []
        lr_steps = []
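    # Linear LR scaling rule (Goyal et al., 2017): a base LR of 0.1 per 256
    # images, scaled to the global batch, e.g. 8 ranks x batch 256 gives
    # lr = 8 * 256 * 0.1 / 256 = 0.8. The LARC default of 3.7 is presumably a
    # tuned constant, since LARC rescales updates per layer anyway.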
    if not FLAGS.lr:
        if FLAGS.use_larc:
            FLAGS.lr = 3.7
        else:
            FLAGS.lr = (hvd.size() * FLAGS.batch_size * 0.1) / 256
    if not FLAGS.save_checkpoints_steps:
        # default to saving one checkpoint per epoch
        FLAGS.save_checkpoints_steps = nstep_per_epoch
    if not FLAGS.save_summary_steps:
        # default to saving summaries once per epoch
        FLAGS.save_summary_steps = nstep_per_epoch
    if not FLAGS.eval:
        rank0log(logger, "Using a learning rate of ", FLAGS.lr)
        rank0log(logger, "Checkpointing every " + str(FLAGS.save_checkpoints_steps) + " steps")
        rank0log(logger, "Saving summary every " + str(FLAGS.save_summary_steps) + " steps")
    warmup_it = nstep_per_epoch * FLAGS.warmup_epochs
    classifier = tf.estimator.Estimator(
        model_fn=cnn_model_function,
        model_dir=FLAGS.log_dir,
        params={
            "model": FLAGS.model,
            "decay_steps": decay_steps,
            "n_classes": 1000,
            "dtype": tf.float16 if FLAGS.fp16 else tf.float32,
            "format": "channels_first",
            "device": "/gpu:0",
            "lr": FLAGS.lr,
            "mom": FLAGS.mom,
            "wdecay": FLAGS.wdecay,
            "use_larc": FLAGS.use_larc,
            "leta": FLAGS.leta,
            "steps": steps,
            "lr_steps": lr_steps,
            "lr_decay_mode": FLAGS.lr_decay_mode,
            "warmup_it": warmup_it,
            "warmup_lr": FLAGS.warmup_lr,
            "cdr_first_decay_ratio": FLAGS.cdr_first_decay_ratio,
            "cdr_t_mul": FLAGS.cdr_t_mul,
            "cdr_m_mul": FLAGS.cdr_m_mul,
            "cdr_alpha": FLAGS.cdr_alpha,
            "lc_periods": FLAGS.lc_periods,
            "lc_alpha": FLAGS.lc_alpha,
            "lc_beta": FLAGS.lc_beta,
            "loss_scale": FLAGS.loss_scale,
            "adv_bn_init": FLAGS.adv_bn_init,
            "conv_init": tf.variance_scaling_initializer() if FLAGS.adv_conv_init else None,
        },
        config=tf.estimator.RunConfig(
            # tf_random_seed=31 * (1 + hvd.rank()),
            session_config=config,
            save_summary_steps=FLAGS.save_summary_steps if do_checkpoint else None,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps if do_checkpoint else None,
            keep_checkpoint_max=None,
        ),
    )
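    # Training path. BroadcastGlobalVariablesHook(0) copies rank 0's freshly
    # initialized weights to every rank so all workers start identical;
    # PrefillStagingAreasHook is presumably a custom hook defined elsewhere in
    # this script that warms the input pipeline's staging areas.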
    if not FLAGS.eval:
        num_preproc_threads = 5
        rank0log(logger, "Using preprocessing threads per GPU: ", num_preproc_threads)
        training_hooks = [hvd.BroadcastGlobalVariablesHook(0), PrefillStagingAreasHook()]
        if hvd.rank() == 0:
            training_hooks.append(
                LogSessionRunHook(
                    global_batch_size, num_training_samples, FLAGS.display_every, logger
                )
            )
        try:
            start_time = time.time()
            classifier.train(
                input_fn=lambda: make_dataset(
                    train_filenames,
                    training_samples_per_rank,
                    FLAGS.batch_size,
                    height,
                    width,
                    FLAGS.brightness,
                    FLAGS.contrast,
                    FLAGS.saturation,
                    FLAGS.hue,
                    training=True,
                    num_threads=num_preproc_threads,
                    shard=True,
                    synthetic=FLAGS.synthetic,
                    increased_aug=FLAGS.increased_aug,
                ),
                max_steps=nstep,
                hooks=training_hooks,
            )
            rank0log(logger, "Finished in ", time.time() - start_time)
        except KeyboardInterrupt:
            print("Keyboard interrupt")
    elif FLAGS.eval and not FLAGS.synthetic:
        rank0log(logger, "Evaluating")
        rank0log(logger, "Validation dataset size: {}".format(get_num_records(eval_filenames)))
        barrier = hvd.allreduce(tf.constant(0, dtype=tf.float32))
        tf.Session(config=config).run(barrier)
        time.sleep(5)  # a little extra margin...
        if FLAGS.num_gpus == 1:
            rank0log(
                logger,