def get_dataset()

in tensorflow_graphics/projects/points_to_3Dobjects/train_multi_objects/train.py [0:0]


def get_dataset(split, shape_soft_labels, shape_pointclouds=None):
  """Builds the tf.data input pipeline for one dataset split.

  TFRecord files matching `FLAGS.tfrecords_dir/<split>` are listed,
  interleaved, decoded and filtered. Depending on flags, the samples are then
  augmented (`FLAGS.train`), extended with soft shape labels
  (`FLAGS.soft_shape_labels`) and reduced to a small repeated subset
  (`FLAGS.debug`). No batching is applied here.

  Args:
    split: Path component joined onto `FLAGS.tfrecords_dir` to locate the
      records of this split.
    shape_soft_labels: Table indexed by shape id; the indexed entry is cast to
      float32 and stored under `sample['shapes_soft']`. Only used when
      `FLAGS.soft_shape_labels` is set. Presumably one row of soft labels per
      shape -- TODO confirm against the caller.
    shape_pointclouds: Optional value that is only printed when provided; it
      does not influence the returned dataset.

  Returns:
    A `tf.data.Dataset` yielding decoded sample dictionaries.
  """
  # Compare against None explicitly: truth-testing a multi-element array or
  # tensor is ambiguous and raises instead of returning a bool.
  if shape_pointclouds is not None:
    print(shape_pointclouds)

  tfrecord_path = os.path.join(FLAGS.tfrecords_dir, split)

  # Defaults for training. Debug and validation runs read files in a fixed
  # order; FLAGS.val takes precedence when both flags are set.
  buffer_size, shuffle, cycles = 10000, True, 10000
  if FLAGS.debug:
    buffer_size, shuffle, cycles = 1, False, 1
  if FLAGS.val:
    buffer_size, shuffle, cycles = 100, False, 1

  tfrecords_pattern = io.expand_rio_pattern(tfrecord_path)
  dataset = tf.data.Dataset.list_files(tfrecords_pattern, shuffle=shuffle)
  dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=cycles)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=buffer_size)

  # The decoder and the validity filter are selected by the record path.
  if 'scannet' in tfrecord_path:
    dataset = dataset.map(extract_protos.decode_bytes_multiple_scannet)
    # Keep only samples that report exactly three boxes.
    dataset = dataset.filter(lambda sample: sample['num_boxes'] == 3)
  else:
    dataset = dataset.map(extract_protos.decode_bytes_multiple)
    # Drop samples in which any shape id is -1 or lower.
    dataset = dataset.filter(
        lambda sample: tf.reduce_min(sample['shapes']) > -1)

  def augment(sample):
    """Randomly perturbs colors, mirrors and optionally blurs one sample."""
    image = sample['image']
    # Color jitter, applied with probability 0.8.
    if tf.random.uniform([1], 0, 1.0) < 0.8:
      image = tf.image.random_saturation(image, 1.0, 10.0)
      image = tf.image.random_contrast(image, 0.05, 5.0)
      image = tf.image.random_hue(image, 0.5)
      image = tf.image.random_brightness(image, 0.8)
      sample['image'] = image
    # Horizontal mirroring, applied with probability 0.5. The 3D annotations
    # and the 2D boxes are updated together with the image.
    if tf.random.uniform([1], 0, 1.0) < 0.5:
      sample['image'] = tf.image.flip_left_right(sample['image'])
      # Negate the first translation component.
      sample['translations_3d'] *= [[-1.0, 1.0, 1.0]]
      # Transpose every 3x3 rotation matrix, then flatten back to 9 values.
      sample['rotations_3d'] = tf.reshape(sample['rotations_3d'], [-1, 3, 3])
      sample['rotations_3d'] = tf.transpose(sample['rotations_3d'],
                                            perm=[0, 2, 1])
      sample['rotations_3d'] = tf.reshape(sample['rotations_3d'], [-1, 9])
      # Columns 1 and 3 are swapped and mirrored around 1.
      # Presumably boxes are normalized [ymin, xmin, ymax, xmax] -- TODO
      # confirm against the decoder.
      bbox = sample['groundtruth_boxes']
      bbox = tf.stack([bbox[:, 0], 1 - bbox[:, 3], bbox[:, 2], 1 - bbox[:, 1]],
                      axis=-1)
      sample['groundtruth_boxes'] = bbox
    if FLAGS.gaussian_augmentation:
      # NOTE(review): every branch draws a fresh random number, so the
      # thresholds are conditional on the previous branches failing rather
      # than cumulative -- confirm this is the intended distribution.
      if tf.random.uniform([1], 0, 1.0) < 0.15:
        sample['image'] = tf_utils.gaussian_blur(sample['image'], sigma=1)
      elif tf.random.uniform([1], 0, 1.0) < 0.30:
        sample['image'] = tf_utils.gaussian_blur(sample['image'], sigma=2)
      elif tf.random.uniform([1], 0, 1.0) < 0.45:
        sample['image'] = tf_utils.gaussian_blur(sample['image'], sigma=3)
    return sample

  # Augmentation depends on FLAGS.train only, so it also runs in debug mode.
  if FLAGS.train:
    dataset = dataset.map(augment, num_parallel_calls=FLAGS.num_workers)

  def add_soft_shape_labels(sample):
    """Looks up the soft label entry for every shape id of the sample."""
    sample['shapes_soft'] = tf.map_fn(
        fn=lambda t: tf.cast(shape_soft_labels[t], tf.float32),
        elems=tf.cast(sample['shapes'], tf.int32),
        fn_output_signature=tf.float32)
    return sample

  if FLAGS.soft_shape_labels:
    dataset = dataset.map(add_soft_shape_labels,
                          num_parallel_calls=FLAGS.num_workers)

  # Create dataset for overfitting when debugging: keep the first few samples
  # and repeat them (once for validation, 1000 times otherwise).
  if FLAGS.debug:
    t = FLAGS.num_overfitting_samples
    dataset = dataset.take(t)
    mult = 1 if FLAGS.val else 1000
    dataset = dataset.repeat(mult)
    if t > 1:
      # NOTE(review): buffer_size is 1 here unless FLAGS.val is also set, in
      # which case this shuffle leaves the order unchanged -- confirm intent.
      dataset = dataset.shuffle(buffer_size)

  return dataset