Skip to main content

context_harness/
sqlite_store.rs

1//! SQLite-backed [`Store`] implementation.
2//!
3//! Maps each [`Store`] operation to the existing SQLite queries used by
4//! the ingestion, search, and retrieval modules.
5
6use anyhow::Result;
7use async_trait::async_trait;
8use sqlx::{Row, SqlitePool};
9
10use context_harness_core::embedding::{blob_to_vec, cosine_similarity, vec_to_blob};
11use context_harness_core::models::{Chunk, Document};
12use context_harness_core::store::{
13    ChunkCandidate, ChunkResponse, DocumentMetadata, DocumentResponse, Store,
14};
15
16/// SQLite implementation of the [`Store`] trait.
17///
18/// Wraps a [`SqlitePool`] and translates every `Store` method into one
19/// or more SQL statements against the existing schema (documents, chunks,
20/// chunks_fts, chunk_vectors, embeddings).
21pub struct SqliteStore {
22    pool: SqlitePool,
23}
24
25impl SqliteStore {
26    pub fn new(pool: SqlitePool) -> Self {
27        Self { pool }
28    }
29
30    #[allow(dead_code)]
31    pub fn pool(&self) -> &SqlitePool {
32        &self.pool
33    }
34}
35
/// Turns free-form user text into an FTS5 `MATCH` expression.
///
/// The input is split on every character that is neither alphanumeric nor
/// `_`, empty fragments are dropped, and each remaining term is wrapped in
/// double quotes before the terms are joined with single spaces.
///
/// Quoting makes FTS5 read every term as a string rather than as query
/// syntax, so words such as `AND`, `OR` or `NOT` typed by a user are
/// searched for as text instead of being parsed as operators (a dangling
/// operator, e.g. `cats AND`, would otherwise be a syntax error).
///
/// Returns an empty string when the input contains no searchable term.
fn fts_query_from_user_text(query: &str) -> String {
    query
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|term| !term.is_empty())
        // Terms only contain alphanumerics and underscores at this point,
        // so there is no embedded double quote that would need escaping.
        .map(|term| format!("\"{}\"", term))
        .collect::<Vec<_>>()
        .join(" ")
}
43
44fn format_ts_iso(ts: i64) -> String {
45    chrono::DateTime::from_timestamp(ts, 0)
46        .map(|dt| dt.format("%Y-%m-%dT%H:%M:%SZ").to_string())
47        .unwrap_or_else(|| ts.to_string())
48}
49
50#[async_trait]
51impl Store for SqliteStore {
52    async fn upsert_document(&self, doc: &Document) -> Result<String> {
53        sqlx::query(
54            r#"
55            INSERT INTO documents (id, source, source_id, source_url, title, author,
56                                   created_at, updated_at, content_type, body,
57                                   metadata_json, raw_json, dedup_hash)
58            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
59            ON CONFLICT(source, source_id) DO UPDATE SET
60                source_url = excluded.source_url,
61                title = excluded.title,
62                author = excluded.author,
63                updated_at = excluded.updated_at,
64                content_type = excluded.content_type,
65                body = excluded.body,
66                metadata_json = excluded.metadata_json,
67                raw_json = excluded.raw_json,
68                dedup_hash = excluded.dedup_hash
69            "#,
70        )
71        .bind(&doc.id)
72        .bind(&doc.source)
73        .bind(&doc.source_id)
74        .bind(&doc.source_url)
75        .bind(&doc.title)
76        .bind(&doc.author)
77        .bind(doc.created_at)
78        .bind(doc.updated_at)
79        .bind(&doc.content_type)
80        .bind(&doc.body)
81        .bind(&doc.metadata_json)
82        .bind(&doc.raw_json)
83        .bind(&doc.dedup_hash)
84        .execute(&self.pool)
85        .await?;
86
87        Ok(doc.id.clone())
88    }
89
90    async fn replace_chunks(
91        &self,
92        doc_id: &str,
93        chunks: &[Chunk],
94        vectors: Option<&[Vec<f32>]>,
95    ) -> Result<()> {
96        let mut tx = self.pool.begin().await?;
97
98        sqlx::query(
99            "DELETE FROM chunk_vectors WHERE chunk_id IN (SELECT id FROM chunks WHERE document_id = ?)",
100        )
101        .bind(doc_id)
102        .execute(&mut *tx)
103        .await?;
104
105        sqlx::query(
106            "DELETE FROM embeddings WHERE chunk_id IN (SELECT id FROM chunks WHERE document_id = ?)",
107        )
108        .bind(doc_id)
109        .execute(&mut *tx)
110        .await?;
111
112        sqlx::query("DELETE FROM chunks_fts WHERE document_id = ?")
113            .bind(doc_id)
114            .execute(&mut *tx)
115            .await?;
116
117        sqlx::query("DELETE FROM chunks WHERE document_id = ?")
118            .bind(doc_id)
119            .execute(&mut *tx)
120            .await?;
121
122        for (i, chunk) in chunks.iter().enumerate() {
123            sqlx::query(
124                "INSERT INTO chunks (id, document_id, chunk_index, text, hash) VALUES (?, ?, ?, ?, ?)",
125            )
126            .bind(&chunk.id)
127            .bind(&chunk.document_id)
128            .bind(chunk.chunk_index)
129            .bind(&chunk.text)
130            .bind(&chunk.hash)
131            .execute(&mut *tx)
132            .await?;
133
134            sqlx::query("INSERT INTO chunks_fts (chunk_id, document_id, text) VALUES (?, ?, ?)")
135                .bind(&chunk.id)
136                .bind(&chunk.document_id)
137                .bind(&chunk.text)
138                .execute(&mut *tx)
139                .await?;
140
141            if let Some(vecs) = vectors {
142                if let Some(vec) = vecs.get(i) {
143                    let blob = vec_to_blob(vec);
144                    sqlx::query(
145                        r#"
146                        INSERT INTO chunk_vectors (chunk_id, document_id, embedding)
147                        VALUES (?, ?, ?)
148                        ON CONFLICT(chunk_id) DO UPDATE SET
149                            document_id = excluded.document_id,
150                            embedding = excluded.embedding
151                        "#,
152                    )
153                    .bind(&chunk.id)
154                    .bind(doc_id)
155                    .bind(&blob)
156                    .execute(&mut *tx)
157                    .await?;
158                }
159            }
160        }
161
162        tx.commit().await?;
163        Ok(())
164    }
165
166    async fn upsert_embedding(
167        &self,
168        chunk_id: &str,
169        doc_id: &str,
170        vector: &[f32],
171        model: &str,
172        dims: usize,
173        content_hash: &str,
174    ) -> Result<()> {
175        let now = chrono::Utc::now().timestamp();
176        let blob = vec_to_blob(vector);
177
178        sqlx::query(
179            r#"
180            INSERT INTO embeddings (chunk_id, model, dims, created_at, hash)
181            VALUES (?, ?, ?, ?, ?)
182            ON CONFLICT(chunk_id) DO UPDATE SET
183                model = excluded.model,
184                dims = excluded.dims,
185                created_at = excluded.created_at,
186                hash = excluded.hash
187            "#,
188        )
189        .bind(chunk_id)
190        .bind(model)
191        .bind(dims as i64)
192        .bind(now)
193        .bind(content_hash)
194        .execute(&self.pool)
195        .await?;
196
197        sqlx::query(
198            r#"
199            INSERT INTO chunk_vectors (chunk_id, document_id, embedding)
200            VALUES (?, ?, ?)
201            ON CONFLICT(chunk_id) DO UPDATE SET
202                document_id = excluded.document_id,
203                embedding = excluded.embedding
204            "#,
205        )
206        .bind(chunk_id)
207        .bind(doc_id)
208        .bind(&blob)
209        .execute(&self.pool)
210        .await?;
211
212        Ok(())
213    }
214
215    async fn get_document(&self, id: &str) -> Result<Option<DocumentResponse>> {
216        let doc_row = sqlx::query(
217            "SELECT id, source, source_id, source_url, title, author, created_at, updated_at, content_type, body, metadata_json FROM documents WHERE id = ?",
218        )
219        .bind(id)
220        .fetch_optional(&self.pool)
221        .await?;
222
223        let doc_row = match doc_row {
224            Some(row) => row,
225            None => return Ok(None),
226        };
227
228        let created_at: i64 = doc_row.get("created_at");
229        let updated_at: i64 = doc_row.get("updated_at");
230        let metadata_json: String = doc_row.get("metadata_json");
231
232        let metadata: serde_json::Value =
233            serde_json::from_str(&metadata_json).unwrap_or(serde_json::json!({}));
234
235        let chunk_rows = sqlx::query(
236            "SELECT chunk_index, text FROM chunks WHERE document_id = ? ORDER BY chunk_index ASC",
237        )
238        .bind(id)
239        .fetch_all(&self.pool)
240        .await?;
241
242        let chunks: Vec<ChunkResponse> = chunk_rows
243            .iter()
244            .map(|row| ChunkResponse {
245                index: row.get("chunk_index"),
246                text: row.get("text"),
247            })
248            .collect();
249
250        Ok(Some(DocumentResponse {
251            id: doc_row.get("id"),
252            source: doc_row.get("source"),
253            source_id: doc_row.get("source_id"),
254            source_url: doc_row.get("source_url"),
255            title: doc_row.get("title"),
256            author: doc_row.get("author"),
257            created_at: format_ts_iso(created_at),
258            updated_at: format_ts_iso(updated_at),
259            content_type: doc_row.get("content_type"),
260            body: doc_row.get("body"),
261            metadata,
262            chunks,
263        }))
264    }
265
266    async fn get_document_metadata(&self, id: &str) -> Result<Option<DocumentMetadata>> {
267        let row = sqlx::query(
268            "SELECT id, title, source, source_id, updated_at, source_url FROM documents WHERE id = ?",
269        )
270        .bind(id)
271        .fetch_optional(&self.pool)
272        .await?;
273
274        Ok(row.map(|r| DocumentMetadata {
275            id: r.get("id"),
276            title: r.get("title"),
277            source: r.get("source"),
278            source_id: r.get("source_id"),
279            source_url: r.get("source_url"),
280            updated_at: r.get("updated_at"),
281        }))
282    }
283
284    async fn keyword_search(
285        &self,
286        query: &str,
287        limit: i64,
288        _source: Option<&str>,
289        _since: Option<&str>,
290    ) -> Result<Vec<ChunkCandidate>> {
291        let fts_query = fts_query_from_user_text(query);
292        if fts_query.is_empty() {
293            return Ok(Vec::new());
294        }
295
296        let rows = sqlx::query(
297            r#"
298            SELECT chunk_id, document_id, rank,
299                   snippet(chunks_fts, 2, '>>>', '<<<', '...', 48) AS snippet
300            FROM chunks_fts
301            WHERE chunks_fts MATCH ?
302            ORDER BY rank
303            LIMIT ?
304            "#,
305        )
306        .bind(fts_query)
307        .bind(limit)
308        .fetch_all(&self.pool)
309        .await?;
310
311        let candidates: Vec<ChunkCandidate> = rows
312            .iter()
313            .map(|row| {
314                let rank: f64 = row.get("rank");
315                ChunkCandidate {
316                    chunk_id: row.get("chunk_id"),
317                    document_id: row.get("document_id"),
318                    raw_score: -rank,
319                    snippet: row.get("snippet"),
320                }
321            })
322            .collect();
323
324        Ok(candidates)
325    }
326
327    async fn vector_search(
328        &self,
329        query_vec: &[f32],
330        limit: i64,
331        _source: Option<&str>,
332        _since: Option<&str>,
333    ) -> Result<Vec<ChunkCandidate>> {
334        let rows = sqlx::query(
335            r#"
336            SELECT cv.chunk_id, cv.document_id, cv.embedding,
337                   COALESCE(substr(c.text, 1, 240), '') AS snippet
338            FROM chunk_vectors cv
339            JOIN chunks c ON c.id = cv.chunk_id
340            "#,
341        )
342        .fetch_all(&self.pool)
343        .await?;
344
345        let mut candidates: Vec<ChunkCandidate> = rows
346            .iter()
347            .map(|row| {
348                let blob: Vec<u8> = row.get("embedding");
349                let vec = blob_to_vec(&blob);
350                let similarity = cosine_similarity(query_vec, &vec) as f64;
351                ChunkCandidate {
352                    chunk_id: row.get("chunk_id"),
353                    document_id: row.get("document_id"),
354                    raw_score: similarity,
355                    snippet: row.get("snippet"),
356                }
357            })
358            .collect();
359
360        candidates.sort_by(|a, b| {
361            b.raw_score
362                .partial_cmp(&a.raw_score)
363                .unwrap_or(std::cmp::Ordering::Equal)
364        });
365        candidates.truncate(limit as usize);
366
367        Ok(candidates)
368    }
369}