in models/official/efficientnet/main.py [0:0]
def main(unused_argv):
input_image_size = FLAGS.input_image_size
if not input_image_size:
input_image_size = model_builder_factory.get_model_input_size(
FLAGS.model_name)
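# Optionally carve a holdout split out of the training data. The 1024.0 below
# assumes the standard ImageNet TFRecord layout of 1024 training shards.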
if FLAGS.holdout_shards:
holdout_images = int(FLAGS.num_train_images * FLAGS.holdout_shards / 1024.0)
FLAGS.num_train_images -= holdout_images
if FLAGS.eval_name and 'test' in FLAGS.eval_name:
FLAGS.holdout_shards = None  # Do not use the holdout split when evaluating on the test set.
else:
FLAGS.num_eval_images = holdout_images
# For the ImageNet dataset, include the background label if the number of
# output classes is 1001.
include_background_label = (FLAGS.num_label_classes == 1001)
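# Resolve the TPU cluster when a TPU name/address is given or TPU execution is
# requested; otherwise run without a cluster resolver.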
if FLAGS.tpu or FLAGS.use_tpu:
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
FLAGS.tpu,
zone=FLAGS.tpu_zone,
project=FLAGS.gcp_project)
else:
tpu_cluster_resolver = None
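# With async checkpointing, the estimator's own periodic saving is disabled;
# an AsyncCheckpointSaverHook (attached below in 'train' mode) writes
# checkpoints instead.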
if FLAGS.use_async_checkpointing:
save_checkpoints_steps = None
else:
save_checkpoints_steps = max(100, FLAGS.iterations_per_loop)
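# TPUEstimator run configuration: Grappler's meta optimizer is disabled and
# input pipelines are built per host (PER_HOST_V2).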
config = tf.estimator.tpu.RunConfig(
cluster=tpu_cluster_resolver,
model_dir=FLAGS.model_dir,
save_checkpoints_steps=save_checkpoints_steps,
log_step_count_steps=FLAGS.log_step_count_steps,
session_config=tf.ConfigProto(
graph_options=tf.GraphOptions(
rewrite_options=rewriter_config_pb2.RewriterConfig(
disable_meta_optimizer=True))),
tpu_config=tf.estimator.tpu.TPUConfig(
iterations_per_loop=FLAGS.iterations_per_loop,
tpu_job_name=FLAGS.tpu_job_name,
per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig
.PER_HOST_V2)) # pylint: disable=line-too-long
# Initializes model parameters.
params = dict(
steps_per_epoch=FLAGS.num_train_images / FLAGS.train_batch_size,
use_bfloat16=FLAGS.use_bfloat16)
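# With the typical ImageNet defaults (1,281,167 training images and a train
# batch size of 2048, assuming those flag defaults), steps_per_epoch is
# roughly 625.6.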
est = tf.estimator.tpu.TPUEstimator(
use_tpu=FLAGS.use_tpu,
model_fn=model_fn,
config=config,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=FLAGS.eval_batch_size,
export_to_tpu=FLAGS.export_to_tpu,
params=params)
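# The same estimator instance drives training, evaluation, and export below.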
if (FLAGS.model_name.startswith('efficientnet-lite') or
FLAGS.model_name.startswith('efficientnet-edgetpu')):
# The lite and edgetpu variants use bilinear resizing for easier post-training quantization.
resize_method = tf.image.ResizeMethod.BILINEAR
else:
resize_method = None
# Input pipelines are slightly different (with regard to shuffling and
# preprocessing) between training and evaluation.
def build_imagenet_input(is_training):
"""Generate ImageNetInput for training and eval."""
if FLAGS.bigtable_instance:
logging.info('Using Bigtable dataset, table %s', FLAGS.bigtable_table)
select_train, select_eval = _select_tables_from_flags()
return imagenet_input.ImageNetBigtableInput(
is_training=is_training,
use_bfloat16=FLAGS.use_bfloat16,
transpose_input=FLAGS.transpose_input,
selection=select_train if is_training else select_eval,
num_label_classes=FLAGS.num_label_classes,
include_background_label=include_background_label,
augment_name=FLAGS.augment_name,
mixup_alpha=FLAGS.mixup_alpha,
randaug_num_layers=FLAGS.randaug_num_layers,
randaug_magnitude=FLAGS.randaug_magnitude,
resize_method=resize_method)
else:
if FLAGS.data_dir == FAKE_DATA_DIR:
logging.info('Using fake dataset.')
else:
logging.info('Using dataset: %s', FLAGS.data_dir)
return imagenet_input.ImageNetInput(
is_training=is_training,
data_dir=FLAGS.data_dir,
transpose_input=FLAGS.transpose_input,
cache=FLAGS.use_cache and is_training,
image_size=input_image_size,
num_parallel_calls=FLAGS.num_parallel_calls,
use_bfloat16=FLAGS.use_bfloat16,
num_label_classes=FLAGS.num_label_classes,
include_background_label=include_background_label,
augment_name=FLAGS.augment_name,
mixup_alpha=FLAGS.mixup_alpha,
randaug_num_layers=FLAGS.randaug_num_layers,
randaug_magnitude=FLAGS.randaug_magnitude,
resize_method=resize_method,
holdout_shards=FLAGS.holdout_shards)
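# Build the training and evaluation input pipelines.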
imagenet_train = build_imagenet_input(is_training=True)
imagenet_eval = build_imagenet_input(is_training=False)
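# In 'eval' mode, poll the model directory and evaluate every new checkpoint
# until the final training step is reached or the checkpoint iterator times
# out.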
if FLAGS.mode == 'eval':
eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
# Run evaluation whenever a new checkpoint appears in the model directory.
for ckpt in tf.train.checkpoints_iterator(
FLAGS.model_dir, timeout=FLAGS.eval_timeout):
logging.info('Starting to evaluate.')
try:
start_timestamp = time.time() # This time will include compilation time
eval_results = est.evaluate(
input_fn=imagenet_eval.input_fn,
steps=eval_steps,
checkpoint_path=ckpt,
name=FLAGS.eval_name)
elapsed_time = int(time.time() - start_timestamp)
logging.info('Eval results: %s. Elapsed seconds: %d',
eval_results, elapsed_time)
if FLAGS.archive_ckpt:
utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'], ckpt)
# Terminate the eval job when the final checkpoint is reached.
try:
current_step = int(os.path.basename(ckpt).split('-')[1])
except IndexError:
logging.info('%s has no global step info: stop!', ckpt)
break
if current_step >= FLAGS.train_steps:
logging.info(
'Evaluation finished after training step %d', current_step)
break
except tf.errors.NotFoundError:
# Since the coordinator is on a different job than the TPU worker,
# sometimes the TPU worker does not finish initializing until long after
# the CPU job tells it to start evaluating. In this case, the checkpoint
# file could have been deleted already.
logging.info(
'Checkpoint %s no longer exists, skipping checkpoint', ckpt)
else: # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir) # pylint: disable=protected-access,line-too-long
logging.info(
'Training for %d steps (%.2f epochs in total). Current'
' step %d.', FLAGS.train_steps,
FLAGS.train_steps / params['steps_per_epoch'], current_step)
start_timestamp = time.time() # This time will include compilation time
if FLAGS.mode == 'train':
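# Pure training mode: a single est.train call runs for the full
# FLAGS.train_steps.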
hooks = []
if FLAGS.use_async_checkpointing:
try:
from tensorflow.contrib.tpu.python.tpu import async_checkpoint # pylint: disable=g-import-not-at-top
except ImportError as e:
logging.exception(
'Async checkpointing is not supported in TensorFlow 2.x')
raise e
hooks.append(
async_checkpoint.AsyncCheckpointSaverHook(
checkpoint_dir=FLAGS.model_dir,
save_steps=max(100, FLAGS.iterations_per_loop)))
est.train(
input_fn=imagenet_train.input_fn,
max_steps=FLAGS.train_steps,
hooks=hooks)
else:
assert FLAGS.mode == 'train_and_eval'
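# train_and_eval: alternate between training for up to FLAGS.steps_per_eval
# steps and evaluating the latest checkpoint until FLAGS.train_steps is
# reached.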
while current_step < FLAGS.train_steps:
# Train for up to steps_per_eval number of steps.
# At the end of training, a checkpoint will be written to --model_dir.
next_checkpoint = min(current_step + FLAGS.steps_per_eval,
FLAGS.train_steps)
est.train(input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
current_step = next_checkpoint
logging.info('Finished training up to step %d. Elapsed seconds %d.',
next_checkpoint, int(time.time() - start_timestamp))
# Evaluate the model on the most recent checkpoint in --model_dir.
# Since evaluation happens in batches of --eval_batch_size, the final
# partial batch of images (num_eval_images modulo eval_batch_size) is
# dropped. As long as the batch size is consistent, the evaluated images
# are also consistent.
logging.info('Starting to evaluate.')
eval_results = est.evaluate(
input_fn=imagenet_eval.input_fn,
steps=FLAGS.num_eval_images // FLAGS.eval_batch_size,
name=FLAGS.eval_name)
logging.info('Eval results at step %d: %s',
next_checkpoint, eval_results)
ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
if FLAGS.archive_ckpt:
utils.archive_ckpt(eval_results, eval_results['top_1_accuracy'], ckpt)
elapsed_time = int(time.time() - start_timestamp)
logging.info('Finished training up to step %d. Elapsed seconds %d.',
FLAGS.train_steps, elapsed_time)
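# Optionally export the final model as a SavedModel.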
if FLAGS.export_dir:
export(est, FLAGS.export_dir, input_image_size)