in tensorflow_hub/tools/make_image_classifier/make_image_classifier_lib.py [0:0]
def make_image_classifier(tfhub_module,
                          image_dir,
                          hparams,
                          distribution_strategy=None,
                          requested_image_size=None,
                          log_dir=None,
                          use_tf_data_input=False):
  """Builds and trains a TensorFlow model for image classification.

  The model is built on top of a `hub.KerasLayer` wrapping `tfhub_module`.
  Training data is read from `image_dir` either via tf.data or via Keras
  utilities, depending on `use_tf_data_input`.

  Args:
    tfhub_module: A Python string with the handle of the Hub module.
    image_dir: A Python string naming a directory with subdirectories of images,
      one per class.
    hparams: A HParams object with hyperparameters controlling the training.
      This function reads `do_fine_tuning` (whether the Hub layer is
      trainable). On the Keras input path it also reads `batch_size`,
      `validation_split`, `do_data_augmentation`, `rotation_range`,
      `horizontal_flip`, `width_shift_range`, `height_shift_range`,
      `shear_range` and `zoom_range`. The object is also passed on to the
      data, model and training helpers.
    distribution_strategy: The DistributionStrategy make_image_classifier is
      running with. If None (the default), no strategy scope is entered, so
      the model is created under whatever strategy is currently active.
    requested_image_size: A Python integer controlling the size of images to
      feed into the Hub module. If the module has a fixed input size, this must
      be omitted or set to that same value.
    log_dir: A directory to write logs for TensorBoard into (defaults to None,
      no logs will then be written).
    use_tf_data_input: Whether to read input with a tf.data.Dataset and use TF
      ops for preprocessing. If False, input is read with Keras utilities and
      the augmentation hyperparameters listed above are applied.

  Returns:
    A tuple `(model, labels, train_result)`:
    - `model` is the model returned by `build_model()`.
    - `labels` is the sequence of class-label strings found in `image_dir`.
    - `train_result` is whatever `train_model()` returns (presumably a Keras
      training history; confirm against `train_model`).
  """
  import contextlib  # Local import: only needed for the no-strategy fallback.

  # The signature advertises distribution_strategy=None, so calling .scope()
  # on it unconditionally would raise AttributeError. Fall back to a no-op
  # context instead; variable creation then uses the current strategy.
  if distribution_strategy is None:
    strategy_scope = contextlib.nullcontext()
  else:
    strategy_scope = distribution_strategy.scope()

  with strategy_scope:
    module_layer = hub.KerasLayer(
        tfhub_module, trainable=hparams.do_fine_tuning)
    image_size = _image_size_for_module(module_layer, requested_image_size)
    print("Using module {} with image size {}".format(tfhub_module, image_size))
    if use_tf_data_input:
      train_data_and_size, valid_data_and_size, labels = _get_data_as_datasets(
          image_dir, image_size, hparams)
    else:
      # Only the Keras input path consumes these augmentation settings.
      augmentation_params = dict(
          rotation_range=hparams.rotation_range,
          horizontal_flip=hparams.horizontal_flip,
          width_shift_range=hparams.width_shift_range,
          height_shift_range=hparams.height_shift_range,
          shear_range=hparams.shear_range,
          zoom_range=hparams.zoom_range)
      train_data_and_size, valid_data_and_size, labels = _get_data_with_keras(
          image_dir, image_size, hparams.batch_size, hparams.validation_split,
          hparams.do_data_augmentation, augmentation_params)
    print("Found", len(labels), "classes:", ", ".join(labels))
    model = build_model(module_layer, hparams, image_size, len(labels))
    train_result = train_model(model, hparams, train_data_and_size,
                               valid_data_and_size, log_dir)
  return model, labels, train_result