# train(): training entry point.
#
# From tensorflow_graphics/projects/points_to_3Dobjects/train_multi_objects/train.py


def train(max_num_steps_epoch=None,
          save_initial_checkpoint=False,
          gpu_ids=None):
  """Runs the full multi-epoch training loop under a MirroredStrategy.

  Builds the shape-embedding inputs and the tf.data input pipeline, creates
  the model and optimizer inside the strategy scope, restores the latest
  checkpoint (or a pretrained one when continuing from a previous run), and
  then trains epoch by epoch, saving a checkpoint after every epoch.

  Args:
    max_num_steps_epoch: Optional cap on the number of steps per epoch,
      forwarded to `_train_epoch`. None means no cap.
    save_initial_checkpoint: If True and no checkpoint exists yet, saves an
      epoch-0 checkpoint of the untrained network before training starts.
    gpu_ids: Optional device selection forwarded to `tf_utils.get_devices`;
      None lets the helper pick its defaults.
  """

  strategy = tf.distribute.MirroredStrategy(tf_utils.get_devices(gpu_ids))
  logging.info('Number of devices: %d', strategy.num_replicas_in_sync)
  shape_centers, shape_sdfs, shape_pointclouds, dict_clusters = \
      get_shapes('scannet' in FLAGS.tfrecords_dir)
  soft_shape_labels = get_soft_shape_labels(shape_sdfs)
  dataset = get_dataset('train*.tfrecord', soft_shape_labels, shape_pointclouds)

  # NOTE(review): debug leftover — draws one sample onto the current
  # matplotlib figure but never shows/saves it; consider removing.
  for sample in dataset.take(1):
    plt.imshow(sample['image'])

  if FLAGS.debug:
    FLAGS.num_epochs = 50
  if FLAGS.continue_from_checkpoint:
    # Continuing a run: allow the schedule to go on for as many epochs again.
    FLAGS.num_epochs *= 2

  # Counters tracked in the checkpoint so a restored run resumes where it
  # left off (epoch index and cumulative step count across past epochs).
  latest_epoch = tf.Variable(0, trainable=False)
  num_epochs_var = tf.Variable(FLAGS.num_epochs, trainable=False)
  number_of_steps_previous_epochs = tf.Variable(0, trainable=False,
                                                dtype=tf.int64)
  with strategy.scope():
    work_unit = None
    logging_dir = os.path.join(FLAGS.logdir, 'logging')
    logger = logger_util.Logger(logging_dir, 'train', work_unit, '',
                                save_loss_tensorboard_frequency=100,
                                print_loss_frequency=1000)

    optimizer = tf.keras.optimizers.Adam(learning_rate=get_learning_rate_fn())
    model = get_model(shape_centers,
                      shape_sdfs,
                      shape_pointclouds,
                      dict_clusters)
    model.optimizer = optimizer

    # Preprocessing (resize/normalize) and training-target generation are
    # both expressed as transform groups from the project factory.
    transforms = {'name': 'centernet_preprocessing',
                  'params': {'image_size': (FLAGS.image_height,
                                            FLAGS.image_width),
                             'transform_gt_annotations': True,
                             'random': False}}
    train_targets = {'name': 'centernet_train_targets',
                     'params': {'num_classes': FLAGS.num_classes,
                                'image_size': (FLAGS.image_height,
                                               FLAGS.image_width),
                                'stride': model.output_stride}}
    transform_fn = transforms_factory.TransformsFactory.get_transform_group(
        **transforms)
    train_targets_fn = transforms_factory.TransformsFactory.get_transform_group(
        **train_targets)
    input_image_size = transforms['params']['image_size']

    # Per-sample preprocessing, then batch, then per-batch target generation.
    dataset = dataset.map(transform_fn, num_parallel_calls=FLAGS.num_workers)
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.map(train_targets_fn,
                          num_parallel_calls=FLAGS.num_workers)

    if FLAGS.batch_size > 1:
      # BUG FIX: `Dataset.prefetch` returns a new dataset; the original code
      # discarded the result, so no prefetching actually happened.
      dataset = dataset.prefetch(int(FLAGS.batch_size * 1.5))

    dataset = strategy.experimental_distribute_dataset(dataset)

    checkpoint_dir = os.path.join(FLAGS.logdir, 'training_ckpts')

    if FLAGS.replication:
      checkpoint_dir = os.path.join(checkpoint_dir, 'r=30')

    checkpoint = tf.train.Checkpoint(
        epoch=latest_epoch,
        model=model.network,
        optimizer=optimizer,
        number_of_steps_previous_epochs=number_of_steps_previous_epochs,
        num_epochs=num_epochs_var)

    manager = tf.train.CheckpointManager(checkpoint,
                                         checkpoint_dir,
                                         max_to_keep=5)

    # Restore priority: (1) latest checkpoint of this run, (2) a pretrained
    # checkpoint when --continue_from_checkpoint is set, (3) fresh start.
    if manager.latest_checkpoint:
      logging.info('Restoring from %s', manager.latest_checkpoint)
      checkpoint.restore(manager.latest_checkpoint)
    elif FLAGS.continue_from_checkpoint:
      init_checkpoint_dir = os.path.join(
          FLAGS.continue_from_checkpoint, 'training_ckpts')
      init_manager = tf.train.CheckpointManager(checkpoint,
                                                init_checkpoint_dir,
                                                None)
      logging.info('Restoring from pretrained %s',
                   init_manager.latest_checkpoint)
      checkpoint.restore(init_manager.latest_checkpoint)
    else:
      logging.info('Not restoring any previous training checkpoint.')

    if save_initial_checkpoint and not manager.latest_checkpoint:
      # Use a throwaway Checkpoint/Manager so the main manager's internal
      # save counter is not incremented by this epoch-0 snapshot.
      tmp_ckpt = tf.train.Checkpoint(epoch=latest_epoch, model=model.network)
      tmp_manager = tf.train.CheckpointManager(tmp_ckpt, checkpoint_dir, None)
      save_path = tmp_manager.save(0)
      logging.info('Saved checkpoint for epoch %d: %s',
                   int(latest_epoch.numpy()), save_path)
    # Epochs are 1-based; a restored `latest_epoch` means we resume at the
    # epoch after the last completed one.
    latest_epoch.assign_add(1)

    with logger.summary_writer.as_default():
      # NOTE(review): the range end uses FLAGS.num_epochs rather than the
      # checkpoint-restored `num_epochs_var` — confirm this is intentional.
      for epoch in range(int(latest_epoch.numpy()), FLAGS.num_epochs + 1):
        latest_epoch.assign(epoch)
        n_steps = _train_epoch(epoch, model, dataset, logger,
                               number_of_steps_previous_epochs,
                               max_num_steps_epoch, input_image_size)
        number_of_steps_previous_epochs.assign_add(n_steps)
        save_path = manager.save()
        logging.info('Saved checkpoint for epoch %d: %s',
                     int(latest_epoch.numpy()), save_path)