context_harness/embedding/
mod.rs

//! Embedding provider abstraction and implementations.
//!
//! Re-exports the [`EmbeddingProvider`] trait (from `context_harness_core`) and
//! provides concrete implementations:
//! - **[`DisabledProvider`]** — placeholder used when embeddings are not configured; any attempt to embed with it fails.
//! - **[`OpenAIProvider`]** — calls the OpenAI embeddings API with batching, retry, and backoff.
//! - **[`OllamaProvider`]** — calls a local Ollama instance's `/api/embed` endpoint.
//! - **[`LocalProvider`]** — runs models locally via fastembed (primary) or tract (musl/Intel Mac); no network calls after model download.
//!
//! Also re-exports vector utilities for working with sqlite-vec:
//! - [`cosine_similarity`] — compute similarity between two embedding vectors
//! - [`vec_to_blob`] — encode a `Vec<f32>` as little-endian bytes for SQLite BLOB storage
//! - [`blob_to_vec`] — decode a SQLite BLOB back into a `Vec<f32>`
//!
//! # Provider Selection
//!
//! Use [`create_provider`] to instantiate the appropriate provider based
//! on the configuration:
//!
//! ```rust,no_run
//! # use context_harness::config::EmbeddingConfig;
//! # use context_harness::embedding::create_provider;
//! let config = EmbeddingConfig::default(); // provider = "disabled"
//! let provider = create_provider(&config).unwrap();
//! assert_eq!(provider.model_name(), "disabled");
//! ```
//!
//! # Retry Strategy
//!
//! The OpenAI and Ollama providers use exponential backoff for transient errors,
//! retrying up to `max_retries` times after the first attempt:
//! - HTTP 429 (rate limited) and 5xx (server error) → retry
//! - Other non-success statuses (e.g. 4xx other than 429) → fail immediately
//! - Network errors → retry
//! - Backoff: 1s, 2s, 4s, 8s, 16s, then 32s for every later retry (capped at 2^5 seconds)

