# def main()
#
# in tensorflow_hub/tools/make_image_classifier/make_image_classifier.py [0:0]


def main(args):
  """Main function to be called by absl.app.run() after flag parsing.

  Builds and trains an image classifier as configured by the command-line
  flags, then writes whichever outputs were requested: the labels file, a
  SavedModel and/or a TFLite model.

  If a TFLite model is requested without --saved_model_dir, the intermediate
  SavedModel is written to a temporary directory that is removed again once
  the conversion has finished (or failed).

  Args:
    args: Remaining positional command-line arguments; unused.

  Raises:
    ValueError: If use_tf_data_input is explicitly enabled on tensorflow<2.5.
  """
  # Local import: only needed to clean up the temporary SavedModel below.
  import shutil  # pylint: disable=g-import-not-at-top

  del args
  _check_keras_dependencies()
  hparams = _get_hparams_from_flags()

  image_dir = FLAGS.image_dir or lib.get_default_image_dir()

  if FLAGS.set_memory_growth:
    _set_gpu_memory_growth()

  use_tf_data_input = FLAGS.use_tf_data_input
  # For tensorflow<2.5, TF preprocessing layers do not support distribution
  # strategy, so an explicit request for use_tf_data_input is rejected there.
  if use_tf_data_input is True and (LooseVersion(tf.__version__) <
                                    LooseVersion("2.5.0")):
    raise ValueError("use_tf_data_input is not supported for tensorflow<2.5")
  # For tensorflow>=2.5, an unset flag (None) defaults to using
  # tf.data.Dataset and TF preprocessing layers.
  if use_tf_data_input is None and (LooseVersion(tf.__version__) >=
                                    LooseVersion("2.5.0")):
    use_tf_data_input = True

  model, labels, train_result = lib.make_image_classifier(
      FLAGS.tfhub_module, image_dir, hparams,
      lib.get_distribution_strategy(FLAGS.distribution_strategy),
      FLAGS.image_size, FLAGS.summaries_dir, use_tf_data_input)
  if FLAGS.assert_accuracy_at_least:
    _assert_accuracy(train_result, FLAGS.assert_accuracy_at_least)
  print("Done with training.")

  if FLAGS.labels_output_file:
    with tf.io.gfile.GFile(FLAGS.labels_output_file, "w") as f:
      # The appended empty string yields a trailing newline after the last
      # label.
      f.write("\n".join(labels + ("",)))
    print("Labels written to", FLAGS.labels_output_file)

  saved_model_dir = FLAGS.saved_model_dir
  temp_saved_model_dir = None
  if FLAGS.tflite_output_file and not saved_model_dir:
    # We need a SavedModel for conversion, even if the user did not request it.
    # Track it separately so that it can be deleted afterwards.
    temp_saved_model_dir = tempfile.mkdtemp()
    saved_model_dir = temp_saved_model_dir

  try:
    if saved_model_dir:
      tf.saved_model.save(model, saved_model_dir)
      if temp_saved_model_dir is None:
        # Only announce a SavedModel that outlives this function.
        print("SavedModel model exported to", saved_model_dir)

    if FLAGS.tflite_output_file:
      converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
      # convert() must run while the SavedModel directory still exists.
      lite_model_content = converter.convert()
      with tf.io.gfile.GFile(FLAGS.tflite_output_file, "wb") as f:
        f.write(lite_model_content)
      print("TFLite model exported to", FLAGS.tflite_output_file)
  finally:
    if temp_saved_model_dir is not None:
      # Best-effort cleanup: a failure to delete must not mask the real
      # outcome of the export.
      shutil.rmtree(temp_saved_model_dir, ignore_errors=True)