fn predict()

in backends/ort/src/lib.rs [296:378]


    fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
        let batch_size = batch.len();
        let max_length = batch.max_length as usize;

        let (input_ids, type_ids, attention_mask) = {
            let elems = batch_size * max_length;

            if batch_size > 1 {
                // Prepare padded batch
                let mut input_ids = Vec::with_capacity(elems);
                let mut type_ids = Vec::with_capacity(elems);
                let mut attention_mask = Vec::with_capacity(elems);

                for i in 0..batch_size {
                    let start = batch.cumulative_seq_lengths[i] as usize;
                    let end = batch.cumulative_seq_lengths[i + 1] as usize;
                    let seq_length = (end - start) as u32;

                    // Copy values
                    for j in start..end {
                        input_ids.push(batch.input_ids[j] as i64);
                        type_ids.push(batch.token_type_ids[j] as i64);
                        attention_mask.push(1_i64);
                    }

                    // Add padding if needed
                    let padding = batch.max_length - seq_length;
                    if padding > 0 {
                        for _ in 0..padding {
                            input_ids.push(0);
                            type_ids.push(0);
                            attention_mask.push(0_i64);
                        }
                    }
                }
                (input_ids, type_ids, attention_mask)
            } else {
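                // Single sequence in the batch: it already spans max_length tokens,
                // so no padding is required and the attention mask is all ones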
                let attention_mask = vec![1_i64; elems];

                (
                    batch.input_ids.into_iter().map(|v| v as i64).collect(),
                    batch.token_type_ids.into_iter().map(|v| v as i64).collect(),
                    attention_mask,
                )
            }
        };

        // Create ndarrays
        let input_ids = ndarray::Array2::from_shape_vec((batch_size, max_length), input_ids).e()?;
        let attention_mask =
            ndarray::Array2::from_shape_vec((batch_size, max_length), attention_mask).e()?;

        // Create onnx inputs
        let inputs = match self.type_id_name.as_ref() {
            Some(type_id_name) => {
                // Add type ids to inputs
                let type_ids =
                    ndarray::Array2::from_shape_vec((batch_size, max_length), type_ids).e()?;
                ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone(), type_id_name => type_ids].e()?
            }
            None => {
                ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone()]
                    .e()?
            }
        };

        // Run model
        let outputs = self.session.run(inputs).e()?;

        // Get logits ndarray
        let outputs = outputs["logits"]
            .try_extract_tensor::<f32>()
            .e()?
            .to_owned();

        // Map each row of the logits matrix back to its position in the batch
        let mut predictions =
            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
        for (i, r) in outputs.rows().into_iter().enumerate() {
            predictions.insert(i, r.to_vec());
        }

        Ok(predictions)
    }
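
For reference, the padding scheme used in the multi-sequence branch can be exercised in isolation. The sketch below is not part of the backend; `pad_batch` and the sample data are hypothetical, and it only mirrors the cumulative-length/padding logic used above to build `input_ids` and `attention_mask`:

    // Minimal, standalone illustration (hypothetical helper, not in the backend):
    // ragged sequences described by cumulative lengths are flattened into
    // row-major [batch_size, max_length] buffers plus an attention mask
    // (1 for real tokens, 0 for padding), mirroring the branch above.
    fn pad_batch(
        input_ids: &[u32],
        cumulative_seq_lengths: &[u32],
        max_length: usize,
    ) -> (Vec<i64>, Vec<i64>) {
        let batch_size = cumulative_seq_lengths.len() - 1;
        let mut padded = Vec::with_capacity(batch_size * max_length);
        let mut attention_mask = Vec::with_capacity(batch_size * max_length);

        for i in 0..batch_size {
            let start = cumulative_seq_lengths[i] as usize;
            let end = cumulative_seq_lengths[i + 1] as usize;

            // Copy the real tokens and mark them as attended
            for &id in &input_ids[start..end] {
                padded.push(id as i64);
                attention_mask.push(1);
            }
            // Pad the remainder of the row with zeros
            for _ in (end - start)..max_length {
                padded.push(0);
                attention_mask.push(0);
            }
        }

        (padded, attention_mask)
    }

    fn main() {
        // Two hypothetical sequences of lengths 3 and 2, padded to max_length = 3
        let input_ids = [101, 2023, 102, 101, 102];
        let cumulative_seq_lengths = [0, 3, 5];
        let (padded, mask) = pad_batch(&input_ids, &cumulative_seq_lengths, 3);
        assert_eq!(padded, vec![101, 2023, 102, 101, 102, 0]);
        assert_eq!(mask, vec![1, 1, 1, 1, 1, 0]);
    }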