# Excerpt: val()
# From tensorflow_graphics/projects/points_to_3Dobjects/train_multi_objects/train.py

def val(gpu_ids=None, record_losses=False, split='val', part_id=-2):
  """Evaluation loop: polls training checkpoints and evaluates each new one.

  Builds the evaluation dataset and model under a MirroredStrategy scope,
  then repeatedly restores the latest training checkpoint and runs
  `_val_epoch` on it, until the training epoch counter reaches `num_epochs`.
  Evaluation progress (last evaluated epoch, step offset) is persisted in a
  separate checkpoint so a restarted evaluator does not redo finished epochs.

  Args:
    gpu_ids: Optional device selector forwarded to `tf_utils.get_devices`;
      None uses the default devices.
    record_losses: Whether `_val_epoch` should also record losses.
    split: Dataset split prefix for the tfrecord files (e.g. 'val').
    part_id: Shard selector. -2 (the default) evaluates all shards
      ('*.tfrecord') in an endless polling loop; any other value evaluates
      only shard '-<part_id:05d>-of-00100.tfrecord' once and returns after
      the first evaluated checkpoint.
  """
  # Evaluation always runs with batch size 1.
  FLAGS.batch_size = 1

  strategy = tf.distribute.MirroredStrategy(tf_utils.get_devices(gpu_ids))
  logging.info('Number of devices: %d', strategy.num_replicas_in_sync)
  shape_centers, shape_sdfs, shape_pointclouds, dict_clusters = \
      get_shapes('scannet' in FLAGS.tfrecords_dir)
  soft_shape_labels = get_soft_shape_labels(shape_sdfs)
  # part_id == -2 means "all shards"; otherwise pick one of the 100 shards.
  part = '*.tfrecord' if part_id == -2 else \
      '-'+str(part_id).zfill(5)+'-of-00100.tfrecord'
  dataset = get_dataset(split+part, soft_shape_labels, shape_pointclouds)

  val_evaluator = get_evaluator()

  with strategy.scope():
    name = 'eval_'+str(split)
    work_unit = None
    logging_dir = os.path.join(FLAGS.logdir, 'logging')
    logger = logger_util.Logger(logging_dir, name, work_unit,
                                FLAGS.xmanager_metric,
                                save_loss_tensorboard_frequency=10,
                                print_loss_frequency=1000)
    # `epoch`/`num_epochs` mirror the training checkpoint; `latest_epoch` and
    # `number_of_steps_previous_epochs` are this evaluator's own persisted
    # progress markers.
    epoch = tf.Variable(0, trainable=False)
    latest_epoch = tf.Variable(-1, trainable=False)
    num_epochs = tf.Variable(-1, trainable=False)
    number_of_steps_previous_epochs = \
        tf.Variable(0, trainable=False, dtype=tf.int64)

    model = get_model(shape_centers,
                      shape_sdfs,
                      shape_pointclouds,
                      dict_clusters)

    # Deterministic preprocessing (random=False) so evaluation is repeatable.
    transforms = {'name': 'centernet_preprocessing',
                  'params': {'image_size': (FLAGS.image_height,
                                            FLAGS.image_width),
                             'transform_gt_annotations': True,
                             'random': False}}
    train_targets = {'name': 'centernet_train_targets',
                     'params': {'num_classes': FLAGS.num_classes,
                                'image_size': (FLAGS.image_height,
                                               FLAGS.image_width),
                                'stride': model.output_stride}}
    transform_fn = transforms_factory.TransformsFactory.get_transform_group(
        **transforms)
    train_targets_fn = transforms_factory.TransformsFactory.get_transform_group(
        **train_targets)
    input_image_size = transforms['params']['image_size']

    dataset = dataset.map(transform_fn, num_parallel_calls=FLAGS.num_workers)
    dataset = dataset.batch(FLAGS.batch_size)

    if train_targets_fn is not None:
      dataset = dataset.map(train_targets_fn,
                            num_parallel_calls=FLAGS.num_workers)

    if tf.distribute.has_strategy():
      strategy = tf.distribute.get_strategy()
      dataset = strategy.experimental_distribute_dataset(dataset)
      # Variable-sized inputs cannot run in graph mode, and a batch cannot be
      # split across replicas when element shapes differ.
      if transforms is not None and input_image_size is None:
        if FLAGS.run_graph:
          FLAGS.run_graph = False
          logging.info('Graph mode has been disabled because the input does '
                       'not have constant size.')
        if FLAGS.batch_size > strategy.num_replicas_in_sync:
          raise ValueError('Batch size cannot be bigger than the number of GPUs'
                           ' when the input does not have constant size')

    # Evaluator-side checkpoint: remembers which training epoch was last
    # evaluated so restarts resume instead of re-evaluating.
    val_checkpoint_dir = os.path.join(FLAGS.logdir, f'{name}_ckpts')
    val_checkpoint = tf.train.Checkpoint(
        epoch=latest_epoch,
        number_of_steps_previous_epochs=number_of_steps_previous_epochs)
    val_manager = tf.train.CheckpointManager(
        val_checkpoint, val_checkpoint_dir, max_to_keep=1)
    if val_manager.latest_checkpoint:
      val_checkpoint.restore(val_manager.latest_checkpoint)

    train_checkpoint_dir = os.path.join(FLAGS.logdir, 'training_ckpts')
    if FLAGS.replication:
      # NOTE(review): hard-coded replication sub-directory — presumably
      # matches the writer in the training job; verify against train().
      train_checkpoint_dir = os.path.join(train_checkpoint_dir, 'r=30')

    train_checkpoint = tf.train.Checkpoint(epoch=epoch, model=model.network,
                                           num_epochs=num_epochs)
    latest_checkpoint = ''

    if FLAGS.master == 'local' or FLAGS.plot:
      local_dump = os.path.join(FLAGS.logdir, 'images')
      if not tf.io.gfile.exists(local_dump):
        tf.io.gfile.makedirs(local_dump)

    with logger.summary_writer.as_default():
      # Poll for new training checkpoints until training reports completion.
      while True:
        curr_latest_checkpoint = \
            tf.train.latest_checkpoint(train_checkpoint_dir)
        if (curr_latest_checkpoint is not None and
            latest_checkpoint != curr_latest_checkpoint):
          latest_checkpoint = curr_latest_checkpoint
          # Restoring updates `epoch`, `num_epochs` and the model weights.
          train_checkpoint.restore(curr_latest_checkpoint)
          # Evaluate if this is a new epoch, or unconditionally once when
          # --eval_only was requested (then clear the flag).
          if epoch != latest_epoch or FLAGS.eval_only:
            FLAGS.eval_only = False
            logging.info('Evaluating checkpoint in %s: %s.',
                         name, latest_checkpoint)
            n_steps = _val_epoch(name, model, dataset, input_image_size,
                                 epoch.numpy(), logger,
                                 number_of_steps_previous_epochs,
                                 val_evaluator, record_losses)
            number_of_steps_previous_epochs.assign_add(n_steps)

            latest_epoch.assign(epoch.numpy())
            if part_id < -1:
              # Full-dataset mode: persist progress and keep polling.
              val_manager.save()
            else:
              # Single-shard mode: one evaluation is enough.
              return
        if epoch == num_epochs:
          break
        time.sleep(1)