def make_image_classifier()

in tensorflow_hub/tools/make_image_classifier/make_image_classifier_lib.py


def make_image_classifier(tfhub_module,
                          image_dir,
                          hparams,
                          distribution_strategy=None,
                          requested_image_size=None,
                          log_dir=None,
                          use_tf_data_input=False):
  """Builds and trains a TensorFLow model for image classification.

  Args:
    tfhub_module: A Python string with the handle of the Hub module.
    image_dir: A Python string naming a directory with subdirectories of images,
      one per class.
    hparams: A HParams object with hyperparameters controlling the training.
    distribution_strategy: The tf.distribute strategy to build and train the
      model under. If None, the default strategy is used.
    requested_image_size: A Python integer controlling the size of images to
      feed into the Hub module. If the module has a fixed input size, this must
      be omitted or set to that same value.
    log_dir: A directory to write logs for TensorBoard into (defaults to None,
      no logs will then be written).
    use_tf_data_input: Whether to read input with a tf.data.Dataset and use TF
      ops for preprocessing.
  """
  augmentation_params = dict(
      rotation_range=hparams.rotation_range,
      horizontal_flip=hparams.horizontal_flip,
      width_shift_range=hparams.width_shift_range,
      height_shift_range=hparams.height_shift_range,
      shear_range=hparams.shear_range,
      zoom_range=hparams.zoom_range)

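  # Build the input pipeline and model under the distribution strategy so
  # variables are created on the intended devices.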
  if distribution_strategy is None:
    # No strategy was given: fall back to the default strategy so scope() works.
    distribution_strategy = tf.distribute.get_strategy()
  with distribution_strategy.scope():
    module_layer = hub.KerasLayer(
        tfhub_module, trainable=hparams.do_fine_tuning)
    image_size = _image_size_for_module(module_layer, requested_image_size)
    print("Using module {} with image size {}".format(tfhub_module, image_size))
    if use_tf_data_input:
      train_data_and_size, valid_data_and_size, labels = _get_data_as_datasets(
          image_dir, image_size, hparams)
    else:
      train_data_and_size, valid_data_and_size, labels = _get_data_with_keras(
          image_dir, image_size, hparams.batch_size, hparams.validation_split,
          hparams.do_data_augmentation, augmentation_params)
    print("Found", len(labels), "classes:", ", ".join(labels))
    model = build_model(module_layer, hparams, image_size, len(labels))
    train_result = train_model(model, hparams, train_data_and_size,
                               valid_data_and_size, log_dir)
  return model, labels, train_result
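
For reference, a minimal usage sketch is shown below. The get_default_hparams()
helper, the TF-Hub handle, and the paths are assumptions for illustration; they
are not a prescribed invocation and should be adjusted to your own setup.

import tensorflow as tf
from tensorflow_hub.tools.make_image_classifier import make_image_classifier_lib as lib

# Hyperparameters and distribution strategy (assumed helper and standard strategy).
hparams = lib.get_default_hparams()
strategy = tf.distribute.MirroredStrategy()

# Train a classifier from a directory with one subdirectory of images per class.
model, labels, train_result = lib.make_image_classifier(
    tfhub_module="https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4",
    image_dir="/path/to/images",
    hparams=hparams,
    distribution_strategy=strategy,
    requested_image_size=224)

# Export the trained model for serving or further use.
model.save("/tmp/image_classifier_saved_model")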