in phosa/pose_optimization.py [0:0]
def visualize_optimal_poses(model, image_crop, mask, score=0):
    """
    Visualizes the 8 best-scoring object poses.

    Draws a 2 x 5 grid: the cropped image, the mask, then up to 8 poses with
    the lowest total loss rendered over the cropped image. Panels are ordered
    from lowest loss ("Rank 0") upward.

    Args:
        model (PoseOptimizer): Optimizer whose ``rotations`` (6D
            representation, converted via ``rot6d_to_matrix``),
            ``translations``, ``vertices``, ``faces`` and ``renderer.K`` are
            read. Calling ``model()`` must return ``(loss_dict, silhouettes)``.
        image_crop (H x H x 3): Cropped image, shown and used as the render
            background.
        mask (M x M x 3): Object mask.
        score (float): Mask confidence score (optional). Shown in the mask
            panel title only when positive.
    """
    num_vis = 8
    translations = model.translations
    verts = model.vertices[0]
    faces = model.faces[0]
    K_roi = model.renderer.K
    obj_renderer = PerspectiveRenderer()

    fig = plt.figure(figsize=(10, 4))
    ax1 = fig.add_subplot(2, 5, 1)
    ax1.imshow(image_crop)
    ax1.axis("off")
    ax1.set_title("Cropped Image")
    ax2 = fig.add_subplot(2, 5, 2)
    ax2.imshow(mask)
    ax2.axis("off")
    if score > 0:
        ax2.set_title(f"Mask Conf: {score:.2f}")
    else:
        ax2.set_title("Mask")

    # Pure visualization: no gradients are needed, so avoid building an
    # autograd graph for the forward pass and the renders.
    with torch.no_grad():
        loss_dict, _ = model()
        losses = sum(loss_dict.values())
        inds = torch.argsort(losses)[:num_vis]
        # Convert the whole rotation batch once, rather than once per panel.
        rotation_matrices = rot6d_to_matrix(model.rotations)
        for i, ind in enumerate(inds.cpu().numpy()):
            ax = fig.add_subplot(2, 5, i + 3)
            ax.imshow(
                obj_renderer(
                    vertices=verts,
                    faces=faces,
                    image=image_crop,
                    translation=translations[ind],
                    rotation=rotation_matrices[ind],
                    color_name="red",
                    K=K_roi,
                )
            )
            ax.set_title(f"Rank {i}: {float(losses[ind]):.1f}")
            ax.axis("off")
    plt.show()