context_harness/
sqlite_store.rs

1//! SQLite-backed [`Store`] implementation.
2//!
3//! Maps each [`Store`] operation to the existing SQLite queries used by
4//! the ingestion, search, and retrieval modules.
5
6use anyhow::Result;
7use async_trait::async_trait;
8use sqlx::{Row, SqlitePool};
9
10use context_harness_core::embedding::{blob_to_vec, cosine_similarity, vec_to_blob};
11use context_harness_core::models::{Chunk, Document};
12use context_harness_core::store::{
13    ChunkCandidate, ChunkResponse, DocumentMetadata, DocumentResponse, Store,
14};
15
16/// SQLite implementation of the [`Store`] trait.
17///
18/// Wraps a [`SqlitePool`] and translates every `Store` method into one
19/// or more SQL statements against the existing schema (documents, chunks,
20/// chunks_fts, chunk_vectors, embeddings).
pub struct SqliteStore {
    /// Connection pool shared by every query this store issues.
    pool: SqlitePool,
}
24
impl SqliteStore {
    /// Create a store backed by the given SQLite connection pool.
    pub fn new(pool: SqlitePool) -> Self {
        Self { pool }
    }

    /// Borrow the underlying pool (e.g. for migrations or ad-hoc queries).
    // NOTE(review): dead_code suggests no in-crate caller today — confirm
    // external callers exist before removing this accessor.
    #[allow(dead_code)]
    pub fn pool(&self) -> &SqlitePool {
        &self.pool
    }
}
35
36fn format_ts_iso(ts: i64) -> String {
37    chrono::DateTime::from_timestamp(ts, 0)
38        .map(|dt| dt.format("%Y-%m-%dT%H:%M:%SZ").to_string())
39        .unwrap_or_else(|| ts.to_string())
40}
41
#[async_trait]
impl Store for SqliteStore {
    /// Insert a document, or update the row sharing the same
    /// `(source, source_id)` pair.
    ///
    /// On the update path every column except `id`, `source`, `source_id`
    /// and `created_at` is overwritten with the incoming values.
    ///
    /// NOTE(review): on conflict the stored row keeps its original `id`,
    /// yet this method still returns `doc.id` — verify callers don't
    /// assume the returned id matches the persisted row in that case.
    async fn upsert_document(&self, doc: &Document) -> Result<String> {
        sqlx::query(
            r#"
            INSERT INTO documents (id, source, source_id, source_url, title, author,
                                   created_at, updated_at, content_type, body,
                                   metadata_json, raw_json, dedup_hash)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(source, source_id) DO UPDATE SET
                source_url = excluded.source_url,
                title = excluded.title,
                author = excluded.author,
                updated_at = excluded.updated_at,
                content_type = excluded.content_type,
                body = excluded.body,
                metadata_json = excluded.metadata_json,
                raw_json = excluded.raw_json,
                dedup_hash = excluded.dedup_hash
            "#,
        )
        .bind(&doc.id)
        .bind(&doc.source)
        .bind(&doc.source_id)
        .bind(&doc.source_url)
        .bind(&doc.title)
        .bind(&doc.author)
        .bind(doc.created_at)
        .bind(doc.updated_at)
        .bind(&doc.content_type)
        .bind(&doc.body)
        .bind(&doc.metadata_json)
        .bind(&doc.raw_json)
        .bind(&doc.dedup_hash)
        .execute(&self.pool)
        .await?;

        Ok(doc.id.clone())
    }

