1use anyhow::Result;
7use async_trait::async_trait;
8use sqlx::{Row, SqlitePool};
9
10use context_harness_core::embedding::{blob_to_vec, cosine_similarity, vec_to_blob};
11use context_harness_core::models::{Chunk, Document};
12use context_harness_core::store::{
13 ChunkCandidate, ChunkResponse, DocumentMetadata, DocumentResponse, Store,
14};
15
16pub struct SqliteStore {
22 pool: SqlitePool,
23}
24
25impl SqliteStore {
26 pub fn new(pool: SqlitePool) -> Self {
27 Self { pool }
28 }
29
30 #[allow(dead_code)]
31 pub fn pool(&self) -> &SqlitePool {
32 &self.pool
33 }
34}
35
/// Turn free-form user text into an FTS5 `MATCH` expression that cannot fail
/// to parse.
///
/// The input is split on every character that is neither alphanumeric nor
/// `_`, and empty pieces are dropped. Each remaining term is wrapped in double
/// quotes, and the terms are joined with single spaces, which FTS5 treats as
/// an implicit AND.
///
/// Quoting is what makes the result safe. As a bareword, a term such as `AND`,
/// `OR` or `NOT` would be parsed as an FTS5 operator, so inputs like
/// `cats AND` or `NOT` would otherwise be syntax errors.
///
/// Returns an empty string when the input contains no usable terms.
fn fts_query_from_user_text(query: &str) -> String {
    query
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|term| !term.is_empty())
        // Terms contain only alphanumerics and `_`, so they can never hold a
        // `"` that would need escaping inside the FTS5 string literal.
        .map(|term| format!("\"{term}\""))
        .collect::<Vec<_>>()
        .join(" ")
}
43
44fn format_ts_iso(ts: i64) -> String {
45 chrono::DateTime::from_timestamp(ts, 0)
46 .map(|dt| dt.format("%Y-%m-%dT%H:%M:%SZ").to_string())
47 .unwrap_or_else(|| ts.to_string())
48}
49
50#[async_trait]
51impl Store for SqliteStore {
52 async fn upsert_document(&self, doc: &Document) -> Result<String> {
53 sqlx::query(
54 r#"
55 INSERT INTO documents (id, source, source_id, source_url, title, author,
56 created_at, updated_at, content_type, body,
57 metadata_json, raw_json, dedup_hash)
58 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
59 ON CONFLICT(source, source_id) DO UPDATE SET
60 source_url = excluded.source_url,
61 title = excluded.title,
62 author = excluded.author,
63 updated_at = excluded.updated_at,
64 content_type = excluded.content_type,
65 body = excluded.body,
66 metadata_json = excluded.metadata_json,
67 raw_json = excluded.raw_json,
68 dedup_hash = excluded.dedup_hash
69 "#,
70 )
71 .bind(&doc.id)
72 .bind(&doc.source)
73 .bind(&doc.source_id)
74 .bind(&doc.source_url)
75 .bind(&doc.title)
76 .bind(&doc.author)
77 .bind(doc.created_at)
78 .bind(doc.updated_at)
79 .bind(&doc.content_type)
80 .bind(&doc.body)
81 .bind(&doc.metadata_json)
82 .bind(&doc.raw_json)
83 .bind(&doc.dedup_hash)
84 .execute(&self.pool)
85 .await?;
86
87 Ok(doc.id.clone())
88 }
89
90 async fn replace_chunks(
91 &self,
92 doc_id: &str,
93 chunks: &[Chunk],
94 vectors: Option<&[Vec<f32>]>,
95 ) -> Result<()> {
96 let mut tx = self.pool.begin().await?;
97
98 sqlx::query(
99 "DELETE FROM chunk_vectors WHERE chunk_id IN (SELECT id FROM chunks WHERE document_id = ?)",
100 )
101 .bind(doc_id)
102 .execute(&mut *tx)
103 .await?;
104
105 sqlx::query(
106 "DELETE FROM embeddings WHERE chunk_id IN (SELECT id FROM chunks WHERE document_id = ?)",
107 )
108 .bind(doc_id)
109 .execute(&mut *tx)
110 .await?;
111
112 sqlx::query("DELETE FROM chunks_fts WHERE document_id = ?")
113 .bind(doc_id)
114 .execute(&mut *tx)
115 .await?;
116
117 sqlx::query("DELETE FROM chunks WHERE document_id = ?")
118 .bind(doc_id)
119 .execute(&mut *tx)
120 .await?;
121
122 for (i, chunk) in chunks.iter().enumerate() {
123 sqlx::query(
124 "INSERT INTO chunks (id, document_id, chunk_index, text, hash) VALUES (?, ?, ?, ?, ?)",
125 )
126 .bind(&chunk.id)
127 .bind(&chunk.document_id)
128 .bind(chunk.chunk_index)
129 .bind(&chunk.text)
130 .bind(&chunk.hash)
131 .execute(&mut *tx)
132 .await?;
133
134 sqlx::query("INSERT INTO chunks_fts (chunk_id, document_id, text) VALUES (?, ?, ?)")
135 .bind(&chunk.id)
136 .bind(&chunk.document_id)
137 .bind(&chunk.text)
138 .execute(&mut *tx)
139 .await?;
140
141 if let Some(vecs) = vectors {
142 if let Some(vec) = vecs.get(i) {
143 let blob = vec_to_blob(vec);
144 sqlx::query(
145 r#"
146 INSERT INTO chunk_vectors (chunk_id, document_id, embedding)
147 VALUES (?, ?, ?)
148 ON CONFLICT(chunk_id) DO UPDATE SET
149 document_id = excluded.document_id,
150 embedding = excluded.embedding
151 "#,
152 )
153 .bind(&chunk.id)
154 .bind(doc_id)
155 .bind(&blob)
156 .execute(&mut *tx)
157 .await?;
158 }
159 }
160 }
161
162 tx.commit().await?;
163 Ok(())
164 }
165
166 async fn upsert_embedding(
167 &self,
168 chunk_id: &str,
169 doc_id: &str,
170 vector: &[f32],
171 model: &str,
172 dims: usize,
173 content_hash: &str,
174 ) -> Result<()> {
175 let now = chrono::Utc::now().timestamp();
176 let blob = vec_to_blob(vector);
177
178 sqlx::query(
179 r#"
180 INSERT INTO embeddings (chunk_id, model, dims, created_at, hash)
181 VALUES (?, ?, ?, ?, ?)
182 ON CONFLICT(chunk_id) DO UPDATE SET
183 model = excluded.model,
184 dims = excluded.dims,
185 created_at = excluded.created_at,
186 hash = excluded.hash
187 "#,
188 )
189 .bind(chunk_id)
190 .bind(model)
191 .bind(dims as i64)
192 .bind(now)
193 .bind(content_hash)
194 .execute(&self.pool)
195 .await?;
196
197 sqlx::query(
198 r#"
199 INSERT INTO chunk_vectors (chunk_id, document_id, embedding)
200 VALUES (?, ?, ?)
201 ON CONFLICT(chunk_id) DO UPDATE SET
202 document_id = excluded.document_id,
203 embedding = excluded.embedding
204 "#,
205 )
206 .bind(chunk_id)
207 .bind(doc_id)
208 .bind(&blob)
209 .execute(&self.pool)
210 .await?;
211
212 Ok(())
213 }
214
215 async fn get_document(&self, id: &str) -> Result<Option<DocumentResponse>> {
216 let doc_row = sqlx::query(
217 "SELECT id, source, source_id, source_url, title, author, created_at, updated_at, content_type, body, metadata_json FROM documents WHERE id = ?",
218 )
219 .bind(id)
220 .fetch_optional(&self.pool)
221 .await?;
222
223 let doc_row = match doc_row {
224 Some(row) => row,
225 None => return Ok(None),
226 };
227
228 let created_at: i64 = doc_row.get("created_at");
229 let updated_at: i64 = doc_row.get("updated_at");
230 let metadata_json: String = doc_row.get("metadata_json");
231
232 let metadata: serde_json::Value =
233 serde_json::from_str(&metadata_json).unwrap_or(serde_json::json!({}));
234
235 let chunk_rows = sqlx::query(
236 "SELECT chunk_index, text FROM chunks WHERE document_id = ? ORDER BY chunk_index ASC",
237 )
238 .bind(id)
239 .fetch_all(&self.pool)
240 .await?;
241
242 let chunks: Vec<ChunkResponse> = chunk_rows
243 .iter()
244 .map(|row| ChunkResponse {
245 index: row.get("chunk_index"),
246 text: row.get("text"),
247 })
248 .collect();
249
250 Ok(Some(DocumentResponse {
251 id: doc_row.get("id"),
252 source: doc_row.get("source"),
253 source_id: doc_row.get("source_id"),
254 source_url: doc_row.get("source_url"),
255 title: doc_row.get("title"),
256 author: doc_row.get("author"),
257 created_at: format_ts_iso(created_at),
258 updated_at: format_ts_iso(updated_at),
259 content_type: doc_row.get("content_type"),
260 body: doc_row.get("body"),
261 metadata,
262 chunks,
263 }))
264 }
265
266 async fn get_document_metadata(&self, id: &str) -> Result<Option<DocumentMetadata>> {
267 let row = sqlx::query(
268 "SELECT id, title, source, source_id, updated_at, source_url FROM documents WHERE id = ?",
269 )
270 .bind(id)
271 .fetch_optional(&self.pool)
272 .await?;
273
274 Ok(row.map(|r| DocumentMetadata {
275 id: r.get("id"),
276 title: r.get("title"),
277 source: r.get("source"),
278 source_id: r.get("source_id"),
279 source_url: r.get("source_url"),
280 updated_at: r.get("updated_at"),
281 }))
282 }
283
284 async fn keyword_search(
285 &self,
286 query: &str,
287 limit: i64,
288 _source: Option<&str>,
289 _since: Option<&str>,
290 ) -> Result<Vec<ChunkCandidate>> {
291 let fts_query = fts_query_from_user_text(query);
292 if fts_query.is_empty() {
293 return Ok(Vec::new());
294 }
295
296 let rows = sqlx::query(
297 r#"
298 SELECT chunk_id, document_id, rank,
299 snippet(chunks_fts, 2, '>>>', '<<<', '...', 48) AS snippet
300 FROM chunks_fts
301 WHERE chunks_fts MATCH ?
302 ORDER BY rank
303 LIMIT ?
304 "#,
305 )
306 .bind(fts_query)
307 .bind(limit)
308 .fetch_all(&self.pool)
309 .await?;
310
311 let candidates: Vec<ChunkCandidate> = rows
312 .iter()
313 .map(|row| {
314 let rank: f64 = row.get("rank");
315 ChunkCandidate {
316 chunk_id: row.get("chunk_id"),
317 document_id: row.get("document_id"),
318 raw_score: -rank,
319 snippet: row.get("snippet"),
320 }
321 })
322 .collect();
323
324 Ok(candidates)
325 }
326
327 async fn vector_search(
328 &self,
329 query_vec: &[f32],
330 limit: i64,
331 _source: Option<&str>,
332 _since: Option<&str>,
333 ) -> Result<Vec<ChunkCandidate>> {
334 let rows = sqlx::query(
335 r#"
336 SELECT cv.chunk_id, cv.document_id, cv.embedding,
337 COALESCE(substr(c.text, 1, 240), '') AS snippet
338 FROM chunk_vectors cv
339 JOIN chunks c ON c.id = cv.chunk_id
340 "#,
341 )
342 .fetch_all(&self.pool)
343 .await?;
344
345 let mut candidates: Vec<ChunkCandidate> = rows
346 .iter()
347 .map(|row| {
348 let blob: Vec<u8> = row.get("embedding");
349 let vec = blob_to_vec(&blob);
350 let similarity = cosine_similarity(query_vec, &vec) as f64;
351 ChunkCandidate {
352 chunk_id: row.get("chunk_id"),
353 document_id: row.get("document_id"),
354 raw_score: similarity,
355 snippet: row.get("snippet"),
356 }
357 })
358 .collect();
359
360 candidates.sort_by(|a, b| {
361 b.raw_score
362 .partial_cmp(&a.raw_score)
363 .unwrap_or(std::cmp::Ordering::Equal)
364 });
365 candidates.truncate(limit as usize);
366
367 Ok(candidates)
368 }
369}