in models/official/resnet/benchmark/resnet_benchmark.py [0:0]
def main(unused_argv):
tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
config = contrib_tpu.RunConfig(
cluster=tpu_cluster_resolver,
model_dir=FLAGS.model_dir,
save_checkpoints_steps=FLAGS.iterations_per_loop,
keep_checkpoint_max=None,
tpu_config=contrib_tpu.TPUConfig(
iterations_per_loop=FLAGS.iterations_per_loop,
num_shards=FLAGS.num_cores,
per_host_input_for_training=contrib_tpu.InputPipelineConfig.PER_HOST_V2)) # pylint: disable=line-too-long
# Input pipelines are slightly different (with regards to shuffling and
# preprocessing) between training and evaluation.
imagenet_train = imagenet_input.ImageNetInput(
is_training=True,
data_dir=FLAGS.data_dir,
use_bfloat16=True,
transpose_input=FLAGS.transpose_input)
imagenet_eval = imagenet_input.ImageNetInput(
is_training=False,
data_dir=FLAGS.data_dir,
use_bfloat16=True,
transpose_input=FLAGS.transpose_input)
if FLAGS.use_fast_lr:
resnet_main.LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(1.0, 4), (0.1, 21), (0.01, 35), (0.001, 43)
]
imagenet_train_small = imagenet_input.ImageNetInput(
is_training=True,
image_size=128,
data_dir=FLAGS.data_dir_small,
num_parallel_calls=FLAGS.num_parallel_calls,
use_bfloat16=True,
transpose_input=FLAGS.transpose_input,
cache=True)
imagenet_eval_small = imagenet_input.ImageNetInput(
is_training=False,
image_size=128,
data_dir=FLAGS.data_dir_small,
num_parallel_calls=FLAGS.num_parallel_calls,
use_bfloat16=True,
transpose_input=FLAGS.transpose_input,
cache=True)
imagenet_train_large = imagenet_input.ImageNetInput(
is_training=True,
image_size=288,
data_dir=FLAGS.data_dir,
num_parallel_calls=FLAGS.num_parallel_calls,
use_bfloat16=True,
transpose_input=FLAGS.transpose_input)
imagenet_eval_large = imagenet_input.ImageNetInput(
is_training=False,
image_size=288,
data_dir=FLAGS.data_dir,
num_parallel_calls=FLAGS.num_parallel_calls,
use_bfloat16=True,
transpose_input=FLAGS.transpose_input)
resnet_classifier = contrib_tpu.TPUEstimator(
use_tpu=FLAGS.use_tpu,
model_fn=resnet_main.resnet_model_fn,
config=config,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=FLAGS.eval_batch_size)
if FLAGS.mode == 'train':
current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir) # pylint: disable=protected-access,line-too-long
batches_per_epoch = NUM_TRAIN_IMAGES / FLAGS.train_batch_size
tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
' step %d.' % (FLAGS.train_steps,
FLAGS.train_steps / batches_per_epoch,
current_step))
start_timestamp = time.time() # This time will include compilation time
# Write a dummy file at the start of training so that we can measure the
# runtime at each checkpoint from the file write time.
tf.gfile.MkDir(FLAGS.model_dir)
if not tf.gfile.Exists(os.path.join(FLAGS.model_dir, 'START')):
with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'START'), 'w') as f:
f.write(str(start_timestamp))
if FLAGS.use_fast_lr:
small_steps = int(18 * NUM_TRAIN_IMAGES / FLAGS.train_batch_size)
normal_steps = int(41 * NUM_TRAIN_IMAGES / FLAGS.train_batch_size)
large_steps = int(min(50 * NUM_TRAIN_IMAGES / FLAGS.train_batch_size,
FLAGS.train_steps))
resnet_classifier.train(
input_fn=imagenet_train_small.input_fn, max_steps=small_steps)
resnet_classifier.train(
input_fn=imagenet_train.input_fn, max_steps=normal_steps)
resnet_classifier.train(
input_fn=imagenet_train_large.input_fn,
max_steps=large_steps)
else:
resnet_classifier.train(
input_fn=imagenet_train.input_fn, max_steps=FLAGS.train_steps)
else:
assert FLAGS.mode == 'eval'
start_timestamp = tf.gfile.Stat(
os.path.join(FLAGS.model_dir, 'START')).mtime_nsec
results = []
eval_steps = NUM_EVAL_IMAGES // FLAGS.eval_batch_size
ckpt_steps = set()
all_files = tf.gfile.ListDirectory(FLAGS.model_dir)
for f in all_files:
mat = re.match(CKPT_PATTERN, f)
if mat is not None:
ckpt_steps.add(int(mat.group('gs')))
ckpt_steps = sorted(list(ckpt_steps))
tf.logging.info('Steps to be evaluated: %s' % str(ckpt_steps))
for step in ckpt_steps:
ckpt = os.path.join(FLAGS.model_dir, 'model.ckpt-%d' % step)
batches_per_epoch = NUM_TRAIN_IMAGES // FLAGS.train_batch_size
current_epoch = step // batches_per_epoch
if FLAGS.use_fast_lr:
if current_epoch < 18:
eval_input_fn = imagenet_eval_small.input_fn
if current_epoch >= 18 and current_epoch < 41:
eval_input_fn = imagenet_eval.input_fn
if current_epoch >= 41: # 41:
eval_input_fn = imagenet_eval_large.input_fn
else:
eval_input_fn = imagenet_eval.input_fn
end_timestamp = tf.gfile.Stat(ckpt + '.index').mtime_nsec
elapsed_hours = (end_timestamp - start_timestamp) / (1e9 * 3600.0)
tf.logging.info('Starting to evaluate.')
eval_start = time.time() # This time will include compilation time
eval_results = resnet_classifier.evaluate(
input_fn=eval_input_fn,
steps=eval_steps,
checkpoint_path=ckpt)
eval_time = int(time.time() - eval_start)
tf.logging.info('Eval results: %s. Elapsed seconds: %d' %
(eval_results, eval_time))
results.append([
current_epoch,
elapsed_hours,
'%.2f' % (eval_results['top_1_accuracy'] * 100),
'%.2f' % (eval_results['top_5_accuracy'] * 100),
])
time.sleep(60)
with tf.gfile.GFile(os.path.join(FLAGS.model_dir, 'results.tsv'), 'wb') as tsv_file: # pylint: disable=line-too-long
writer = csv.writer(tsv_file, delimiter='\t')
writer.writerow(['epoch', 'hours', 'top1Accuracy', 'top5Accuracy'])
writer.writerows(results)