def _get_dataset()

in models/hific/model.py [0:0]


  def _get_dataset(self, batch_size, crop_size,
                   images_glob, tfds_arguments: TFDSArguments):
    """Build the input dataset, either from a glob of PNG files or from TFDS.

    Each image is optionally randomly rescaled and randomly cropped (see
    `_preprocess`), then images are grouped into batches. Unless
    `self.evaluation` is set, the dataset is repeated indefinitely and
    shuffled at the example level (before batching).

    Args:
      batch_size: int, number of images per batch. An incomplete final batch
        is dropped.
      crop_size: int or None. If truthy, every image is randomly rescaled and
        then randomly cropped to (crop_size, crop_size, 3). If falsy, images
        are passed through unchanged.
      images_glob: str or None. If truthy, a glob pattern of PNG files to
        read. It takes precedence over `tfds_arguments`.
      tfds_arguments: argument for TFDS. Its `dataset_name`, `downloads_dir`
        and `features_key` fields are used when `images_glob` is falsy. The
        "train" split is read if `self.training`, else "validation".

    Returns:
      Instance of tf.data.Dataset yielding batches of images.

    Raises:
      ValueError: if `images_glob` is given but matches no files.
    """

    crop_size_float = tf.constant(crop_size, tf.float32) if crop_size else None
    smallest_fac = tf.constant(0.75, tf.float32)
    biggest_fac = tf.constant(0.95, tf.float32)

    with tf.name_scope("tfds"):
      if images_glob:
        images = sorted(glob.glob(images_glob))
        tf.logging.info(
            f"Using images_glob={images_glob} ({len(images)} images)")
        if not images:
          # Fail early and clearly. An empty list would otherwise build a
          # non-string dataset that only errors later inside `tf.read_file`.
          raise ValueError(f"No files match images_glob={images_glob}")
        filenames = tf.data.Dataset.from_tensor_slices(images)
        # channels=3 forces RGB output, so grayscale / RGBA PNGs match the
        # 3-channel crop below and can be batched with RGB images.
        dataset = filenames.map(
            lambda x: tf.image.decode_png(tf.read_file(x), channels=3))
      else:
        tf.logging.info(f"Using TFDS={tfds_arguments}")
        builder = tfds.builder(
            tfds_arguments.dataset_name, data_dir=tfds_arguments.downloads_dir)
        builder.download_and_prepare()
        split = "train" if self.training else "validation"
        dataset = builder.as_dataset(split=split)

      def _preprocess(features):
        """Extract the image and, if `crop_size` is set, rescale + crop it."""
        if images_glob:
          image = features
        else:
          image = features[tfds_arguments.features_key]
        if not crop_size:
          return image
        tf.logging.info("Scaling down %s and cropping to %d x %d", image,
                        crop_size, crop_size)
        with tf.name_scope("random_scale"):
          # Rescale by a random factor drawn from [smallest_fac, biggest_fac]
          # (i.e. downscale) to remove JPG artifacts. This code also handles
          # images that have one side shorter than crop_size. In that case we
          # always rescale such that this side becomes exactly `crop_size`
          # (upscaling if needed). Overall, images returned will never be
          # smaller than `crop_size`.
          image_shape = tf.cast(tf.shape(image), tf.float32)
          height, width = image_shape[0], image_shape[1]
          smallest_side = tf.math.minimum(height, width)
          # The smallest factor such that the rescaled image is still at
          # least `crop_size` on its short side. Will be bigger than 1 for
          # images smaller than `crop_size`.
          image_smallest_fac = crop_size_float / smallest_side
          min_fac = tf.math.maximum(smallest_fac, image_smallest_fac)
          # If min_fac already exceeds biggest_fac, the range collapses and
          # the scale is exactly min_fac.
          max_fac = tf.math.maximum(min_fac, biggest_fac)
          scale = tf.random_uniform([],
                                    minval=min_fac,
                                    maxval=max_fac,
                                    dtype=tf.float32,
                                    seed=42)
          # ceil() keeps the short side >= crop_size despite float rounding.
          new_size = tf.cast(
              tf.stack([tf.ceil(scale * height), tf.ceil(scale * width)]),
              tf.int32)
          image = tf.image.resize_images(image, new_size)
        with tf.name_scope("random_crop"):
          image = tf.image.random_crop(image, [crop_size, crop_size, 3])
        return image

      dataset = dataset.map(
          _preprocess, num_parallel_calls=DATASET_NUM_PARALLEL)

      if not self.evaluation:
        # Make sure we don't run out of data.
        dataset = dataset.repeat()
        # Shuffle individual examples *before* batching, so batch composition
        # changes between passes. Shuffling after `batch` would only reorder
        # fixed batches (of neighbouring files for `images_glob`, since the
        # list is sorted). It would also hold DATASET_SHUFFLE_BUFFER whole
        # batches in memory.
        dataset = dataset.shuffle(buffer_size=DATASET_SHUFFLE_BUFFER)
      dataset = dataset.batch(batch_size, drop_remainder=True)
      dataset = dataset.prefetch(buffer_size=DATASET_PREFETCH_BUFFER)

      return dataset