1use anyhow::{bail, Result};
22use sha2::{Digest, Sha256};
23use sqlx::{Row, SqlitePool};
24
25use crate::config::Config;
26use crate::db;
27use crate::embedding;
28
29pub async fn run_embed_pending(
46 config: &Config,
47 limit: Option<usize>,
48 batch_size_override: Option<usize>,
49 dry_run: bool,
50) -> Result<()> {
51 if !config.embedding.is_enabled() {
52 bail!("Embedding provider is disabled. Set [embedding] provider in config.");
53 }
54
55 let provider = embedding::create_provider(&config.embedding)?;
56 let model_name = provider.model_name().to_string();
57 let pool = db::connect(config).await?;
58 let batch_size = batch_size_override.unwrap_or(config.embedding.batch_size);
59
60 let pending = find_pending_chunks(&pool, &model_name, limit).await?;
62
63 if dry_run {
64 println!("embed pending (dry-run)");
65 println!(" chunks needing embeddings: {}", pending.len());
66 return Ok(());
67 }
68
69 if pending.is_empty() {
70 println!("embed pending");
71 println!(" all chunks up to date");
72 return Ok(());
73 }
74
75 let total = pending.len();
76 let mut embedded = 0u64;
77 let mut failed = 0u64;
78
79 for batch in pending.chunks(batch_size) {
80 let texts: Vec<String> = batch.iter().map(|p| p.text.clone()).collect();
81
82 match embedding::embed_texts(provider.as_ref(), &config.embedding, &texts).await {
83 Ok(vectors) => {
84 for (item, vec) in batch.iter().zip(vectors.iter()) {
85 let blob = embedding::vec_to_blob(vec);
86 upsert_embedding(
87 &pool,
88 &item.chunk_id,
89 &item.document_id,
90 &model_name,
91 provider.dims(),
92 &item.text_hash,
93 &blob,
94 )
95 .await?;
96 embedded += 1;
97 }
98 }
99 Err(e) => {
100 eprintln!("Warning: embedding batch failed: {}", e);
101 failed += batch.len() as u64;
102 }
103 }
104 }
105
106 println!("embed pending");
107 println!(" total pending: {}", total);
108 println!(" embedded: {}", embedded);
109 println!(" failed: {}", failed);
110
111 pool.close().await;
112 Ok(())
113}
114
115pub async fn run_embed_rebuild(config: &Config, batch_size_override: Option<usize>) -> Result<()> {
129 if !config.embedding.is_enabled() {
130 bail!("Embedding provider is disabled. Set [embedding] provider in config.");
131 }
132
133 let provider = embedding::create_provider(&config.embedding)?;
134 let model_name = provider.model_name().to_string();
135 let pool = db::connect(config).await?;
136 let batch_size = batch_size_override.unwrap_or(config.embedding.batch_size);
137
138 sqlx::query("DELETE FROM chunk_vectors")
140 .execute(&pool)
141 .await?;
142 sqlx::query("DELETE FROM embeddings").execute(&pool).await?;
143
144 println!("embed rebuild — cleared existing embeddings");
145
146 let all_chunks = find_pending_chunks(&pool, &model_name, None).await?;
148
149 if all_chunks.is_empty() {
150 println!(" no chunks to embed");
151 pool.close().await;
152 return Ok(());
153 }
154
155 let total = all_chunks.len();
156 let mut embedded = 0u64;
157 let mut failed = 0u64;
158
159 for batch in all_chunks.chunks(batch_size) {
160 let texts: Vec<String> = batch.iter().map(|p| p.text.clone()).collect();
161
162 match embedding::embed_texts(provider.as_ref(), &config.embedding, &texts).await {
163 Ok(vectors) => {
164 for (item, vec) in batch.iter().zip(vectors.iter()) {
165 let blob = embedding::vec_to_blob(vec);
166 upsert_embedding(
167 &pool,
168 &item.chunk_id,
169 &item.document_id,
170 &model_name,
171 provider.dims(),
172 &item.text_hash,
173 &blob,
174 )
175 .await?;
176 embedded += 1;
177 }
178 }
179 Err(e) => {
180 eprintln!("Warning: embedding batch failed: {}", e);
181 failed += batch.len() as u64;
182 }
183 }
184 }
185
186 println!("embed rebuild");
187 println!(" total chunks: {}", total);
188 println!(" embedded: {}", embedded);
189 println!(" failed: {}", failed);
190
191 pool.close().await;
192 Ok(())
193}
194
195pub async fn embed_chunks_inline(
207 config: &Config,
208 pool: &SqlitePool,
209 chunks: &[crate::models::Chunk],
210) -> (u64, u64) {
211 if !config.embedding.is_enabled() {
212 return (0, 0);
213 }
214
215 let provider = match embedding::create_provider(&config.embedding) {
216 Ok(p) => p,
217 Err(e) => {
218 eprintln!("Warning: could not create embedding provider: {}", e);
219 return (0, chunks.len() as u64);
220 }
221 };
222
223 let model_name = provider.model_name().to_string();
224 let mut embedded = 0u64;
225 let mut pending = 0u64;
226
227 for batch in chunks.chunks(config.embedding.batch_size) {
228 let mut need_embedding = Vec::new();
230 for chunk in batch {
231 let text_hash = hash_text(&chunk.text);
232 let existing: Option<String> =
233 sqlx::query_scalar("SELECT hash FROM embeddings WHERE chunk_id = ? AND model = ?")
234 .bind(&chunk.id)
235 .bind(&model_name)
236 .fetch_optional(pool)
237 .await
238 .unwrap_or(None);
239
240 if existing.as_deref() == Some(&text_hash) {
241 embedded += 1;
243 continue;
244 }
245
246 need_embedding.push((chunk, text_hash));
247 }
248
249 if need_embedding.is_empty() {
250 continue;
251 }
252
253 let texts: Vec<String> = need_embedding.iter().map(|(c, _)| c.text.clone()).collect();
254
255 match embedding::embed_texts(provider.as_ref(), &config.embedding, &texts).await {
256 Ok(vectors) => {
257 for ((chunk, text_hash), vec) in need_embedding.iter().zip(vectors.iter()) {
258 let blob = embedding::vec_to_blob(vec);
259 if let Err(e) = upsert_embedding(
260 pool,
261 &chunk.id,
262 &chunk.document_id,
263 &model_name,
264 provider.dims(),
265 text_hash,
266 &blob,
267 )
268 .await
269 {
270 eprintln!("Warning: failed to store embedding for {}: {}", chunk.id, e);
271 pending += 1;
272 } else {
273 embedded += 1;
274 }
275 }
276 }
277 Err(e) => {
278 eprintln!("Warning: embedding batch failed: {}", e);
279 pending += need_embedding.len() as u64;
280 }
281 }
282 }
283
284 (embedded, pending)
285}
286
/// A chunk selected for (re-)embedding, together with the hash to record once it is stored.
#[derive(Debug, Clone)]
struct PendingChunk {
    /// Value of `chunks.id`.
    chunk_id: String,
    /// Value of `chunks.document_id`.
    document_id: String,
    /// Chunk text that is sent to the embedding provider.
    text: String,
    /// Lowercase hex SHA-256 of `text`, written to `embeddings.hash`.
    text_hash: String,
}
294
295async fn find_pending_chunks(
301 pool: &SqlitePool,
302 model: &str,
303 limit: Option<usize>,
304) -> Result<Vec<PendingChunk>> {
305 let limit_val = limit.unwrap_or(usize::MAX) as i64;
306
307 let rows = sqlx::query(
309 r#"
310 SELECT c.id AS chunk_id, c.document_id, c.text, c.hash AS chunk_hash
311 FROM chunks c
312 LEFT JOIN embeddings e ON e.chunk_id = c.id AND e.model = ?
313 WHERE e.chunk_id IS NULL OR e.hash != c.hash
314 ORDER BY c.document_id, c.chunk_index
315 LIMIT ?
316 "#,
317 )
318 .bind(model)
319 .bind(limit_val)
320 .fetch_all(pool)
321 .await?;
322
323 let results: Vec<PendingChunk> = rows
324 .iter()
325 .map(|row| {
326 let text: String = row.get("text");
327 let text_hash = hash_text(&text);
328 PendingChunk {
329 chunk_id: row.get("chunk_id"),
330 document_id: row.get("document_id"),
331 text,
332 text_hash,
333 }
334 })
335 .collect();
336
337 Ok(results)
338}
339
340async fn upsert_embedding(
344 pool: &SqlitePool,
345 chunk_id: &str,
346 document_id: &str,
347 model: &str,
348 dims: usize,
349 text_hash: &str,
350 blob: &[u8],
351) -> Result<()> {
352 let now = chrono::Utc::now().timestamp();
353
354 sqlx::query(
355 r#"
356 INSERT INTO embeddings (chunk_id, model, dims, created_at, hash)
357 VALUES (?, ?, ?, ?, ?)
358 ON CONFLICT(chunk_id) DO UPDATE SET
359 model = excluded.model,
360 dims = excluded.dims,
361 created_at = excluded.created_at,
362 hash = excluded.hash
363 "#,
364 )
365 .bind(chunk_id)
366 .bind(model)
367 .bind(dims as i64)
368 .bind(now)
369 .bind(text_hash)
370 .execute(pool)
371 .await?;
372
373 sqlx::query(
374 r#"
375 INSERT INTO chunk_vectors (chunk_id, document_id, embedding)
376 VALUES (?, ?, ?)
377 ON CONFLICT(chunk_id) DO UPDATE SET
378 document_id = excluded.document_id,
379 embedding = excluded.embedding
380 "#,
381 )
382 .bind(chunk_id)
383 .bind(document_id)
384 .bind(blob)
385 .execute(pool)
386 .await?;
387
388 Ok(())
389}
390
391fn hash_text(text: &str) -> String {
397 let mut hasher = Sha256::new();
398 hasher.update(text.as_bytes());
399 format!("{:x}", hasher.finalize())
400}