// context_harness_core/store/memory.rs

use std::collections::HashMap;
use std::sync::RwLock;

use anyhow::Result;
use async_trait::async_trait;

use crate::models::{Chunk, Document};

use super::{ChunkCandidate, ChunkResponse, DocumentMetadata, DocumentResponse, Store};
/// A document held in memory together with its metadata, parsed from JSON
/// once at insert time (see `upsert_document`) so reads never re-parse.
struct StoredDoc {
    doc: Document,
    // Parsed form of `doc.metadata_json`; `{}` if the JSON failed to parse.
    metadata_json_parsed: serde_json::Value,
}
21
/// One text chunk, tagged with its owning document id so all chunks of a
/// document can be removed or filtered together.
struct StoredChunk {
    chunk: Chunk,
    document_id: String,
}
26
/// An embedding vector for a single chunk. The underscore-prefixed fields
/// are recorded for parity with persistent stores but never read back by
/// this in-memory implementation.
struct StoredVector {
    chunk_id: String,
    document_id: String,
    vector: Vec<f32>,
    _model: String,
    _dims: usize,
    _content_hash: String,
}
35
/// A thread-safe, in-memory implementation of [`Store`]; all data is lost
/// when the store is dropped. Each collection sits behind its own `RwLock`
/// so readers of one collection don't block writers of another.
pub struct InMemoryStore {
    docs: RwLock<HashMap<String, StoredDoc>>,
    chunks: RwLock<Vec<StoredChunk>>,
    vectors: RwLock<Vec<StoredVector>>,
}
42
43impl InMemoryStore {
44 pub fn new() -> Self {
45 Self {
46 docs: RwLock::new(HashMap::new()),
47 chunks: RwLock::new(Vec::new()),
48 vectors: RwLock::new(Vec::new()),
49 }
50 }
51}
52
impl Default for InMemoryStore {
    /// Equivalent to [`InMemoryStore::new`]: an empty store.
    fn default() -> Self {
        Self::new()
    }
}
58
59fn format_ts_iso(ts: i64) -> String {
60 chrono::DateTime::from_timestamp(ts, 0)
61 .map(|dt| dt.format("%Y-%m-%dT%H:%M:%SZ").to_string())
62 .unwrap_or_else(|| ts.to_string())
63}
64
/// Cosine similarity of two vectors.
///
/// Returns 0.0 for mismatched lengths, empty inputs, or (near-)zero-magnitude
/// vectors rather than producing NaN/Inf. Accumulation order matches a simple
/// left-to-right sum, so results are reproducible across calls.
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let mut dot = 0.0_f32;
    let mut sq_a = 0.0_f32;
    let mut sq_b = 0.0_f32;
    for (x, y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let mag_a = sq_a.sqrt();
    let mag_b = sq_b.sqrt();
    if mag_a < f32::EPSILON || mag_b < f32::EPSILON {
        0.0
    } else {
        dot / (mag_a * mag_b)
    }
}
78
79#[async_trait]
80impl Store for InMemoryStore {
81 async fn upsert_document(&self, doc: &Document) -> Result<String> {
82 let parsed = serde_json::from_str(&doc.metadata_json).unwrap_or(serde_json::json!({}));
83 let mut docs = self.docs.write().unwrap();
84 docs.insert(
85 doc.id.clone(),
86 StoredDoc {
87 doc: doc.clone(),
88 metadata_json_parsed: parsed,
89 },
90 );
91 Ok(doc.id.clone())
92 }
93
94 async fn replace_chunks(
95 &self,
96 doc_id: &str,
97 chunks: &[Chunk],
98 vectors: Option<&[Vec<f32>]>,
99 ) -> Result<()> {
100 {
101 let mut stored = self.chunks.write().unwrap();
102 stored.retain(|sc| sc.document_id != doc_id);
103 for c in chunks {
104 stored.push(StoredChunk {
105 chunk: c.clone(),
106 document_id: doc_id.to_string(),
107 });
108 }
109 }
110 if let Some(vecs) = vectors {
111 let mut stored_vecs = self.vectors.write().unwrap();
112 stored_vecs.retain(|sv| sv.document_id != doc_id);
113 for (c, v) in chunks.iter().zip(vecs.iter()) {
114 stored_vecs.push(StoredVector {
115 chunk_id: c.id.clone(),
116 document_id: doc_id.to_string(),
117 vector: v.clone(),
118 _model: String::new(),
119 _dims: v.len(),
120 _content_hash: c.hash.clone(),
121 });
122 }
123 }
124 Ok(())
125 }
126
127 async fn upsert_embedding(
128 &self,
129 chunk_id: &str,
130 doc_id: &str,
131 vector: &[f32],
132 model: &str,
133 dims: usize,
134 content_hash: &str,
135 ) -> Result<()> {
136 let mut vecs = self.vectors.write().unwrap();
137 vecs.retain(|sv| sv.chunk_id != chunk_id);
138 vecs.push(StoredVector {
139 chunk_id: chunk_id.to_string(),
140 document_id: doc_id.to_string(),
141 vector: vector.to_vec(),
142 _model: model.to_string(),
143 _dims: dims,
144 _content_hash: content_hash.to_string(),
145 });
146 Ok(())
147 }
148
149 async fn get_document(&self, id: &str) -> Result<Option<DocumentResponse>> {
150 let docs = self.docs.read().unwrap();
151 let stored = match docs.get(id) {
152 Some(s) => s,
153 None => return Ok(None),
154 };
155 let chunks_guard = self.chunks.read().unwrap();
156 let mut chunk_responses: Vec<ChunkResponse> = chunks_guard
157 .iter()
158 .filter(|sc| sc.document_id == id)
159 .map(|sc| ChunkResponse {
160 index: sc.chunk.chunk_index,
161 text: sc.chunk.text.clone(),
162 })
163 .collect();
164 chunk_responses.sort_by_key(|c| c.index);
165
166 Ok(Some(DocumentResponse {
167 id: stored.doc.id.clone(),
168 source: stored.doc.source.clone(),
169 source_id: stored.doc.source_id.clone(),
170 source_url: stored.doc.source_url.clone(),
171 title: stored.doc.title.clone(),
172 author: stored.doc.author.clone(),
173 created_at: format_ts_iso(stored.doc.created_at),
174 updated_at: format_ts_iso(stored.doc.updated_at),
175 content_type: stored.doc.content_type.clone(),
176 body: stored.doc.body.clone(),
177 metadata: stored.metadata_json_parsed.clone(),
178 chunks: chunk_responses,
179 }))
180 }
181
182 async fn get_document_metadata(&self, id: &str) -> Result<Option<DocumentMetadata>> {
183 let docs = self.docs.read().unwrap();
184 Ok(docs.get(id).map(|s| DocumentMetadata {
185 id: s.doc.id.clone(),
186 title: s.doc.title.clone(),
187 source: s.doc.source.clone(),
188 source_id: s.doc.source_id.clone(),
189 source_url: s.doc.source_url.clone(),
190 updated_at: s.doc.updated_at,
191 }))
192 }
193
194 async fn keyword_search(
195 &self,
196 query: &str,
197 limit: i64,
198 _source: Option<&str>,
199 _since: Option<&str>,
200 ) -> Result<Vec<ChunkCandidate>> {
201 let query_lower = query.to_lowercase();
202 let terms: Vec<&str> = query_lower.split_whitespace().collect();
203 if terms.is_empty() {
204 return Ok(Vec::new());
205 }
206 let chunks_guard = self.chunks.read().unwrap();
207 let mut candidates: Vec<ChunkCandidate> = chunks_guard
208 .iter()
209 .filter_map(|sc| {
210 let text_lower = sc.chunk.text.to_lowercase();
211 let matches: usize = terms.iter().filter(|t| text_lower.contains(*t)).count();
212 if matches > 0 {
213 let snippet = sc.chunk.text.chars().take(240).collect::<String>();
214 Some(ChunkCandidate {
215 chunk_id: sc.chunk.id.clone(),
216 document_id: sc.document_id.clone(),
217 raw_score: matches as f64,
218 snippet,
219 })
220 } else {
221 None
222 }
223 })
224 .collect();
225 candidates.sort_by(|a, b| {
226 b.raw_score
227 .partial_cmp(&a.raw_score)
228 .unwrap_or(std::cmp::Ordering::Equal)
229 });
230 candidates.truncate(limit as usize);
231 Ok(candidates)
232 }
233
234 async fn vector_search(
235 &self,
236 query_vec: &[f32],
237 limit: i64,
238 _source: Option<&str>,
239 _since: Option<&str>,
240 ) -> Result<Vec<ChunkCandidate>> {
241 let vecs = self.vectors.read().unwrap();
242 let chunks_guard = self.chunks.read().unwrap();
243 let mut candidates: Vec<ChunkCandidate> = vecs
244 .iter()
245 .map(|sv| {
246 let sim = cosine_sim(query_vec, &sv.vector) as f64;
247 let snippet = chunks_guard
248 .iter()
249 .find(|sc| sc.chunk.id == sv.chunk_id)
250 .map(|sc| sc.chunk.text.chars().take(240).collect::<String>())
251 .unwrap_or_default();
252 ChunkCandidate {
253 chunk_id: sv.chunk_id.clone(),
254 document_id: sv.document_id.clone(),
255 raw_score: sim,
256 snippet,
257 }
258 })
259 .collect();
260 candidates.sort_by(|a, b| {
261 b.raw_score
262 .partial_cmp(&a.raw_score)
263 .unwrap_or(std::cmp::Ordering::Equal)
264 });
265 candidates.truncate(limit as usize);
266 Ok(candidates)
267 }
268}