def features_and_preds()

in point_e/evals/feature_extractor.py [0:0]


    def features_and_preds(self, streamer: NpzStreamer) -> Tuple[np.ndarray, np.ndarray]:
        """
        Run every point cloud yielded by ``streamer`` through the per-device models.

        Each streamed chunk holds up to ``device_batch_size * len(self.devices)``
        clouds, read from the ``"arr_0"`` key. The chunk is passed through
        ``normalize_point_clouds``, cut into slices of ``device_batch_size``
        (at most one slice per device, since the ``zip`` truncates), and the
        slices are evaluated concurrently on a thread pool with one worker
        per device. Slice ``k`` is moved to ``self.devices[k]`` and fed to
        ``self.models[k]``.

        :param streamer: source of point-cloud batches.
        :return: a tuple ``(features, predictions)``, each concatenated along
                 axis 0 in stream order. ``predictions`` is ``exp`` of the
                 model's first output.
        """
        per_step = self.device_batch_size * len(self.devices)
        # The stream is opened here, before the pool exists, matching the
        # genexp's eager evaluation of its outermost iterable.
        cloud_batches = (item["arr_0"] for item in streamer.stream(per_step, ["arr_0"]))

        def run_model(indexed_slice):
            # Executed on a pool thread; the index selects the model that
            # matches the device the slice was moved to.
            idx, device_slice = indexed_slice
            with torch.no_grad():
                return self.models[idx](device_slice, features=True)

        all_features = []
        all_preds = []

        with ThreadPool(len(self.devices)) as workers:
            for raw_clouds in cloud_batches:
                clouds = normalize_point_clouds(raw_clouds)
                starts = range(0, len(clouds), self.device_batch_size)
                device_slices = [
                    torch.from_numpy(clouds[start : start + self.device_batch_size])
                    .permute(0, 2, 1)
                    .to(dtype=torch.float32, device=dev)
                    for start, dev in zip(starts, self.devices)
                ]
                # imap yields results in submission order, so the outputs
                # stay aligned with the input order.
                for logits, _, feats in workers.imap(run_model, enumerate(device_slices)):
                    all_features.append(feats.cpu().numpy())
                    all_preds.append(logits.exp().cpu().numpy())

        return np.concatenate(all_features, axis=0), np.concatenate(all_preds, axis=0)