in threestudio/models/isosurface.py [0:0]
def _forward(self, pos_nx3, sdf_n, tet_fx4):
    """Extract a triangle mesh from a tetrahedral grid via marching tetrahedra.

    A grid vertex counts as "occupied" when its SDF value is > 0. Every tet
    whose corners are neither all occupied nor all empty contributes one or
    two triangles. Their corners lie on the zero crossings of the tet's
    sign-changing edges.

    The combinatorial part is computed under ``torch.no_grad()``. This covers
    which tets and edges cross the surface and how the crossings connect.
    The zero-crossing interpolation runs outside ``no_grad``, so gradients
    flow back to ``pos_nx3`` and ``sdf_n``.

    Args:
        pos_nx3: Grid vertex positions, shape (N, 3).
        sdf_n: Per-vertex SDF values indexed by vertex id. Either shape (N,)
            or (N, 1) works with the indexing below.
        tet_fx4: Vertex ids of each tetrahedron, reshapeable to (F, 4).

    Returns:
        verts: (V, 3) surface vertices, one per unique sign-changing edge.
        faces: (T, 3) long tensor of indices into ``verts``.

    Attributes of ``self`` that are read. NOTE(review): their contents are
    inferred from how they are used here; confirm against their definitions.
        base_tet_edges: Local vertex-index pairs of a tet's 6 edges
            (12 entries; the ``reshape(-1, 6)`` below depends on that count).
        sort_edges: Presumably puts edge endpoints in a canonical order so
            that edges shared between tets deduplicate.
        num_triangles_table: Triangle count per 4-bit occupancy code. Only
            counts 1 and 2 produce faces below.
        triangle_table: Per occupancy code, the local edge ids (0..5) of the
            triangle corners.
    """
    device = pos_nx3.device

    with torch.no_grad():
        # Occupancy of every vertex and of each tet's 4 corners.
        occ_n = sdf_n > 0
        occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
        occ_sum = torch.sum(occ_fx4, -1)
        # Only tets with both occupied and empty corners can cross the surface.
        valid_tets = (occ_sum > 0) & (occ_sum < 4)

        # find all vertices: gather the 6 edges of each valid tet and
        # deduplicate edges shared by neighbouring tets. idx_map sends every
        # per-tet edge slot to its unique edge.
        all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
        all_edges = self.sort_edges(all_edges)
        unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
        unique_edges = unique_edges.long()

        # An edge yields a surface vertex iff exactly one endpoint is occupied.
        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
        # Number the crossing edges 0..M-1; every other edge maps to -1.
        mapping = torch.full(
            (unique_edges.shape[0],), -1, dtype=torch.long, device=device
        )
        mapping[mask_edges] = torch.arange(
            mask_edges.sum(), dtype=torch.long, device=device
        )
        idx_map = mapping[idx_map]  # map edges to verts

        interp_v = unique_edges[mask_edges]

    # Differentiable part. Each vertex sits at the zero crossing of the
    # linearly interpolated SDF along its edge:
    #     p = (p0 * (-s1) + p1 * s0) / (s0 - s1)
    # Exactly one endpoint has s > 0 and the other s <= 0, so the
    # denominator is nonzero for finite SDF values.
    edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
    edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
    sdf_0 = edges_to_interp_sdf[:, 0:1]
    sdf_1 = edges_to_interp_sdf[:, 1:2]
    # Built out-of-place to avoid in-place writes on autograd-tracked views.
    weights = torch.cat((-sdf_1, sdf_0), dim=1) / (sdf_0 - sdf_1)
    verts = (edges_to_interp * weights).sum(1)

    idx_map = idx_map.reshape(-1, 6)

    # 4-bit occupancy code per valid tet: corner k contributes 2**k.
    v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
    tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
    num_triangles = self.num_triangles_table[tetindex]

    # Generate triangle indices. Look up the local edge ids of each triangle
    # corner, then translate them to vertex ids through idx_map.
    one_tri = num_triangles == 1
    two_tri = num_triangles == 2
    faces = torch.cat(
        (
            torch.gather(
                input=idx_map[one_tri],
                dim=1,
                index=self.triangle_table[tetindex[one_tri]][:, :3],
            ).reshape(-1, 3),
            torch.gather(
                input=idx_map[two_tri],
                dim=1,
                index=self.triangle_table[tetindex[two_tri]][:, :6],
            ).reshape(-1, 3),
        ),
        dim=0,
    )
    return verts, faces