in src/python/tensorflow_cloud/core/experimental/models.py [0:0]
def classifier_trainer(dataset_name, model_name, batch_size, epochs,
                       train_split, validation_split, one_hot, model_dirs):
  """Training loop for image classifier from TF model Garden using TFDS.

  Builds the model returned by `get_model`, loads the train/validation data
  via `load_data_from_builder`, trains it with Adam and saves the result.

  Args:
    dataset_name: Name passed to `tfds.builder`; the dataset's 'label'
      feature must expose `num_classes`.
    model_name: 'resnet', or otherwise an EfficientNet variant whose model
      exposes `config.resolution` and `config.width_coefficient`.
    batch_size: Batch size passed to both `get_model` and the data loader.
    epochs: Maximum number of epochs (early stopping may end training
      sooner).
    train_split: Split passed to `load_data_from_builder` for training data.
    validation_split: Split passed to `load_data_from_builder` for
      validation data.
    one_hot: Forwarded to the data loader; also selects categorical (True)
      or sparse categorical (False) loss and accuracy so they match the
      label encoding.
    model_dirs: Mapping with keys 'tensorboard_logs', 'model_checkpoint' and
      'saved_model' giving the output locations for TensorBoard logs,
      checkpoints and the final saved model.
  """
  builder = tfds.builder(dataset_name)
  num_classes = builder.info.features['label'].num_classes
  model = get_model(model_name, batch_size, num_classes)

  if model_name == 'resnet':
    image_size = 224
    width_ratio = 1
  else:  # Assumes model_name is an efficientnet version.
    image_size = model.config.resolution
    width_ratio = model.config.width_coefficient

  train_ds, validation_ds = load_data_from_builder(
      builder, train_split, validation_split, image_size, width_ratio,
      batch_size, one_hot, num_classes)

  # The loss/metric must match the label encoding produced by the loader:
  # CategoricalCrossentropy cannot consume integer class ids, so use the
  # sparse variants when labels are not one-hot encoded.
  # NOTE(review): assumes load_data_from_builder yields integer class ids
  # when one_hot is False — confirm against the loader.
  if one_hot:
    loss = tf.keras.losses.CategoricalCrossentropy()
    accuracy = tf.keras.metrics.CategoricalAccuracy(dtype=tf.float32)
  else:
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy(dtype=tf.float32)

  callbacks = [
      tf.keras.callbacks.TensorBoard(log_dir=model_dirs['tensorboard_logs']),
      # Keeps only the best checkpoint according to ModelCheckpoint's
      # default monitor (val_loss in Keras).
      tf.keras.callbacks.ModelCheckpoint(
          model_dirs['model_checkpoint'], save_best_only=True),
      # NOTE(review): stops on *training* loss, unlike the checkpoint's
      # validation-based monitor — confirm this is intended.
      tf.keras.callbacks.EarlyStopping(
          monitor='loss', min_delta=0.001, patience=3),
  ]
  model.compile(
      optimizer=tf.keras.optimizers.Adam(),
      loss=loss,
      metrics=[accuracy],
  )
  model.fit(
      train_ds,
      validation_data=validation_ds,
      epochs=epochs,
      callbacks=callbacks)
  model.save(model_dirs['saved_model'])