def classifier_trainer()

in src/python/tensorflow_cloud/core/experimental/models.py [0:0]


def classifier_trainer(dataset_name, model_name, batch_size, epochs,
                       train_split, validation_split, one_hot, model_dirs):
  """Training loop for image classifier from TF model Garden using TFDS.

  Builds the TFDS dataset, creates the requested model, trains it with Adam,
  and saves the final model.

  Args:
    dataset_name: Name of the TFDS dataset passed to `tfds.builder`. The
      dataset's features must include a 'label' entry with `num_classes`.
    model_name: 'resnet'; any other value is treated as an efficientnet
      version, whose input resolution and width coefficient are read from
      `model.config`.
    batch_size: Batch size forwarded to both `get_model` and the input
      pipeline.
    epochs: Number of epochs passed to `model.fit`.
    train_split: Split name used to build the training dataset.
    validation_split: Split name used to build the validation dataset.
    one_hot: Forwarded to `load_data_from_builder`. Presumably it controls
      whether labels are one-hot encoded (TODO: confirm against that helper).
      When True, the categorical loss and metric are used. When False, labels
      are assumed to be integer class ids, and the sparse variants are used.
    model_dirs: Dict of output locations with the keys 'tensorboard_logs',
      'model_checkpoint' and 'saved_model'.
  """
  builder = tfds.builder(dataset_name)

  num_classes = builder.info.features['label'].num_classes
  model = get_model(model_name, batch_size, num_classes)

  if model_name == 'resnet':
    image_size = 224
    width_ratio = 1
  else:  # Assumes model_name is an efficientnet version
    image_size = model.config.resolution
    width_ratio = model.config.width_coefficient

  train_ds, validation_ds = load_data_from_builder(builder, train_split,
                                                   validation_split, image_size,
                                                   width_ratio, batch_size,
                                                   one_hot, num_classes)
  callbacks = [
      tf.keras.callbacks.TensorBoard(log_dir=model_dirs['tensorboard_logs']),
      # NOTE(review): ModelCheckpoint uses its default monitor ('val_loss' per
      # the Keras docs) for save_best_only, while EarlyStopping below watches
      # the training 'loss'. Confirm this split is intentional.
      tf.keras.callbacks.ModelCheckpoint(
          model_dirs['model_checkpoint'], save_best_only=True),
      tf.keras.callbacks.EarlyStopping(
          monitor='loss', min_delta=0.001, patience=3)
  ]

  # The loss and accuracy metric must match the label encoding. The
  # categorical variants expect one-hot targets. The sparse variants expect
  # integer class ids. Always using the categorical pair broke
  # one_hot=False runs.
  if one_hot:
    loss = tf.keras.losses.CategoricalCrossentropy()
    accuracy = tf.keras.metrics.CategoricalAccuracy(dtype=tf.float32)
  else:
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy(dtype=tf.float32)

  model.compile(
      optimizer=tf.keras.optimizers.Adam(),
      loss=loss,
      metrics=[accuracy],
  )

  model.fit(
      train_ds,
      validation_data=validation_ds,
      epochs=epochs,
      callbacks=callbacks)

  model.save(model_dirs['saved_model'])