issue-bot/src/embeddings/inference_endpoints.rs (52 lines of code) (raw):
use reqwest::{
header::{HeaderMap, HeaderValue, AUTHORIZATION},
Client,
};
use serde::Serialize;
use crate::{config::EmbeddingApiConfig, APP_USER_AGENT};
use super::EmbeddingError;
/// Side of the input selected by the `truncate_direction` field of
/// [`EmbedRequest`].
///
/// `Serialize` is derived without any rename attribute, so each variant is
/// serialized under its own name.
#[derive(Serialize)]
enum TruncateDirection {
// `Left` is never constructed in this file; the `allow` silences the lint
// that would otherwise fire for the unused variant.
#[allow(unused)]
Left,
/// The direction `EmbeddingApi::generate_embedding` sends with every request.
Right,
}
/// JSON body posted to the embedding endpoint.
///
/// NOTE(review): field names are serialized verbatim as the JSON keys (no
/// serde rename is applied) — confirm they match the endpoint's request schema.
#[derive(Serialize)]
struct EmbedRequest {
/// Text to embed.
inputs: String,
/// Presumably asks the server to truncate over-long inputs instead of
/// rejecting them — TODO confirm against the endpoint documentation.
truncate: bool,
/// Which side truncation applies to — exact semantics are defined by the
/// server; verify against the endpoint documentation.
truncate_direction: TruncateDirection,
}
/// Handle for requesting embeddings from a remote HTTP endpoint.
///
/// Bundles the endpoint configuration with the HTTP client that
/// `EmbeddingApi::new` builds for it.
#[derive(Clone)]
pub struct EmbeddingApi {
/// Endpoint configuration; its `url` is the target of every request.
cfg: EmbeddingApiConfig,
/// HTTP client carrying the default headers installed by `EmbeddingApi::new`.
client: Client,
}
impl EmbeddingApi {
    /// Builds an API handle for the endpoint described by `cfg`.
    ///
    /// The underlying HTTP client sends `APP_USER_AGENT` as its user agent and
    /// attaches an `Authorization: Bearer <cfg.auth_token>` header to every
    /// request. The header value is marked as sensitive.
    ///
    /// # Errors
    ///
    /// Returns an error if `cfg.auth_token` cannot be turned into a header
    /// value, or if the HTTP client cannot be built.
    pub fn new(cfg: EmbeddingApiConfig) -> Result<Self, EmbeddingError> {
        let mut headers = HeaderMap::new();
        let mut auth_value = HeaderValue::from_str(&format!("Bearer {}", cfg.auth_token))?;
        auth_value.set_sensitive(true);
        headers.insert(AUTHORIZATION, auth_value);

        let client = Client::builder()
            .user_agent(APP_USER_AGENT)
            .default_headers(headers)
            .build()?;

        Ok(Self { cfg, client })
    }

    /// Requests an embedding for `text` from the configured endpoint.
    ///
    /// Posts an [`EmbedRequest`] with truncation enabled (direction `Right`)
    /// to `cfg.url`, decodes the response body as a list of vectors, and
    /// returns the last vector of that list.
    ///
    /// # Errors
    ///
    /// - The request could not be sent, or the endpoint answered with a
    ///   4xx/5xx status.
    /// - The response body is not a JSON array of arrays of numbers.
    /// - [`EmbeddingError::MissingEmbedding`] if the decoded list is empty.
    // TODO: surface the endpoint's error payload instead of only the status code.
    pub async fn generate_embedding(&self, text: String) -> Result<Vec<f32>, EmbeddingError> {
        self.client
            .post(&self.cfg.url)
            .json(&EmbedRequest {
                inputs: text,
                truncate: true,
                truncate_direction: TruncateDirection::Right,
            })
            .send()
            .await?
            // Without this check an error response would be handed to the JSON
            // decoder below and reported as a confusing deserialization failure
            // rather than as the HTTP error it actually is.
            .error_for_status()?
            .json::<Vec<Vec<f32>>>()
            .await?
            .pop()
            .ok_or(EmbeddingError::MissingEmbedding)
    }
}