in tensorflow_benchmark/tf_cnn_benchmarks/preprocessing.py [0:0]
def minibatch(self, dataset, subset, use_datasets, cache_data,
              shift_ratio=-1):
    """Build the per-split image and label batches for one input step.

    Args:
      dataset: Object exposing `tf_record_pattern(subset)`, which yields the
        file pattern of the TFRecord files to read.
      subset: Forwarded unchanged to `dataset.tf_record_pattern`.
      use_datasets: If true, read through a `tf.data` pipeline; otherwise
        read through `data_flow_ops.RecordInput`.
      cache_data: Only consulted when `use_datasets` is true. If true, the
        pipeline keeps just the first element of the record dataset, caches
        it and repeats it.
      shift_ratio: Forwarded to `RecordInput`; a negative value selects
        `self.shift_ratio`. Not used on the `tf.data` path.

    Returns:
      A tuple `(images, labels)` of two lists with `self.num_splits`
      entries each. Every `images[i]` is cast to `self.dtype` and reshaped
      to `[batch_size_per_split, height, width, 3]`; every `labels[i]` is
      reshaped to `[batch_size_per_split]`.

    Raises:
      ValueError: If `use_datasets` is true and no file matches the
        pattern returned by `dataset.tf_record_pattern(subset)`.
    """
    if shift_ratio < 0:
        shift_ratio = self.shift_ratio
    # Channel count used when reshaping every image batch below.
    depth = 3
    with tf.name_scope('batch_processing'):
        # Build final results per split.
        images = [[] for _ in range(self.num_splits)]
        labels = [[] for _ in range(self.num_splits)]
        if use_datasets:
            glob_pattern = dataset.tf_record_pattern(subset)
            file_names = gfile.Glob(glob_pattern)
            if not file_names:
                raise ValueError('Found no files in --data_dir matching: {}'
                                 .format(glob_pattern))
            ds = tf.data.TFRecordDataset.list_files(file_names)
            ds = ds.apply(
                interleave_ops.parallel_interleave(
                    tf.data.TFRecordDataset, cycle_length=10))
            if cache_data:
                ds = ds.take(1).cache().repeat()
            # Pair every record with a repeating 0..batch_size-1 counter so
            # that parse_and_preprocess receives (value, counter), mirroring
            # the (value, idx) call made on the RecordInput path.
            counter = tf.data.Dataset.range(self.batch_size)
            counter = counter.repeat()
            ds = tf.data.Dataset.zip((ds, counter))
            ds = ds.prefetch(buffer_size=self.batch_size)
            ds = ds.shuffle(buffer_size=10000)
            ds = ds.repeat()
            ds = ds.apply(
                batching.map_and_batch(
                    map_func=self.parse_and_preprocess,
                    batch_size=self.batch_size_per_split,
                    num_parallel_batches=self.num_splits))
            ds = ds.prefetch(buffer_size=self.num_splits)
            ds_iterator = ds.make_one_shot_iterator()
            # One get_next() per split: each call yields an already batched
            # (labels, images) pair.
            for split_index in range(self.num_splits):
                labels[split_index], images[split_index] = (
                    ds_iterator.get_next())
        else:
            record_input = data_flow_ops.RecordInput(
                file_pattern=dataset.tf_record_pattern(subset),
                seed=301,
                parallelism=64,
                buffer_size=10000,
                batch_size=self.batch_size,
                shift_ratio=shift_ratio,
                name='record_input')
            records = record_input.get_yield_op()
            # Break the batch into self.batch_size pieces and reshape each
            # piece to shape [] before parsing it.
            records = tf.split(records, self.batch_size, 0)
            records = [tf.reshape(record, []) for record in records]
            for idx, value in enumerate(records):
                (label, image) = self.parse_and_preprocess(value, idx)
                # Distribute samples over the splits round-robin by index.
                split_index = idx % self.num_splits
                labels[split_index].append(label)
                images[split_index].append(image)
        for split_index in range(self.num_splits):
            if not use_datasets:
                # The RecordInput path collected Python lists per split;
                # combine each list into a single tensor.
                images[split_index] = tf.parallel_stack(images[split_index])
                labels[split_index] = tf.concat(labels[split_index], 0)
            images[split_index] = tf.cast(images[split_index], self.dtype)
            images[split_index] = tf.reshape(
                images[split_index],
                shape=[self.batch_size_per_split, self.height, self.width,
                       depth])
            labels[split_index] = tf.reshape(labels[split_index],
                                             [self.batch_size_per_split])
        return images, labels