def _val_epoch()

in tensorflow_graphics/projects/points_to_3Dobjects/train_multi_objects/train.py [0:0]


def _val_epoch_plot_summaries(model, sample, detections, batch_id,
                              total_steps):
  """Writes the 2D plots of one sample as image summaries.

  If `FLAGS.local_plot_3d` is set, a 3D plot of the detections is produced
  first via `plot.plot_detections_3d`.

  Args:
    model: Model; only `model.dict_clusters` is read here.
    sample: Dictionary of batched input tensors. The keys `scene_filename` and
      `image_data` are read at index `batch_id`.
    detections: Output of `model.postprocess_sample2` for this sample.
    batch_id: Index into the batch dimension of `sample`.
    total_steps: Step value passed to `tf.summary.image`.
  """
  if FLAGS.local_plot_3d:
    # Decode the bytes properly instead of slicing the "b'...'" repr; the
    # result is identical for ASCII file names.
    scene_filename = str(sample['scene_filename'][batch_id].numpy(), 'utf-8')
    # NOTE(review): the [5:] slice presumably strips a fixed-depth prefix of
    # FLAGS.logdir -- confirm against the expected directory layout.
    logdir = os.path.join(
        '..', os.path.join(*(FLAGS.logdir.split(os.path.sep)[5:])),
        'plots3d', scene_filename)
    logging.info(logdir)
    plot.plot_detections_3d(detections, sample, logdir, model.dict_clusters)

  image = tf.io.decode_image(sample['image_data'][batch_id]).numpy()
  figure_heatmaps = plot.plot_to_image(plot.plot_heatmaps(image, detections))
  figure_boxes_2d = plot.plot_to_image(
      plot.plot_boxes_2d(image, sample, detections))
  figure_boxes_3d = plot.plot_to_image(
      plot.plot_boxes_3d(image, sample, detections))

  tf.summary.image('Heatmaps', figure_heatmaps, total_steps)
  tf.summary.image('Boxes 2D', figure_boxes_2d, total_steps)
  tf.summary.image('Boxes 3D', figure_boxes_3d, total_steps)


def _val_epoch_save_qualitative(model, sample, detections, batch_id, iou_min,
                                iou_mean):
  """Saves qualitative results (images, 3D views, blender data) for a sample.

  Files are written below `<logdir>/qualitative/`, where `<logdir>` is
  `FLAGS.logdir`, or `<FLAGS.qualidir>/francis` when `FLAGS.francis` is set.
  File names of the 2D/3D box images are prefixed with the formatted IoU so
  that a directory listing sorts by quality.

  Args:
    model: Model; `model.dict_clusters` and `model.shape_pointclouds` are read.
    sample: Dictionary of batched input tensors. The keys `scene_filename` and
      `image_data` are read at index `batch_id`.
    detections: Output of `model.postprocess_sample2` for this sample.
    batch_id: Index into the batch dimension of `sample`.
    iou_min: Value of `iou_min` reported by the evaluator for this sample.
    iou_mean: Value of `iou_mean` reported by the evaluator for this sample.
  """
  logdir = FLAGS.logdir
  if FLAGS.francis:
    logdir = os.path.join(FLAGS.qualidir, 'francis')

  qualitative_dir = os.path.join(logdir, 'qualitative')
  path_input = os.path.join(qualitative_dir, 'img')
  path_blender = os.path.join(qualitative_dir, 'blender2')
  path_2d_min = os.path.join(qualitative_dir, 'img_2d_min')
  path_2d_mean = os.path.join(qualitative_dir, 'img_2d_mean')
  path_3d_min = os.path.join(qualitative_dir, 'img_3d_min')
  path_3d_mean = os.path.join(qualitative_dir, 'img_3d_mean')
  for directory in (path_input, path_blender, path_2d_min, path_2d_mean,
                    path_3d_min, path_3d_mean):
    tf.io.gfile.makedirs(directory)

  scene_name = str(
      sample['scene_filename'][batch_id].numpy(), 'utf-8').split('.')[0]
  # Negative IoU values are written as a plain '0' prefix.
  iou_min_str = f'{iou_min:.5f}' if iou_min >= 0 else '0'
  iou_mean_str = f'{iou_mean:.5f}' if iou_mean >= 0 else '0'
  min_filename = iou_min_str + '_' + scene_name
  mean_filename = iou_mean_str + '_' + scene_name
  image = tf.io.decode_image(sample['image_data'][batch_id]).numpy()

  def save_current_figure(*filepaths):
    # Writes whatever figure pyplot currently considers active to every path.
    for filepath in filepaths:
      with tf.io.gfile.GFile(filepath, 'wb') as f:
        plt.savefig(f)

  # Plot original image.
  plt.figure(figsize=(5, 5))
  plt.clf()
  plt.imshow(image)
  save_current_figure(os.path.join(path_input, scene_name + '.png'))

  # Plot image with 2D bounding boxes.
  plot.plot_boxes_2d(image, sample, detections,
                     groundtruth=(not FLAGS.francis))
  save_current_figure(os.path.join(path_2d_min, min_filename + '.png'),
                      os.path.join(path_2d_mean, mean_filename + '.png'))

  # Plot image with 3D bounding boxes.
  plot.plot_boxes_3d(image, sample, detections,
                     groundtruth=(not FLAGS.francis))
  save_current_figure(os.path.join(path_3d_min, min_filename + '.png'),
                      os.path.join(path_3d_mean, mean_filename + '.png'))

  if FLAGS.local_plot_3d:
    # Plot 3D visualizer.
    # NOTE(review): the [6:] slice presumably strips a fixed-depth prefix of
    # logdir -- confirm against the expected directory layout.
    relative_logdir = os.path.join(*(logdir.split(os.path.sep)[6:]))
    for subdir, filename in (('web_3d_min', min_filename),
                             ('web_3d_mean', mean_filename)):
      path = os.path.join('..', relative_logdir, 'qualitative', subdir,
                          filename)
      plot.plot_detections_3d(detections,
                              sample,
                              path,
                              model.dict_clusters,
                              local=FLAGS.francis)

  # Save pickles for plotting in blender.
  path_blender_file = os.path.join(path_blender, scene_name)
  plot.save_for_blender(detections, sample, path_blender_file,
                        model.dict_clusters, model.shape_pointclouds,
                        local=FLAGS.francis)

  # Every figure has been written to disk by now. Close them so that open
  # figures do not accumulate (one or more per sample) over the whole epoch.
  plt.close('all')


