def update()

in tensorflow_graphics/projects/points_to_3Dobjects/utils/evaluator.py [0:0]


  def update(self, labeled_sdfs, labeled_classes, labeled_poses,
             predicted_sdfs, predicted_classes, predicted_poses):
    """Computes per-class occupancy IoU between labeled and predicted objects.

    A cubic grid of `self.resolution`^3 world-space samples is placed around
    the labeled objects. For every class id in `range(self.max_num_classes)`,
    each labeled and each predicted object of that class marks the samples it
    occupies (interpolated negated SDF plus `self.tol` greater than zero). The
    IoU of a class is the number of samples marked more than once divided by
    the number of samples marked at least once.

    Side effect: every non-NaN class IoU is appended to
    `self.iou_per_class[class_id]`.

    Args:
      labeled_sdfs: Per-object SDF grids of the ground truth, indexed by object.
      labeled_classes: Per-object class ids of the ground truth.
      labeled_poses: Indexable as (rotations, translations, sizes); each entry
        is indexed by object. Rotations hold 9 values per object (3x3),
        translations and sizes hold 3 values per object.
      predicted_sdfs: Same as `labeled_sdfs`, for the predictions.
      predicted_classes: Same as `labeled_classes`, for the predictions.
      predicted_poses: Same layout as `labeled_poses`, for the predictions.

    Returns:
      A tuple `(mean_iou, min_iou)` over the classes whose IoU is not NaN.
      When no class yields a valid IoU the result is `(0, 0)`.
    """
    labeled_rotations = labeled_poses[0]
    labeled_translations = labeled_poses[1]
    labeled_sizes = labeled_poses[2]

    # Developer toggle: center the sampling grid on the x/z bounding box of the
    # labeled objects (True) or on the mean of their translations (False).
    center_on_bounding_box = True
    if center_on_bounding_box:
      # [min, max] trackers seeded with inverted sentinels.
      # NOTE(review): assumes all box corners lie within (-100, 100).
      # Only the x/z footprint is tracked; the y extent of the grid is fixed
      # to [0, 1] when the samples are generated below.
      box_limits_x = [100, -100]
      box_limits_z = [100, -100]
      for i in range(labeled_translations.shape[0]):
        # Entries 0, 2, 6, 8 of the flattened 3x3 rotation form its x/z part.
        rot = tf.reshape(tf.gather(labeled_rotations[i], [0, 2, 6, 8]), [2, 2])

        min_x = tf.cast(0.0 - labeled_sizes[i][0] / 2.0, dtype=tf.float32)
        max_x = tf.cast(0.0 + labeled_sizes[i][0] / 2.0, dtype=tf.float32)
        min_z = tf.cast(0.0 - labeled_sizes[i][2] / 2.0, dtype=tf.float32)
        max_z = tf.cast(0.0 + labeled_sizes[i][2] / 2.0, dtype=tf.float32)

        translation = tf.reshape([labeled_translations[i][0],
                                  labeled_translations[i][2]], [2, 1])

        # Four corners of the object's footprint, moved into world space.
        pt_0 = rot @ tf.reshape([min_x, min_z], [2, 1]) + translation
        pt_1 = rot @ tf.reshape([min_x, max_z], [2, 1]) + translation
        pt_2 = rot @ tf.reshape([max_x, min_z], [2, 1]) + translation
        pt_3 = rot @ tf.reshape([max_x, max_z], [2, 1]) + translation

        for pt in [pt_0, pt_1, pt_2, pt_3]:
          if pt[0] < box_limits_x[0]:
            box_limits_x[0] = pt[0]

          if pt[0] > box_limits_x[1]:
            box_limits_x[1] = pt[0]

          if pt[1] < box_limits_z[0]:
            box_limits_z[0] = pt[1]

          if pt[1] > box_limits_z[1]:
            box_limits_z[1] = pt[1]
      # Mean of [min, max] is the center of the bounding box along that axis.
      mean_x = tf.reduce_mean(box_limits_x)
      mean_z = tf.reduce_mean(box_limits_z)
    else:
      mean_x = tf.reduce_mean(labeled_translations[:, 0])
      mean_z = tf.reduce_mean(labeled_translations[:, 2])

    # Unit cube of samples centered at (mean_x, mean_z) in x/z, y in [0, 1].
    samples_world = grid.generate(
        (mean_x - 0.5, 0.0, mean_z - 0.5), (mean_x + 0.5, 1.0, mean_z + 0.5),
        [self.resolution, self.resolution, self.resolution])
    samples_world = tf.reshape(samples_world, [-1, 3])
    ious = []

    # Developer toggles for visual debugging; both are off in normal use.
    debug_plot_slices = False
    debug_plot_scatter = False
    if debug_plot_slices:
      _, axs = plt.subplots(labeled_translations.shape[0], 5)
      fig_obj_count = 0
    for class_id in range(self.max_num_classes):
      # Per-sample count of objects (labels and predictions) occupying it.
      sdf_values = tf.zeros_like(samples_world)[:, 0:1]
      # Do the same for the ground truth and predictions
      for mtype, (classes, sdfs, poses) in enumerate([
          (labeled_classes, labeled_sdfs, labeled_poses),
          (predicted_classes, predicted_sdfs, predicted_poses)]):
        for i in range(classes.shape[0]):
          if class_id == classes[i]:
            sdf = tf.expand_dims(sdfs[i], -1)
            sdf = sdf * -1.0  # inside positive, outside zero
            # World samples expressed in the object's own frame
            # (size, rotation, translation undone via inverse=True).
            samples_object = centernet_utils.transform_pointcloud(
                tf.reshape(samples_world, [1, 1, -1, 3]),
                tf.reshape(poses[2][i], [1, 1, 3]),
                tf.reshape(poses[0][i], [1, 1, 3, 3]),
                tf.reshape(poses[1][i], [1, 1, 3]), inverse=True) * 2.0
            # NOTE(review): presumably maps object coordinates to voxel
            # indices of a 32^3 SDF grid with a 29/32 padding factor -
            # confirm against the SDF generation code.
            samples_object = \
                (samples_object * (29.0/32.0) / 2.0 + 0.5) * 32.0 - 0.5
            samples = tf.squeeze(samples_object)
            interpolated = trilinear.interpolate(sdf, samples)

            # Adds 1 wherever the (negated) SDF plus tolerance is positive.
            sdf_values += tf.math.sign(tf.nn.relu(interpolated + self.tol))
            if debug_plot_slices:
              a = 2
              values = interpolated
              inter = tf.reshape(values, [self.resolution,
                                          self.resolution,
                                          self.resolution])
              inter = tf.transpose(tf.reduce_max(inter, axis=a))
              im = axs[fig_obj_count, mtype * 2 + 0].matshow(inter.numpy())
              plt.colorbar(im, ax=axs[fig_obj_count, mtype * 2 + 0])
              print(mtype, fig_obj_count, 0)

              values = tf.math.sign(tf.nn.relu(interpolated + self.tol))
              inter = tf.reshape(values, [self.resolution,
                                          self.resolution,
                                          self.resolution])
              inter = tf.transpose(tf.reduce_max(inter, axis=a))
              im = axs[fig_obj_count, mtype * 2 + 1].matshow(inter.numpy())
              plt.colorbar(im, ax=axs[fig_obj_count, mtype * 2 + 1])
              print(mtype, fig_obj_count, 1)

              if mtype == 1:
                values = sdf_values
                inter = tf.reshape(values, [self.resolution,
                                            self.resolution,
                                            self.resolution])
                inter = tf.transpose(tf.reduce_max(inter, axis=a))
                im = axs[fig_obj_count, 4].matshow(inter.numpy())
                plt.colorbar(im, ax=axs[fig_obj_count, 4])
                print(mtype, fig_obj_count, 2)
                fig_obj_count += 1

      # Samples hit more than once count as intersection, hit at least once
      # as union.
      # NOTE(review): two overlapping objects from the same set (both labels
      # or both predictions) also register as intersection.
      intersection = tf.reduce_sum(tf.math.sign(tf.nn.relu(sdf_values - 1)))
      union = tf.reduce_sum(tf.math.sign(sdf_values))
      iou = intersection / union
      # A class that occupies no sample has union 0, so the ratio is NaN and
      # the class is left out of the statistics.
      if not tf.math.is_nan(iou):
        ious.append(iou)
        self.iou_per_class[class_id].append(iou)
      if debug_plot_scatter:
        _ = plt.figure(figsize=(5, 5))
        plt.clf()
        plt.scatter(samples_world.numpy()[:, 0],
                    samples_world.numpy()[:, 1],
                    marker='.', c=sdf_values.numpy()[:, 0])
        plt.colorbar()
    # Fall back to a single zero only when no class produced a valid IoU;
    # np.min would raise on an empty list.
    if not ious:
      ious = [0]
    return np.mean(ious), np.min(ious)