fn embed()

in backends/ort/src/lib.rs [92:294]


    fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> {
        let batch_size = batch.len();
        let max_length = batch.max_length as usize;

        // Whether at least one of the requests in the batch is padded
        let mut masking = false;

        let (input_ids, type_ids, input_lengths, attention_mask) = {
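            // Flatten the batch into row-major buffers of batch_size * max_length elements,
            // right-padding shorter sequences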
            let elems = batch_size * max_length;

            if batch_size > 1 {
                // Prepare padded batch
                let mut input_ids = Vec::with_capacity(elems);
                let mut type_ids = Vec::with_capacity(elems);
                let mut attention_mask = Vec::with_capacity(elems);
                let mut input_lengths = Vec::with_capacity(batch_size);

                for i in 0..batch_size {
                    let start = batch.cumulative_seq_lengths[i] as usize;
                    let end = batch.cumulative_seq_lengths[i + 1] as usize;
                    let seq_length = (end - start) as u32;
                    input_lengths.push(seq_length as f32);

                    // Copy values
                    for j in start..end {
                        input_ids.push(batch.input_ids[j] as i64);
                        type_ids.push(batch.token_type_ids[j] as i64);
                        attention_mask.push(1_i64);
                    }

                    // Add padding if needed
                    let padding = batch.max_length - seq_length;
                    if padding > 0 {
                        // Flag that the attention mask must be applied
                        masking = true;
                        for _ in 0..padding {
                            input_ids.push(0);
                            type_ids.push(0);
                            attention_mask.push(0_i64);
                        }
                    }
                }
                (input_ids, type_ids, input_lengths, attention_mask)
            } else {
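                // Single sequence: it is never padded, so the batch buffers can be reused directly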
                let attention_mask = vec![1_i64; elems];

                (
                    batch.input_ids.into_iter().map(|v| v as i64).collect(),
                    batch.token_type_ids.into_iter().map(|v| v as i64).collect(),
                    vec![batch.max_length as f32],
                    attention_mask,
                )
            }
        };

        // Create ndarrays: (batch_size, max_length) matrices plus the 1D per-sequence lengths
        let input_ids = ndarray::Array2::from_shape_vec((batch_size, max_length), input_ids).e()?;
        let attention_mask =
            ndarray::Array2::from_shape_vec((batch_size, max_length), attention_mask).e()?;
        let input_lengths = ndarray::Array1::from_vec(input_lengths);

        // Create onnx inputs
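        // attention_mask is cloned into the inputs because the original array is reused
        // below for mean pooling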
        let inputs = match self.type_id_name.as_ref() {
            Some(type_id_name) => {
                // Add type ids to inputs
                let type_ids =
                    ndarray::Array2::from_shape_vec((batch_size, max_length), type_ids).e()?;
                ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone(), type_id_name => type_ids].e()?
            }
            None => {
                ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone()]
                    .e()?
            }
        };

        // Run model
        let outputs = self.session.run(inputs).e()?;

        // Get last_hidden_state ndarray
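        // Model exports name the token-level output either `last_hidden_state` or
        // `token_embeddings`; its shape is (batch_size, max_length, hidden_size)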
        let outputs = outputs
            .get("last_hidden_state")
            .or(outputs.get("token_embeddings"))
            .ok_or(BackendError::Inference(format!(
                "Unknown output keys: {:?}",
                self.session.outputs
            )))?
            .try_extract_tensor::<f32>()
            .e()?
            .to_owned();

        // Final embeddings struct
        let mut embeddings =
            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());

        let has_pooling_requests = !batch.pooled_indices.is_empty();
        let has_raw_requests = !batch.raw_indices.is_empty();
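        // A batch member appears in pooled_indices or raw_indices depending on whether it
        // asked for a pooled embedding or for the raw token embeddings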

        if has_pooling_requests {
            let mut outputs = outputs.clone();

            // Only use pooled_indices if at least one member of the batch asks for raw embeddings
            let indices = if has_raw_requests {
                let indices: Vec<usize> =
                    batch.pooled_indices.iter().map(|v| *v as usize).collect();

                // Select values in the batch
                outputs = outputs.select(Axis(0), &indices);
                Some(indices)
            } else {
                None
            };

            let pooled_embeddings = match self.pool {
                // CLS pooling
                Pool::Cls => outputs.slice(s![.., 0, ..]).into_owned().into_dyn(),
                // Last token pooling is not supported for this model
                Pool::LastToken => unreachable!(),
                // Mean pooling
                Pool::Mean => {
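                    // With padded sequences, zero out the pad positions and divide by the true
                    // lengths; otherwise a plain mean over the token axis is equivalent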
                    if masking {
                        let mut attention_mask = attention_mask;
                        let mut input_lengths = input_lengths;

                        if let Some(indices) = indices {
                            // Select values in the batch
                            attention_mask = attention_mask.select(Axis(0), &indices);
                            input_lengths = input_lengths.select(Axis(0), &indices);
                        };

                        // Cast and reshape
                        let attention_mask = attention_mask.mapv(|x| x as f32).insert_axis(Axis(2));

                        // Mask padded values
                        outputs = outputs.mul(attention_mask);
                        outputs
                            .sum_axis(Axis(1))
                            .div(input_lengths.insert_axis(Axis(1)))
                    } else {
                        outputs.mean_axis(Axis(1)).unwrap()
                    }
                }
                Pool::Splade => unreachable!(),
            };

            for (i, e) in batch
                .pooled_indices
                .into_iter()
                .zip(pooled_embeddings.rows())
            {
                embeddings.insert(i as usize, Embedding::Pooled(e.to_vec()));
            }
        };

        if has_raw_requests {
            // Reshape outputs
            let s = outputs.shape().to_vec();
            #[allow(deprecated)]
            let outputs = outputs.into_shape((s[0] * s[1], s[2])).e()?;
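            // outputs is now (batch_size * max_length, hidden_size): one row per token slot,
            // padding included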

            // We need to remove the padding tokens only if batch_size > 1 and some members of
            // the batch require pooling, or if batch_size > 1 and the members of the batch have
            // different lengths
            let raw_embeddings = if (masking || has_pooling_requests) && batch_size > 1 {
                let mut final_indices: Vec<usize> = Vec::with_capacity(batch_size * max_length);

                for i in batch.raw_indices.iter() {
                    let start = i * batch.max_length;
                    let i = *i as usize;
                    let length =
                        batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i];

                    for j in start..start + length {
                        // Add indices for the tokens of this specific member of the batch
                        final_indices.push(j as usize);
                    }
                }

                // Select the tokens with final indices
                outputs.select(Axis(0), &final_indices)
            } else {
                outputs
            };

            // Used for indexing in the raw_embeddings tensor
            let input_lengths: Vec<usize> = (0..batch_size)
                .map(|i| {
                    (batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i]) as usize
                })
                .collect();

            let mut cumulative_length = 0;
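            // Each raw request occupies `length` consecutive rows of raw_embeddings, in
            // raw_indices order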
            for i in batch.raw_indices.into_iter() {
                let length = input_lengths[i as usize];
                let e = raw_embeddings.slice(s![cumulative_length..cumulative_length + length, ..]);
                let e = e.rows().into_iter().map(|v| v.to_vec()).collect();

                embeddings.insert(i as usize, Embedding::All(e));
                cumulative_length += length;
            }
        }

        Ok(embeddings)
    }
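
For reference, here is a minimal standalone sketch of the masked mean pooling performed in the Pool::Mean branch above. It assumes only the ndarray crate; the function and variable names (masked_mean, hidden, mask, lengths) are illustrative and not part of the backend.

use ndarray::{arr1, arr2, arr3, Array1, Array2, Array3, Axis};
use std::ops::{Div, Mul};

/// Masked mean pooling: zero out padded positions, sum over the token axis,
/// then divide by the true sequence lengths.
fn masked_mean(hidden: Array3<f32>, mask: Array2<i64>, lengths: Array1<f32>) -> Array2<f32> {
    // (batch, seq, hidden) * (batch, seq, 1): the mask is broadcast over the hidden dimension
    let mask = mask.mapv(|x| x as f32).insert_axis(Axis(2));
    hidden
        .mul(mask)
        .sum_axis(Axis(1))
        .div(lengths.insert_axis(Axis(1)))
}

fn main() {
    // Two sequences padded to max_length = 2, hidden size = 2;
    // the second sequence has one real token and one pad token
    let hidden = arr3(&[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [9.0, 9.0]]]);
    let mask = arr2(&[[1_i64, 1], [1, 0]]);
    let lengths = arr1(&[2.0_f32, 1.0]);
    // Row 0: [2.0, 3.0], row 1: [5.0, 6.0]; the pad token does not leak into the second row
    println!("{:?}", masked_mean(hidden, mask, lengths));
}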