1use anyhow::Result;
7use async_trait::async_trait;
8use sqlx::{Row, SqlitePool};
9
10use context_harness_core::embedding::{blob_to_vec, cosine_similarity, vec_to_blob};
11use context_harness_core::models::{Chunk, Document};
12use context_harness_core::store::{
13 ChunkCandidate, ChunkResponse, DocumentMetadata, DocumentResponse, Store,
14};
15
/// SQLite-backed implementation of the [`Store`] trait.
///
/// Holds a [`SqlitePool`] that is shared by every query; cloning the pool
/// is handled internally by sqlx, so this struct is cheap to pass around.
pub struct SqliteStore {
    pool: SqlitePool,
}
24
25impl SqliteStore {
26 pub fn new(pool: SqlitePool) -> Self {
27 Self { pool }
28 }
29
30 #[allow(dead_code)]
31 pub fn pool(&self) -> &SqlitePool {
32 &self.pool
33 }
34}
35
36fn format_ts_iso(ts: i64) -> String {
37 chrono::DateTime::from_timestamp(ts, 0)
38 .map(|dt| dt.format("%Y-%m-%dT%H:%M:%SZ").to_string())
39 .unwrap_or_else(|| ts.to_string())
40}
41
42#[async_trait]
43impl Store for SqliteStore {
44 async fn upsert_document(&self, doc: &Document) -> Result<String> {
45 sqlx::query(
46 r#"
47 INSERT INTO documents (id, source, source_id, source_url, title, author,
48 created_at, updated_at, content_type, body,
49 metadata_json, raw_json, dedup_hash)
50 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
51 ON CONFLICT(source, source_id) DO UPDATE SET
52 source_url = excluded.source_url,
53 title = excluded.title,
54 author = excluded.author,
55 updated_at = excluded.updated_at,
56 content_type = excluded.content_type,
57 body = excluded.body,
58 metadata_json = excluded.metadata_json,
59 raw_json = excluded.raw_json,
60 dedup_hash = excluded.dedup_hash
61 "#,
62 )
63 .bind(&doc.id)
64 .bind(&doc.source)
65 .bind(&doc.source_id)
66 .bind(&doc.source_url)
67 .bind(&doc.title)
68 .bind(&doc.author)
69 .bind(doc.created_at)
70 .bind(doc.updated_at)
71 .bind(&doc.content_type)
72 .bind(&doc.body)
73 .bind(&doc.metadata_json)
74 .bind(&doc.raw_json)
75 .bind(&doc.dedup_hash)
76 .execute(&self.pool)
77 .await?;
78
79 Ok(doc.id.clone())
80 }
81
82 async fn replace_chunks(
83 &self,
84 doc_id: &str,
85 chunks: &[Chunk],
86 vectors: Option<&[Vec<f32>]>,
87 ) -> Result<()> {
88 let mut tx = self.pool.begin().await?;
89
90 sqlx::query(
91 "DELETE FROM chunk_vectors WHERE chunk_id IN (SELECT id FROM chunks WHERE document_id = ?)",
92 )
93 .bind(doc_id)
94 .execute(&mut *tx)
95 .await?;
96
97 sqlx::query(
98 "DELETE FROM embeddings WHERE chunk_id IN (SELECT id FROM chunks WHERE document_id = ?)",
99 )
100 .bind(doc_id)
101 .execute(&mut *tx)
102 .await?;
103
104 sqlx::query("DELETE FROM chunks_fts WHERE document_id = ?")
105 .bind(doc_id)
106 .execute(&mut *tx)
107 .await?;
108
109 sqlx::query("DELETE FROM chunks WHERE document_id = ?")
110 .bind(doc_id)
111 .execute(&mut *tx)
112 .await?;
113
114 for (i, chunk) in chunks.iter().enumerate() {
115 sqlx::query(
116 "INSERT INTO chunks (id, document_id, chunk_index, text, hash) VALUES (?, ?, ?, ?, ?)",
117 )
118 .bind(&chunk.id)
119 .bind(&chunk.document_id)
120 .bind(chunk.chunk_index)
121 .bind(&chunk.text)
122 .bind(&chunk.hash)
123 .execute(&mut *tx)
124 .await?;
125
126 sqlx::query("INSERT INTO chunks_fts (chunk_id, document_id, text) VALUES (?, ?, ?)")
127 .bind(&chunk.id)
128 .bind(&chunk.document_id)
129 .bind(&chunk.text)
130 .execute(&mut *tx)
131 .await?;
132
133 if let Some(vecs) = vectors {
134 if let Some(vec) = vecs.get(i) {
135 let blob = vec_to_blob(vec);
136 sqlx::query(
137 r#"
138 INSERT INTO chunk_vectors (chunk_id, document_id, embedding)
139 VALUES (?, ?, ?)
140 ON CONFLICT(chunk_id) DO UPDATE SET
141 document_id = excluded.document_id,
142 embedding = excluded.embedding
143 "#,
144 )
145 .bind(&chunk.id)
146 .bind(doc_id)
147 .bind(&blob)
148 .execute(&mut *tx)
149 .await?;
150 }
151 }
152 }
153
154 tx.commit().await?;
155 Ok(())
156 }
157
158 async fn upsert_embedding(
159 &self,
160 chunk_id: &str,
161 doc_id: &str,
162 vector: &[f32],
163 model: &str,
164 dims: usize,
165 content_hash: &str,
166 ) -> Result<()> {
167 let now = chrono::Utc::now().timestamp();
168 let blob = vec_to_blob(vector);
169
170 sqlx::query(
171 r#"
172 INSERT INTO embeddings (chunk_id, model, dims, created_at, hash)
173 VALUES (?, ?, ?, ?, ?)
174 ON CONFLICT(chunk_id) DO UPDATE SET
175 model = excluded.model,
176 dims = excluded.dims,
177 created_at = excluded.created_at,
178 hash = excluded.hash
179 "#,
180 )
181 .bind(chunk_id)
182 .bind(model)
183 .bind(dims as i64)
184 .bind(now)
185 .bind(content_hash)
186 .execute(&self.pool)
187 .await?;
188
189 sqlx::query(
190 r#"
191 INSERT INTO chunk_vectors (chunk_id, document_id, embedding)
192 VALUES (?, ?, ?)
193 ON CONFLICT(chunk_id) DO UPDATE SET
194 document_id = excluded.document_id,
195 embedding = excluded.embedding
196 "#,
197 )
198 .bind(chunk_id)
199 .bind(doc_id)
200 .bind(&blob)
201 .execute(&self.pool)
202 .await?;
203
204 Ok(())
205 }
206
207 async fn get_document(&self, id: &str) -> Result<Option<DocumentResponse>> {
208 let doc_row = sqlx::query(
209 "SELECT id, source, source_id, source_url, title, author, created_at, updated_at, content_type, body, metadata_json FROM documents WHERE id = ?",
210 )
211 .bind(id)
212 .fetch_optional(&self.pool)
213 .await?;
214
215 let doc_row = match doc_row {
216 Some(row) => row,
217 None => return Ok(None),
218 };
219
220 let created_at: i64 = doc_row.get("created_at");
221 let updated_at: i64 = doc_row.get("updated_at");
222 let metadata_json: String = doc_row.get("metadata_json");
223
224 let metadata: serde_json::Value =
225 serde_json::from_str(&metadata_json).unwrap_or(serde_json::json!({}));
226
227 let chunk_rows = sqlx::query(
228 "SELECT chunk_index, text FROM chunks WHERE document_id = ? ORDER BY chunk_index ASC",
229 )
230 .bind(id)
231 .fetch_all(&self.pool)
232 .await?;
233
234 let chunks: Vec<ChunkResponse> = chunk_rows
235 .iter()
236 .map(|row| ChunkResponse {
237 index: row.get("chunk_index"),
238 text: row.get("text"),
239 })
240 .collect();
241
242 Ok(Some(DocumentResponse {
243 id: doc_row.get("id"),
244 source: doc_row.get("source"),
245 source_id: doc_row.get("source_id"),
246 source_url: doc_row.get("source_url"),
247 title: doc_row.get("title"),
248 author: doc_row.get("author"),
249 created_at: format_ts_iso(created_at),
250 updated_at: format_ts_iso(updated_at),
251 content_type: doc_row.get("content_type"),
252 body: doc_row.get("body"),
253 metadata,
254 chunks,
255 }))
256 }
257
258 async fn get_document_metadata(&self, id: &str) -> Result<Option<DocumentMetadata>> {
259 let row = sqlx::query(
260 "SELECT id, title, source, source_id, updated_at, source_url FROM documents WHERE id = ?",
261 )
262 .bind(id)
263 .fetch_optional(&self.pool)
264 .await?;
265
266 Ok(row.map(|r| DocumentMetadata {
267 id: r.get("id"),
268 title: r.get("title"),
269 source: r.get("source"),
270 source_id: r.get("source_id"),
271 source_url: r.get("source_url"),
272 updated_at: r.get("updated_at"),
273 }))
274 }
275
276 async fn keyword_search(
277 &self,
278 query: &str,
279 limit: i64,
280 _source: Option<&str>,
281 _since: Option<&str>,
282 ) -> Result<Vec<ChunkCandidate>> {
283 let rows = sqlx::query(
284 r#"
285 SELECT chunk_id, document_id, rank,
286 snippet(chunks_fts, 2, '>>>', '<<<', '...', 48) AS snippet
287 FROM chunks_fts
288 WHERE chunks_fts MATCH ?
289 ORDER BY rank
290 LIMIT ?
291 "#,
292 )
293 .bind(query)
294 .bind(limit)
295 .fetch_all(&self.pool)
296 .await?;
297
298 let candidates: Vec<ChunkCandidate> = rows
299 .iter()
300 .map(|row| {
301 let rank: f64 = row.get("rank");
302 ChunkCandidate {
303 chunk_id: row.get("chunk_id"),
304 document_id: row.get("document_id"),
305 raw_score: -rank,
306 snippet: row.get("snippet"),
307 }
308 })
309 .collect();
310
311 Ok(candidates)
312 }
313
314 async fn vector_search(
315 &self,
316 query_vec: &[f32],
317 limit: i64,
318 _source: Option<&str>,
319 _since: Option<&str>,
320 ) -> Result<Vec<ChunkCandidate>> {
321 let rows = sqlx::query(
322 r#"
323 SELECT cv.chunk_id, cv.document_id, cv.embedding,
324 COALESCE(substr(c.text, 1, 240), '') AS snippet
325 FROM chunk_vectors cv
326 JOIN chunks c ON c.id = cv.chunk_id
327 "#,
328 )
329 .fetch_all(&self.pool)
330 .await?;
331
332 let mut candidates: Vec<ChunkCandidate> = rows
333 .iter()
334 .map(|row| {
335 let blob: Vec<u8> = row.get("embedding");
336 let vec = blob_to_vec(&blob);
337 let similarity = cosine_similarity(query_vec, &vec) as f64;
338 ChunkCandidate {
339 chunk_id: row.get("chunk_id"),
340 document_id: row.get("document_id"),
341 raw_score: similarity,
342 snippet: row.get("snippet"),
343 }
344 })
345 .collect();
346
347 candidates.sort_by(|a, b| {
348 b.raw_score
349 .partial_cmp(&a.raw_score)
350 .unwrap_or(std::cmp::Ordering::Equal)
351 });
352 candidates.truncate(limit as usize);
353
354 Ok(candidates)
355 }
356}