context_harness/
embed_cmd.rs

//! Embedding CLI commands: `ctx embed pending` and `ctx embed rebuild`.
//!
//! Manages the embedding lifecycle:
//!
//! - **[`run_embed_pending`]** — backfill missing or stale embeddings
//! - **[`run_embed_rebuild`]** — delete and regenerate all embeddings
//! - **[`embed_chunks_inline`]** — embed chunks during sync (non-fatal)
//!
//! # Staleness Detection
//!
//! Each chunk's text is hashed (SHA-256). When the hash in the `embeddings`
//! table differs from the chunk's current hash, the embedding is considered
//! stale and will be re-generated by `embed pending`.
//!
//! # Batching
//!
//! Embeddings are generated in batches (configurable via `embedding.batch_size`
//! or `--batch-size` flag). Each batch is a single API call to the embedding
//! provider. Failed batches are logged but don't abort the entire operation.
21use anyhow::{bail, Result};
22use sha2::{Digest, Sha256};
23use sqlx::{Row, SqlitePool};
24
25use crate::config::Config;
26use crate::db;
27use crate::embedding;
28
29/// Backfill embeddings for chunks that are missing or have stale hashes.
30///
31/// Finds all chunks where either:
32/// 1. No embedding exists for the current model, or
33/// 2. The embedding's stored hash doesn't match the chunk's current text hash.
34///
35/// # Arguments
36///
37/// * `config` — Application configuration.
38/// * `limit` — Optional cap on the number of chunks to process.
39/// * `batch_size_override` — Override the config's `embedding.batch_size`.
40/// * `dry_run` — If `true`, report counts without writing anything.
41///
42/// # Errors
43///
44/// Returns an error if the embedding provider is disabled.
45pub async fn run_embed_pending(
46    config: &Config,
47    limit: Option<usize>,
48    batch_size_override: Option<usize>,
49    dry_run: bool,
50) -> Result<()> {
51    if !config.embedding.is_enabled() {
52        bail!("Embedding provider is disabled. Set [embedding] provider in config.");
53    }
54
55    let provider = embedding::create_provider(&config.embedding)?;
56    let model_name = provider.model_name().to_string();
57    let pool = db::connect(config).await?;
58    let batch_size = batch_size_override.unwrap_or(config.embedding.batch_size);
59
60    // Find chunks missing embeddings or with stale hashes
61    let pending = find_pending_chunks(&pool, &model_name, limit).await?;
62
63    if dry_run {
64        println!("embed pending (dry-run)");
65        println!("  chunks needing embeddings: {}", pending.len());
66        return Ok(());
67    }
68
69    if pending.is_empty() {
70        println!("embed pending");
71        println!("  all chunks up to date");
72        return Ok(());
73    }
74
75    let total = pending.len();
76    let mut embedded = 0u64;
77    let mut failed = 0u64;
78
79    for batch in pending.chunks(batch_size) {
80        let texts: Vec<String> = batch.iter().map(|p| p.text.clone()).collect();
81
82        match embedding::embed_texts(provider.as_ref(), &config.embedding, &texts).await {
83            Ok(vectors) => {
84                for (item, vec) in batch.iter().zip(vectors.iter()) {
85                    let blob = embedding::vec_to_blob(vec);
86                    upsert_embedding(
87                        &pool,
88                        &item.chunk_id,
89                        &item.document_id,
90                        &model_name,
91                        provider.dims(),
92                        &item.text_hash,
93                        &blob,
94                    )
95                    .await?;
96                    embedded += 1;
97                }
98            }
99            Err(e) => {
100                eprintln!("Warning: embedding batch failed: {}", e);
101                failed += batch.len() as u64;
102            }
103        }
104    }
105
106    println!("embed pending");
107    println!("  total pending: {}", total);
108    println!("  embedded: {}", embedded);
109    println!("  failed: {}", failed);
110
111    pool.close().await;
112    Ok(())
113}
114
115/// Delete all embeddings and regenerate for all chunks.
116///
117/// Clears both the `embeddings` metadata table and the `chunk_vectors`
118/// blob table, then re-embeds every chunk in the database.
119///
120/// # Arguments
121///
122/// * `config` — Application configuration.
123/// * `batch_size_override` — Override the config's `embedding.batch_size`.
124///
125/// # Errors
126///
127/// Returns an error if the embedding provider is disabled.
128pub async fn run_embed_rebuild(config: &Config, batch_size_override: Option<usize>) -> Result<()> {
129    if !config.embedding.is_enabled() {
130        bail!("Embedding provider is disabled. Set [embedding] provider in config.");
131    }
132
133    let provider = embedding::create_provider(&config.embedding)?;
134    let model_name = provider.model_name().to_string();
135    let pool = db::connect(config).await?;
136    let batch_size = batch_size_override.unwrap_or(config.embedding.batch_size);
137
138    // Delete all existing embeddings
139    sqlx::query("DELETE FROM chunk_vectors")
140        .execute(&pool)
141        .await?;
142    sqlx::query("DELETE FROM embeddings").execute(&pool).await?;
143
144    println!("embed rebuild — cleared existing embeddings");
145
146    // Get all chunks
147    let all_chunks = find_pending_chunks(&pool, &model_name, None).await?;
148
149    if all_chunks.is_empty() {
150        println!("  no chunks to embed");
151        pool.close().await;
152        return Ok(());
153    }
154
155    let total = all_chunks.len();
156    let mut embedded = 0u64;
157    let mut failed = 0u64;
158
159    for batch in all_chunks.chunks(batch_size) {
160        let texts: Vec<String> = batch.iter().map(|p| p.text.clone()).collect();
161
162        match embedding::embed_texts(provider.as_ref(), &config.embedding, &texts).await {
163            Ok(vectors) => {
164                for (item, vec) in batch.iter().zip(vectors.iter()) {
165                    let blob = embedding::vec_to_blob(vec);
166                    upsert_embedding(
167                        &pool,
168                        &item.chunk_id,
169                        &item.document_id,
170                        &model_name,
171                        provider.dims(),
172                        &item.text_hash,
173                        &blob,
174                    )
175                    .await?;
176                    embedded += 1;
177                }
178            }
179            Err(e) => {
180                eprintln!("Warning: embedding batch failed: {}", e);
181                failed += batch.len() as u64;
182            }
183        }
184    }
185
186    println!("embed rebuild");
187    println!("  total chunks: {}", total);
188    println!("  embedded: {}", embedded);
189    println!("  failed: {}", failed);
190
191    pool.close().await;
192    Ok(())
193}
194
195/// Embed chunks during sync (inline). Non-fatal on failure.
196///
197/// Called by [`crate::ingest::run_sync`] after chunking each document.
198/// Checks each chunk for existing, up-to-date embeddings before
199/// calling the provider, avoiding redundant API calls.
200///
201/// # Returns
202///
203/// A tuple `(embedded, pending)`:
204/// - `embedded` — number of chunks successfully embedded (or already up-to-date)
205/// - `pending` — number of chunks that failed to embed
206pub async fn embed_chunks_inline(
207    config: &Config,
208    pool: &SqlitePool,
209    chunks: &[crate::models::Chunk],
210) -> (u64, u64) {
211    if !config.embedding.is_enabled() {
212        return (0, 0);
213    }
214
215    let provider = match embedding::create_provider(&config.embedding) {
216        Ok(p) => p,
217        Err(e) => {
218            eprintln!("Warning: could not create embedding provider: {}", e);
219            return (0, chunks.len() as u64);
220        }
221    };
222
223    let model_name = provider.model_name().to_string();
224    let mut embedded = 0u64;
225    let mut pending = 0u64;
226
227    for batch in chunks.chunks(config.embedding.batch_size) {
228        // Check which chunks need embedding
229        let mut need_embedding = Vec::new();
230        for chunk in batch {
231            let text_hash = hash_text(&chunk.text);
232            let existing: Option<String> =
233                sqlx::query_scalar("SELECT hash FROM embeddings WHERE chunk_id = ? AND model = ?")
234                    .bind(&chunk.id)
235                    .bind(&model_name)
236                    .fetch_optional(pool)
237                    .await
238                    .unwrap_or(None);
239
240            if existing.as_deref() == Some(&text_hash) {
241                // Already up to date
242                embedded += 1;
243                continue;
244            }
245
246            need_embedding.push((chunk, text_hash));
247        }
248
249        if need_embedding.is_empty() {
250            continue;
251        }
252
253        let texts: Vec<String> = need_embedding.iter().map(|(c, _)| c.text.clone()).collect();
254
255        match embedding::embed_texts(provider.as_ref(), &config.embedding, &texts).await {
256            Ok(vectors) => {
257                for ((chunk, text_hash), vec) in need_embedding.iter().zip(vectors.iter()) {
258                    let blob = embedding::vec_to_blob(vec);
259                    if let Err(e) = upsert_embedding(
260                        pool,
261                        &chunk.id,
262                        &chunk.document_id,
263                        &model_name,
264                        provider.dims(),
265                        text_hash,
266                        &blob,
267                    )
268                    .await
269                    {
270                        eprintln!("Warning: failed to store embedding for {}: {}", chunk.id, e);
271                        pending += 1;
272                    } else {
273                        embedded += 1;
274                    }
275                }
276            }
277            Err(e) => {
278                eprintln!("Warning: embedding batch failed: {}", e);
279                pending += need_embedding.len() as u64;
280            }
281        }
282    }
283
284    (embedded, pending)
285}
286
/// A chunk that needs embedding (missing or stale).
struct PendingChunk {
    // Primary key of the row in `chunks` (`chunks.id`).
    chunk_id: String,
    // Owning document (`chunks.document_id`); also written into `chunk_vectors`.
    document_id: String,
    // Chunk text to send to the embedding provider.
    text: String,
    // Hex SHA-256 of `text`; stored in `embeddings.hash` for staleness checks.
    text_hash: String,
}
294
295/// Find chunks that are missing embeddings or have stale hashes.
296///
297/// A chunk is "pending" if:
298/// 1. No row exists in `embeddings` for this chunk+model, or
299/// 2. The stored `hash` differs from the chunk's current `hash`.
300async fn find_pending_chunks(
301    pool: &SqlitePool,
302    model: &str,
303    limit: Option<usize>,
304) -> Result<Vec<PendingChunk>> {
305    let limit_val = limit.unwrap_or(usize::MAX) as i64;
306
307    // Chunks that either have no embedding or have a stale hash
308    let rows = sqlx::query(
309        r#"
310        SELECT c.id AS chunk_id, c.document_id, c.text, c.hash AS chunk_hash
311        FROM chunks c
312        LEFT JOIN embeddings e ON e.chunk_id = c.id AND e.model = ?
313        WHERE e.chunk_id IS NULL OR e.hash != c.hash
314        ORDER BY c.document_id, c.chunk_index
315        LIMIT ?
316        "#,
317    )
318    .bind(model)
319    .bind(limit_val)
320    .fetch_all(pool)
321    .await?;
322
323    let results: Vec<PendingChunk> = rows
324        .iter()
325        .map(|row| {
326            let text: String = row.get("text");
327            let text_hash = hash_text(&text);
328            PendingChunk {
329                chunk_id: row.get("chunk_id"),
330                document_id: row.get("document_id"),
331                text,
332                text_hash,
333            }
334        })
335        .collect();
336
337    Ok(results)
338}
339
340/// Upsert an embedding into both `embeddings` (metadata) and `chunk_vectors` (blob).
341///
342/// Uses `INSERT ... ON CONFLICT DO UPDATE` for idempotent writes.
343async fn upsert_embedding(
344    pool: &SqlitePool,
345    chunk_id: &str,
346    document_id: &str,
347    model: &str,
348    dims: usize,
349    text_hash: &str,
350    blob: &[u8],
351) -> Result<()> {
352    let now = chrono::Utc::now().timestamp();
353
354    sqlx::query(
355        r#"
356        INSERT INTO embeddings (chunk_id, model, dims, created_at, hash)
357        VALUES (?, ?, ?, ?, ?)
358        ON CONFLICT(chunk_id) DO UPDATE SET
359            model = excluded.model,
360            dims = excluded.dims,
361            created_at = excluded.created_at,
362            hash = excluded.hash
363        "#,
364    )
365    .bind(chunk_id)
366    .bind(model)
367    .bind(dims as i64)
368    .bind(now)
369    .bind(text_hash)
370    .execute(pool)
371    .await?;
372
373    sqlx::query(
374        r#"
375        INSERT INTO chunk_vectors (chunk_id, document_id, embedding)
376        VALUES (?, ?, ?)
377        ON CONFLICT(chunk_id) DO UPDATE SET
378            document_id = excluded.document_id,
379            embedding = excluded.embedding
380        "#,
381    )
382    .bind(chunk_id)
383    .bind(document_id)
384    .bind(blob)
385    .execute(pool)
386    .await?;
387
388    Ok(())
389}
390
391/// Compute SHA-256 hash of text content (hex-encoded).
392///
393/// Used for embedding staleness detection: if the hash of the current
394/// chunk text differs from the hash stored in the `embeddings` table,
395/// the embedding is stale and needs to be regenerated.
396fn hash_text(text: &str) -> String {
397    let mut hasher = Sha256::new();
398    hasher.update(text.as_bytes());
399    format!("{:x}", hasher.finalize())
400}