in mesh_tensorflow/experimental/unet.py
def get_dataset_creator(dataset_str):
  """Returns a function that creates an unbatched dataset."""
  if dataset_str == 'train':
    data_file_pattern = FLAGS.train_file_pattern.format(FLAGS.ct_resolution)
    shuffle = True
    interleave = True
  else:
    assert dataset_str == 'eval'
    data_file_pattern = FLAGS.eval_file_pattern.format(FLAGS.ct_resolution)
    shuffle = False
    interleave = False
  def _dataset_creator():
    """Returns an unbatched dataset."""
    def _get_stacked_2d_slices(image_3d, label_3d):
      """Returns stacked 2d slices of the 3d scan."""
      image_stack = []
      label_stack = []

      for begin_idx in range(0, FLAGS.ct_resolution - FLAGS.image_c + 1):
        # Each image slice is a slab of image_c consecutive axial slices;
        # its label is the single center slice of that slab.
        slice_begin = [0, 0, begin_idx]
        slice_size = [FLAGS.ct_resolution, FLAGS.ct_resolution, FLAGS.image_c]
        image = tf.slice(image_3d, slice_begin, slice_size)

        slice_begin = [0, 0, begin_idx + FLAGS.image_c // 2]
        slice_size = [FLAGS.ct_resolution, FLAGS.ct_resolution, 1]
        label = tf.slice(label_3d, slice_begin, slice_size)

        # Split each spatial dim into (num_blocks, size_per_block) so the
        # tensor can be laid out across the mesh.
        spatial_dims_w_blocks = [FLAGS.image_nx_block,
                                 FLAGS.ct_resolution // FLAGS.image_nx_block,
                                 FLAGS.image_ny_block,
                                 FLAGS.ct_resolution // FLAGS.image_ny_block]

        image = tf.reshape(image, spatial_dims_w_blocks + [FLAGS.image_c])
        label = tf.reshape(label, spatial_dims_w_blocks)

        label = tf.cast(label, tf.int32)
        label = tf.one_hot(label, FLAGS.label_c)

        data_dtype = tf.as_dtype(FLAGS.mtf_dtype)
        image = tf.cast(image, data_dtype)
        label = tf.cast(label, data_dtype)
        image_stack.append(image)
        label_stack.append(label)

      return tf.stack(image_stack), tf.stack(label_stack)
    def _parser_fn(serialized_example):
      """Parses a single tf.Example into image and label tensors."""
      features = {}
      features['image/ct_image'] = tf.FixedLenFeature([], tf.string)
      features['image/label'] = tf.FixedLenFeature([], tf.string)
      parsed = tf.parse_single_example(serialized_example, features=features)

      spatial_dims = [FLAGS.ct_resolution] * 3
      if FLAGS.sampled_2d_slices:
        noise_shape = [FLAGS.ct_resolution] * 2 + [FLAGS.image_c]
      else:
        noise_shape = [FLAGS.ct_resolution] * 3

      image = tf.decode_raw(parsed['image/ct_image'], tf.float32)
      label = tf.decode_raw(parsed['image/label'], tf.float32)
      if dataset_str != 'train':
        # Preprocess intensity, clip to 0 ~ 1.
        # The training set is already preprocessed.
        image = tf.clip_by_value(image / 1024.0 + 0.5, 0, 1)

      image = tf.reshape(image, spatial_dims)
      label = tf.reshape(label, spatial_dims)

      if dataset_str == 'eval' and FLAGS.sampled_2d_slices:
        return _get_stacked_2d_slices(image, label)

      if FLAGS.sampled_2d_slices:
        # Take a random slab of image_c consecutive slices from the image
        # and label.
        begin_idx = tf.random_uniform(
            shape=[], minval=0,
            maxval=FLAGS.ct_resolution - FLAGS.image_c + 1, dtype=tf.int32)
        slice_begin = [0, 0, begin_idx]
        slice_size = [FLAGS.ct_resolution, FLAGS.ct_resolution, FLAGS.image_c]

        image = tf.slice(image, slice_begin, slice_size)
        label = tf.slice(label, slice_begin, slice_size)

      if dataset_str == 'train':
        # Training-time data augmentation: random flips and 180-degree
        # rotations, per-class intensity shifts, corruption, additive noise,
        # and projective transforms.
        for flip_axis in [0, 1, 2]:
          image, label = data_aug_lib.maybe_flip(image, label, flip_axis)
        image, label = data_aug_lib.maybe_rot180(image, label, static_axis=2)
        image = data_aug_lib.intensity_shift(
            image, label,
            FLAGS.per_class_intensity_scale, FLAGS.per_class_intensity_shift)
        image = data_aug_lib.image_corruption(
            image, label, FLAGS.ct_resolution,
            FLAGS.image_corrupt_ratio_mean, FLAGS.image_corrupt_ratio_stddev)
        image = data_aug_lib.maybe_add_noise(
            image, noise_shape, 1, 4,
            FLAGS.image_noise_probability, FLAGS.image_noise_ratio)
        image, label = data_aug_lib.projective_transform(
            image, label, FLAGS.ct_resolution,
            FLAGS.image_translate_ratio, FLAGS.image_transform_ratio,
            FLAGS.sampled_2d_slices)

      if FLAGS.sampled_2d_slices:
        # Only get the center slice of label.
        label = tf.slice(label, [0, 0, FLAGS.image_c // 2],
                         [FLAGS.ct_resolution, FLAGS.ct_resolution, 1])

      spatial_dims_w_blocks = [FLAGS.image_nx_block,
                               FLAGS.ct_resolution // FLAGS.image_nx_block,
                               FLAGS.image_ny_block,
                               FLAGS.ct_resolution // FLAGS.image_ny_block]
      if not FLAGS.sampled_2d_slices:
        spatial_dims_w_blocks += [FLAGS.ct_resolution]

      image = tf.reshape(image, spatial_dims_w_blocks + [FLAGS.image_c])
      label = tf.reshape(label, spatial_dims_w_blocks)

      label = tf.cast(label, tf.int32)
      label = tf.one_hot(label, FLAGS.label_c)

      data_dtype = tf.as_dtype(FLAGS.mtf_dtype)
      image = tf.cast(image, data_dtype)
      label = tf.cast(label, data_dtype)
      return image, label
    dataset_fn = functools.partial(tf.data.TFRecordDataset,
                                   compression_type='GZIP')
    dataset = tf.data.Dataset.list_files(data_file_pattern,
                                         shuffle=shuffle).repeat()

    if interleave:
      # Read several TFRecord files in parallel, in non-deterministic order.
      dataset = dataset.apply(
          tf.data.experimental.parallel_interleave(
              lambda file_name: dataset_fn(file_name).prefetch(1),
              cycle_length=FLAGS.n_dataset_read_interleave,
              sloppy=True))
    else:
      # Read one file at a time, in deterministic order.
      dataset = dataset.apply(
          tf.data.experimental.parallel_interleave(
              lambda file_name: dataset_fn(file_name).prefetch(1),
              cycle_length=1,
              sloppy=False))

    if shuffle:
      dataset = dataset.shuffle(FLAGS.n_dataset_processes).map(
          _parser_fn, num_parallel_calls=FLAGS.n_dataset_processes)
    else:
      dataset = dataset.map(_parser_fn)

    if dataset_str == 'eval' and FLAGS.sampled_2d_slices:
      # When evaluating on slices, unbatch slices that belong to one CT scan.
      dataset = dataset.unbatch()

    return dataset

  return _dataset_creator
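

# Usage sketch (added for illustration; not part of the original file). The
# creator returned above is normally wired into an estimator input_fn
# elsewhere in this module; the helper below only shows the shape of a
# typical call. The default batch size is an assumed placeholder.
def _example_train_input_fn(batch_size=8):
  """Illustrative only: batches and prefetches the unbatched train dataset."""
  dataset_creator = get_dataset_creator('train')
  return dataset_creator().batch(batch_size, drop_remainder=True).prefetch(1)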