in shapenet/data/mesh_vox.py [0:0]
def __getitem__(self, idx):
    """Load one example: a rendered view of one model plus optional 3D targets.

    Args:
        idx: Position in ``self.synset_ids`` / ``self.model_ids`` /
            ``self.image_ids``; together these select a synset id, a model
            id and an image index for that model.

    Returns:
        Tuple ``(img, verts, faces, points, normals, voxels, P, id_str)``:

        - ``img``: the RGB image after ``self.transform``.
        - ``verts``, ``faces``: the ``"verts"`` / ``"faces"`` entries of
          ``mesh.pt``, with ``verts`` passed through
          ``project_verts(verts, RT)``; both ``None`` unless
          ``self.return_mesh``.
        - ``points``, ``normals``: a random subset (at most
          ``self.num_samples``) of the model's precomputed samples. ``points``
          is passed through ``project_verts`` and ``normals`` is rotated by
          ``RT[:3, :3]``. Both are ``None`` when ``self.sample_online`` is set.
        - ``voxels``: ``None`` unless ``self.voxel_size > 0``. Otherwise it is
          the contents of ``vox<voxel_size>/<iid:03d>.pt`` if that file
          exists, else the ``"voxel_coords"`` entry of ``voxels.pt``.
        - ``P``: ``K.mm(RT)`` when raw ``voxel_coords`` are returned,
          otherwise ``None``.
        - ``id_str``: ``"<sid>-<mid>-<iid:02d>"``.
    """
    sid = self.synset_ids[idx]
    mid = self.model_ids[idx]
    iid = self.image_ids[idx]
    model_dir = os.path.join(self.data_dir, sid, mid)

    def _load(relpath):
        """``torch.load`` the file at ``relpath`` under this model's directory."""
        # NOTE(review): torch.load unpickles the file, so data_dir must be a
        # trusted location.
        with PathManager.open(os.path.join(model_dir, relpath), "rb") as f:
            return torch.load(f)

    # Always read metadata for this model; TODO cache in __init__?
    metadata = _load("metadata.pt")
    K = metadata["intrinsic"]
    RT = metadata["extrinsics"][iid]

    # Load the image
    img_path = os.path.join(model_dir, "images", metadata["image_list"][iid])
    with PathManager.open(img_path, "rb") as f:
        # convert() makes PIL read the pixel data while f is still open.
        img = Image.open(f).convert("RGB")
    img = self.transform(img)

    # Maybe read mesh
    verts, faces = None, None
    if self.return_mesh:
        mesh_data = _load("mesh.pt")
        verts, faces = mesh_data["verts"], mesh_data["faces"]
        verts = project_verts(verts, RT)

    # Maybe use cached samples
    points, normals = None, None
    if not self.sample_online:
        samples = self.mid_to_samples.get(mid, None)
        if samples is None:
            # They were not cached in memory, so read off disk
            samples = _load("samples.pt")
        points = samples["points_sampled"]
        normals = samples["normals_sampled"]
        # Random subset of at most num_samples samples. Named `perm` so it
        # does not shadow the `idx` argument.
        perm = torch.randperm(points.shape[0])[: self.num_samples]
        points, normals = points[perm], normals[perm]
        points = project_verts(points, RT)
        normals = normals.mm(RT[:3, :3].t())  # Only rotate, don't translate

    voxels, P = None, None
    if self.voxel_size > 0:
        # Use precomputed voxels if we have them, otherwise return voxel_coords
        # and we will compute voxels in postprocess
        voxel_file = "vox%d/%03d.pt" % (self.voxel_size, iid)
        if PathManager.isfile(os.path.join(model_dir, voxel_file)):
            voxels = _load(voxel_file)
        else:
            voxels = _load("voxels.pt")["voxel_coords"]
            # Projection matrix returned alongside the raw voxel_coords;
            # presumably consumed by the postprocess step mentioned above.
            P = K.mm(RT)

    id_str = "%s-%s-%02d" % (sid, mid, iid)
    return img, verts, faces, points, normals, voxels, P, id_str