def _get_data_as_datasets()

in tensorflow_hub/tools/make_image_classifier/make_image_classifier_lib.py [0:0]


def _get_data_as_datasets(image_dir, image_size, hparams):
  """Gets training and validation data via tf.data.Dataset.

  The images under `image_dir` are split into a training and a validation
  subset according to `hparams.validation_split`, using a fixed seed so the
  split is reproducible across runs and the two subsets do not overlap.
  Training data is rebatched, repeated indefinitely, rescaled to [0, 1] and,
  if `hparams.do_data_augmentation` is set, randomly augmented with Keras
  preprocessing layers. Validation data is only rebatched and rescaled.

  Args:
    image_dir: A Python string with the name of a directory that contains
      subdirectories of images, one per class.
    image_size: A list or tuple with 2 Python integers specifying the fixed
      height and width to which input images are resized.
    hparams: A HParams object with hyperparameters controlling the training.
      Fields read here: shear_range, validation_split, batch_size,
      do_data_augmentation, rotation_range, width_shift_range,
      height_shift_range, zoom_range, horizontal_flip.

  Returns:
    A nested tuple ((train_data, train_size),
                    (valid_data, valid_size), labels) where:
    train_data, valid_data: tf.data.Dataset for use with Model.fit, each
      yielding batch of tuples (images, labels) where
        images is a float32 Tensor of shape [batch_size, height, width, 3]
          with pixel values in range [0,1],
        labels is a float32 Tensor of shape [batch_size, num_classes]
          with one-hot encoded classes.
      train_data repeats indefinitely; valid_data is finite.
    train_size, valid_size: Python integers with the numbers of training
      and validation examples, respectively.
    labels: A tuple of strings with the class labels (subdirectory names).
      The index of a label in this tuple is the numeric class id.

  Raises:
    ValueError: if hparams.shear_range is non-zero, since shear is not
      supported when augmenting with preprocessing layers.
  """
  # Check if hparam.shear_range is set. If yes, throw an error since shear is
  # not supported when using preprocessing layers.
  if hparams.shear_range != 0:
    raise ValueError("Found non-zero value for shear_range. Shear is not "
                     "supported when using reading input with tf.data.Dataset "
                     "and using preprocessing layers.")

  def read_subset(subset):
    """Reads one subset ("training" or "validation") of image_dir.

    Both subsets MUST be read with the same `seed` and the same `shuffle`
    setting. image_dataset_from_directory shuffles the full file list (when
    shuffle=True) and then cuts off the last `validation_split` fraction as
    the validation subset. Reading one subset with shuffle=False would take
    its slice from the unshuffled, class-sorted list instead, so the subsets
    would overlap and validation would be dominated by the last classes.
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        image_dir,
        validation_split=hparams.validation_split,
        subset=subset,
        label_mode="categorical",
        # Seed needs to provided when using validation_split and shuffle=True.
        # A fixed seed is used so that the validation set is stable across
        # runs.
        seed=123,
        shuffle=True,
        image_size=image_size,
        # With batch_size=1, cardinality() equals the number of images; the
        # data is rebatched to hparams.batch_size below.
        batch_size=1)

  train_ds = read_subset("training")
  class_names = tuple(train_ds.class_names)
  # int() so callers get the documented Python integer, not a NumPy scalar.
  train_size = int(train_ds.cardinality().numpy())
  train_ds = train_ds.unbatch().batch(hparams.batch_size)
  train_ds = train_ds.repeat()

  normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(
      1. / 255)
  preprocessing_model = tf.keras.Sequential([normalization_layer])
  if hparams.do_data_augmentation:
    # NOTE(review): RandomRotation interprets its factor as a fraction of a
    # full turn (2*pi), whereas ImageDataGenerator's rotation_range is given
    # in degrees. If hparams.rotation_range is in degrees (e.g. 40), this
    # rotates by an essentially arbitrary angle — confirm the intended unit.
    preprocessing_model.add(
        tf.keras.layers.experimental.preprocessing.RandomRotation(
            hparams.rotation_range))
    preprocessing_model.add(
        tf.keras.layers.experimental.preprocessing.RandomTranslation(
            0, hparams.width_shift_range))
    preprocessing_model.add(
        tf.keras.layers.experimental.preprocessing.RandomTranslation(
            hparams.height_shift_range, 0))
    # Like the old tf.keras.preprocessing.image.ImageDataGenerator(),
    # image sizes are fixed when reading, and then a random zoom is applied.
    # If all training inputs are larger than image_size, one could also use
    # RandomCrop with a batch size of 1 and rebatch later.
    preprocessing_model.add(
        tf.keras.layers.experimental.preprocessing.RandomZoom(
            hparams.zoom_range, hparams.zoom_range))
    if hparams.horizontal_flip:
      preprocessing_model.add(
          tf.keras.layers.experimental.preprocessing.RandomFlip(
              mode="horizontal"))
  # The random augmentation layers only transform their inputs in training
  # mode. Inside Dataset.map (outside Model.fit) that mode is otherwise left
  # to Keras defaults, so request it explicitly. Rescaling behaves the same
  # either way.
  train_ds = train_ds.map(
      lambda images, labels: (preprocessing_model(images, training=True),
                              labels))

  # Shuffled with the same seed as the training subset (see read_subset), so
  # only the iteration order varies between epochs; metrics are unaffected.
  val_ds = read_subset("validation")
  valid_size = int(val_ds.cardinality().numpy())
  val_ds = val_ds.unbatch().batch(hparams.batch_size)
  val_ds = val_ds.map(lambda images, labels:
                      (normalization_layer(images), labels))

  return ((train_ds, train_size), (val_ds, valid_size), class_names)