in tensorflow_graphics/projects/points_to_3Dobjects/train_multi_objects/train.py [0:0]
def val(gpu_ids=None, record_losses=False, split='val', part_id=-2):
  """Runs the validation loop, repeatedly evaluating new training checkpoints.

  Polls `FLAGS.logdir/training_ckpts` for new checkpoints; whenever a new one
  appears (or `FLAGS.eval_only` is set) it runs `_val_epoch` over the
  validation dataset, then records the evaluated epoch in a small side
  checkpoint so a restarted evaluator does not re-evaluate the same epoch.
  Loops until the restored `epoch` reaches `num_epochs`.

  Args:
    gpu_ids: Optional GPU ids forwarded to `tf_utils.get_devices`.
    record_losses: Whether `_val_epoch` should record losses.
    split: Dataset split prefix of the tfrecord files, e.g. 'val' or 'test'.
    part_id: Which tfrecord shard to evaluate. -2 (default) evaluates all
      shards and keeps polling forever; any other value evaluates a single
      shard of an assumed 100-shard split and returns after one pass.
  """
  FLAGS.batch_size = 1  # Evaluation is done one sample at a time.
  strategy = tf.distribute.MirroredStrategy(tf_utils.get_devices(gpu_ids))
  logging.info('Number of devices: %d', strategy.num_replicas_in_sync)
  shape_centers, shape_sdfs, shape_pointclouds, dict_clusters = \
      get_shapes('scannet' in FLAGS.tfrecords_dir)
  soft_shape_labels = get_soft_shape_labels(shape_sdfs)
  # part_id == -2 selects every shard; otherwise a single shard of a
  # 100-shard split, e.g. 'val-00042-of-00100.tfrecord'.
  part = '*.tfrecord' if part_id == -2 else \
      '-' + str(part_id).zfill(5) + '-of-00100.tfrecord'
  dataset = get_dataset(split + part, soft_shape_labels, shape_pointclouds)
  val_evaluator = get_evaluator()
  with strategy.scope():
    name = 'eval_' + str(split)
    work_unit = None
    logging_dir = os.path.join(FLAGS.logdir, 'logging')
    logger = logger_util.Logger(logging_dir, name, work_unit,
                                FLAGS.xmanager_metric,
                                save_loss_tensorboard_frequency=10,
                                print_loss_frequency=1000)
    # `epoch`/`num_epochs` mirror the training checkpoint; `latest_epoch`
    # tracks the last epoch already evaluated so each is evaluated only once.
    epoch = tf.Variable(0, trainable=False)
    latest_epoch = tf.Variable(-1, trainable=False)
    num_epochs = tf.Variable(-1, trainable=False)
    number_of_steps_previous_epochs = \
        tf.Variable(0, trainable=False, dtype=tf.int64)
    model = get_model(shape_centers,
                      shape_sdfs,
                      shape_pointclouds,
                      dict_clusters)
    # Deterministic preprocessing (random=False) so evaluation is repeatable.
    transforms = {'name': 'centernet_preprocessing',
                  'params': {'image_size': (FLAGS.image_height,
                                            FLAGS.image_width),
                             'transform_gt_annotations': True,
                             'random': False}}
    train_targets = {'name': 'centernet_train_targets',
                     'params': {'num_classes': FLAGS.num_classes,
                                'image_size': (FLAGS.image_height,
                                               FLAGS.image_width),
                                'stride': model.output_stride}}
    transform_fn = transforms_factory.TransformsFactory.get_transform_group(
        **transforms)
    train_targets_fn = transforms_factory.TransformsFactory.get_transform_group(
        **train_targets)
    input_image_size = transforms['params']['image_size']
    dataset = dataset.map(transform_fn, num_parallel_calls=FLAGS.num_workers)
    dataset = dataset.batch(FLAGS.batch_size)
    if train_targets_fn is not None:
      dataset = dataset.map(train_targets_fn,
                            num_parallel_calls=FLAGS.num_workers)
    # NOTE(review): `and False` intentionally disables this debug plotting;
    # drop the `and False` to re-enable it locally.
    if FLAGS.debug and False:
      for d in dataset.take(1):
        image = tf.io.decode_image(d['image_data'][0]).numpy()
        heatmaps = d['centers'][0]
        plot.plot_gt_heatmaps(image, heatmaps)
    if tf.distribute.has_strategy():
      strategy = tf.distribute.get_strategy()
      dataset = strategy.experimental_distribute_dataset(dataset)
    # NOTE(review): `input_image_size` is always set just above, so this
    # branch is currently dead; kept for parity with the training setup.
    if transforms is not None and input_image_size is None:
      if FLAGS.run_graph:
        FLAGS.run_graph = False
        logging.info('Graph mode has been disabled because the input does '
                     'not have constant size.')
      if FLAGS.batch_size > strategy.num_replicas_in_sync:
        raise ValueError('Batch size cannot be bigger than the number of GPUs'
                         ' when the input does not have constant size')
    # Side checkpoint remembering which epoch was evaluated last.
    val_checkpoint_dir = os.path.join(FLAGS.logdir, f'{name}_ckpts')
    val_checkpoint = tf.train.Checkpoint(
        epoch=latest_epoch,
        number_of_steps_previous_epochs=number_of_steps_previous_epochs)
    val_manager = tf.train.CheckpointManager(
        val_checkpoint, val_checkpoint_dir, max_to_keep=1)
    if val_manager.latest_checkpoint:
      val_checkpoint.restore(val_manager.latest_checkpoint)
    train_checkpoint_dir = os.path.join(FLAGS.logdir, 'training_ckpts')
    if FLAGS.replication:
      train_checkpoint_dir = os.path.join(train_checkpoint_dir, 'r=30')
    train_checkpoint = tf.train.Checkpoint(epoch=epoch, model=model.network,
                                           num_epochs=num_epochs)
    latest_checkpoint = ''
    if FLAGS.master == 'local' or FLAGS.plot:
      local_dump = os.path.join(FLAGS.logdir, 'images')
      if not tf.io.gfile.exists(local_dump):
        tf.io.gfile.makedirs(local_dump)
    with logger.summary_writer.as_default():
      # Poll for new training checkpoints until `num_epochs` is reached.
      while True:
        curr_latest_checkpoint = \
            tf.train.latest_checkpoint(train_checkpoint_dir)
        if (curr_latest_checkpoint is not None and
            latest_checkpoint != curr_latest_checkpoint):
          latest_checkpoint = curr_latest_checkpoint
          train_checkpoint.restore(curr_latest_checkpoint)
          if epoch != latest_epoch or FLAGS.eval_only:
            FLAGS.eval_only = False  # Only force a single re-evaluation.
            logging.info('Evaluating checkpoint in %s: %s.',
                         name, latest_checkpoint)
            n_steps = _val_epoch(name, model, dataset, input_image_size,
                                 epoch.numpy(), logger,
                                 number_of_steps_previous_epochs,
                                 val_evaluator, record_losses)
            number_of_steps_previous_epochs.assign_add(n_steps)
            latest_epoch.assign(epoch.numpy())
            if part_id < -1:
              # Evaluating all shards: persist progress and keep polling.
              val_manager.save()
            else:
              return  # Single-shard mode: one pass only.
        if epoch == num_epochs:
          break
        time.sleep(1)