in backends/python/src/lib.rs [104:131]
/// Runs `predict` for `batch` through the Python backend client and returns one
/// score vector per batch entry, keyed by the entry's position in the client response.
///
/// The async client call is driven to completion with `self.tokio_runtime.block_on`,
/// so this method blocks the calling thread until the client responds.
///
/// # Errors
///
/// Returns [`BackendError::Inference`] when:
/// - `batch.raw_indices` is non-empty (raw outputs are not supported by this backend);
/// - the client call fails (the client error is converted with `to_string`);
/// - the number of returned score vectors differs from `batch.len()`, so the response
///   cannot be matched one-to-one with the batch entries.
fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
    if !batch.raw_indices.is_empty() {
        return Err(BackendError::Inference(
            "raw embeddings are not supported for the Python backend.".to_string(),
        ));
    }
    // Read the size up front: the batch's fields are moved into the client call below.
    let batch_size = batch.len();

    let results = self
        .tokio_runtime
        .block_on(self.backend_client.clone().predict(
            batch.input_ids,
            batch.token_type_ids,
            batch.position_ids,
            batch.cumulative_seq_lengths,
            batch.max_length,
        ))
        .map_err(|err| BackendError::Inference(err.to_string()))?;

    // Key each score vector by its position in the response. Extending the pre-sized
    // map directly avoids materializing an intermediate `Vec` first.
    let mut predictions =
        HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
    predictions.extend(results.into_iter().map(|r| r.values).enumerate());

    // Results are matched to batch entries purely by position, so a count mismatch would
    // leave some entries without a prediction (or add keys no entry corresponds to).
    // Surface that as an inference error here rather than returning a partial map.
    if predictions.len() != batch_size {
        return Err(BackendError::Inference(format!(
            "the Python backend returned {} predictions for a batch of size {}",
            predictions.len(),
            batch_size
        )));
    }

    Ok(predictions)
}