in contactopt/diffcontact.py [0:0]
def capsule_sdf(mesh_verts, mesh_normals, query_points, query_normals, caps_rad, caps_top, caps_bot, foreach_on_mesh):
    """
    Find the normalized distance between points and capsules aligned with the mesh normals.

    Capsule SDF formulation from https://iquilezles.org/www/articles/distfunctions/distfunctions.htm
    Each mesh vertex owns a capsule whose axis runs from ``vert + normal * caps_top`` to
    ``vert + normal * caps_bot``. Points are paired with a single nearest neighbour (K=1) using
    ``pytorch3d.ops.knn_points``. ``foreach_on_mesh`` selects the pairing direction.

    :param mesh_verts: (batch, V, 3)
    :param mesh_normals: (batch, V, 3), presumably unit length -- TODO confirm with callers
    :param query_points: (batch, Q, 3)
    :param query_normals: (batch, Q, 3)
    :param caps_rad: scalar, radius of capsules
    :param caps_top: scalar, signed offset along the normal from the mesh vertex to the top of the capsule axis
    :param caps_bot: scalar, signed offset along the normal from the mesh vertex to the bottom of the capsule axis
    :param foreach_on_mesh: boolean, foreach point on mesh find closest query (V), or foreach query find closest mesh (Q)
    :return: tuple ``(sdf, normal_dot)``, each of shape (batch, V) if foreach_on_mesh else (batch, Q)
        - sdf: distance to the capsule axis segment divided by caps_rad, i.e. (normalized SDF) + 1.
          It is 0 on the axis segment, 1 on the capsule surface and >1 outside the capsule.
        - normal_dot: dot product of the paired mesh normal and query normal.
    """
    # Floor on the squared axis length. When caps_top == caps_bot, or a mesh normal has zero length,
    # the projection below would be 0/0 = NaN, which also poisons gradients. The floor only changes
    # (near-)degenerate axes: there h becomes 0 and the capsule reduces to a sphere around its top point.
    eps = 1e-12

    # TODO implement normal check?
    if foreach_on_mesh:  # Foreach mesh vert, find closest query point
        _, nearest_idx, nearest_pos = pytorch3d.ops.knn_points(mesh_verts, query_points, K=1, return_nn=True)  # TODO should attract capsule middle?
        nearest_idx = nearest_idx.squeeze(2)  # (batch, V)
        capsule_tops = mesh_verts + mesh_normals * caps_top  # (batch, V, 3)
        capsule_bots = mesh_verts + mesh_normals * caps_bot  # (batch, V, 3)
        delta_top = nearest_pos[:, :, 0, :] - capsule_tops  # Nearest query point relative to capsule top, (batch, V, 3)
        paired_query_normals = batch_idx_normals = batched_index_select(query_normals, 1, nearest_idx)
        normal_dot = torch.sum(mesh_normals * paired_query_normals, dim=2)
    else:  # Foreach query vert, find closest mesh point
        # Only the indices are used here, so skip gathering the neighbour positions.
        _, nearest_idx, _ = pytorch3d.ops.knn_points(query_points, mesh_verts, K=1, return_nn=False)  # TODO should attract capsule middle?
        nearest_idx = nearest_idx.squeeze(2)  # (batch, Q)
        closest_mesh_verts = batched_index_select(mesh_verts, 1, nearest_idx)  # Shape (batch, Q, 3)
        closest_mesh_normals = batched_index_select(mesh_normals, 1, nearest_idx)  # Shape (batch, Q, 3)
        capsule_tops = closest_mesh_verts + closest_mesh_normals * caps_top  # Coordinates of the top foci of the capsules (batch, Q, 3)
        capsule_bots = closest_mesh_verts + closest_mesh_normals * caps_bot
        delta_top = query_points - capsule_tops  # Query point relative to capsule top, (batch, Q, 3)
        normal_dot = torch.sum(query_normals * closest_mesh_normals, dim=2)

    bot_to_top = capsule_bots - capsule_tops  # Axis vector, pointing from the capsule top towards the bottom
    along_axis = torch.sum(delta_top * bot_to_top, dim=2)  # Dot product
    top_to_bot_square = torch.sum(bot_to_top * bot_to_top, dim=2)  # Squared axis length
    # h is the clamped projection parameter: 0 at the capsule top and 1 at the capsule bottom.
    h = torch.clamp(along_axis / torch.clamp(top_to_bot_square, min=eps), 0, 1)
    dist_to_axis = torch.norm(delta_top - bot_to_top * h.unsqueeze(2), dim=2)  # Distance to the capsule axis segment
    return dist_to_axis / caps_rad, normal_dot  # (Normalized SDF)+1: 0 on the axis, 1 on the capsule surface