in shap_e/rendering/raycast/cast.py [0:0]
def cast_rays(rays: Rays, mesh: TriMesh, checkpoint: bool = False) -> RayCollisions:
    """
    Cast a batch of rays onto a mesh.

    Every ray is tested against every triangle with a vectorized
    Möller–Trumbore intersection, and the nearest hit in front of each ray's
    origin is kept.

    :param rays: the rays to cast; ``rays.origins`` and ``rays.directions``
                 hold one row per ray (assumed [M x 3]).
    :param mesh: the mesh to intersect; ``mesh.vertices[mesh.faces]`` gives
                 one [3 x 3] vertex matrix per triangle (N triangles).
    :param checkpoint: if True, delegate to ``RayCollisionFunction.apply``
                       (presumably a memory-saving autograd function defined
                       elsewhere in this module) instead of the dense path below.
    :return: per-ray collision info. For rays whose ``collides`` entry is
             False, ``ray_dists`` is ``inf`` and the triangle index,
             barycentric coordinates and normal are arbitrary and should be
             ignored.
    """
    if checkpoint:
        collides, ray_dists, tri_indices, barycentric, normals = RayCollisionFunction.apply(
            rays.origins, rays.directions, mesh.faces, mesh.vertices
        )
        return RayCollisions(
            collides=collides,
            ray_dists=ray_dists,
            tri_indices=tri_indices,
            barycentric=barycentric,
            normals=normals,
        )

    # Vectorized Möller–Trumbore, adapted from
    # https://github.com/unixpickle/vae-textures/blob/2968549ddd4a3487f9437d4db00793324453cd59/vae_textures/render.py#L98
    # Shapes below: M = number of rays, N = number of triangles.
    normals = mesh.normals()  # [N x 3]
    directions = rays.directions  # [M x 3]

    # Reject ray/triangle pairs where the ray is (nearly) parallel to the plane.
    collides = (directions @ normals.T).abs() > 1e-8  # [M x N]

    tris = mesh.vertices[mesh.faces]  # [N x 3 x 3]
    v1 = tris[:, 1] - tris[:, 0]  # [N x 3], edge vertex0 -> vertex1
    v2 = tris[:, 2] - tris[:, 0]  # [N x 3], edge vertex0 -> vertex2

    cross1 = cross_product(directions[:, None], v2[None])  # [M x N x 3]
    det = torch.sum(cross1 * v1[None], dim=-1)  # [M x N]
    valid_det = det.abs() > 1e-8
    collides = torch.logical_and(collides, valid_det)
    # Divide by 1 where the determinant is (near) zero. Those pairs are already
    # rejected above, and dividing by an exact zero (e.g. degenerate triangles)
    # would create inf/NaN values that leak NaN gradients through the
    # torch.where/torch.min selection below.
    inv_det = 1 / torch.where(valid_det, det, torch.ones_like(det))  # [M x N]

    o = rays.origins[:, None] - tris[None, :, 0]  # [M x N x 3]
    bary1 = inv_det * torch.sum(o * cross1, dim=-1)  # [M x N]
    collides = torch.logical_and(collides, torch.logical_and(bary1 >= 0, bary1 <= 1))

    cross2 = cross_product(o, v1[None])  # [M x N x 3]
    bary2 = inv_det * torch.sum(directions[:, None] * cross2, dim=-1)  # [M x N]
    # The hit is inside the triangle only if the remaining weight bary0 is also
    # non-negative, i.e. bary1 + bary2 <= 1. Checking bary2 <= 1 alone would
    # accept hits anywhere in the parallelogram spanned by v1 and v2.
    collides = torch.logical_and(
        collides, torch.logical_and(bary2 >= 0, bary1 + bary2 <= 1)
    )
    bary0 = 1 - (bary1 + bary2)

    # Ray parameter t of the hit; only hits in front of the origin count.
    scale = inv_det * torch.sum(v2 * cross2, dim=-1)  # [M x N]
    collides = torch.logical_and(collides, scale > 0)

    # Select the nearest collision per ray (inf where a ray hits nothing).
    ray_dists, tri_indices = torch.min(
        torch.where(collides, scale, torch.full_like(scale, torch.inf)), dim=-1
    )  # [M]
    ray_index = torch.arange(tri_indices.shape[0], device=tri_indices.device)
    nearest_bary = torch.stack(
        [bary[ray_index, tri_indices] for bary in (bary0, bary1, bary2)],
        dim=-1,
    )  # [M x 3]
    return RayCollisions(
        collides=torch.any(collides, dim=-1),
        ray_dists=ray_dists,
        tri_indices=tri_indices,
        barycentric=nearest_bary,
        normals=normals[tri_indices],
    )