in tensorflow_graphics/projects/points_to_3Dobjects/train_multi_objects/train.py [0:0]
def get_dataset(split, shape_soft_labels, shape_pointclouds=None):
  """Builds the tf.data pipeline of decoded multi-object samples for `split`.

  Reads the TFRecord files found under `FLAGS.tfrecords_dir/split`, decodes
  them and filters out unusable samples. Depending on the flags, it then
  applies training-time augmentation, attaches soft shape labels and, in debug
  mode, builds a small repeated dataset for overfitting.

  Args:
    split: Split name; joined onto `FLAGS.tfrecords_dir` to locate the
      TFRecord files.
    shape_soft_labels: Indexed by the integer values of `sample['shapes']`
      when `FLAGS.soft_shape_labels` is set. NOTE(review): presumably a tensor
      of per-shape soft-label vectors -- confirm against the caller.
    shape_pointclouds: Optional; only printed here, not otherwise used.

  Returns:
    A `tf.data.Dataset` of sample dictionaries.
  """
  if shape_pointclouds:
    print(shape_pointclouds)
  tfrecord_path = os.path.join(FLAGS.tfrecords_dir, split)

  # Debug and validation runs disable shuffling and read one file at a time,
  # which makes the sample order deterministic.
  buffer_size, shuffle, cycles = 10000, True, 10000
  if FLAGS.debug:
    buffer_size, shuffle, cycles = 1, False, 1
  if FLAGS.val:
    buffer_size, shuffle, cycles = 100, False, 1

  tfrecords_pattern = io.expand_rio_pattern(tfrecord_path)
  dataset = tf.data.Dataset.list_files(tfrecords_pattern, shuffle=shuffle)
  dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=cycles)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=buffer_size)

  if 'scannet' in tfrecord_path:
    dataset = dataset.map(extract_protos.decode_bytes_multiple_scannet)
    # Only ScanNet samples with exactly three boxes are kept.
    dataset = dataset.filter(lambda sample: sample['num_boxes'] == 3)
  else:
    dataset = dataset.map(extract_protos.decode_bytes_multiple)
    # Drop samples whose shape ids contain any value <= -1.
    dataset = dataset.filter(
        lambda sample: tf.reduce_min(sample['shapes']) > -1)

  def augment(sample):
    """Randomly applies color jitter, horizontal flip and blur to `sample`."""
    # Scalar draws give AutoGraph a scalar predicate for each `tf.cond`.
    image = sample['image']
    if tf.random.uniform([], 0, 1.0) < 0.8:
      image = tf.image.random_saturation(image, 1.0, 10.0)
      image = tf.image.random_contrast(image, 0.05, 5.0)
      image = tf.image.random_hue(image, 0.5)
      image = tf.image.random_brightness(image, 0.8)
    sample['image'] = image

    if tf.random.uniform([], 0, 1.0) < 0.5:
      # Mirror the image and keep the 3D labels consistent with it.
      sample['image'] = tf.image.flip_left_right(sample['image'])
      # Negate the x component of every translation.
      sample['translations_3d'] *= [[-1.0, 1.0, 1.0]]
      # Rotations are stored flattened to 9 values; transpose each 3x3 matrix.
      sample['rotations_3d'] = tf.reshape(sample['rotations_3d'], [-1, 3, 3])
      sample['rotations_3d'] = tf.transpose(sample['rotations_3d'],
                                            perm=[0, 2, 1])
      sample['rotations_3d'] = tf.reshape(sample['rotations_3d'], [-1, 9])
      # Mirror the boxes horizontally: columns 1 and 3 become 1 - column 3
      # and 1 - column 1. NOTE(review): assumes normalized
      # [ymin, xmin, ymax, xmax] boxes -- confirm.
      bbox = sample['groundtruth_boxes']
      bbox = tf.stack([bbox[:, 0], 1 - bbox[:, 3], bbox[:, 2], 1 - bbox[:, 1]],
                      axis=-1)
      sample['groundtruth_boxes'] = bbox

    if FLAGS.gaussian_augmentation:
      # All three branches share one draw, so each sigma is applied with
      # probability 0.15 and 55% of samples stay unblurred. Independent draws
      # per branch would skew this to roughly 15% / 26% / 27%.
      blur_draw = tf.random.uniform([], 0, 1.0)
      if blur_draw < 0.15:
        sample['image'] = tf_utils.gaussian_blur(sample['image'], sigma=1)
      elif blur_draw < 0.30:
        sample['image'] = tf_utils.gaussian_blur(sample['image'], sigma=2)
      elif blur_draw < 0.45:
        sample['image'] = tf_utils.gaussian_blur(sample['image'], sigma=3)
    return sample

  if FLAGS.train:  # and not FLAGS.debug:
    dataset = dataset.map(augment, num_parallel_calls=FLAGS.num_workers)

  def add_soft_shape_labels(sample):
    """Adds `sample['shapes_soft']` by looking up each id in `sample['shapes']`."""
    sample['shapes_soft'] = tf.map_fn(
        fn=lambda t: tf.cast(shape_soft_labels[t], tf.float32),
        elems=tf.cast(sample['shapes'], tf.int32),
        fn_output_signature=tf.float32)
    return sample

  if FLAGS.soft_shape_labels:
    dataset = dataset.map(add_soft_shape_labels,
                          num_parallel_calls=FLAGS.num_workers)

  # Create dataset for overfitting when debugging.
  if FLAGS.debug:
    num_samples = FLAGS.num_overfitting_samples
    dataset = dataset.take(num_samples)
    mult = 1 if FLAGS.val else 1000
    dataset = dataset.repeat(mult)
    if num_samples > 1:
      # buffer_size is 1 in debug mode (without --val), which would make this
      # shuffle a no-op. Use a buffer covering the whole overfitting set.
      dataset = dataset.shuffle(max(buffer_size, num_samples))
  return dataset