in backends/ort/src/lib.rs [296:378]
fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
    let batch_size = batch.len();
    let max_length = batch.max_length as usize;
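
    // The ONNX model takes dense [batch_size, max_length] tensors, so shorter
    // sequences in a multi-sequence batch are padded up to max_length below.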
    let (input_ids, type_ids, attention_mask) = {
        let elems = batch_size * max_length;

        if batch_size > 1 {
            // Prepare padded batch
            let mut input_ids = Vec::with_capacity(elems);
            let mut type_ids = Vec::with_capacity(elems);
            let mut attention_mask = Vec::with_capacity(elems);

            for i in 0..batch_size {
                let start = batch.cumulative_seq_lengths[i] as usize;
                let end = batch.cumulative_seq_lengths[i + 1] as usize;
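                // `start..end` spans sequence i in the batch's flattened
                // token buffers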
                let seq_length = (end - start) as u32;

                // Copy values
                for j in start..end {
                    input_ids.push(batch.input_ids[j] as i64);
                    type_ids.push(batch.token_type_ids[j] as i64);
                    attention_mask.push(1_i64);
                }
                // Pad to max_length with zeros; attention 0 masks these
                // positions out for the model
                let padding = batch.max_length - seq_length;
                if padding > 0 {
                    for _ in 0..padding {
                        input_ids.push(0);
                        type_ids.push(0);
                        attention_mask.push(0_i64);
                    }
                }
            }
            (input_ids, type_ids, attention_mask)
        } else {
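            // Single sequence: it already defines max_length, so no padding
            // is needed and the buffers convert directly.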
            let attention_mask = vec![1_i64; elems];

            (
                batch.input_ids.into_iter().map(|v| v as i64).collect(),
                batch.token_type_ids.into_iter().map(|v| v as i64).collect(),
                attention_mask,
            )
        }
    };

    // Create [batch_size, max_length] ndarrays
    let input_ids = ndarray::Array2::from_shape_vec((batch_size, max_length), input_ids).e()?;
    let attention_mask =
        ndarray::Array2::from_shape_vec((batch_size, max_length), attention_mask).e()?;

    // Create onnx inputs
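    // Some ONNX exports also declare a token-type input; when the model has
    // one, its name (e.g. "token_type_ids") is stored in `self.type_id_name`.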
    let inputs = match self.type_id_name.as_ref() {
        Some(type_id_name) => {
            // Add type ids to inputs
            let type_ids =
                ndarray::Array2::from_shape_vec((batch_size, max_length), type_ids).e()?;
            ort::inputs![
                "input_ids" => input_ids,
                "attention_mask" => attention_mask,
                type_id_name => type_ids
            ]
            .e()?
        }
        None => {
            ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask].e()?
        }
    };

    // Run model
    let outputs = self.session.run(inputs).e()?;

    // Extract the `logits` output tensor as an owned f32 ndarray
    let outputs = outputs["logits"]
        .try_extract_tensor::<f32>()
        .e()?
        .to_owned();
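
    // One logits row per input sequence, keyed by batch index; usize keys are
    // already unique integers, so the identity NoHashHasher skips hashing.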
    let mut predictions =
        HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
    for (i, r) in outputs.rows().into_iter().enumerate() {
        predictions.insert(i, r.to_vec());
    }

    Ok(predictions)
}
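
// Hypothetical caller sketch (names assumed, not part of this file):
//
//     let predictions = backend.predict(batch)?;
//     for i in 0..batch_size {
//         let logits = &predictions[&i];
//         // e.g. apply softmax to turn logits into class probabilities
//     }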