35#[cfg(feature = "local-embeddings-tract")]
36mod local_tract;
37
38use anyhow::{bail, Result};
39use std::time::Duration;
40
41use crate::config::EmbeddingConfig;
42
43#[allow(unused_imports)]
44pub use context_harness_core::embedding::{
45    blob_to_vec, cosine_similarity, vec_to_blob, EmbeddingProvider,
46};
47
48/// Embed a batch of texts using the configured provider.
49///
50/// This is the main entry point for generating embeddings. It dispatches
51/// to the appropriate backend based on the config's `provider` field.
52///
53/// # Arguments
54///
55/// * `_provider` — Provider instance (used for metadata; dispatch is config-based).
56/// * `config` — Embedding configuration with provider, model, and retry settings.
57/// * `texts` — Batch of text strings to embed.
58///
59/// # Returns
60///
61/// A vector of embedding vectors, one per input text, in the same order.
62///
63/// # Errors
64///
65/// - `"disabled"` provider: always returns an error.
66/// - `"openai"` provider: returns an error if the API key is missing,
67///   the API returns a non-retryable error, or all retries are exhausted.
68pub async fn embed_texts(
69    _provider: &dyn EmbeddingProvider,
70    config: &EmbeddingConfig,
71    texts: &[String],
72) -> Result<Vec<Vec<f32>>> {
73    match config.provider.as_str() {
74        "openai" => embed_openai(config, texts).await,
75        "ollama" => embed_ollama(config, texts).await,
76        #[cfg(feature = "local-embeddings-fastembed")]
77        "local" => embed_local_fastembed(config, texts).await,
78        #[cfg(all(feature = "local-embeddings-tract", not(feature = "local-embeddings-fastembed")))]
79        "local" => embed_local_tract(config, texts).await,
80        #[cfg(not(any(feature = "local-embeddings-fastembed", feature = "local-embeddings-tract")))]
81        "local" => bail!(
82            "Local embedding provider requires one of: --features local-embeddings-fastembed, --features local-embeddings-tract"
83        ),
84        "disabled" => bail!("Embedding provider is disabled"),
85        other => bail!("Unknown embedding provider: {}", other),
86    }
87}
88
89/// Embed a single query text.
90///
91/// Convenience wrapper around [`embed_texts`] for single-text use cases
92/// (e.g. embedding a search query for semantic search).
93pub async fn embed_query(
94    provider: &dyn EmbeddingProvider,
95    config: &EmbeddingConfig,
96    text: &str,
97) -> Result<Vec<f32>> {
98    let results = embed_texts(provider, config, &[text.to_string()]).await?;
99    results
100        .into_iter()
101        .next()
102        .ok_or_else(|| anyhow::anyhow!("Empty embedding response"))
103}
104
105// ============ Disabled Provider ============
106
107/// A no-op embedding provider that always returns errors.
108///
109/// Used when `embedding.provider = "disabled"` in the configuration.
110/// Any attempt to embed text will fail with a descriptive error message.
111pub struct DisabledProvider;
112
113impl EmbeddingProvider for DisabledProvider {
114    fn model_name(&self) -> &str {
115        "disabled"
116    }
117    fn dims(&self) -> usize {
118        0
119    }
120}
121
122// ============ OpenAI Provider ============
123
/// [`EmbeddingProvider`] metadata for the OpenAI embeddings API.
///
/// Embedding requests go to `POST /v1/embeddings` with the configured model.
/// They authenticate with the `OPENAI_API_KEY` environment variable, whose
/// presence [`OpenAIProvider::new`] checks.
///
/// # Features
///
/// - Several texts can be embedded per API call (batching)
/// - Rate-limit (429) and server (5xx) errors are retried with exponential backoff
/// - Request timeout and retry count come from the configuration
pub struct OpenAIProvider {
    /// Model identifier sent to the API, e.g. `"text-embedding-3-small"`.
    model: String,
    /// Length of the vectors the model produces, e.g. `1536`.
    dims: usize,
}
140
141impl OpenAIProvider {
142    /// Create a new OpenAI provider from configuration.
143    ///
144    /// # Errors
145    ///
146    /// Returns an error if `model` or `dims` is not set in config,
147    /// or if `OPENAI_API_KEY` is not in the environment.
148    pub fn new(config: &EmbeddingConfig) -> Result<Self> {
149        let model = config
150            .model
151            .clone()
152            .ok_or_else(|| anyhow::anyhow!("embedding.model required for OpenAI provider"))?;
153        let dims = config
154            .dims
155            .ok_or_else(|| anyhow::anyhow!("embedding.dims required for OpenAI provider"))?;
156
157        // Verify API key is available
158        if std::env::var("OPENAI_API_KEY").is_err() {
159            bail!("OPENAI_API_KEY environment variable not set");
160        }
161
162        Ok(Self { model, dims })
163    }
164}
165
166impl EmbeddingProvider for OpenAIProvider {
167    fn model_name(&self) -> &str {
168        &self.model
169    }
170    fn dims(&self) -> usize {
171        self.dims
172    }
173}
174
175/// Call the OpenAI embeddings API with retry/backoff.
176///
177/// Sends a batch of texts to `POST https://api.openai.com/v1/embeddings`
178/// and returns the embedding vectors in input order.
179///
180/// Retry strategy:
181/// - HTTP 429 or 5xx → retry with exponential backoff
182/// - HTTP 4xx (not 429) → fail immediately
183/// - Network error → retry
184async fn embed_openai(config: &EmbeddingConfig, texts: &[String]) -> Result<Vec<Vec<f32>>> {
185    let api_key =
186        std::env::var("OPENAI_API_KEY").map_err(|_| anyhow::anyhow!("OPENAI_API_KEY not set"))?;
187
188    let model = config
189        .model
190        .as_ref()
191        .ok_or_else(|| anyhow::anyhow!("embedding.model required"))?;
192
193    let client = reqwest::Client::builder()
194        .timeout(Duration::from_secs(config.timeout_secs))
195        .build()?;
196
197    let body = serde_json::json!({
198        "model": model,
199        "input": texts,
200    });
201
202    let mut last_err = None;
203
204    for attempt in 0..=config.max_retries {
205        if attempt > 0 {
206            // Exponential backoff: 1s, 2s, 4s, 8s, ...
207            let delay = Duration::from_secs(1 << (attempt - 1).min(5));
208            tokio::time::sleep(delay).await;
209        }
210
211        let resp = client
212            .post("https://api.openai.com/v1/embeddings")
213            .header("Authorization", format!("Bearer {}", api_key))
214            .header("Content-Type", "application/json")
215            .json(&body)
216            .send()
217            .await;
218
219        match resp {
220            Ok(response) => {
221                let status = response.status();
222
223                if status.is_success() {
224                    let json: serde_json::Value = response.json().await?;
225                    return parse_openai_response(&json);
226                }
227
228                // Rate limited or server error — retry
229                if status.as_u16() == 429 || status.is_server_error() {
230                    let body_text = response.text().await.unwrap_or_default();
231                    last_err = Some(anyhow::anyhow!(
232                        "OpenAI API error {}: {}",
233                        status,
234                        body_text
235                    ));
236                    continue;
237                }
238
239                // Client error (not 429) — don't retry
240                let body_text = response.text().await.unwrap_or_default();
241                bail!("OpenAI API error {}: {}", status, body_text);
242            }
243            Err(e) => {
244                last_err = Some(e.into());
245                continue;
246            }
247        }
248    }
249
250    Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Embedding failed after retries")))
251}
252
253/// Parse the OpenAI embeddings API response JSON.
254///
255/// Extracts the `data[].embedding` arrays and returns them in order.
256fn parse_openai_response(json: &serde_json::Value) -> Result<Vec<Vec<f32>>> {
257    let data = json
258        .get("data")
259        .and_then(|d| d.as_array())
260        .ok_or_else(|| anyhow::anyhow!("Invalid OpenAI response: missing data array"))?;
261
262    let mut embeddings = Vec::with_capacity(data.len());
263
264    for item in data {
265        let embedding = item
266            .get("embedding")
267            .and_then(|e| e.as_array())
268            .ok_or_else(|| anyhow::anyhow!("Invalid OpenAI response: missing embedding"))?;
269
270        let vec: Vec<f32> = embedding
271            .iter()
272            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
273            .collect();
274
275        embeddings.push(vec);
276    }
277
278    // Sort by index to ensure order matches input
279    Ok(embeddings)
280}
281
282// ============ Ollama Provider ============
283
/// [`EmbeddingProvider`] metadata for a local Ollama server.
///
/// Embedding requests go to `POST /api/embed` on the configured Ollama URL,
/// which is `http://localhost:11434` when unset. Ollama must be running and
/// have an embedding model pulled, e.g. `ollama pull nomic-embed-text`.
pub struct OllamaProvider {
    /// Model name passed to Ollama.
    model: String,
    /// Vector length produced by the model, taken from `embedding.dims`.
    dims: usize,
    /// Base URL of the Ollama server. Nothing reads it yet: requests take the
    /// URL straight from the config, hence the `dead_code` allowance.
    #[allow(dead_code)]
    url: String,
}
294
295impl OllamaProvider {
296    pub fn new(config: &EmbeddingConfig) -> Result<Self> {
297        let model = config
298            .model
299            .clone()
300            .ok_or_else(|| anyhow::anyhow!("embedding.model required for Ollama provider"))?;
301        let dims = config
302            .dims
303            .ok_or_else(|| anyhow::anyhow!("embedding.dims required for Ollama provider"))?;
304        let url = config
305            .url
306            .clone()
307            .unwrap_or_else(|| "http://localhost:11434".to_string());
308
309        Ok(Self { model, dims, url })
310    }
311}
312
313impl EmbeddingProvider for OllamaProvider {
314    fn model_name(&self) -> &str {
315        &self.model
316    }
317    fn dims(&self) -> usize {
318        self.dims
319    }
320}
321
322async fn embed_ollama(config: &EmbeddingConfig, texts: &[String]) -> Result<Vec<Vec<f32>>> {
323    let model = config
324        .model
325        .as_ref()
326        .ok_or_else(|| anyhow::anyhow!("embedding.model required"))?;
327
328    let url = config.url.as_deref().unwrap_or("http://localhost:11434");
329
330    let client = reqwest::Client::builder()
331        .timeout(Duration::from_secs(config.timeout_secs))
332        .build()?;
333
334    let body = serde_json::json!({
335        "model": model,
336        "input": texts,
337    });
338
339    let mut last_err = None;
340
341    for attempt in 0..=config.max_retries {
342        if attempt > 0 {
343            let delay = Duration::from_secs(1 << (attempt - 1).min(5));
344            tokio::time::sleep(delay).await;
345        }
346
347        let resp = client
348            .post(format!("{}/api/embed", url))
349            .header("Content-Type", "application/json")
350            .json(&body)
351            .send()
352            .await;
353
354        match resp {
355            Ok(response) => {
356                let status = response.status();
357
358                if status.is_success() {
359                    let json: serde_json::Value = response.json().await?;
360                    return parse_ollama_response(&json);
361                }
362
363                if status.as_u16() == 429 || status.is_server_error() {
364                    let body_text = response.text().await.unwrap_or_default();
365                    last_err = Some(anyhow::anyhow!(
366                        "Ollama API error {}: {}",
367                        status,
368                        body_text
369                    ));
370                    continue;
371                }
372
373                let body_text = response.text().await.unwrap_or_default();
374                bail!("Ollama API error {}: {}", status, body_text);
375            }
376            Err(e) => {
377                last_err = Some(anyhow::anyhow!(
378                    "Ollama connection error (is Ollama running at {}?): {}",
379                    url,
380                    e
381                ));
382                continue;
383            }
384        }
385    }
386
387    Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Ollama embedding failed after retries")))
388}
389
390fn parse_ollama_response(json: &serde_json::Value) -> Result<Vec<Vec<f32>>> {
391    let embeddings = json
392        .get("embeddings")
393        .and_then(|e| e.as_array())
394        .ok_or_else(|| anyhow::anyhow!("Invalid Ollama response: missing embeddings array"))?;
395
396    let mut result = Vec::with_capacity(embeddings.len());
397
398    for embedding in embeddings {
399        let vec: Vec<f32> = embedding
400            .as_array()
401            .ok_or_else(|| anyhow::anyhow!("Invalid Ollama response: embedding is not an array"))?
402            .iter()
403            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
404            .collect();
405        result.push(vec);
406    }
407
408    Ok(result)
409}
410
411// ============ Local Provider (fastembed or tract) ============
412
/// [`EmbeddingProvider`] metadata for in-process embedding models.
///
/// A Cargo feature selects the backend: fastembed on primary platforms, tract
/// on musl / Intel Mac builds. Models are fetched from Hugging Face on first
/// use and cached; after that, embedding works entirely offline. There are no
/// system dependencies, because fastembed bundles ORT and tract is pure Rust.
#[cfg(any(
    feature = "local-embeddings-fastembed",
    feature = "local-embeddings-tract"
))]
pub struct LocalProvider {
    /// Configured model name; `all-minilm-l6-v2` when unset.
    model_name: String,
    /// Vector length, either from config or looked up from the model name.
    dims: usize,
}
426
#[cfg(any(
    feature = "local-embeddings-fastembed",
    feature = "local-embeddings-tract"
))]
impl LocalProvider {
    /// Build a local provider from the configuration.
    ///
    /// The model name and dimensionality come from `resolve_local_model`,
    /// which supplies defaults for unset values.
    pub fn new(config: &EmbeddingConfig) -> Result<Self> {
        resolve_local_model(config).map(|(model_name, dims)| Self { model_name, dims })
    }
}
437
#[cfg(any(
    feature = "local-embeddings-fastembed",
    feature = "local-embeddings-tract"
))]
impl EmbeddingProvider for LocalProvider {
    fn dims(&self) -> usize {
        self.dims
    }

    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }
}
450
/// Resolve the local model name and vector dimensionality from config.
///
/// The model defaults to `all-minilm-l6-v2`. When `embedding.dims` is unset,
/// the dimensionality is looked up from the known model names. Unrecognised
/// names fall back to 384. This function currently never returns `Err`.
#[cfg(any(
    feature = "local-embeddings-fastembed",
    feature = "local-embeddings-tract"
))]
fn resolve_local_model(config: &EmbeddingConfig) -> Result<(String, usize)> {
    let model_name = match &config.model {
        Some(name) => name.clone(),
        None => String::from("all-minilm-l6-v2"),
    };

    let default_dims = match model_name.as_str() {
        "bge-base-en-v1.5"
        | "nomic-embed-text-v1"
        | "nomic-embed-text-v1.5"
        | "multilingual-e5-base" => 768,
        "bge-large-en-v1.5" | "multilingual-e5-large" => 1024,
        // all-minilm-l6-v2, bge-small-en-v1.5, multilingual-e5-small, and any
        // unrecognised model.
        _ => 384,
    };
    let dims = config.dims.unwrap_or(default_dims);

    Ok((model_name, dims))
}
475
/// Map a configured local model name to the corresponding fastembed model.
///
/// # Errors
///
/// Fails for any name outside the supported list; the error message lists
/// the supported names.
#[cfg(feature = "local-embeddings-fastembed")]
fn config_to_fastembed_model(name: &str) -> Result<fastembed::EmbeddingModel> {
    use fastembed::EmbeddingModel as Model;

    let model = match name {
        "all-minilm-l6-v2" => Model::AllMiniLML6V2,
        "bge-small-en-v1.5" => Model::BGESmallENV15,
        "bge-base-en-v1.5" => Model::BGEBaseENV15,
        "bge-large-en-v1.5" => Model::BGELargeENV15,
        "nomic-embed-text-v1" => Model::NomicEmbedTextV1,
        "nomic-embed-text-v1.5" => Model::NomicEmbedTextV15,
        "multilingual-e5-small" => Model::MultilingualE5Small,
        "multilingual-e5-base" => Model::MultilingualE5Base,
        "multilingual-e5-large" => Model::MultilingualE5Large,
        other => bail!(
            "Unknown local embedding model: '{}'. Supported models: \
             all-minilm-l6-v2, bge-small-en-v1.5, bge-base-en-v1.5, bge-large-en-v1.5, \
             nomic-embed-text-v1, nomic-embed-text-v1.5, \
             multilingual-e5-small, multilingual-e5-base, multilingual-e5-large",
            other
        ),
    };

    Ok(model)
}
497
/// Embed `texts` in-process with fastembed.
///
/// The model comes from `embedding.model` (default `all-minilm-l6-v2`), and
/// `embedding.batch_size` is passed through to fastembed. Inference runs on
/// Tokio's blocking thread pool so it does not stall the async runtime.
///
/// Loaded models are cached for the life of the process, keyed by model name.
/// Constructing a `TextEmbedding` loads the model files (downloading them on
/// first use) and builds an inference session, which is far too expensive to
/// repeat for every batch. A mutex guards the cache, so concurrent calls are
/// serialized.
///
/// # Errors
///
/// Fails in any of these cases:
/// - The model name is unsupported.
/// - The model cannot be initialized.
/// - Inference fails.
/// - The blocking task panics or is cancelled.
#[cfg(feature = "local-embeddings-fastembed")]
async fn embed_local_fastembed(
    config: &EmbeddingConfig,
    texts: &[String],
) -> Result<Vec<Vec<f32>>> {
    use std::collections::hash_map::Entry;
    use std::collections::HashMap;
    use std::sync::{Mutex, OnceLock};

    static MODELS: OnceLock<Mutex<HashMap<String, fastembed::TextEmbedding>>> = OnceLock::new();

    let model_name = config
        .model
        .clone()
        .unwrap_or_else(|| "all-minilm-l6-v2".to_string());

    let fastembed_model = config_to_fastembed_model(&model_name)?;
    let batch_size = config.batch_size;
    let texts = texts.to_vec();

    tokio::task::spawn_blocking(move || {
        // Recover from poisoning so that one panicking call does not disable
        // local embeddings for the rest of the process.
        // NOTE(review): assumes a model is still usable after a panic during
        // inference — confirm.
        let mut models = MODELS
            .get_or_init(|| Mutex::new(HashMap::new()))
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        let model = match models.entry(model_name) {
            Entry::Occupied(slot) => slot.into_mut(),
            Entry::Vacant(slot) => {
                let loaded = fastembed::TextEmbedding::try_new(
                    fastembed::InitOptions::new(fastembed_model).with_show_download_progress(true),
                )
                .map_err(|e| anyhow::anyhow!("Failed to initialize local embedding model: {}", e))?;
                slot.insert(loaded)
            }
        };

        model
            .embed(texts, Some(batch_size))
            .map_err(|e| anyhow::anyhow!("Local embedding failed: {}", e))
    })
    .await?
}
526
/// Embed `texts` with the pure-Rust tract backend in the `local_tract` module.
///
/// [`embed_texts`] only dispatches here when the fastembed feature is off.
/// With both local features enabled, this function is compiled but unused,
/// hence the conditional `dead_code` allowance.
#[cfg(feature = "local-embeddings-tract")]
#[cfg_attr(
    all(
        feature = "local-embeddings-fastembed",
        feature = "local-embeddings-tract"
    ),
    allow(dead_code)
)]
async fn embed_local_tract(config: &EmbeddingConfig, texts: &[String]) -> Result<Vec<Vec<f32>>> {
    let vectors = local_tract::embed_local_tract(config, texts).await?;
    Ok(vectors)
}
538
539/// Create the appropriate [`EmbeddingProvider`] based on configuration.
540///
541/// # Supported Providers
542///
543/// | Config Value | Provider |
544/// |-------------|----------|
545/// | `"disabled"` | [`DisabledProvider`] |
546/// | `"openai"` | [`OpenAIProvider`] |
547/// | `"ollama"` | [`OllamaProvider`] |
548/// | `"local"` | `LocalProvider` (fastembed or tract, see features) |
549///
550/// # Errors
551///
552/// Returns an error for unknown provider names or if the provider
553/// cannot be initialized (missing config, API key, or feature flag).
554pub fn create_provider(config: &EmbeddingConfig) -> Result<Box<dyn EmbeddingProvider>> {
555    match config.provider.as_str() {
556        "disabled" => Ok(Box::new(DisabledProvider)),
557        "openai" => Ok(Box::new(OpenAIProvider::new(config)?)),
558        "ollama" => Ok(Box::new(OllamaProvider::new(config)?)),
559        #[cfg(any(feature = "local-embeddings-fastembed", feature = "local-embeddings-tract"))]
560        "local" => Ok(Box::new(LocalProvider::new(config)?)),
561        #[cfg(not(any(feature = "local-embeddings-fastembed", feature = "local-embeddings-tract")))]
562        "local" => bail!(
563            "Local embedding provider requires one of: --features local-embeddings-fastembed, --features local-embeddings-tract"
564        ),
565        other => bail!("Unknown embedding provider: {}", other),
566    }
567}