in tensorflow_graphics/projects/points_to_3Dobjects/utils/evaluator.py [0:0]
def update(self, labeled_sdfs, labeled_classes, labeled_poses,
           predicted_sdfs, predicted_classes, predicted_poses):
  """Accumulates per-class volumetric IoU between labeled and predicted shapes.

  A regular grid of `self.resolution`**3 world-space points is sampled in a
  1 x 1 x 1 cube. Its x/z center is the center of the x-z bounding rectangle
  of all labeled box footprints, and its y range is [0, 1]. For every class id
  in `range(self.max_num_classes)`, each labeled and predicted object of that
  class is rasterized onto the grid by interpolating its SDF. The IoU of the
  labeled and predicted occupancy is then computed.

  Args:
    labeled_sdfs: Per-object SDF grids of the labeled objects; `sdfs[i]` is
      sampled with `trilinear.interpolate` (presumably 32^3 voxels, see the
      coordinate mapping below -- TODO confirm).
    labeled_classes: Per-object class ids of the labeled objects.
    labeled_poses: Sequence `(rotations, translations, sizes)`. `rotations[i]`
      is indexed as a flat row-major 9-vector and reshaped to 3x3;
      `translations[i]` and `sizes[i]` are 3-vectors whose components 0 and 2
      are used as x and z.
    predicted_sdfs: Same as `labeled_sdfs`, for the predicted objects.
    predicted_classes: Same as `labeled_classes`, for the predicted objects.
    predicted_poses: Same as `labeled_poses`, for the predicted objects.

  Returns:
    `(mean_iou, min_iou)` over the classes whose IoU is defined (non-NaN) for
    this example, or `(0, 0)` when no class has a defined IoU.

  Side effects:
    Appends every defined per-class IoU to `self.iou_per_class[class_id]`.
  """
  labeled_rotations = labeled_poses[0]
  labeled_translations = labeled_poses[1]
  labeled_sizes = labeled_poses[2]

  # Center the evaluation cube on the x-z bounding rectangle of the labeled
  # boxes' footprints. Each footprint's corners are rotated and translated
  # into world coordinates.
  footprint_corners = []
  for i in range(labeled_translations.shape[0]):
    # x-z sub-block of the rotation: elements R00, R02, R20, R22 of the
    # row-major 3x3 matrix.
    rot = tf.reshape(tf.gather(labeled_rotations[i], [0, 2, 6, 8]), [2, 2])
    half_x = tf.cast(labeled_sizes[i][0] / 2.0, dtype=tf.float32)
    half_z = tf.cast(labeled_sizes[i][2] / 2.0, dtype=tf.float32)
    # Columns are the four (x, z) corners of the box-aligned footprint.
    local_corners = tf.reshape([-half_x, -half_x, half_x, half_x,
                                -half_z, half_z, -half_z, half_z], [2, 4])
    translation = tf.reshape([labeled_translations[i][0],
                              labeled_translations[i][2]], [2, 1])
    footprint_corners.append(rot @ local_corners + translation)
  if footprint_corners:
    corners = tf.concat(footprint_corners, axis=1)  # [2, 4 * num_labeled]
    mean_x = (tf.reduce_min(corners[0]) + tf.reduce_max(corners[0])) / 2.0
    mean_z = (tf.reduce_min(corners[1]) + tf.reduce_max(corners[1])) / 2.0
  else:
    # No labeled objects: center the cube on the origin.
    mean_x = tf.constant(0.0, dtype=tf.float32)
    mean_z = tf.constant(0.0, dtype=tf.float32)

  samples_world = grid.generate(
      (mean_x - 0.5, 0.0, mean_z - 0.5), (mean_x + 0.5, 1.0, mean_z + 0.5),
      [self.resolution, self.resolution, self.resolution])
  samples_world = tf.reshape(samples_world, [-1, 3])
  # Loop-invariant batched view passed to transform_pointcloud.
  samples_world_batched = tf.reshape(samples_world, [1, 1, -1, 3])

  ious = []
  for class_id in range(self.max_num_classes):
    # Per-sample count of objects of this class (labeled and predicted
    # together) that contain the sample.
    sdf_values = tf.zeros_like(samples_world)[:, 0:1]
    for classes, sdfs, poses in [
        (labeled_classes, labeled_sdfs, labeled_poses),
        (predicted_classes, predicted_sdfs, predicted_poses)]:
      for i in range(classes.shape[0]):
        if class_id != classes[i]:
          continue
        # Negate so the interior is positive (assumes the stored SDF is
        # negative inside -- TODO confirm). The exterior becomes negative
        # here and is zeroed by the relu below.
        sdf = tf.expand_dims(sdfs[i], -1) * -1.0
        samples_object = centernet_utils.transform_pointcloud(
            samples_world_batched,
            tf.reshape(poses[2][i], [1, 1, 3]),
            tf.reshape(poses[0][i], [1, 1, 3, 3]),
            tf.reshape(poses[1][i], [1, 1, 3]), inverse=True) * 2.0
        # NOTE(review): presumably maps box-normalized coordinates in [-1, 1]
        # to voxel indices of a 32^3 SDF grid with 29/32 padding. Confirm
        # against how the SDFs were generated.
        samples_object = (
            (samples_object * (29.0 / 32.0) / 2.0 + 0.5) * 32.0 - 0.5)
        samples = tf.squeeze(samples_object)
        interpolated = trilinear.interpolate(sdf, samples)
        # 1 where the sample is inside (within `self.tol`), else 0.
        sdf_values += tf.math.sign(tf.nn.relu(interpolated + self.tol))
    # Samples covered by >= 2 objects of this class count as intersection.
    # This also includes overlap between two objects of the same set.
    intersection = tf.reduce_sum(tf.math.sign(tf.nn.relu(sdf_values - 1)))
    union = tf.reduce_sum(tf.math.sign(sdf_values))
    iou = intersection / union  # NaN when no sample is covered (0 / 0).
    if not tf.math.is_nan(iou):
      ious.append(iou)
      self.iou_per_class[class_id].append(iou)
  # Fall back to 0 only when no class produced a defined IoU
  # (np.min raises on an empty sequence).
  if not ious:
    ious = [0]
  return np.mean(ious), np.min(ious)