def _val_epoch(
    name,
    model,
    dataset,
    input_image_size,
    epoch,
    logger,
    number_of_steps_previous_epochs,
    evaluator: evaluator_util.Evaluator,
    record_loss=False):
  """Runs one validation epoch over `dataset`.

  For every sample the model is run through the current distribution strategy,
  the outputs of the last network stage are post-processed into detections,
  the detections are handed to `evaluator`, and plots / qualitative results
  are written. Timing scalars are logged under `meta/`.

  Args:
    name: Printed to stdout when truthy.
    model: Model providing `test_sample`, `postprocess_sample2`,
      `dict_clusters` and `shape_pointclouds`.
    dataset: Iterable dataset; consumed via
      `tf_utils.get_next_sample_dataset` until it yields `None` or an empty
      batch. The post-processing assumes a batch size of 1.
    input_image_size: Forwarded to `model.postprocess_sample2`.
    epoch: Step value used for the epoch-level losses and metrics.
    logger: Logger used for scalars, losses and metrics.
    number_of_steps_previous_epochs: Offset added to the running sample count
      to form the step value of per-iteration logs.
    evaluator: Evaluator accumulating detections; reset at the start and
      evaluated at the end of the epoch.
    record_loss: Whether to compute and log losses and metrics. Forced to
      False when `FLAGS.part_id > -2`.

  Returns:
    A scalar `tf.int64` tensor holding the summed batch sizes of all processed
    samples.
  """
  if name:
    print(name)

  if FLAGS.part_id > -2:
    record_loss = False
  strategy = tf.distribute.get_strategy()

  def distributed_step(sample):
    training = False
    output, loss = strategy.run(model.test_sample,
                                args=(sample, record_loss, training))
    losses_value = {}
    if record_loss:
      losses_value = {
          key: strategy.reduce(tf.distribute.ReduceOp.SUM, value, axis=None)
          for key, value in loss.items()
      }
    return output, losses_value

  if FLAGS.run_graph:
    distributed_step = tf.function(distributed_step)

  logger.reset_losses()
  evaluator.reset_metrics()

  dataset_iterator = iter(dataset)
  n_steps = tf.constant(0, dtype=tf.int64)
  # We assume batch_size=1 from post-processing onwards, so only the first
  # element of each batch is evaluated and plotted.
  batch_id = 0

  while True:
    logging.info('val %d', int(n_steps.numpy()))
    start_time = time.time()
    sample = tf_utils.get_next_sample_dataset(dataset_iterator)
    if sample is None:
      break
    batch_size = tf_utils.compute_batch_size(sample)
    if batch_size == 0:
      break
    n_steps += tf.cast(batch_size, tf.int64)
    total_steps = n_steps + number_of_steps_previous_epochs
    logger.record_scalar('meta/time_read', time.time() - start_time,
                         total_steps)

    start_time = time.time()
    outputs, losses = distributed_step(sample)
    logger.record_scalar('meta/forward_pass', time.time() - start_time,
                         total_steps)

    start_time = time.time()
    if record_loss:
      logger.record_losses('iterations/', losses, total_steps)
    outputs = outputs[-1]  # only take outputs from last hourglass
    detections = model.postprocess_sample2(input_image_size, sample, outputs)
    logger.record_scalar('meta/post_processing', time.time() - start_time,
                         total_steps)

    tmp_sample = {k: v[batch_id] for k, v in sample.items()}
    result_dict = evaluator.add_detections(tmp_sample, detections)
    iou_mean, iou_min = result_dict['iou_mean'], result_dict['iou_min']

    # Image summaries are only written for the first few samples.
    if ((FLAGS.master == 'local' or FLAGS.plot) and not FLAGS.francis and
        n_steps < tf.constant(13, dtype=tf.int64) and FLAGS.part_id < -1):
      _val_epoch_plot_summaries(model, sample, detections, batch_id,
                                total_steps)

    # NOTE: this used to be gated by
    # `(FLAGS.part_id > -1 and FLAGS.qualitative) or FLAGS.francis`, but the
    # condition ended in `or True`, so qualitative results have always been
    # written for every sample. That behavior is kept, now stated explicitly.
    _val_epoch_save_qualitative(model, sample, detections, batch_id, iou_min,
                                iou_mean)

  if record_loss:
    logger.record_losses_epoch('epoch/', epoch)

  metrics = evaluator.evaluate()
  if record_loss:
    logger.record_dictionary_scalars('metrics/', metrics, epoch)
  return n_steps