// context_harness/embedding/mod.rs
#[cfg(feature = "local-embeddings-tract")]
36mod local_tract;
37
38use anyhow::{bail, Result};
39use std::time::Duration;
40
41use crate::config::EmbeddingConfig;
42
43#[allow(unused_imports)]
44pub use context_harness_core::embedding::{
45 blob_to_vec, cosine_similarity, vec_to_blob, EmbeddingProvider,
46};
47
/// Embed a batch of texts using the backend named in `config.provider`.
///
/// Dispatch is driven by the config string, not by `_provider` — the trait
/// object is accepted for interface symmetry with `embed_query` but unused
/// here. The `"local"` arm is feature-gated: fastembed takes precedence when
/// both local backends are compiled in, and with neither feature the arm
/// becomes an informative error.
///
/// # Errors
/// Fails when the provider is `"disabled"`, unknown, or not compiled in,
/// or when the selected backend's call fails.
pub async fn embed_texts(
    _provider: &dyn EmbeddingProvider,
    config: &EmbeddingConfig,
    texts: &[String],
) -> Result<Vec<Vec<f32>>> {
    match config.provider.as_str() {
        "openai" => embed_openai(config, texts).await,
        "ollama" => embed_ollama(config, texts).await,
        #[cfg(feature = "local-embeddings-fastembed")]
        "local" => embed_local_fastembed(config, texts).await,
        #[cfg(all(feature = "local-embeddings-tract", not(feature = "local-embeddings-fastembed")))]
        "local" => embed_local_tract(config, texts).await,
        #[cfg(not(any(feature = "local-embeddings-fastembed", feature = "local-embeddings-tract")))]
        "local" => bail!(
            "Local embedding provider requires one of: --features local-embeddings-fastembed, --features local-embeddings-tract"
        ),
        "disabled" => bail!("Embedding provider is disabled"),
        other => bail!("Unknown embedding provider: {}", other),
    }
}
88
89pub async fn embed_query(
94 provider: &dyn EmbeddingProvider,
95 config: &EmbeddingConfig,
96 text: &str,
97) -> Result<Vec<f32>> {
98 let results = embed_texts(provider, config, &[text.to_string()]).await?;
99 results
100 .into_iter()
101 .next()
102 .ok_or_else(|| anyhow::anyhow!("Empty embedding response"))
103}
104
105pub struct DisabledProvider;
112
113impl EmbeddingProvider for DisabledProvider {
114 fn model_name(&self) -> &str {
115 "disabled"
116 }
117 fn dims(&self) -> usize {
118 0
119 }
120}
121
/// Metadata-only handle for the OpenAI embeddings backend.
///
/// Holds the configured model name and dimensionality; the actual HTTP
/// calls happen in `embed_openai`, which re-reads config and environment.
pub struct OpenAIProvider {
    // Model identifier sent as `"model"` in API requests.
    model: String,
    // Configured embedding dimensionality (not verified against responses here).
    dims: usize,
}
140
141impl OpenAIProvider {
142 pub fn new(config: &EmbeddingConfig) -> Result<Self> {
149 let model = config
150 .model
151 .clone()
152 .ok_or_else(|| anyhow::anyhow!("embedding.model required for OpenAI provider"))?;
153 let dims = config
154 .dims
155 .ok_or_else(|| anyhow::anyhow!("embedding.dims required for OpenAI provider"))?;
156
157 if std::env::var("OPENAI_API_KEY").is_err() {
159 bail!("OPENAI_API_KEY environment variable not set");
160 }
161
162 Ok(Self { model, dims })
163 }
164}
165
166impl EmbeddingProvider for OpenAIProvider {
167 fn model_name(&self) -> &str {
168 &self.model
169 }
170 fn dims(&self) -> usize {
171 self.dims
172 }
173}
174
175async fn embed_openai(config: &EmbeddingConfig, texts: &[String]) -> Result<Vec<Vec<f32>>> {
185 let api_key =
186 std::env::var("OPENAI_API_KEY").map_err(|_| anyhow::anyhow!("OPENAI_API_KEY not set"))?;
187
188 let model = config
189 .model
190 .as_ref()
191 .ok_or_else(|| anyhow::anyhow!("embedding.model required"))?;
192
193 let client = reqwest::Client::builder()
194 .timeout(Duration::from_secs(config.timeout_secs))
195 .build()?;
196
197 let body = serde_json::json!({
198 "model": model,
199 "input": texts,
200 });
201
202 let mut last_err = None;
203
204 for attempt in 0..=config.max_retries {
205 if attempt > 0 {
206 let delay = Duration::from_secs(1 << (attempt - 1).min(5));
208 tokio::time::sleep(delay).await;
209 }
210
211 let resp = client
212 .post("https://api.openai.com/v1/embeddings")
213 .header("Authorization", format!("Bearer {}", api_key))
214 .header("Content-Type", "application/json")
215 .json(&body)
216 .send()
217 .await;
218
219 match resp {
220 Ok(response) => {
221 let status = response.status();
222
223 if status.is_success() {
224 let json: serde_json::Value = response.json().await?;
225 return parse_openai_response(&json);
226 }
227
228 if status.as_u16() == 429 || status.is_server_error() {
230 let body_text = response.text().await.unwrap_or_default();
231 last_err = Some(anyhow::anyhow!(
232 "OpenAI API error {}: {}",
233 status,
234 body_text
235 ));
236 continue;
237 }
238
239 let body_text = response.text().await.unwrap_or_default();
241 bail!("OpenAI API error {}: {}", status, body_text);
242 }
243 Err(e) => {
244 last_err = Some(e.into());
245 continue;
246 }
247 }
248 }
249
250 Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Embedding failed after retries")))
251}
252
253fn parse_openai_response(json: &serde_json::Value) -> Result<Vec<Vec<f32>>> {
257 let data = json
258 .get("data")
259 .and_then(|d| d.as_array())
260 .ok_or_else(|| anyhow::anyhow!("Invalid OpenAI response: missing data array"))?;
261
262 let mut embeddings = Vec::with_capacity(data.len());
263
264 for item in data {
265 let embedding = item
266 .get("embedding")
267 .and_then(|e| e.as_array())
268 .ok_or_else(|| anyhow::anyhow!("Invalid OpenAI response: missing embedding"))?;
269
270 let vec: Vec<f32> = embedding
271 .iter()
272 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
273 .collect();
274
275 embeddings.push(vec);
276 }
277
278 Ok(embeddings)
280}
281
/// Metadata-only handle for an Ollama embeddings backend.
pub struct OllamaProvider {
    // Model identifier sent in API requests.
    model: String,
    // Configured embedding dimensionality.
    dims: usize,
    // Base URL of the Ollama server; stored but not read by this impl.
    #[allow(dead_code)]
    url: String,
}
294
295impl OllamaProvider {
296 pub fn new(config: &EmbeddingConfig) -> Result<Self> {
297 let model = config
298 .model
299 .clone()
300 .ok_or_else(|| anyhow::anyhow!("embedding.model required for Ollama provider"))?;
301 let dims = config
302 .dims
303 .ok_or_else(|| anyhow::anyhow!("embedding.dims required for Ollama provider"))?;
304 let url = config
305 .url
306 .clone()
307 .unwrap_or_else(|| "http://localhost:11434".to_string());
308
309 Ok(Self { model, dims, url })
310 }
311}
312
313impl EmbeddingProvider for OllamaProvider {
314 fn model_name(&self) -> &str {
315 &self.model
316 }
317 fn dims(&self) -> usize {
318 self.dims
319 }
320}
321
322async fn embed_ollama(config: &EmbeddingConfig, texts: &[String]) -> Result<Vec<Vec<f32>>> {
323 let model = config
324 .model
325 .as_ref()
326 .ok_or_else(|| anyhow::anyhow!("embedding.model required"))?;
327
328 let url = config.url.as_deref().unwrap_or("http://localhost:11434");
329
330 let client = reqwest::Client::builder()
331 .timeout(Duration::from_secs(config.timeout_secs))
332 .build()?;
333
334 let body = serde_json::json!({
335 "model": model,
336 "input": texts,
337 });
338
339 let mut last_err = None;
340
341 for attempt in 0..=config.max_retries {
342 if attempt > 0 {
343 let delay = Duration::from_secs(1 << (attempt - 1).min(5));
344 tokio::time::sleep(delay).await;
345 }
346
347 let resp = client
348 .post(format!("{}/api/embed", url))
349 .header("Content-Type", "application/json")
350 .json(&body)
351 .send()
352 .await;
353
354 match resp {
355 Ok(response) => {
356 let status = response.status();
357
358 if status.is_success() {
359 let json: serde_json::Value = response.json().await?;
360 return parse_ollama_response(&json);
361 }
362
363 if status.as_u16() == 429 || status.is_server_error() {
364 let body_text = response.text().await.unwrap_or_default();
365 last_err = Some(anyhow::anyhow!(
366 "Ollama API error {}: {}",
367 status,
368 body_text
369 ));
370 continue;
371 }
372
373 let body_text = response.text().await.unwrap_or_default();
374 bail!("Ollama API error {}: {}", status, body_text);
375 }
376 Err(e) => {
377 last_err = Some(anyhow::anyhow!(
378 "Ollama connection error (is Ollama running at {}?): {}",
379 url,
380 e
381 ));
382 continue;
383 }
384 }
385 }
386
387 Err(last_err.unwrap_or_else(|| anyhow::anyhow!("Ollama embedding failed after retries")))
388}
389
390fn parse_ollama_response(json: &serde_json::Value) -> Result<Vec<Vec<f32>>> {
391 let embeddings = json
392 .get("embeddings")
393 .and_then(|e| e.as_array())
394 .ok_or_else(|| anyhow::anyhow!("Invalid Ollama response: missing embeddings array"))?;
395
396 let mut result = Vec::with_capacity(embeddings.len());
397
398 for embedding in embeddings {
399 let vec: Vec<f32> = embedding
400 .as_array()
401 .ok_or_else(|| anyhow::anyhow!("Invalid Ollama response: embedding is not an array"))?
402 .iter()
403 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
404 .collect();
405 result.push(vec);
406 }
407
408 Ok(result)
409}
410
/// Metadata-only handle for in-process (local) embedding models.
///
/// Compiled only when at least one local embedding backend feature is
/// enabled.
#[cfg(any(
    feature = "local-embeddings-fastembed",
    feature = "local-embeddings-tract"
))]
pub struct LocalProvider {
    // Resolved model name (defaults applied by `resolve_local_model`).
    model_name: String,
    // Resolved embedding dimensionality for that model.
    dims: usize,
}
426
#[cfg(any(
    feature = "local-embeddings-fastembed",
    feature = "local-embeddings-tract"
))]
impl LocalProvider {
    /// Build a provider from config, applying local-model defaults
    /// via `resolve_local_model`.
    pub fn new(config: &EmbeddingConfig) -> Result<Self> {
        resolve_local_model(config).map(|(model_name, dims)| Self { model_name, dims })
    }
}
437
#[cfg(any(
    feature = "local-embeddings-fastembed",
    feature = "local-embeddings-tract"
))]
impl EmbeddingProvider for LocalProvider {
    /// Resolved dimensionality.
    fn dims(&self) -> usize {
        self.dims
    }

