context_harness_core/store/
memory.rs

1//! In-memory [`Store`] implementation for testing and WASM targets.
2//!
3//! Uses `HashMap` and `Vec` behind `std::sync::RwLock` for thread safety.
4//! Vector search is brute-force cosine similarity over all stored vectors.
//! Keyword search is a naive case-insensitive term match over chunk text (no FTS index).
6
7use std::collections::HashMap;
8use std::sync::RwLock;
9
10use anyhow::Result;
11use async_trait::async_trait;
12
13use crate::models::{Chunk, Document};
14
15use super::{ChunkCandidate, ChunkResponse, DocumentMetadata, DocumentResponse, Store};
16
/// A document as stored in memory, with its metadata JSON parsed once at
/// upsert time so reads never re-parse it.
struct StoredDoc {
    // The document exactly as supplied to `upsert_document`.
    doc: Document,
    // `doc.metadata_json` parsed to a JSON value; an empty object (`{}`)
    // when the original string failed to parse.
    metadata_json_parsed: serde_json::Value,
}
21
/// A chunk paired with the id of its owning document.
struct StoredChunk {
    chunk: Chunk,
    // Owning document id; used to filter and to replace chunks per document.
    document_id: String,
}
26
/// One embedding vector for one chunk.
///
/// The underscore-prefixed fields are recorded on write but never read back
/// by this in-memory implementation; they mirror what a persistent store
/// would keep.
struct StoredVector {
    chunk_id: String,
    document_id: String,
    vector: Vec<f32>,
    _model: String,
    _dims: usize,
    _content_hash: String,
}
35
/// In-memory store for testing and WASM environments.
pub struct InMemoryStore {
    // Documents keyed by document id.
    docs: RwLock<HashMap<String, StoredDoc>>,
    // All chunks across all documents; filtered by `document_id` on read.
    chunks: RwLock<Vec<StoredChunk>>,
    // All embedding vectors across all documents.
    vectors: RwLock<Vec<StoredVector>>,
}
42
43impl InMemoryStore {
44    pub fn new() -> Self {
45        Self {
46            docs: RwLock::new(HashMap::new()),
47            chunks: RwLock::new(Vec::new()),
48            vectors: RwLock::new(Vec::new()),
49        }
50    }
51}
52
53impl Default for InMemoryStore {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59fn format_ts_iso(ts: i64) -> String {
60    chrono::DateTime::from_timestamp(ts, 0)
61        .map(|dt| dt.format("%Y-%m-%dT%H:%M:%SZ").to_string())
62        .unwrap_or_else(|| ts.to_string())
63}
64
/// Cosine similarity between two vectors.
///
/// Returns `0.0` for mismatched lengths, empty inputs, or when either vector
/// has (near-)zero magnitude, so callers never see NaN from a zero divisor.
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Accumulate dot product and both squared magnitudes in a single pass.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let mag_a = sq_a.sqrt();
    let mag_b = sq_b.sqrt();
    if mag_a < f32::EPSILON || mag_b < f32::EPSILON {
        return 0.0;
    }
    dot / (mag_a * mag_b)
}
78
79#[async_trait]
80impl Store for InMemoryStore {
81    async fn upsert_document(&self, doc: &Document) -> Result<String> {
82        let parsed = serde_json::from_str(&doc.metadata_json).unwrap_or(serde_json::json!({}));
83        let mut docs = self.docs.write().unwrap();
84        docs.insert(
85            doc.id.clone(),
86            StoredDoc {
87                doc: doc.clone(),
88                metadata_json_parsed: parsed,
89            },
90        );
91        Ok(doc.id.clone())
92    }
93
94    async fn replace_chunks(
95        &self,
96        doc_id: &str,
97        chunks: &[Chunk],
98        vectors: Option<&[Vec<f32>]>,
99    ) -> Result<()> {
100        {
101            let mut stored = self.chunks.write().unwrap();
102            stored.retain(|sc| sc.document_id != doc_id);
103            for c in chunks {
104                stored.push(StoredChunk {
105                    chunk: c.clone(),
106                    document_id: doc_id.to_string(),
107                });
108            }
109        }
110        if let Some(vecs) = vectors {
111            let mut stored_vecs = self.vectors.write().unwrap();
112            stored_vecs.retain(|sv| sv.document_id != doc_id);
113            for (c, v) in chunks.iter().zip(vecs.iter()) {
114                stored_vecs.push(StoredVector {
115                    chunk_id: c.id.clone(),
116                    document_id: doc_id.to_string(),
117                    vector: v.clone(),
118                    _model: String::new(),
119                    _dims: v.len(),
120                    _content_hash: c.hash.clone(),
121                });
122            }
123        }
124        Ok(())
125    }
126
127    async fn upsert_embedding(
128        &self,
129        chunk_id: &str,
130        doc_id: &str,
131        vector: &[f32],
132        model: &str,
133        dims: usize,
134        content_hash: &str,
135    ) -> Result<()> {
136        let mut vecs = self.vectors.write().unwrap();
137        vecs.retain(|sv| sv.chunk_id != chunk_id);
138        vecs.push(StoredVector {
139            chunk_id: chunk_id.to_string(),
140            document_id: doc_id.to_string(),
141            vector: vector.to_vec(),
142            _model: model.to_string(),
143            _dims: dims,
144            _content_hash: content_hash.to_string(),
145        });
146        Ok(())
147    }
148
149    async fn get_document(&self, id: &str) -> Result<Option<DocumentResponse>> {
150        let docs = self.docs.read().unwrap();
151        let stored = match docs.get(id) {
152            Some(s) => s,
153            None => return Ok(None),
154        };
155        let chunks_guard = self.chunks.read().unwrap();
156        let mut chunk_responses: Vec<ChunkResponse> = chunks_guard
157            .iter()
158            .filter(|sc| sc.document_id == id)
159            .map(|sc| ChunkResponse {
160                index: sc.chunk.chunk_index,
161                text: sc.chunk.text.clone(),
162            })
163            .collect();
164        chunk_responses.sort_by_key(|c| c.index);
165
166        Ok(Some(DocumentResponse {
167            id: stored.doc.id.clone(),
168            source: stored.doc.source.clone(),
169            source_id: stored.doc.source_id.clone(),
170            source_url: stored.doc.source_url.clone(),
171            title: stored.doc.title.clone(),
172            author: stored.doc.author.clone(),
173            created_at: format_ts_iso(stored.doc.created_at),
174            updated_at: format_ts_iso(stored.doc.updated_at),
175            content_type: stored.doc.content_type.clone(),
176            body: stored.doc.body.clone(),
177            metadata: stored.metadata_json_parsed.clone(),
178            chunks: chunk_responses,
179        }))
180    }
181
182    async fn get_document_metadata(&self, id: &str) -> Result<Option<DocumentMetadata>> {
183        let docs = self.docs.read().unwrap();
184        Ok(docs.get(id).map(|s| DocumentMetadata {
185            id: s.doc.id.clone(),
186            title: s.doc.title.clone(),
187            source: s.doc.source.clone(),
188            source_id: s.doc.source_id.clone(),
189            source_url: s.doc.source_url.clone(),
190            updated_at: s.doc.updated_at,
191        }))
192    }
193
194    async fn keyword_search(
195        &self,
196        query: &str,
197        limit: i64,
198        _source: Option<&str>,
199        _since: Option<&str>,
200    ) -> Result<Vec<ChunkCandidate>> {
201        let query_lower = query.to_lowercase();
202        let terms: Vec<&str> = query_lower.split_whitespace().collect();
203        if terms.is_empty() {
204            return Ok(Vec::new());
205        }
206        let chunks_guard = self.chunks.read().unwrap();
207        let mut candidates: Vec<ChunkCandidate> = chunks_guard
208            .iter()
209            .filter_map(|sc| {
210                let text_lower = sc.chunk.text.to_lowercase();
211                let matches: usize = terms.iter().filter(|t| text_lower.contains(*t)).count();
212                if matches > 0 {
213                    let snippet = sc.chunk.text.chars().take(240).collect::<String>();
214                    Some(ChunkCandidate {
215                        chunk_id: sc.chunk.id.clone(),
216                        document_id: sc.document_id.clone(),
217                        raw_score: matches as f64,
218                        snippet,
219                    })
220                } else {
221                    None
222                }
223            })
224            .collect();
225        candidates.sort_by(|a, b| {
226            b.raw_score
227                .partial_cmp(&a.raw_score)
228                .unwrap_or(std::cmp::Ordering::Equal)
229        });
230        candidates.truncate(limit as usize);
231        Ok(candidates)
232    }
233
234    async fn vector_search(
235        &self,
236        query_vec: &[f32],
237        limit: i64,
238        _source: Option<&str>,
239        _since: Option<&str>,
240    ) -> Result<Vec<ChunkCandidate>> {
241        let vecs = self.vectors.read().unwrap();
242        let chunks_guard = self.chunks.read().unwrap();
243        let mut candidates: Vec<ChunkCandidate> = vecs
244            .iter()
245            .map(|sv| {
246                let sim = cosine_sim(query_vec, &sv.vector) as f64;
247                let snippet = chunks_guard
248                    .iter()
249                    .find(|sc| sc.chunk.id == sv.chunk_id)
250                    .map(|sc| sc.chunk.text.chars().take(240).collect::<String>())
251                    .unwrap_or_default();
252                ChunkCandidate {
253                    chunk_id: sv.chunk_id.clone(),
254                    document_id: sv.document_id.clone(),
255                    raw_score: sim,
256                    snippet,
257                }
258            })
259            .collect();
260        candidates.sort_by(|a, b| {
261            b.raw_score
262                .partial_cmp(&a.raw_score)
263                .unwrap_or(std::cmp::Ordering::Equal)
264        });
265        candidates.truncate(limit as usize);
266        Ok(candidates)
267    }
268}