in src/datasets.py [0:0]
def __init__(self, config, train_config, dataset_info, set_name="train", num_samples=2048):
    """Load an entire view-cell dataset split into device memory.

    Reads ``transforms_{set_name}.json`` from ``self.dataset_path`` (set by the
    superclass), loads every frame's color image (``<file_path>.png``), its
    camera pose (``transform_matrix``), and — when a ``<file_path>_depth.npz``
    file exists — its depth image, then moves the stacked arrays to
    ``self.device`` as torch tensors and runs the configured feature
    preprocessing over the assembled data dict.

    Args:
        config: global configuration object (opaque here; forwarded to
            feature ``preprocess`` calls).
        train_config: training configuration providing parallel feature lists
            ``f_in`` / ``f_out`` whose ``preprocess`` hooks are invoked below.
        dataset_info: dataset metadata, consumed by the superclass.
        set_name: split name; selects ``transforms_{set_name}.json`` and the
            progress-bar label.
        num_samples: forwarded to the superclass (not used directly here).
    """
    super(FullyLoadedViewCellDataset, self).__init__(config, train_config, dataset_info, set_name, num_samples)
    # Per-split transforms file lists one entry per frame under "frames".
    with open(os.path.join(self.dataset_path, f"transforms_{set_name}.json"), "r") as f:
        json_data = json.load(f)
    self.num_items = len(json_data["frames"])
    # Buffers are allocated lazily on the first frame, once image/pose shapes
    # are known (see the `self.color_images is None` branch below).
    transforms = None
    self.color_images = None
    tqdm_range = trange(len(json_data["frames"]), desc=f"Loading dataset {set_name:5s}", leave=True)
    for frame_idx in tqdm_range:
        frame = json_data["frames"][frame_idx]
        file_path = os.path.join(self.dataset_path, frame["file_path"])
        file_name = file_path + ".png"
        color_image = self.load_color_image(file_name)
        depth_image = None
        pose = np.array(frame["transform_matrix"]).astype(np.float32)
        # Depth is optional per frame: only loaded when the sidecar .npz exists.
        depth_name = file_path + "_depth.npz"
        if os.path.exists(depth_name):
            depth_image = self.load_depth_image(depth_name)
        if self.color_images is None:
            # First frame: size the full-dataset buffers from this frame's
            # shapes. len(self) presumably equals self.num_items — TODO confirm
            # against the superclass __len__.
            self.color_images = np.zeros((len(self), color_image.shape[0], color_image.shape[1],
                                          color_image.shape[2]), dtype=np.float32)
            transforms = np.zeros((len(self), pose.shape[0], pose.shape[1]), dtype=np.float32)
            if depth_image is not None:
                # NOTE(review): the depth buffer is allocated only when the
                # FIRST frame has a depth file; if a later frame has depth but
                # the first does not, the assignment below would fail — verify
                # datasets are consistently with/without depth.
                # depth_image appears to carry a leading (e.g. batch/channel)
                # dimension: shape[1]/shape[2] are used here and [0] is indexed
                # when storing — TODO confirm load_depth_image's return shape.
                self.depth_images = np.zeros((len(self), depth_image.shape[1], depth_image.shape[2], 1),
                                             dtype=np.float32)
        self.color_images[frame_idx] = color_image
        transforms[frame_idx] = pose
        if depth_image is not None:
            self.depth_images[frame_idx] = depth_image[0]
    # Derive per-image positions/rotations/ray directions from the poses
    # (populates self.poses / self.rotations / self.directions used below).
    self.preprocess_pos_and_dir(transforms)
    self.color_images = torch.from_numpy(self.color_images).to(self.device)
    # NOTE(review): if no frame had depth, self.depth_images is never assigned
    # in this method — presumably the superclass initializes it to None; TODO
    # confirm, otherwise this attribute access raises AttributeError.
    if self.depth_images is not None:
        self.depth_images = torch.from_numpy(self.depth_images).to(self.device)
    data = {DatasetKeyConstants.color_image_full: self.color_images,
            DatasetKeyConstants.depth_image_full: self.depth_images,
            DatasetKeyConstants.image_pose: self.poses,
            DatasetKeyConstants.image_rotation: self.rotations,
            DatasetKeyConstants.ray_directions: self.directions}
    # we now call preprocess on all features to perform necessary preprocess steps
    for feature_idx in range(len(self.train_config.f_in)):
        f_in = self.train_config.f_in[feature_idx]
        f_in.preprocess(data, self.device, self.config)
        # Input-feature preprocessing may replace the depth tensor in `data`;
        # mirror it back onto the instance so both stay in sync.
        self.depth_images = data[DatasetKeyConstants.depth_image_full]
        f_out = self.train_config.f_out[feature_idx]
        f_out.preprocess(data, self.device, self.config)