# 07_training/serverlessml/flowers/ingest/tfrecords.py
#!/usr/bin/env python
# Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
import numpy as np
import tensorflow as tf
from tensorflow.data.experimental import AUTOTUNE
class _Preprocessor:
    """Decodes flower images (from TFRecords or JPEG files) and resizes them
    to a fixed height/width for model input."""

    def __init__(self, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS):
        # Target output geometry used by preprocess()/decode_jpeg().
        self.IMG_HEIGHT = IMG_HEIGHT
        self.IMG_WIDTH = IMG_WIDTH
        self.IMG_CHANNELS = IMG_CHANNELS

    def read_from_tfr(self, proto):
        """Parse one serialized Example into (float image tensor, integer label)."""
        parsed = tf.io.parse_single_example(proto, {
            'image': tf.io.VarLenFeature(tf.float32),
            'shape': tf.io.VarLenFeature(tf.int64),
            'label': tf.io.FixedLenFeature([], tf.string, default_value=''),
            'label_int': tf.io.FixedLenFeature([], tf.int64, default_value=0),
        })
        # Image pixels were written flattened; restore the original shape.
        dense_shape = tf.sparse.to_dense(parsed['shape'])
        flat_pixels = tf.sparse.to_dense(parsed['image'])
        return tf.reshape(flat_pixels, dense_shape), parsed['label_int']

    def read_from_jpegfile(self, filename):
        # same code as in 05_create_dataset/jpeg_to_tfrecord.py
        raw = tf.io.read_file(filename)
        decoded = tf.image.decode_jpeg(raw, channels=self.IMG_CHANNELS)
        # convert_image_dtype also scales pixel values into [0, 1] floats.
        return tf.image.convert_image_dtype(decoded, tf.float32)

    def preprocess(self, img):
        """Resize with padding (no distortion) to the configured dimensions."""
        return tf.image.resize_with_pad(img, self.IMG_HEIGHT, self.IMG_WIDTH)
# most efficient way to read the data
# as determined in 07a_ingest.ipynb
# splits the files into two halves and interleaves datasets
def create_preproc_dataset(pattern, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS):
    """
    Does interleaving, parallel calls, prefetch, batching
    Caching is not a good idea on large datasets.
    """
    preproc = _Preprocessor(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
    # Shuffle the matched filenames so shards are read in random order.
    files = list(tf.random.shuffle(tf.io.gfile.glob(pattern)))
    n_files = len(files)
    if n_files > 1:
        print("Interleaving the reading of {} files.".format(n_files))
        mid = n_files // 2

        def _half_reader(idx):
            # idx is a tensor; AutoGraph turns this branch into a tf.cond.
            if idx == 0:
                shard = files[:mid]
            else:
                shard = files[mid:]
            return tf.data.TFRecordDataset(shard, compression_type='GZIP')

        # Two "halves" read in parallel and interleaved record-by-record.
        trainds = tf.data.Dataset.range(2).interleave(
            _half_reader, num_parallel_calls=AUTOTUNE)
    else:
        print("WARNING! Only {} files match {}".format(n_files, pattern))
        trainds = tf.data.TFRecordDataset(files, compression_type='GZIP')

    def _resize(img, label):
        # Only the image needs preprocessing; the label passes through.
        return preproc.preprocess(img), label

    trainds = (trainds
               .map(preproc.read_from_tfr, num_parallel_calls=AUTOTUNE)
               .map(_resize, num_parallel_calls=AUTOTUNE)
               .shuffle(200)
               .prefetch(AUTOTUNE))
    return trainds
def create_preproc_image(filename, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS):
    """Read a single JPEG file and return it resized/padded to the target shape."""
    helper = _Preprocessor(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
    return helper.preprocess(helper.read_from_jpegfile(filename))