1use anyhow::{bail, Result};
19use chrono::NaiveDate;
20use serde::Serialize;
21use std::collections::HashMap;
22
23use crate::store::{ChunkCandidate, DocumentMetadata, Store};
24
/// Tunable knobs controlling candidate fetch sizes and score blending
/// for a single search run.
#[derive(Debug, Clone)]
pub struct SearchParams {
    // Blend factor used in "hybrid" mode: 0.0 = pure keyword, 1.0 = pure semantic.
    pub hybrid_alpha: f64,
    // How many chunk candidates to request from the keyword index.
    pub candidate_k_keyword: i64,
    // How many chunk candidates to request from the vector index.
    pub candidate_k_vector: i64,
    // Maximum number of documents returned to the caller (results are truncated to this).
    pub final_limit: i64,
}
37
/// One search invocation; borrows the caller's query text and embedding.
#[derive(Debug, Clone)]
pub struct SearchRequest<'a> {
    // Raw query text. A blank/whitespace-only query returns no results.
    pub query: &'a str,
    // Query embedding; required when `mode` is "semantic" or "hybrid".
    pub query_vec: Option<&'a [f32]>,
    // One of "keyword", "semantic", or "hybrid"; anything else is rejected.
    pub mode: &'a str,
    // If set, only documents whose metadata `source` matches are returned.
    pub source_filter: Option<&'a str>,
    // If set ("%Y-%m-%d"), only documents updated at/after that date (UTC midnight).
    pub since: Option<&'a str>,
    // Candidate-count and blending parameters.
    pub params: SearchParams,
    // When true, each result carries a `ScoreExplanation`.
    pub explain: bool,
}
56
/// One document in the final, ranked search output.
#[derive(Debug, Clone, Serialize)]
pub struct SearchResultItem {
    // Document id (from document metadata).
    pub id: String,
    // Final blended score of the document's best chunk.
    pub score: f64,
    pub title: Option<String>,
    pub source: String,
    pub source_id: String,
    // ISO-8601 UTC string (see `format_ts_iso`).
    pub updated_at: String,
    // Snippet of the highest-scoring chunk for this document.
    pub snippet: String,
    pub source_url: Option<String>,
    // Present only when the request set `explain`; omitted from JSON otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub explain: Option<ScoreExplanation>,
}
80
/// Per-result score breakdown, attached when `SearchRequest::explain` is set.
#[derive(Debug, Clone, Serialize)]
pub struct ScoreExplanation {
    // Normalized keyword score of the document's best chunk (0.0 if absent).
    pub keyword_score: f64,
    // Normalized semantic score of the document's best chunk (0.0 if absent).
    pub semantic_score: f64,
    // Effective alpha actually used (forced to 0.0 / 1.0 in keyword / semantic mode).
    pub alpha: f64,
    // Total keyword candidates fetched for this query (not per-document).
    pub keyword_candidates: usize,
    // Total vector candidates fetched for this query (not per-document).
    pub vector_candidates: usize,
}
95
96pub async fn search<S: Store>(store: &S, req: &SearchRequest<'_>) -> Result<Vec<SearchResultItem>> {
102 if req.query.trim().is_empty() {
103 return Ok(Vec::new());
104 }
105
106 match req.mode {
107 "keyword" | "semantic" | "hybrid" => {}
108 _ => bail!(
109 "Unknown search mode: {}. Use keyword, semantic, or hybrid.",
110 req.mode
111 ),
112 }
113
114 let keyword_candidates = if req.mode == "keyword" || req.mode == "hybrid" {
115 store
116 .keyword_search(
117 req.query,
118 req.params.candidate_k_keyword,
119 req.source_filter,
120 req.since,
121 )
122 .await?
123 } else {
124 Vec::new()
125 };
126
127 let vector_candidates = if req.mode == "semantic" || req.mode == "hybrid" {
128 match req.query_vec {
129 Some(qv) => {
130 store
131 .vector_search(
132 qv,
133 req.params.candidate_k_vector,
134 req.source_filter,
135 req.since,
136 )
137 .await?
138 }
139 None => bail!("query_vec is required for semantic/hybrid mode"),
140 }
141 } else {
142 Vec::new()
143 };
144
145 if keyword_candidates.is_empty() && vector_candidates.is_empty() {
146 return Ok(Vec::new());
147 }
148
149 let norm_keyword = normalize_scores(&keyword_candidates);
150 let norm_vector = normalize_scores(&vector_candidates);
151
152 let kw_map: HashMap<&str, f64> = norm_keyword
153 .iter()
154 .map(|(c, s)| (c.chunk_id.as_str(), *s))
155 .collect();
156 let vec_map: HashMap<&str, f64> = norm_vector
157 .iter()
158 .map(|(c, s)| (c.chunk_id.as_str(), *s))
159 .collect();
160
161 let mut all_chunks: HashMap<String, &ChunkCandidate> = HashMap::new();
162 for c in &keyword_candidates {
163 all_chunks.entry(c.chunk_id.clone()).or_insert(c);
164 }
165 for c in &vector_candidates {
166 all_chunks.entry(c.chunk_id.clone()).or_insert(c);
167 }
168
169 let effective_alpha = match req.mode {
170 "keyword" => 0.0,
171 "semantic" => 1.0,
172 _ => req.params.hybrid_alpha,
173 };
174
175 struct ScoredChunk {
176 document_id: String,
177 hybrid_score: f64,
178 keyword_score: f64,
179 semantic_score: f64,
180 snippet: String,
181 }
182
183 let kw_count = keyword_candidates.len();
184 let vec_count = vector_candidates.len();
185
186 let mut scored_chunks: Vec<ScoredChunk> = all_chunks
187 .iter()
188 .map(|(chunk_id, cand)| {
189 let k = kw_map.get(chunk_id.as_str()).copied().unwrap_or(0.0);
190 let v = vec_map.get(chunk_id.as_str()).copied().unwrap_or(0.0);
191 let hybrid = (1.0 - effective_alpha) * k + effective_alpha * v;
192 ScoredChunk {
193 document_id: cand.document_id.clone(),
194 hybrid_score: hybrid,
195 keyword_score: k,
196 semantic_score: v,
197 snippet: cand.snippet.clone(),
198 }
199 })
200 .collect();
201
202 struct DocResult {
203 doc_id: String,
204 doc_score: f64,
205 keyword_score: f64,
206 semantic_score: f64,
207 best_snippet: String,
208 }
209
210 let mut doc_map: HashMap<String, DocResult> = HashMap::new();
211
212 scored_chunks.sort_by(|a, b| {
213 b.hybrid_score
214 .partial_cmp(&a.hybrid_score)
215 .unwrap_or(std::cmp::Ordering::Equal)
216 });
217
218 for sc in &scored_chunks {
219 let entry = doc_map
220 .entry(sc.document_id.clone())
221 .or_insert_with(|| DocResult {
222 doc_id: sc.document_id.clone(),
223 doc_score: sc.hybrid_score,
224 keyword_score: sc.keyword_score,
225 semantic_score: sc.semantic_score,
226 best_snippet: sc.snippet.clone(),
227 });
228 if sc.hybrid_score > entry.doc_score {
229 entry.doc_score = sc.hybrid_score;
230 entry.keyword_score = sc.keyword_score;
231 entry.semantic_score = sc.semantic_score;
232 entry.best_snippet = sc.snippet.clone();
233 }
234 }
235
236 let mut results: Vec<SearchResultItem> = Vec::new();
237
238 for doc_result in doc_map.values() {
239 let meta: Option<DocumentMetadata> =
240 store.get_document_metadata(&doc_result.doc_id).await?;
241
242 if let Some(meta) = meta {
243 if let Some(src) = req.source_filter {
244 if meta.source != src {
245 continue;
246 }
247 }
248
249 if let Some(since_str) = req.since {
250 let since_date = NaiveDate::parse_from_str(since_str, "%Y-%m-%d")?;
251 let since_ts = since_date
252 .and_hms_opt(0, 0, 0)
253 .unwrap()
254 .and_utc()
255 .timestamp();
256 if meta.updated_at < since_ts {
257 continue;
258 }
259 }
260
261 let updated_at_iso = format_ts_iso(meta.updated_at);
262
263 let explanation = if req.explain {
264 Some(ScoreExplanation {
265 keyword_score: doc_result.keyword_score,
266 semantic_score: doc_result.semantic_score,
267 alpha: effective_alpha,
268 keyword_candidates: kw_count,
269 vector_candidates: vec_count,
270 })
271 } else {
272 None
273 };
274
275 results.push(SearchResultItem {
276 id: meta.id,
277 score: doc_result.doc_score,
278 title: meta.title,
279 source: meta.source,
280 source_id: meta.source_id,
281 updated_at: updated_at_iso,
282 snippet: doc_result.best_snippet.clone(),
283 source_url: meta.source_url,
284 explain: explanation,
285 });
286 }
287 }
288
289 results.sort_by(|a, b| {
290 b.score
291 .partial_cmp(&a.score)
292 .unwrap_or(std::cmp::Ordering::Equal)
293 .then(b.updated_at.cmp(&a.updated_at))
294 .then(a.id.cmp(&b.id))
295 });
296
297 results.truncate(req.params.final_limit as usize);
298
299 Ok(results)
300}
301
302pub fn format_ts_iso(ts: i64) -> String {
304 chrono::DateTime::from_timestamp(ts, 0)
305 .map(|dt| dt.format("%Y-%m-%dT%H:%M:%SZ").to_string())
306 .unwrap_or_else(|| ts.to_string())
307}
308
309pub fn normalize_scores(candidates: &[ChunkCandidate]) -> Vec<(&ChunkCandidate, f64)> {
313 if candidates.is_empty() {
314 return Vec::new();
315 }
316
317 let s_min = candidates
318 .iter()
319 .map(|c| c.raw_score)
320 .fold(f64::INFINITY, f64::min);
321 let s_max = candidates
322 .iter()
323 .map(|c| c.raw_score)
324 .fold(f64::NEG_INFINITY, f64::max);
325
326 candidates
327 .iter()
328 .map(|c| {
329 let norm = if (s_max - s_min).abs() < f64::EPSILON {
330 1.0
331 } else {
332 (c.raw_score - s_min) / (s_max - s_min)
333 };
334 (c, norm)
335 })
336 .collect()
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ChunkCandidate` with an empty snippet (snippets are
    /// irrelevant to the score math under test).
    fn make_candidate(chunk_id: &str, doc_id: &str, score: f64) -> ChunkCandidate {
        ChunkCandidate {
            chunk_id: chunk_id.to_string(),
            document_id: doc_id.to_string(),
            raw_score: score,
            snippet: String::new(),
        }
    }

    /// Normalize `candidates` and index the scores by chunk id — mirrors
    /// the lookup maps `search` builds internally.
    fn score_map(candidates: &[ChunkCandidate]) -> HashMap<&str, f64> {
        normalize_scores(candidates)
            .into_iter()
            .map(|(c, s)| (c.chunk_id.as_str(), s))
            .collect()
    }

    /// Normalized score for `c` in `map`, defaulting to 0.0 when absent
    /// (same convention as `search`).
    fn get_score(map: &HashMap<&str, f64>, c: &ChunkCandidate) -> f64 {
        map.get(c.chunk_id.as_str()).copied().unwrap_or(0.0)
    }

    /// Rank chunk ids by `score_of`, descending.
    fn ranked<'a>(
        candidates: &'a [ChunkCandidate],
        score_of: impl Fn(&ChunkCandidate) -> f64,
    ) -> Vec<&'a str> {
        let mut scored: Vec<(&str, f64)> = candidates
            .iter()
            .map(|c| (c.chunk_id.as_str(), score_of(c)))
            .collect();
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        scored.into_iter().map(|(id, _)| id).collect()
    }

    #[test]
    fn test_normalize_empty() {
        assert!(normalize_scores(&[]).is_empty());
    }

    #[test]
    fn test_normalize_single() {
        // A lone candidate has no range to scale over; it gets 1.0.
        let candidates = vec![make_candidate("c1", "d1", 5.0)];
        let result = normalize_scores(&candidates);
        assert_eq!(result.len(), 1);
        assert!((result[0].1 - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_normalize_range() {
        let candidates = vec![
            make_candidate("c1", "d1", 10.0),
            make_candidate("c2", "d2", 5.0),
            make_candidate("c3", "d3", 0.0),
        ];
        let result = normalize_scores(&candidates);
        assert!((result[0].1 - 1.0).abs() < 1e-9);
        assert!((result[1].1 - 0.5).abs() < 1e-9);
        assert!((result[2].1 - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_normalize_all_equal() {
        // Degenerate range (max == min) maps everything to 1.0.
        let candidates = vec![
            make_candidate("c1", "d1", 3.0),
            make_candidate("c2", "d2", 3.0),
        ];
        for (_, score) in &normalize_scores(&candidates) {
            assert!((*score - 1.0).abs() < 1e-9);
        }
    }

    #[test]
    fn test_scores_always_in_unit() {
        // Negative and large raw scores must still land inside [0, 1].
        let candidates = vec![
            make_candidate("c1", "d1", -5.0),
            make_candidate("c2", "d2", 100.0),
            make_candidate("c3", "d3", 42.0),
        ];
        for (_, score) in &normalize_scores(&candidates) {
            assert!(
                *score >= 0.0 && *score <= 1.0,
                "Score out of range: {}",
                score
            );
        }
    }

    #[test]
    fn test_hybrid_alpha_zero_equals_keyword() {
        let kw = vec![
            make_candidate("c1", "d1", 10.0),
            make_candidate("c2", "d2", 5.0),
            make_candidate("c3", "d3", 1.0),
        ];
        let vec_cands = vec![
            make_candidate("c1", "d1", 0.1),
            make_candidate("c2", "d2", 0.9),
        ];
        let kw_map = score_map(&kw);
        let vec_map = score_map(&vec_cands);

        let alpha = 0.0;
        let h_order = ranked(&kw, |c| {
            (1.0 - alpha) * get_score(&kw_map, c) + alpha * get_score(&vec_map, c)
        });
        let k_order = ranked(&kw, |c| get_score(&kw_map, c));
        assert_eq!(h_order, k_order, "alpha=0 should produce keyword ordering");
    }

    #[test]
    fn test_hybrid_alpha_one_equals_vector() {
        let kw = vec![
            make_candidate("c1", "d1", 10.0),
            make_candidate("c2", "d2", 5.0),
        ];
        let vec_cands = vec![
            make_candidate("c1", "d1", 0.1),
            make_candidate("c2", "d2", 0.9),
            make_candidate("c3", "d3", 0.5),
        ];
        let kw_map = score_map(&kw);
        let vec_map = score_map(&vec_cands);

        let alpha = 1.0;
        let h_order = ranked(&vec_cands, |c| {
            (1.0 - alpha) * get_score(&kw_map, c) + alpha * get_score(&vec_map, c)
        });
        let v_order = ranked(&vec_cands, |c| get_score(&vec_map, c));
        assert_eq!(h_order, v_order, "alpha=1 should produce vector ordering");
    }
}