in tensorflow_hub/tools/make_image_classifier/make_image_classifier.py [0:0]
def main(args):
"""Main function to be called by absl.app.run() after flag parsing."""
del args
_check_keras_dependencies()
hparams = _get_hparams_from_flags()
image_dir = FLAGS.image_dir or lib.get_default_image_dir()
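# Optionally enable incremental GPU memory allocation instead of reserving
# all GPU memory up front.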
if FLAGS.set_memory_growth:
_set_gpu_memory_growth()
use_tf_data_input = FLAGS.use_tf_data_input
# For tensorflow<2.5, TF preprocessing layers do not support distribution
# strategies, so use_tf_data_input may only default to True for TF >= 2.5.
if use_tf_data_input is True and (LooseVersion(tf.__version__) <
LooseVersion("2.5.0")):
raise ValueError("use_tf_data_input is not supported for tensorflow<2.5")
# For tensorflow>=2.5, default to using tf.data.Dataset input with TF
# preprocessing layers.
if use_tf_data_input is None and (LooseVersion(tf.__version__) >=
LooseVersion("2.5.0")):
use_tf_data_input = True
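# Train the classifier on the images in image_dir; this returns the Keras
# model, the class labels, and the training/evaluation results.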
model, labels, train_result = lib.make_image_classifier(
FLAGS.tfhub_module, image_dir, hparams,
lib.get_distribution_strategy(FLAGS.distribution_strategy),
FLAGS.image_size, FLAGS.summaries_dir, use_tf_data_input)
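# Optionally fail the run if the measured accuracy is below the requested
# threshold.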
if FLAGS.assert_accuracy_at_least:
_assert_accuracy(train_result, FLAGS.assert_accuracy_at_least)
print("Done with training.")
if FLAGS.labels_output_file:
with tf.io.gfile.GFile(FLAGS.labels_output_file, "w") as f:
f.write("\n".join(labels + ("",)))
print("Labels written to", FLAGS.labels_output_file)
saved_model_dir = FLAGS.saved_model_dir
if FLAGS.tflite_output_file and not saved_model_dir:
# We need a SavedModel for conversion, even if the user did not request it.
saved_model_dir = tempfile.mkdtemp()
if saved_model_dir:
tf.saved_model.save(model, saved_model_dir)
print("SavedModel model exported to", saved_model_dir)
if FLAGS.tflite_output_file:
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
lite_model_content = converter.convert()
with tf.io.gfile.GFile(FLAGS.tflite_output_file, "wb") as f:
f.write(lite_model_content)
print("TFLite model exported to", FLAGS.tflite_output_file)