in tensorflow_examples/lite/model_maker/core/task/image_classifier.py [0:0]
def create(cls,
           train_data,
           model_spec='efficientnet_lite0',
           validation_data=None,
           batch_size=None,
           epochs=None,
           steps_per_epoch=None,
           train_whole_model=None,
           dropout_rate=None,
           learning_rate=None,
           momentum=None,
           shuffle=False,
           use_augmentation=False,
           use_hub_library=True,
           warmup_steps=None,
           model_dir=None,
           do_train=True):
  """Loads data and retrains the model based on data for image classification.

  Args:
    train_data: Training data. Its `index_to_label` provides the label map, and
      it is also passed to the classifier as `representative_data`.
    model_spec: Specification for the model (resolved through `ms.get`).
    validation_data: Validation data. If None, skips validation process.
    batch_size: Number of samples per training step. Passed through unchanged
      to the hyperparameter helper of the selected training backend.
    epochs: Number of epochs for training.
    steps_per_epoch: Integer or None. Total number of steps (batches of
      samples) before declaring one epoch finished and starting the next
      epoch. If `steps_per_epoch` is None, the epoch will run until the input
      dataset is exhausted.
    train_whole_model: If true, the Hub module is trained together with the
      classification layer on top. Otherwise, only train the top
      classification layer.
    dropout_rate: The rate for dropout.
    learning_rate: Learning rate. When `use_hub_library` is False it is the
      base learning rate for a train batch size of 256 and is scaled linearly
      with the actual batch size by the training library; this function passes
      it through unchanged.
    momentum: a Python float forwarded to the optimizer. Only used when
      `use_hub_library` is True (a warning is logged if it is set otherwise).
    shuffle: Whether the data should be shuffled.
    use_augmentation: Use data augmentation for preprocessing.
    use_hub_library: Use `make_image_classifier_lib` from tensorflow hub to
      retrain the model. If False, `train_image_classifier_lib` is used.
    warmup_steps: Number of warmup steps for warmup schedule on learning rate.
      If None, the default warmup_steps is used which is the total training
      steps in two epochs. Only used when `use_hub_library` is False (a
      warning is logged if it is set otherwise).
    model_dir: The location of the model checkpoint files. Only used when
      `use_hub_library` is False (a warning is logged if it is set otherwise).
    do_train: Whether to run training. If False, the model is only built with
      loss and metrics attached (e.g. for evaluation) and not trained.

  Returns:
    An instance based on ImageClassifier.

  Raises:
    ValueError: If the current TF behavior version is not listed in
      `model_spec.compat_tf_versions`.
  """
  model_spec = ms.get(model_spec)

  # Query the TF behavior once so the check and the error message agree.
  tf_behavior = compat.get_tf_behavior()
  if tf_behavior not in model_spec.compat_tf_versions:
    raise ValueError('Incompatible versions. Expect {}, but got {}.'.format(
        model_spec.compat_tf_versions, tf_behavior))

  # Each backend honours a different subset of the arguments; tell the caller
  # about values that would otherwise be dropped silently.
  if use_hub_library:
    ignored_args = (('warmup_steps', warmup_steps), ('model_dir', model_dir))
  else:
    ignored_args = (('momentum', momentum),)
  for arg_name, arg_value in ignored_args:
    if arg_value is not None:
      tf.compat.v1.logging.warning(
          '`%s` is ignored when use_hub_library=%s.', arg_name,
          use_hub_library)

  if use_hub_library:
    hparams = get_hub_lib_hparams(
        batch_size=batch_size,
        train_epochs=epochs,
        do_fine_tuning=train_whole_model,
        dropout_rate=dropout_rate,
        learning_rate=learning_rate,
        momentum=momentum)
  else:
    hparams = train_image_classifier_lib.HParams.get_hparams(
        batch_size=batch_size,
        train_epochs=epochs,
        do_fine_tuning=train_whole_model,
        dropout_rate=dropout_rate,
        learning_rate=learning_rate,
        warmup_steps=warmup_steps,
        model_dir=model_dir)

  image_classifier = cls(
      model_spec,
      train_data.index_to_label,
      shuffle=shuffle,
      hparams=hparams,
      use_augmentation=use_augmentation,
      representative_data=train_data)

  if do_train:
    tf.compat.v1.logging.info('Retraining the models...')
    image_classifier.train(train_data, validation_data, steps_per_epoch)
  else:
    # Used in evaluation: build the model (with loss/metrics) without training.
    image_classifier.create_model(with_loss_and_metrics=True)
  return image_classifier