    /// Replace every chunk — and all rows derived from those chunks — for
    /// one document, inside a single transaction.
    ///
    /// When `vectors` is `Some`, the vector at index `i` is stored for
    /// chunk `i`; chunks with no matching vector are inserted without one.
    /// Either everything commits or nothing does.
    async fn replace_chunks(
        &self,
        doc_id: &str,
        chunks: &[Chunk],
        vectors: Option<&[Vec<f32>]>,
    ) -> Result<()> {
        let mut tx = self.pool.begin().await?;

        // Delete derived rows first (they reference chunks by chunk_id,
        // so the subqueries must run while the chunks still exist),
        // then the FTS mirror, then the chunks themselves.
        sqlx::query(
            "DELETE FROM chunk_vectors WHERE chunk_id IN (SELECT id FROM chunks WHERE document_id = ?)",
        )
        .bind(doc_id)
        .execute(&mut *tx)
        .await?;

        sqlx::query(
            "DELETE FROM embeddings WHERE chunk_id IN (SELECT id FROM chunks WHERE document_id = ?)",
        )
        .bind(doc_id)
        .execute(&mut *tx)
        .await?;

        sqlx::query("DELETE FROM chunks_fts WHERE document_id = ?")
            .bind(doc_id)
            .execute(&mut *tx)
            .await?;

        sqlx::query("DELETE FROM chunks WHERE document_id = ?")
            .bind(doc_id)
            .execute(&mut *tx)
            .await?;

        for (i, chunk) in chunks.iter().enumerate() {
            sqlx::query(
                "INSERT INTO chunks (id, document_id, chunk_index, text, hash) VALUES (?, ?, ?, ?, ?)",
            )
            .bind(&chunk.id)
            .bind(&chunk.document_id)
            .bind(chunk.chunk_index)
            .bind(&chunk.text)
            .bind(&chunk.hash)
            .execute(&mut *tx)
            .await?;

            // Mirror the text into the FTS table used by keyword_search.
            sqlx::query("INSERT INTO chunks_fts (chunk_id, document_id, text) VALUES (?, ?, ?)")
                .bind(&chunk.id)
                .bind(&chunk.document_id)
                .bind(&chunk.text)
                .execute(&mut *tx)
                .await?;

            if let Some(vecs) = vectors {
                if let Some(vec) = vecs.get(i) {
                    let blob = vec_to_blob(vec);
                    // ON CONFLICT is defensive here: all vectors for this
                    // document were just deleted above, so inserts should
                    // never actually collide within this transaction.
                    sqlx::query(
                        r#"
                        INSERT INTO chunk_vectors (chunk_id, document_id, embedding)
                        VALUES (?, ?, ?)
                        ON CONFLICT(chunk_id) DO UPDATE SET
                            document_id = excluded.document_id,
                            embedding = excluded.embedding
                        "#,
                    )
                    .bind(&chunk.id)
                    .bind(doc_id)
                    .bind(&blob)
                    .execute(&mut *tx)
                    .await?;
                }
            }
        }

        tx.commit().await?;
        Ok(())
    }

    /// Record an embedding for a chunk: bookkeeping metadata in
    /// `embeddings` (model, dims, content hash, creation time) and the
    /// serialized vector in `chunk_vectors` (consumed by `vector_search`).
    ///
    /// NOTE(review): the two statements are not wrapped in a transaction,
    /// so a failure between them can leave metadata without a vector —
    /// confirm callers tolerate that.
    async fn upsert_embedding(
        &self,
        chunk_id: &str,
        doc_id: &str,
        vector: &[f32],
        model: &str,
        dims: usize,
        content_hash: &str,
    ) -> Result<()> {
        let now = chrono::Utc::now().timestamp();
        let blob = vec_to_blob(vector);

        sqlx::query(
            r#"
            INSERT INTO embeddings (chunk_id, model, dims, created_at, hash)
            VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(chunk_id) DO UPDATE SET
                model = excluded.model,
                dims = excluded.dims,
                created_at = excluded.created_at,
                hash = excluded.hash
            "#,
        )
        .bind(chunk_id)
        .bind(model)
        .bind(dims as i64)
        .bind(now)
        .bind(content_hash)
        .execute(&self.pool)
        .await?;

        sqlx::query(
            r#"
            INSERT INTO chunk_vectors (chunk_id, document_id, embedding)
            VALUES (?, ?, ?)
            ON CONFLICT(chunk_id) DO UPDATE SET
                document_id = excluded.document_id,
                embedding = excluded.embedding
            "#,
        )
        .bind(chunk_id)
        .bind(doc_id)
        .bind(&blob)
        .execute(&self.pool)
        .await?;

        Ok(())
    }

    /// Fetch a full document plus its chunks, ordered by chunk index.
    ///
    /// Timestamps are rendered as ISO-8601 strings; metadata JSON that
    /// fails to parse degrades to an empty object rather than erroring.
    /// Returns `Ok(None)` when no document has the given id.
    async fn get_document(&self, id: &str) -> Result<Option<DocumentResponse>> {
        let doc_row = sqlx::query(
            "SELECT id, source, source_id, source_url, title, author, created_at, updated_at, content_type, body, metadata_json FROM documents WHERE id = ?",
        )
        .bind(id)
        .fetch_optional(&self.pool)
        .await?;

        let doc_row = match doc_row {
            Some(row) => row,
            None => return Ok(None),
        };

        let created_at: i64 = doc_row.get("created_at");
        let updated_at: i64 = doc_row.get("updated_at");
        let metadata_json: String = doc_row.get("metadata_json");

        // Tolerate corrupt/legacy metadata: fall back to `{}` on parse failure.
        let metadata: serde_json::Value =
            serde_json::from_str(&metadata_json).unwrap_or(serde_json::json!({}));

        let chunk_rows = sqlx::query(
            "SELECT chunk_index, text FROM chunks WHERE document_id = ? ORDER BY chunk_index ASC",
        )
        .bind(id)
        .fetch_all(&self.pool)
        .await?;

        let chunks: Vec<ChunkResponse> = chunk_rows
            .iter()
            .map(|row| ChunkResponse {
                index: row.get("chunk_index"),
                text: row.get("text"),
            })
            .collect();

        Ok(Some(DocumentResponse {
            id: doc_row.get("id"),
            source: doc_row.get("source"),
            source_id: doc_row.get("source_id"),
            source_url: doc_row.get("source_url"),
            title: doc_row.get("title"),
            author: doc_row.get("author"),
            created_at: format_ts_iso(created_at),
            updated_at: format_ts_iso(updated_at),
            content_type: doc_row.get("content_type"),
            body: doc_row.get("body"),
            metadata,
            chunks,
        }))
    }

