in tensorflow_graphics/projects/points_to_3Dobjects/train_multi_objects/train.py [0:0]
def train(max_num_steps_epoch=None,
          save_initial_checkpoint=False,
          gpu_ids=None):
  """Runs the multi-object training loop under a MirroredStrategy.

  Builds the input pipeline, model and optimizer inside the strategy scope,
  restores the most recent checkpoint (or a pretrained one given via
  ``--continue_from_checkpoint``), then trains for ``FLAGS.num_epochs``
  epochs, saving a checkpoint after every epoch.

  Args:
    max_num_steps_epoch: Optional cap on the number of steps per epoch,
      forwarded to `_train_epoch`. None means no cap.
    save_initial_checkpoint: If True and no checkpoint exists yet, saves an
      epoch-0 checkpoint of the untrained network before training starts.
    gpu_ids: Optional GPU ids forwarded to `tf_utils.get_devices` to pick
      the devices used by the MirroredStrategy. None uses all devices.
  """
  strategy = tf.distribute.MirroredStrategy(tf_utils.get_devices(gpu_ids))
  logging.info('Number of devices: %d', strategy.num_replicas_in_sync)

  shape_centers, shape_sdfs, shape_pointclouds, dict_clusters = \
      get_shapes('scannet' in FLAGS.tfrecords_dir)
  soft_shape_labels = get_soft_shape_labels(shape_sdfs)
  dataset = get_dataset('train*.tfrecord', soft_shape_labels, shape_pointclouds)

  # Debug visualization of a single sample; also forces the input pipeline
  # to be traced once before training. NOTE(review): there is no
  # plt.show()/savefig, so the figure is never actually displayed —
  # presumably a debugging leftover; confirm before removing.
  for sample in dataset.take(1):
    plt.imshow(sample['image'])

  if FLAGS.debug:
    FLAGS.num_epochs = 50
  if FLAGS.continue_from_checkpoint:
    # When fine-tuning from a pretrained run, allow twice as many epochs.
    FLAGS.num_epochs *= 2

  # Counters tracked inside the checkpoint so that a restarted job resumes
  # exactly where it stopped.
  latest_epoch = tf.Variable(0, trainable=False)
  num_epochs_var = tf.Variable(FLAGS.num_epochs, trainable=False)
  number_of_steps_previous_epochs = tf.Variable(0, trainable=False,
                                                dtype=tf.int64)

  with strategy.scope():
    work_unit = None
    logging_dir = os.path.join(FLAGS.logdir, 'logging')
    logger = logger_util.Logger(logging_dir, 'train', work_unit, '',
                                save_loss_tensorboard_frequency=100,
                                print_loss_frequency=1000)
    optimizer = tf.keras.optimizers.Adam(learning_rate=get_learning_rate_fn())
    model = get_model(shape_centers,
                      shape_sdfs,
                      shape_pointclouds,
                      dict_clusters)
    model.optimizer = optimizer

    # Image preprocessing (deterministic: random=False) and CenterNet
    # training-target generation, both resolved via the transforms factory.
    transforms = {'name': 'centernet_preprocessing',
                  'params': {'image_size': (FLAGS.image_height,
                                            FLAGS.image_width),
                             'transform_gt_annotations': True,
                             'random': False}}
    train_targets = {'name': 'centernet_train_targets',
                     'params': {'num_classes': FLAGS.num_classes,
                                'image_size': (FLAGS.image_height,
                                               FLAGS.image_width),
                                'stride': model.output_stride}}
    transform_fn = transforms_factory.TransformsFactory.get_transform_group(
        **transforms)
    train_targets_fn = transforms_factory.TransformsFactory.get_transform_group(
        **train_targets)
    input_image_size = transforms['params']['image_size']

    dataset = dataset.map(transform_fn, num_parallel_calls=FLAGS.num_workers)
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.map(train_targets_fn,
                          num_parallel_calls=FLAGS.num_workers)
    if FLAGS.batch_size > 1:
      # BUG FIX: tf.data transformations are not in-place. The original code
      # discarded the return value of prefetch(), so no prefetching ever
      # happened; rebind the dataset to actually enable it.
      dataset = dataset.prefetch(int(FLAGS.batch_size * 1.5))
    dataset = strategy.experimental_distribute_dataset(dataset)

    checkpoint_dir = os.path.join(FLAGS.logdir, 'training_ckpts')
    if FLAGS.replication:
      checkpoint_dir = os.path.join(checkpoint_dir, 'r=30')
    checkpoint = tf.train.Checkpoint(
        epoch=latest_epoch,
        model=model.network,
        optimizer=optimizer,
        number_of_steps_previous_epochs=number_of_steps_previous_epochs,
        num_epochs=num_epochs_var)
    manager = tf.train.CheckpointManager(checkpoint,
                                         checkpoint_dir,
                                         max_to_keep=5)
    # Restore priority: (1) latest checkpoint of this run, (2) a pretrained
    # run given via --continue_from_checkpoint, (3) train from scratch.
    if manager.latest_checkpoint:
      logging.info('Restoring from %s', manager.latest_checkpoint)
      checkpoint.restore(manager.latest_checkpoint)
    elif FLAGS.continue_from_checkpoint:
      init_checkpoint_dir = os.path.join(
          FLAGS.continue_from_checkpoint, 'training_ckpts')
      init_manager = tf.train.CheckpointManager(checkpoint,
                                                init_checkpoint_dir,
                                                None)
      logging.info('Restoring from pretrained %s',
                   init_manager.latest_checkpoint)
      checkpoint.restore(init_manager.latest_checkpoint)
    else:
      logging.info('Not restoring any previous training checkpoint.')

    if save_initial_checkpoint and not manager.latest_checkpoint:
      # Save the untrained network through a throwaway manager so the main
      # manager's internal save counter is not incremented.
      tmp_ckpt = tf.train.Checkpoint(epoch=latest_epoch, model=model.network)
      tmp_manager = tf.train.CheckpointManager(tmp_ckpt, checkpoint_dir, None)
      save_path = tmp_manager.save(0)
      logging.info('Saved checkpoint for epoch %d: %s',
                   int(latest_epoch.numpy()), save_path)
      latest_epoch.assign_add(1)

    with logger.summary_writer.as_default():
      # Resume from the restored epoch counter; checkpoint after each epoch.
      for epoch in range(int(latest_epoch.numpy()), FLAGS.num_epochs + 1):
        latest_epoch.assign(epoch)
        n_steps = _train_epoch(epoch, model, dataset, logger,
                               number_of_steps_previous_epochs,
                               max_num_steps_epoch, input_image_size)
        number_of_steps_previous_epochs.assign_add(n_steps)
        save_path = manager.save()
        logging.info('Saved checkpoint for epoch %d: %s',
                     int(latest_epoch.numpy()), save_path)