def visualize_optimal_poses()

in phosa/pose_optimization.py [0:0]


def visualize_optimal_poses(model, image_crop, mask, score=0):
    """
    Visualizes the 8 best-scoring object poses.

    Runs one forward pass of ``model`` to score every candidate pose, sums the
    entries of the returned loss dict into a total loss, and sorts candidates
    by that loss (lowest first). A 2 x 5 figure is then shown containing the
    cropped image, the mask, and a render of each of the 8 lowest-loss poses
    (drawn by ``PerspectiveRenderer`` with ``image_crop`` passed as its
    ``image`` argument). Each render is titled with its rank and total loss.

    Args:
        model (PoseOptimizer).
        image_crop (H x H x 3).
        mask (M x M x 3).
        score (float): Mask confidence score (optional). Shown in the mask
            panel's title when positive; otherwise the title is just "Mask".
    """
    num_vis = 8  # 2 x 5 grid: 2 panels for crop/mask + 8 pose renders.
    verts = model.vertices[0]
    faces = model.faces[0]
    K_roi = model.renderer.K
    obj_renderer = PerspectiveRenderer()

    # Visualization only: skip building an autograd graph for the forward
    # pass, the rotation conversion, and the indexing of model parameters.
    with torch.no_grad():
        loss_dict, _ = model()
        # Sum of all loss terms; assumes each term holds one value per
        # candidate pose (TODO confirm against PoseOptimizer.forward).
        losses = sum(loss_dict.values())
        inds = torch.argsort(losses)[:num_vis].tolist()
        # Convert every 6D rotation once instead of once per rendered pose.
        rotation_matrices = rot6d_to_matrix(model.rotations)
        translations = model.translations

        fig = plt.figure(figsize=(10, 4))
        ax1 = fig.add_subplot(2, 5, 1)
        ax1.imshow(image_crop)
        ax1.axis("off")
        ax1.set_title("Cropped Image")

        ax2 = fig.add_subplot(2, 5, 2)
        ax2.imshow(mask)
        ax2.axis("off")
        if score > 0:
            ax2.set_title(f"Mask Conf: {score:.2f}")
        else:
            ax2.set_title("Mask")

        # Panels 3..10 hold the poses in ascending-loss order (rank 0 = best).
        for i, ind in enumerate(inds):
            ax = fig.add_subplot(2, 5, i + 3)
            ax.imshow(
                obj_renderer(
                    vertices=verts,
                    faces=faces,
                    image=image_crop,
                    translation=translations[ind],
                    rotation=rotation_matrices[ind],
                    color_name="red",
                    K=K_roi,
                )
            )
            ax.set_title(f"Rank {i}: {losses[ind].item():.1f}")
            ax.axis("off")
    plt.show()