in models/hific/model.py [0:0]
def _get_dataset(self, batch_size, crop_size,
                 images_glob, tfds_arguments: TFDSArguments):
  """Build TFDS dataset.

  Args:
    batch_size: int, batch size.
    crop_size: int, images are randomly cropped to (crop_size, crop_size).
      If falsy, images are returned unscaled and uncropped.
    images_glob: str, optional glob matching PNG images to train on instead
      of a TFDS dataset.
    tfds_arguments: TFDSArguments, arguments for TFDS.

  Returns:
    Instance of tf.data.Dataset.
  """
  crop_size_float = tf.constant(crop_size, tf.float32) if crop_size else None
  smallest_fac = tf.constant(0.75, tf.float32)
  biggest_fac = tf.constant(0.95, tf.float32)
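  # `smallest_fac` and `biggest_fac` bound the random rescale factor applied
  # to each image before cropping (see the `random_scale` scope below).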
with tf.name_scope("tfds"):
if images_glob:
images = sorted(glob.glob(images_glob))
tf.logging.info(
f"Using images_glob={images_glob} ({len(images)} images)")
filenames = tf.data.Dataset.from_tensor_slices(images)
dataset = filenames.map(lambda x: tf.image.decode_png(tf.read_file(x)))
else:
tf.logging.info(f"Using TFDS={tfds_arguments}")
builder = tfds.builder(
tfds_arguments.dataset_name, data_dir=tfds_arguments.downloads_dir)
builder.download_and_prepare()
split = "train" if self.training else "validation"
dataset = builder.as_dataset(split=split)
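    # TFDS yields feature dicts, whereas the `images_glob` branch yields
    # already-decoded image tensors; `_preprocess` normalizes both cases.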
    def _preprocess(features):
      if images_glob:
        image = features
      else:
        image = features[tfds_arguments.features_key]
      if not crop_size:
        return image
      tf.logging.info("Scaling down %s and cropping to %d x %d", image,
                      crop_size, crop_size)
with tf.name_scope("random_scale"):
# Scale down by at least `biggest_fac` and at most `smallest_fac` to
# remove JPG artifacts. This code also handles images that have one
# side shorter than crop_size. In this case, we always upscale such
# that this side becomes the same as `crop_size`. Overall, images
# returned will never be smaller than `crop_size`.
image_shape = tf.cast(tf.shape(image), tf.float32)
height, width = image_shape[0], image_shape[1]
smallest_side = tf.math.minimum(height, width)
# The smallest factor such that the downscaled image is still bigger
# than `crop_size`. Will be bigger than 1 for images smaller than
# `crop_size`.
image_smallest_fac = crop_size_float / smallest_side
min_fac = tf.math.maximum(smallest_fac, image_smallest_fac)
max_fac = tf.math.maximum(min_fac, biggest_fac)
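        # If `min_fac` exceeds `biggest_fac` (i.e., the input image is very
        # small), this collapses the sampling range to [min_fac, min_fac] so
        # that minval <= maxval in `tf.random_uniform` below.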
        scale = tf.random_uniform([],
                                  minval=min_fac,
                                  maxval=max_fac,
                                  dtype=tf.float32,
                                  seed=42,
                                  name=None)
        image = tf.image.resize_images(
            image, [tf.ceil(scale * height),
                    tf.ceil(scale * width)])
with tf.name_scope("random_crop"):
image = tf.image.random_crop(image, [crop_size, crop_size, 3])
return image
    dataset = dataset.map(
        _preprocess, num_parallel_calls=DATASET_NUM_PARALLEL)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    if not self.evaluation:
      # Make sure we don't run out of data.
      dataset = dataset.repeat()
      dataset = dataset.shuffle(buffer_size=DATASET_SHUFFLE_BUFFER)
    dataset = dataset.prefetch(buffer_size=DATASET_PREFETCH_BUFFER)
    return dataset
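
# Usage sketch (not part of the original file): one way the returned dataset
# might be consumed in a TF1-style graph, via a one-shot iterator. `model`,
# `tfds_args`, and the argument values are hypothetical placeholders.
#
#   dataset = model._get_dataset(
#       batch_size=8, crop_size=256, images_glob="", tfds_arguments=tfds_args)
#   images = dataset.make_one_shot_iterator().get_next()
#   # `images` has shape [8, 256, 256, 3]; values are float32 because
#   # `tf.image.resize_images` converts the decoded uint8 images.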