def get_dataset_creator()

in mesh_tensorflow/experimental/unet.py [0:0]


def get_dataset_creator(dataset_str):
  """Returns a function that creates an unbatched dataset.

  Args:
    dataset_str: either 'train' or 'eval'. Selects the input file pattern
      (FLAGS.train_file_pattern or FLAGS.eval_file_pattern, formatted with
      FLAGS.ct_resolution). It also selects whether the file list is
      shuffled, whether files are read with a multi-file interleave, and
      whether data augmentation is applied (train only).

  Returns:
    A zero-argument callable that builds a `tf.data.Dataset` of
    (image, label) pairs. The file list is repeated indefinitely. Images are
    reshaped so the x and y axes are split into
    [n_block, ct_resolution // n_block] pairs. Labels are one-hot encoded
    with FLAGS.label_c classes. Both are cast to FLAGS.mtf_dtype.

  Raises:
    ValueError: if `dataset_str` is neither 'train' nor 'eval'.
  """
  if dataset_str == 'train':
    data_file_pattern = FLAGS.train_file_pattern.format(FLAGS.ct_resolution)
    shuffle = True
    interleave = True
  elif dataset_str == 'eval':
    data_file_pattern = FLAGS.eval_file_pattern.format(FLAGS.ct_resolution)
    shuffle = False
    interleave = False
  else:
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would silently treat any other string as 'eval'.
    raise ValueError(
        "dataset_str must be 'train' or 'eval', got {!r}".format(dataset_str))

  def _dataset_creator():
    """Returns an unbatched dataset."""

    def _blocked_spatial_dims(is_2d):
      """Returns the spatial shape with x and y split into block pairs."""
      # NOTE(review): tf.reshape below presumably requires ct_resolution to
      # be divisible by image_nx_block and image_ny_block.
      dims = [FLAGS.image_nx_block,
              FLAGS.ct_resolution // FLAGS.image_nx_block,
              FLAGS.image_ny_block,
              FLAGS.ct_resolution // FLAGS.image_ny_block]
      if not is_2d:
        # 3d scans keep the full third spatial axis unblocked.
        dims.append(FLAGS.ct_resolution)
      return dims

    def _to_blocks_and_cast(image, label, spatial_dims_w_blocks):
      """Reshapes to the blocked layout, one-hots the label, casts both."""
      image = tf.reshape(image, spatial_dims_w_blocks + [FLAGS.image_c])
      label = tf.reshape(label, spatial_dims_w_blocks)
      # Labels are decoded as float32; convert to integer class ids first.
      label = tf.one_hot(tf.cast(label, tf.int32), FLAGS.label_c)
      data_dtype = tf.as_dtype(FLAGS.mtf_dtype)
      return tf.cast(image, data_dtype), tf.cast(label, data_dtype)

    def _get_stacked_2d_slices(image_3d, label_3d):
      """Returns every 2d sample of the 3d scan, stacked along a new axis 0.

      Each sample takes a window of FLAGS.image_c consecutive planes along
      axis 2 as the image. The label is the single plane at the window's
      center.
      """
      # Loop-invariant values, computed once for all windows.
      spatial_dims_w_blocks = _blocked_spatial_dims(is_2d=True)
      image_size = [FLAGS.ct_resolution, FLAGS.ct_resolution, FLAGS.image_c]
      label_size = [FLAGS.ct_resolution, FLAGS.ct_resolution, 1]

      image_stack = []
      label_stack = []
      for begin_idx in range(0, FLAGS.ct_resolution - FLAGS.image_c + 1):
        image = tf.slice(image_3d, [0, 0, begin_idx], image_size)
        label = tf.slice(
            label_3d, [0, 0, begin_idx + FLAGS.image_c // 2], label_size)
        image, label = _to_blocks_and_cast(image, label, spatial_dims_w_blocks)
        image_stack.append(image)
        label_stack.append(label)

      return tf.stack(image_stack), tf.stack(label_stack)

    def _parser_fn(serialized_example):
      """Parses a single tf.Example into image and label tensors."""
      features = {}
      features['image/ct_image'] = tf.FixedLenFeature([], tf.string)
      features['image/label'] = tf.FixedLenFeature([], tf.string)
      parsed = tf.parse_single_example(serialized_example, features=features)

      spatial_dims = [FLAGS.ct_resolution] * 3
      if FLAGS.sampled_2d_slices:
        noise_shape = [FLAGS.ct_resolution] * 2 + [FLAGS.image_c]
      else:
        noise_shape = [FLAGS.ct_resolution] * 3

      image = tf.decode_raw(parsed['image/ct_image'], tf.float32)
      label = tf.decode_raw(parsed['image/label'], tf.float32)

      if dataset_str != 'train':
        # Preprocess intensity, clip to 0 ~ 1.
        # The training set is already preprocessed.
        image = tf.clip_by_value(image / 1024.0 + 0.5, 0, 1)

      image = tf.reshape(image, spatial_dims)
      label = tf.reshape(label, spatial_dims)

      if dataset_str == 'eval' and FLAGS.sampled_2d_slices:
        # Emit every window of the scan as one stacked element. The pipeline
        # unbatches these back into individual slices below.
        return _get_stacked_2d_slices(image, label)

      if FLAGS.sampled_2d_slices:
        # Take random slices of images and label
        begin_idx = tf.random_uniform(
            shape=[], minval=0,
            maxval=FLAGS.ct_resolution - FLAGS.image_c + 1, dtype=tf.int32)
        slice_begin = [0, 0, begin_idx]
        slice_size = [FLAGS.ct_resolution, FLAGS.ct_resolution, FLAGS.image_c]

        image = tf.slice(image, slice_begin, slice_size)
        label = tf.slice(label, slice_begin, slice_size)

      if dataset_str == 'train':
        for flip_axis in [0, 1, 2]:
          image, label = data_aug_lib.maybe_flip(image, label, flip_axis)
        image, label = data_aug_lib.maybe_rot180(image, label, static_axis=2)
        image = data_aug_lib.intensity_shift(
            image, label,
            FLAGS.per_class_intensity_scale, FLAGS.per_class_intensity_shift)
        image = data_aug_lib.image_corruption(
            image, label, FLAGS.ct_resolution,
            FLAGS.image_corrupt_ratio_mean, FLAGS.image_corrupt_ratio_stddev)
        image = data_aug_lib.maybe_add_noise(
            image, noise_shape, 1, 4,
            FLAGS.image_noise_probability, FLAGS.image_noise_ratio)
        image, label = data_aug_lib.projective_transform(
            image, label, FLAGS.ct_resolution,
            FLAGS.image_translate_ratio, FLAGS.image_transform_ratio,
            FLAGS.sampled_2d_slices)

      if FLAGS.sampled_2d_slices:
        # Only get the center slice of label.
        label = tf.slice(label, [0, 0, FLAGS.image_c // 2],
                         [FLAGS.ct_resolution, FLAGS.ct_resolution, 1])

      # NOTE(review): in 3d mode the image has no channel axis, so this
      # reshape presumably requires FLAGS.image_c == 1.
      return _to_blocks_and_cast(
          image, label, _blocked_spatial_dims(is_2d=FLAGS.sampled_2d_slices))

    dataset_fn = functools.partial(tf.data.TFRecordDataset,
                                   compression_type='GZIP')
    dataset = tf.data.Dataset.list_files(data_file_pattern,
                                         shuffle=shuffle).repeat()

    # Training reads several files concurrently and lets elements arrive out
    # of order (sloppy). Eval reads one file at a time, deterministically.
    cycle_length = FLAGS.n_dataset_read_interleave if interleave else 1
    dataset = dataset.apply(
        tf.data.experimental.parallel_interleave(
            lambda file_name: dataset_fn(file_name).prefetch(1),
            cycle_length=cycle_length,
            sloppy=interleave))

    if shuffle:
      dataset = dataset.shuffle(FLAGS.n_dataset_processes).map(
          _parser_fn, num_parallel_calls=FLAGS.n_dataset_processes)
    else:
      dataset = dataset.map(_parser_fn)

    if dataset_str == 'eval' and FLAGS.sampled_2d_slices:
      # When evaluating on slices, unbatch slices that belong to one CT scan.
      dataset = dataset.unbatch()

    return dataset

  return _dataset_creator