    /// Resolved model name.
    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }
}
450
/// Resolve the local model name and its dimensionality from config.
///
/// Falls back to `all-minilm-l6-v2` when no model is configured, and
/// derives dimensions from a table of known models (defaulting to 384
/// for unrecognized names) unless `embedding.dims` is set explicitly.
#[cfg(any(
    feature = "local-embeddings-fastembed",
    feature = "local-embeddings-tract"
))]
fn resolve_local_model(config: &EmbeddingConfig) -> Result<(String, usize)> {
    let model_name = match config.model.clone() {
        Some(name) => name,
        None => "all-minilm-l6-v2".to_string(),
    };

    let dims = match config.dims {
        Some(d) => d,
        None => match model_name.as_str() {
            // 768-dimensional models.
            "bge-base-en-v1.5"
            | "nomic-embed-text-v1"
            | "nomic-embed-text-v1.5"
            | "multilingual-e5-base" => 768,
            // 1024-dimensional models.
            "bge-large-en-v1.5" | "multilingual-e5-large" => 1024,
            // 384-dimensional models (minilm, small variants) and the
            // fallback for unknown names.
            _ => 384,
        },
    };

    Ok((model_name, dims))
}
475
/// Map a configured model name to the corresponding fastembed model enum.
///
/// # Errors
/// Fails with a list of supported names when the model is not recognized.
#[cfg(feature = "local-embeddings-fastembed")]
fn config_to_fastembed_model(name: &str) -> Result<fastembed::EmbeddingModel> {
    use fastembed::EmbeddingModel as Model;

    let model = match name {
        "all-minilm-l6-v2" => Model::AllMiniLML6V2,
        "bge-small-en-v1.5" => Model::BGESmallENV15,
        "bge-base-en-v1.5" => Model::BGEBaseENV15,
        "bge-large-en-v1.5" => Model::BGELargeENV15,
        "nomic-embed-text-v1" => Model::NomicEmbedTextV1,
        "nomic-embed-text-v1.5" => Model::NomicEmbedTextV15,
        "multilingual-e5-small" => Model::MultilingualE5Small,
        "multilingual-e5-base" => Model::MultilingualE5Base,
        "multilingual-e5-large" => Model::MultilingualE5Large,
        other => bail!(
            "Unknown local embedding model: '{}'. Supported models: \
             all-minilm-l6-v2, bge-small-en-v1.5, bge-base-en-v1.5, bge-large-en-v1.5, \
             nomic-embed-text-v1, nomic-embed-text-v1.5, \
             multilingual-e5-small, multilingual-e5-base, multilingual-e5-large",
            other
        ),
    };

    Ok(model)
}
497
/// Embed texts in-process using fastembed.
///
/// Model initialization and inference are CPU-bound, so both run on a
/// blocking thread via `spawn_blocking` to keep the async runtime free.
///
/// NOTE(review): the model is (re)initialized on every call via `try_new`;
/// a cached instance would avoid repeated load cost — worth confirming
/// whether callers batch enough for this to matter.
///
/// # Errors
/// Fails when the model name is unknown, model initialization fails, or
/// inference fails.
#[cfg(feature = "local-embeddings-fastembed")]
async fn embed_local_fastembed(
    config: &EmbeddingConfig,
    texts: &[String],
) -> Result<Vec<Vec<f32>>> {
    let name = config
        .model
        .clone()
        .unwrap_or_else(|| "all-minilm-l6-v2".to_string());

    let model_id = config_to_fastembed_model(&name)?;
    let batch = config.batch_size;
    // Own the inputs so they can move into the blocking task.
    let owned_texts = texts.to_vec();

    tokio::task::spawn_blocking(move || {
        let mut engine = fastembed::TextEmbedding::try_new(
            fastembed::InitOptions::new(model_id).with_show_download_progress(true),
        )
        .map_err(|e| anyhow::anyhow!("Failed to initialize local embedding model: {}", e))?;

        engine
            .embed(owned_texts, Some(batch))
            .map_err(|e| anyhow::anyhow!("Local embedding failed: {}", e))
    })
    .await?
}
526
/// Tract-backed local embedding entry point; delegates to `local_tract`.
///
/// When both local backends are compiled in, `embed_texts` dispatches to
/// fastembed instead, so this wrapper would be dead code — hence the
/// conditional `allow(dead_code)`.
#[cfg(feature = "local-embeddings-tract")]
#[cfg_attr(
    all(
        feature = "local-embeddings-fastembed",
        feature = "local-embeddings-tract"
    ),
    allow(dead_code)
)]
async fn embed_local_tract(config: &EmbeddingConfig, texts: &[String]) -> Result<Vec<Vec<f32>>> {
    local_tract::embed_local_tract(config, texts).await
}
538
/// Construct the boxed `EmbeddingProvider` named by `config.provider`.
///
/// Mirrors the dispatch in `embed_texts`, except that `"disabled"` yields
/// a usable no-op provider here (whereas `embed_texts` errors for it).
/// The `"local"` arm is feature-gated.
///
/// # Errors
/// Fails for unknown provider names, for `"local"` when no local backend
/// feature is compiled in, or when the chosen provider's constructor fails.
pub fn create_provider(config: &EmbeddingConfig) -> Result<Box<dyn EmbeddingProvider>> {
    match config.provider.as_str() {
        "disabled" => Ok(Box::new(DisabledProvider)),
        "openai" => Ok(Box::new(OpenAIProvider::new(config)?)),
        "ollama" => Ok(Box::new(OllamaProvider::new(config)?)),
        #[cfg(any(feature = "local-embeddings-fastembed", feature = "local-embeddings-tract"))]
        "local" => Ok(Box::new(LocalProvider::new(config)?)),
        #[cfg(not(any(feature = "local-embeddings-fastembed", feature = "local-embeddings-tract")))]
        "local" => bail!(
            "Local embedding provider requires one of: --features local-embeddings-fastembed, --features local-embeddings-tract"
        ),
        other => bail!("Unknown embedding provider: {}", other),
    }
}