1use axum::{
54 extract::{Path, State},
55 http::StatusCode,
56 response::{IntoResponse, Response},
57 routing::{get, post},
58 Json, Router,
59};
60use rmcp::transport::streamable_http_server::{
61 session::local::LocalSessionManager, StreamableHttpService,
62};
63use serde::Serialize;
64use std::sync::Arc;
65use tower_http::cors::{Any, CorsLayer};
66
67use crate::agent_script::{load_agent_definitions, LuaAgentAdapter};
68use crate::agents::{AgentInfo, AgentRegistry};
69use crate::config::Config;
70use crate::mcp::McpBridge;
71use crate::registry::RegistryManager;
72use crate::tool_script::{load_tool_definitions, validate_params, LuaToolAdapter, ToolInfo};
73use crate::traits::{ToolContext, ToolRegistry};
74
75#[derive(Clone)]
77struct AppState {
78 config: Arc<Config>,
80 tools: Arc<ToolRegistry>,
82 agents: Arc<AgentRegistry>,
84}
85
86type ExtState = (Arc<ToolRegistry>, Arc<AgentRegistry>);
88
89pub async fn run_server(config: &Config) -> anyhow::Result<()> {
106 run_server_with_extensions(
107 config,
108 Arc::new(ToolRegistry::new()),
109 Arc::new(AgentRegistry::new()),
110 )
111 .await
112}
113
114pub async fn run_server_with_extensions(
140 config: &Config,
141 extra_tools: Arc<ToolRegistry>,
142 extra_agents: Arc<AgentRegistry>,
143) -> anyhow::Result<()> {
144 let bind_addr = config.server.bind.clone();
145 let config = Arc::new(config.clone());
146
147 let mut tool_registry = ToolRegistry::with_builtins();
149
150 let lua_defs = load_tool_definitions(&config)?;
152 let configured_tool_names: Vec<String> = lua_defs.iter().map(|d| d.name.clone()).collect();
153 for def in lua_defs {
154 tool_registry.register(Box::new(LuaToolAdapter::new(def, config.clone())));
155 }
156
157 let reg_mgr = RegistryManager::from_config(&config);
159 for ext in reg_mgr.list_tools() {
160 if configured_tool_names.iter().any(|n| n == &ext.name) {
161 continue;
162 }
163 if !ext.script_path.exists() {
164 continue;
165 }
166 let tool_cfg = crate::config::ScriptToolConfig {
167 path: ext.script_path.clone(),
168 timeout: 30,
169 extra: toml::Table::new(),
170 };
171 match crate::tool_script::load_single_tool(&ext.name, &tool_cfg) {
172 Ok(def) => {
173 tool_registry.register(Box::new(LuaToolAdapter::new(def, config.clone())));
174 }
175 Err(e) => {
176 eprintln!(
177 "Warning: failed to load registry tool '{}': {}",
178 ext.name, e
179 );
180 }
181 }
182 }
183
184 let tool_count = tool_registry.len() + extra_tools.len();
186 if tool_count > 3 {
187 println!("Registered {} tools:", tool_count);
188 for t in tool_registry.tools() {
189 let tag = if t.is_builtin() { "builtin" } else { "lua" };
190 println!(" POST /tools/{} — {} ({})", t.name(), t.description(), tag);
191 }
192 for t in extra_tools.tools() {
193 println!(" POST /tools/{} — {} (rust)", t.name(), t.description());
194 }
195 }
196
197 let mut agent_registry = AgentRegistry::from_config(&config)?;
199
200 let lua_agents = load_agent_definitions(&config)?;
202 let configured_agent_names: Vec<String> = lua_agents.iter().map(|d| d.name.clone()).collect();
203 for def in lua_agents {
204 agent_registry.register(Box::new(LuaAgentAdapter::new(def, config.clone())));
205 }
206
207 for ext in reg_mgr.list_agents() {
209 if configured_agent_names.iter().any(|n| n == &ext.name) {
210 continue;
211 }
212 if !ext.script_path.exists() {
213 continue;
214 }
215 if ext.script_path.extension().is_some_and(|e| e == "lua") {
216 let agent_cfg = crate::config::ScriptAgentConfig {
217 path: ext.script_path.clone(),
218 timeout: 30,
219 extra: toml::Table::new(),
220 };
221 match crate::agent_script::load_single_agent(&ext.name, &agent_cfg) {
222 Ok(def) => {
223 agent_registry.register(Box::new(LuaAgentAdapter::new(def, config.clone())));
224 }
225 Err(e) => {
226 eprintln!(
227 "Warning: failed to load registry agent '{}': {}",
228 ext.name, e
229 );
230 }
231 }
232 }
233 }
234
235 let agent_count = agent_registry.len() + extra_agents.len();
236 if agent_count > 0 {
237 println!("Registered {} agents:", agent_count);
238 for a in agent_registry.agents() {
239 println!(
240 " POST /agents/{}/prompt — {} ({})",
241 a.name(),
242 a.description(),
243 a.source()
244 );
245 }
246 for a in extra_agents.agents() {
247 println!(
248 " POST /agents/{}/prompt — {} ({})",
249 a.name(),
250 a.description(),
251 a.source()
252 );
253 }
254 }
255
256 let tools = Arc::new(tool_registry);
257 let agents = Arc::new(agent_registry);
258
259 let state = AppState {
260 config: config.clone(),
261 tools: tools.clone(),
262 agents: agents.clone(),
263 };
264
265 let mcp_tools = tools.clone();
267 let mcp_extra = extra_tools.clone();
268 let mcp_agents = agents.clone();
269 let mcp_extra_agents = extra_agents.clone();
270 let mcp_config = config.clone();
271
272 let extra_state = (extra_tools.clone(), extra_agents);
273 let mcp_service = StreamableHttpService::new(
274 move || {
275 Ok(McpBridge::new(
276 mcp_config.clone(),
277 mcp_tools.clone(),
278 mcp_extra.clone(),
279 mcp_agents.clone(),
280 mcp_extra_agents.clone(),
281 ))
282 },
283 Arc::new(LocalSessionManager::default()),
284 Default::default(),
285 );
286
287 let cors = CorsLayer::new()
288 .allow_origin(Any)
289 .allow_methods(Any)
290 .allow_headers(Any);
291
292 let app = Router::new()
293 .route("/tools/list", get(handle_list_tools))
294 .route("/tools/{name}", post(handle_tool_call))
295 .route("/agents/list", get(handle_list_agents))
296 .route("/agents/{name}/prompt", post(handle_resolve_agent))
297 .route("/health", get(handle_health))
298 .with_state((state, extra_state))
299 .nest_service("/mcp", mcp_service)
300 .layer(cors);
301
302 println!("MCP server listening on http://{}", bind_addr);
303 println!(" MCP endpoint: http://{}/mcp", bind_addr);
304
305 let listener = tokio::net::TcpListener::bind(&bind_addr).await?;
306 axum::serve(listener, app).await?;
307
308 Ok(())
309}
310
311#[derive(Serialize)]
315struct ErrorBody {
316 error: ErrorDetail,
317}
318
319#[derive(Serialize)]
321struct ErrorDetail {
322 code: String,
324 message: String,
326}
327
328struct AppError {
330 status: StatusCode,
331 code: String,
332 message: String,
333}
334
335impl IntoResponse for AppError {
336 fn into_response(self) -> Response {
337 let body = ErrorBody {
338 error: ErrorDetail {
339 code: self.code,
340 message: self.message,
341 },
342 };
343 (self.status, Json(body)).into_response()
344 }
345}
346
347fn bad_request(message: impl Into<String>) -> AppError {
349 AppError {
350 status: StatusCode::BAD_REQUEST,
351 code: "bad_request".to_string(),
352 message: message.into(),
353 }
354}
355
356fn not_found(message: impl Into<String>) -> AppError {
358 AppError {
359 status: StatusCode::NOT_FOUND,
360 code: "not_found".to_string(),
361 message: message.into(),
362 }
363}
364
365fn timeout_error(message: impl Into<String>) -> AppError {
367 AppError {
368 status: StatusCode::REQUEST_TIMEOUT,
369 code: "timeout".to_string(),
370 message: message.into(),
371 }
372}
373
374fn tool_error(message: impl Into<String>) -> AppError {
376 AppError {
377 status: StatusCode::INTERNAL_SERVER_ERROR,
378 code: "tool_error".to_string(),
379 message: message.into(),
380 }
381}
382
383fn classify_tool_error(tool_name: &str, err: anyhow::Error) -> AppError {
388 let msg = err.to_string();
389
390 if msg.contains("not found") {
391 not_found(format!("{}: {}", tool_name, msg))
392 } else if msg.contains("must not be empty")
393 || msg.contains("embeddings")
394 || msg.contains("disabled")
395 || msg.contains("invalid")
396 {
397 let mut e = bad_request(format!("{}: {}", tool_name, msg));
399 if msg.contains("embeddings") || msg.contains("disabled") {
401 e.code = "embeddings_disabled".to_string();
402 }
403 e
404 } else if msg.contains("timed out") {
405 timeout_error(format!("{}: {}", tool_name, msg))
406 } else {
407 tool_error(format!("{}: {}", tool_name, msg))
408 }
409}
410
411#[derive(Serialize)]
415struct HealthResponse {
416 status: String,
418 version: String,
420}
421
422async fn handle_health() -> Json<HealthResponse> {
427 Json(HealthResponse {
428 status: "ok".to_string(),
429 version: env!("CARGO_PKG_VERSION").to_string(),
430 })
431}
432
433#[derive(Serialize)]
437struct ToolListResponse {
438 tools: Vec<ToolInfo>,
440}
441
442async fn handle_list_tools(
448 State((state, (extra_tools, _extra_agents))): State<(AppState, ExtState)>,
449) -> Json<ToolListResponse> {
450 let mut tools: Vec<ToolInfo> = state
451 .tools
452 .tools()
453 .iter()
454 .map(|t| ToolInfo {
455 name: t.name().to_string(),
456 description: t.description().to_string(),
457 builtin: t.is_builtin(),
458 parameters: t.parameters_schema(),
459 })
460 .collect();
461
462 for t in extra_tools.tools() {
464 tools.push(ToolInfo {
465 name: t.name().to_string(),
466 description: t.description().to_string(),
467 builtin: false,
468 parameters: t.parameters_schema(),
469 });
470 }
471
472 Json(ToolListResponse { tools })
473}
474
475async fn handle_tool_call(
486 State((state, (extra_tools, _extra_agents))): State<(AppState, ExtState)>,
487 Path(name): Path<String>,
488 Json(params): Json<serde_json::Value>,
489) -> Result<Json<serde_json::Value>, AppError> {
490 let tool = state
492 .tools
493 .find(&name)
494 .or_else(|| extra_tools.find(&name))
495 .ok_or_else(|| not_found(format!("no tool registered with name: {}", name)))?;
496
497 let validated_params = validate_params(&tool.parameters_schema(), ¶ms)
499 .map_err(|e| bad_request(e.to_string()))?;
500
501 let ctx = ToolContext::new(state.config.clone());
503 let result = tool
504 .execute(validated_params, &ctx)
505 .await
506 .map_err(|e| classify_tool_error(&name, e))?;
507
508 Ok(Json(serde_json::json!({ "result": result })))
509}
510
511#[derive(Serialize)]
515struct AgentListResponse {
516 agents: Vec<AgentInfo>,
518}
519
520async fn handle_list_agents(
525 State((state, (_extra_tools, extra_agents))): State<(AppState, ExtState)>,
526) -> Json<AgentListResponse> {
527 let mut agents: Vec<AgentInfo> = state
528 .agents
529 .agents()
530 .iter()
531 .map(|a| AgentInfo {
532 name: a.name().to_string(),
533 description: a.description().to_string(),
534 tools: a.tools(),
535 source: a.source().to_string(),
536 arguments: a.arguments(),
537 })
538 .collect();
539
540 for a in extra_agents.agents() {
542 agents.push(AgentInfo {
543 name: a.name().to_string(),
544 description: a.description().to_string(),
545 tools: a.tools(),
546 source: a.source().to_string(),
547 arguments: a.arguments(),
548 });
549 }
550
551 Json(AgentListResponse { agents })
552}
553
554async fn handle_resolve_agent(
565 State((state, (_extra_tools, extra_agents))): State<(AppState, ExtState)>,
566 Path(name): Path<String>,
567 Json(args): Json<serde_json::Value>,
568) -> Result<Json<serde_json::Value>, AppError> {
569 let agent = state
570 .agents
571 .find(&name)
572 .or_else(|| extra_agents.find(&name))
573 .ok_or_else(|| not_found(format!("no agent registered with name: {}", name)))?;
574
575 let ctx = ToolContext::new(state.config.clone());
576 let prompt = agent
577 .resolve(args, &ctx)
578 .await
579 .map_err(|e| tool_error(format!("agent '{}': {}", name, e)))?;
580
581 Ok(Json(serde_json::to_value(prompt).map_err(|e| {
582 tool_error(format!("failed to serialize agent prompt: {}", e))
583 })?))
584}