def batch()

in src/features.py [0:0]


    def batch(self, data, **kwargs) -> dict:
        """Build the encoded per-sample input batch for a set of sampled rays.

        Samples z values along every ray (optionally guided by a depth
        estimate), turns them into sample positions ``origin + direction * z``,
        normalizes those positions, and encodes positions and view directions
        with ``self.pos_enc`` / ``self.dir_enc``.

        Args:
            data: Mapping providing ``DatasetKeyConstants.image_pose``,
                ``DatasetKeyConstants.image_rotation`` and
                ``DatasetKeyConstants.ray_directions_samples``; optionally
                ``DatasetKeyConstants.depth_image_samples`` (ground-truth depth).
            **kwargs:
                prev_outs: optional list of output dicts from earlier stages.
                    When non-empty, its last entry may supply the depth guide
                    (``FeatureSetKeyConstants.network_output``) and precomputed
                    ray origins / directions.
                is_inference: truthy to disable stratified perturbation, force
                    deterministic z sampling and omit ground-truth depth outputs.

        Returns:
            dict keyed by ``FeatureSetKeyConstants`` entries: the encoded batch
            (reshaped to ``[-1, self.n_ray_samples, features]``), the z values,
            the ray directions and origins, the depth range as a tensor, the
            depth used to guide sampling (may be ``None``) and - outside
            inference when a depth image is present - the ground-truth depth
            plus its ``depth_transform.to_world`` conversion.
        """
        poses = data[DatasetKeyConstants.image_pose]
        rotations = data[DatasetKeyConstants.image_rotation]
        directions = data[DatasetKeyConstants.ray_directions_samples]

        depth_image = None
        if DatasetKeyConstants.depth_image_samples in data:
            depth_image = data[DatasetKeyConstants.depth_image_samples]

        n_images = poses.shape[0]
        n_samples = directions.shape[1]

        prev_outs = kwargs.get('prev_outs', None)
        # Normalize to a real bool so every check below agrees on what counts
        # as inference (identity tests like `is True` reject truthy non-bools).
        is_inference = bool(kwargs.get('is_inference', False))
        # Output dict of the most recent previous stage; `prev_outs` may be
        # None or an empty list, in which case there is none.
        last_out = prev_outs[-1] if prev_outs else None

        # Depth used to guide z sampling: prefer the previous stage's prediction
        # (unless training is configured to use GT depth), otherwise fall back to
        # the GT depth image when the dataset provides one. At inference this
        # fallback is only reached when there is no previous output.
        depth = None
        if last_out is not None and (not self.train_with_gt_depth or is_inference):
            depth = last_out[FeatureSetKeyConstants.network_output]
        elif depth_image is not None:
            depth = depth_image

        z_vals = self.z_sampler.generate(n_images * n_samples, poses.device, depth=depth,
                                         depth_range=self.depth_range,
                                         depth_transform=self.depth_transform,
                                         det=self.deterministic_sampling or is_inference)

        if self.perturb and not is_inference:
            # Stratified sampling: jitter each z value uniformly inside the
            # interval bounded by the midpoints to its neighbours (the first and
            # last values use themselves as the outer bound).
            mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            upper = torch.cat([mids, z_vals[..., -1:]], -1)
            lower = torch.cat([z_vals[..., :1], mids], -1)
            t_rand = torch.rand(z_vals.shape, device=z_vals.device)
            z_vals = lower + (upper - lower) * t_rand

        # Reuse ray origins / directions from the previous stage when provided.
        ray_origins = None
        ray_directions = None
        if last_out is not None:
            if FeatureSetKeyConstants.input_feature_ray_origins in last_out:
                ray_origins = last_out[FeatureSetKeyConstants.input_feature_ray_origins]
            if FeatureSetKeyConstants.input_feature_ray_directions in last_out:
                ray_directions = last_out[FeatureSetKeyConstants.input_feature_ray_directions]

        if ray_directions is None:
            ray_directions = nerf_get_ray_dirs(rotations, directions)

        if ray_origins is None:
            # One origin per sampled ray: the image poses tiled n_samples times
            # along dim 0, flattened to (n_images * n_samples, -1).
            ray_origins = tile(poses, dim=0, n_tile=n_samples).reshape(n_images * n_samples, -1)

        # origin + direction * z for every z value of every ray; the inserted
        # singleton axes broadcast to (..., n_z, coords).
        ray_sample_positions = (ray_origins[..., None, :] + ray_directions[..., None, :] * z_vals[..., :, None])

        if len(self.rayMarchNormalizationCenter) == 3:
            # An explicit 3-component ray-march center overrides the view-cell center.
            center = torch.tensor(self.rayMarchNormalizationCenter, device=ray_sample_positions.device)
        else:
            center = self.view_cell_center
        ray_sample_positions = self.normalizationFunction(ray_sample_positions, center, self.max_depth)

        # Flatten to one position per row and apply the positional encoding.
        inputs_flat = torch.reshape(ray_sample_positions, [-1, ray_sample_positions.shape[-1]])
        embedded = self.pos_enc.encode(inputs_flat)

        # Every sample along a ray shares that ray's view direction.
        input_dirs = ray_directions[:, None].expand(ray_sample_positions.shape)
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = self.dir_enc.encode(input_dirs_flat)

        embedded = torch.cat([embedded, embedded_dirs], -1)
        # Regroup into rows of self.n_ray_samples encoded samples per ray.
        embedded = torch.reshape(embedded, [-1, self.n_ray_samples, embedded.shape[-1]])

        ret_dict = {FeatureSetKeyConstants.input_feature_batch: embedded,
                    FeatureSetKeyConstants.nerf_input_feature_z_vals: z_vals,
                    FeatureSetKeyConstants.nerf_input_feature_ray_directions: ray_directions,
                    FeatureSetKeyConstants.nerf_input_feature_ray_origins: ray_origins}

        if not is_inference and depth_image is not None:
            ret_dict[FeatureSetKeyConstants.input_depth_groundtruth] = depth_image
            ret_dict[FeatureSetKeyConstants.input_depth_groundtruth_world] = self.depth_transform.to_world(
                depth_image, self.depth_range)

        ret_dict[FeatureSetKeyConstants.input_depth_range] = torch.tensor(self.depth_range)
        ret_dict[FeatureSetKeyConstants.input_depth] = depth

        return ret_dict