    /// Fetch the lightweight metadata row for a document, without body
    /// or chunks. Returns `Ok(None)` when the id is unknown.
    ///
    /// NOTE(review): unlike `get_document`, `updated_at` here is passed
    /// through as read from the row (not ISO-formatted) — presumably the
    /// caller formats it; confirm the intended representation.
    async fn get_document_metadata(&self, id: &str) -> Result<Option<DocumentMetadata>> {
        let row = sqlx::query(
            "SELECT id, title, source, source_id, updated_at, source_url FROM documents WHERE id = ?",
        )
        .bind(id)
        .fetch_optional(&self.pool)
        .await?;

        Ok(row.map(|r| DocumentMetadata {
            id: r.get("id"),
            title: r.get("title"),
            source: r.get("source"),
            source_id: r.get("source_id"),
            source_url: r.get("source_url"),
            updated_at: r.get("updated_at"),
        }))
    }

    /// Full-text search over `chunks_fts`, best matches first.
    ///
    /// FTS5's `rank` is ascending (lower = better), so it is negated to
    /// produce a higher-is-better `raw_score`. The snippet is taken from
    /// FTS column 2, which is `text` (see the insert in `replace_chunks`).
    ///
    /// NOTE(review): `_source` / `_since` are accepted but not applied —
    /// results are unfiltered. Also, `query` is passed verbatim to MATCH,
    /// so FTS5 operator syntax in user input can make the query error.
    async fn keyword_search(
        &self,
        query: &str,
        limit: i64,
        _source: Option<&str>,
        _since: Option<&str>,
    ) -> Result<Vec<ChunkCandidate>> {
        let rows = sqlx::query(
            r#"
            SELECT chunk_id, document_id, rank,
                   snippet(chunks_fts, 2, '>>>', '<<<', '...', 48) AS snippet
            FROM chunks_fts
            WHERE chunks_fts MATCH ?
            ORDER BY rank
            LIMIT ?
            "#,
        )
        .bind(query)
        .bind(limit)
        .fetch_all(&self.pool)
        .await?;

        let candidates: Vec<ChunkCandidate> = rows
            .iter()
            .map(|row| {
                let rank: f64 = row.get("rank");
                ChunkCandidate {
                    chunk_id: row.get("chunk_id"),
                    document_id: row.get("document_id"),
                    // Negate so higher raw_score = better match.
                    raw_score: -rank,
                    snippet: row.get("snippet"),
                }
            })
            .collect();

        Ok(candidates)
    }

    /// Brute-force vector search: load every stored vector, score by
    /// cosine similarity against `query_vec`, sort descending, and keep
    /// the top `limit`. Snippets are the first 240 bytes of chunk text.
    ///
    /// NOTE(review): this scans the entire `chunk_vectors` table into
    /// memory on every call — acceptable for small corpora, but worth
    /// revisiting (e.g. an ANN index) as the table grows. As with
    /// `keyword_search`, `_source` / `_since` are not applied.
    async fn vector_search(
        &self,
        query_vec: &[f32],
        limit: i64,
        _source: Option<&str>,
        _since: Option<&str>,
    ) -> Result<Vec<ChunkCandidate>> {
        let rows = sqlx::query(
            r#"
            SELECT cv.chunk_id, cv.document_id, cv.embedding,
                   COALESCE(substr(c.text, 1, 240), '') AS snippet
            FROM chunk_vectors cv
            JOIN chunks c ON c.id = cv.chunk_id
            "#,
        )
        .fetch_all(&self.pool)
        .await?;

        let mut candidates: Vec<ChunkCandidate> = rows
            .iter()
            .map(|row| {
                let blob: Vec<u8> = row.get("embedding");
                let vec = blob_to_vec(&blob);
                let similarity = cosine_similarity(query_vec, &vec) as f64;
                ChunkCandidate {
                    chunk_id: row.get("chunk_id"),
                    document_id: row.get("document_id"),
                    raw_score: similarity,
                    snippet: row.get("snippet"),
                }
            })
            .collect();

        // Descending by score; NaN-safe via partial_cmp fallback to Equal.
        candidates.sort_by(|a, b| {
            b.raw_score
                .partial_cmp(&a.raw_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        candidates.truncate(limit as usize);

        Ok(candidates)
    }
}
356}