in src/features.py [0:0]
def batch(self, data, **kwargs) -> dict:
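    """Assemble the encoded per-sample input batch for a set of rays.

    Returns a dict (keyed by FeatureSetKeyConstants) containing the encoded
    batch, the per-ray sample depths (z_vals), the ray origins/directions and,
    during training, ground-truth depth information.
    """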
    poses = data[DatasetKeyConstants.image_pose]
    rotations = data[DatasetKeyConstants.image_rotation]
    directions = data[DatasetKeyConstants.ray_directions_samples]

    depth_image = None
    if DatasetKeyConstants.depth_image_samples in data:
        depth_image = data[DatasetKeyConstants.depth_image_samples]

    n_images = poses.shape[0]
    n_samples = directions.shape[1]

    prev_outs = kwargs.get('prev_outs', None)
    is_inference = kwargs.get('is_inference', False)
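    # prev_outs, if given, holds the outputs of earlier stages in the pipeline
    # (e.g. a preceding depth-estimating network); is_inference switches the
    # sampler to deterministic mode and disables ground-truth depth usage below.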
    depth = None

    # If we have previous outputs (and we are not training with GT depth, or we
    # are running inference), use the depth predicted for each ray by the
    # previous network. Otherwise (default) fall back to the ground-truth depth.
    if prev_outs is not None and len(prev_outs) > 0 and (not self.train_with_gt_depth or is_inference is True):
        depth = prev_outs[-1][FeatureSetKeyConstants.network_output]
    else:
        if (not is_inference or len(prev_outs) == 0) and depth_image is not None:
            depth = depth_image
    z_vals = self.z_sampler.generate(n_images * n_samples, poses.device, depth=depth, depth_range=self.depth_range,
                                     depth_transform=self.depth_transform,
                                     det=self.deterministic_sampling or is_inference)
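    # z_vals holds one row of sample depths per ray, presumably of shape
    # [n_images * n_samples, n_ray_samples]. The block below optionally jitters
    # these depths within their intervals during training (stratified sampling,
    # as in the original NeRF implementation).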
    if self.perturb and is_inference is False:
        # get intervals between samples
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], -1)
        lower = torch.cat([z_vals[..., :1], mids], -1)
        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape, device=z_vals.device)
        z_vals = lower + (upper - lower) * t_rand
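    # Reuse ray origins/directions already computed by a previous stage if they
    # are present in its output (presumably to avoid recomputing them for later
    # networks); otherwise derive them from the camera rotations and poses below.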
    ray_origins = None
    ray_directions = None

    if prev_outs is not None and len(prev_outs) > 0:
        p_out = prev_outs[-1]
        if FeatureSetKeyConstants.input_feature_ray_origins in p_out:
            ray_origins = p_out[FeatureSetKeyConstants.input_feature_ray_origins]
        if FeatureSetKeyConstants.input_feature_ray_directions in p_out:
            ray_directions = p_out[FeatureSetKeyConstants.input_feature_ray_directions]
    if ray_directions is None:
        ray_directions = nerf_get_ray_dirs(rotations, directions)

    if ray_origins is None:
        # Now that we have the directions for the chosen images, tile the
        # corresponding poses so that every sample gets its ray origin.
        ray_origins = tile(poses, dim=0, n_tile=n_samples).reshape(n_images * n_samples, -1)
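    # Broadcast origins + directions * depths into per-sample 3D positions,
    # i.e. roughly [n_rays, 1, 3] + [n_rays, 1, 3] * [n_rays, n_ray_samples, 1]
    # -> [n_rays, n_ray_samples, 3] (shapes assumed, not checked here).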
    ray_sample_positions = (ray_origins[..., None, :] + ray_directions[..., None, :] * z_vals[..., :, None])
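    # Normalize the sample positions, either around an explicitly configured
    # ray-march normalization center or around the view cell center, so that the
    # positional encoding (presumably) operates on inputs in a bounded range.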
    if len(self.rayMarchNormalizationCenter) == 3:
        ray_sample_positions = self.normalizationFunction(ray_sample_positions,
                                                          torch.tensor(self.rayMarchNormalizationCenter,
                                                                       device=ray_sample_positions.device),
                                                          self.max_depth)
    else:
        ray_sample_positions = self.normalizationFunction(ray_sample_positions, self.view_cell_center,
                                                          self.max_depth)
    # Flatten to per-sample positions and apply positional encoding; view
    # directions are encoded separately and concatenated afterwards.
    inputs = ray_sample_positions
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = self.pos_enc.encode(inputs_flat)

    input_dirs = ray_directions[:, None].expand(inputs.shape)
    input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
    embedded_dirs = self.dir_enc.encode(input_dirs_flat)
    embedded = torch.cat([embedded, embedded_dirs], -1)

    embedded = torch.reshape(embedded, [-1, self.n_ray_samples, embedded.shape[-1]])
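    # embedded is now the per-ray network input, presumably of shape
    # [n_rays, n_ray_samples, pos_enc_dim + dir_enc_dim].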
    ret_dict = {FeatureSetKeyConstants.input_feature_batch: embedded,
                FeatureSetKeyConstants.nerf_input_feature_z_vals: z_vals,
                FeatureSetKeyConstants.nerf_input_feature_ray_directions: ray_directions,
                FeatureSetKeyConstants.nerf_input_feature_ray_origins: ray_origins}
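    # During training, also expose the ground-truth depth (raw and transformed
    # to world space) plus the depth range, presumably for depth supervision
    # further down the pipeline.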
    if not is_inference and depth_image is not None:
        ret_dict[FeatureSetKeyConstants.input_depth_groundtruth] = depth_image
        ret_dict[FeatureSetKeyConstants.input_depth_groundtruth_world] = self.depth_transform.to_world(
            depth_image, self.depth_range)
        ret_dict[FeatureSetKeyConstants.input_depth_range] = torch.tensor(self.depth_range)

    ret_dict[FeatureSetKeyConstants.input_depth] = depth

    return ret_dict
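# Usage sketch (illustrative only; `fs` and the tensors below are assumptions,
# not part of this file, and shapes are inferred from the code above):
#
#     data = {DatasetKeyConstants.image_pose: poses,              # [n_images, 3]
#             DatasetKeyConstants.image_rotation: rotations,      # per-image rotations
#             DatasetKeyConstants.ray_directions_samples: dirs}   # [n_images, n_samples, 3]
#     out = fs.batch(data, prev_outs=[], is_inference=True)
#     encoded = out[FeatureSetKeyConstants.input_feature_batch]   # network input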