diff --git a/Cargo.toml b/Cargo.toml index 317e479bc..a817ee181 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [workspace] -members = ["crates/rmcp", "crates/rmcp-macros", "examples/*"] +members = ["crates/rmcp", "crates/rmcp-macros"] resolver = "2" [workspace.dependencies] diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 738f24462..9b5895038 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -13,6 +13,7 @@ documentation = "https://docs.rs/rmcp" all-features = true [dependencies] +async-stream = "0.3" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" thiserror = "2" diff --git a/crates/rmcp/src/.DS_Store b/crates/rmcp/src/.DS_Store new file mode 100644 index 000000000..f787591ae Binary files /dev/null and b/crates/rmcp/src/.DS_Store differ diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index ea05edd61..7695eb3f0 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -11,6 +11,19 @@ use crate::{ }, transport::IntoTransport, }; + +#[cfg(feature = "transport-sse-server")] +use axum::http::Extensions as AxumExtensions; + +#[cfg(not(feature = "transport-sse-server"))] +#[derive(Debug, Clone, Default)] +pub struct AxumExtensions; + +pub trait ProvidesAxiumExtensions { + fn get_extensions(&self) -> &AxumExtensions; + fn get_workspace_id(&self) -> String; +} + #[cfg(feature = "client")] mod client; #[cfg(feature = "client")] @@ -109,7 +122,7 @@ pub trait ServiceExt: Service + Sized { transport: T, ) -> impl Future, E>> + Send where - T: IntoTransport, + T: IntoTransport + ProvidesAxiumExtensions, E: std::error::Error + From + Send + Sync + 'static, Self: Sized, { @@ -121,7 +134,7 @@ pub trait ServiceExt: Service + Sized { ct: CancellationToken, ) -> impl Future, E>> + Send where - T: IntoTransport, + T: IntoTransport + ProvidesAxiumExtensions, E: std::error::Error + From + Send + Sync + 'static, Self: Sized; } @@ -474,6 +487,8 @@ pub struct RequestContext { pub extensions: Extensions, /// An interface to fetch the remote client or server pub peer: Peer, + pub req_extensions: AxumExtensions, + pub workspace_id: String, } /// Use this function to skip initialization process @@ -485,7 +500,7 @@ pub async fn serve_directly( where R: ServiceRole, S: Service, - T: IntoTransport, + T: IntoTransport + ProvidesAxiumExtensions, E: std::error::Error + Send + Sync + 'static, { serve_directly_with_ct(service, transport, peer_info, Default::default()).await @@ -501,11 +516,22 @@ pub async fn serve_directly_with_ct( where R: ServiceRole, S: Service, - T: IntoTransport, + T: IntoTransport + ProvidesAxiumExtensions, E: std::error::Error + Send + Sync + 'static, { let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info); - serve_inner(service, transport, peer, peer_rx, ct).await + let req_extensions = transport.get_extensions().clone(); + let workspace_id = transport.get_workspace_id(); + serve_inner( + service, + transport, + peer, + peer_rx, + ct, + req_extensions, + workspace_id, + ) + .await } #[instrument(skip_all)] @@ -515,6 +541,8 @@ async fn serve_inner( peer: Peer, mut peer_rx: tokio::sync::mpsc::Receiver>, ct: CancellationToken, + req_extensions: AxumExtensions, + workspace_id: String, ) -> Result, E> where R: ServiceRole, @@ -669,6 +697,8 @@ where peer: peer.clone(), meta: request.get_meta().clone(), extensions: request.extensions().clone(), + req_extensions: req_extensions.clone(), + workspace_id: workspace_id.clone(), }; tokio::spawn(async move { let result = service.handle_request(request, context).await; diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index 75d317f99..782280a96 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -93,7 +93,7 @@ impl> ServiceExt for S { ct: CancellationToken, ) -> impl Future, E>> + Send where - T: IntoTransport, + T: IntoTransport + ProvidesAxiumExtensions, E: std::error::Error + From + Send + Sync + 'static, Self: Sized, { @@ -107,7 +107,7 @@ pub async fn serve_client( ) -> Result, E> where S: Service, - T: IntoTransport, + T: IntoTransport + ProvidesAxiumExtensions, E: std::error::Error + From + Send + Sync + 'static, { serve_client_with_ct(service, transport, Default::default()).await @@ -120,9 +120,11 @@ pub async fn serve_client_with_ct( ) -> Result, E> where S: Service, - T: IntoTransport, + T: IntoTransport + ProvidesAxiumExtensions, E: std::error::Error + From + Send + Sync + 'static, { + let req_extensions = transport.get_extensions().clone(); + let workspace_id = transport.get_workspace_id(); let (sink, stream) = transport.into_transport(); let mut sink = Box::pin(sink); let mut stream = Box::pin(stream); @@ -175,7 +177,17 @@ where ); sink.send(notification).await?; let (peer, peer_rx) = Peer::new(id_provider, initialize_result); - serve_inner(service, (sink, stream), peer, peer_rx, ct).await + + serve_inner( + service, + (sink, stream), + peer, + peer_rx, + ct, + req_extensions, + workspace_id, + ) + .await } macro_rules! method { diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index ab17c84f9..e427dfc62 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -60,7 +60,7 @@ impl> ServiceExt for S { ct: CancellationToken, ) -> impl Future, E>> + Send where - T: IntoTransport, + T: IntoTransport + ProvidesAxiumExtensions, E: std::error::Error + From + Send + Sync + 'static, Self: Sized, { @@ -74,7 +74,7 @@ pub async fn serve_server( ) -> Result, E> where S: Service, - T: IntoTransport, + T: IntoTransport + ProvidesAxiumExtensions, E: std::error::Error + From + Send + Sync + 'static, { serve_server_with_ct(service, transport, CancellationToken::new()).await @@ -129,9 +129,11 @@ pub async fn serve_server_with_ct( ) -> Result, E> where S: Service, - T: IntoTransport, + T: IntoTransport + ProvidesAxiumExtensions, E: std::error::Error + From + Send + Sync + 'static, { + let req_extensions = transport.get_extensions().clone(); + let workspace_id = transport.get_workspace_id(); let (sink, stream) = transport.into_transport(); let mut sink = Box::pin(sink); let mut stream = Box::pin(stream); @@ -162,6 +164,8 @@ where meta: request.get_meta().clone(), extensions: request.extensions().clone(), peer: peer.clone(), + req_extensions: req_extensions.clone(), + workspace_id: workspace_id.clone(), }; // Send initialize response let init_response = service.handle_request(request.clone(), context).await; @@ -207,7 +211,16 @@ where }; let _ = service.handle_notification(notification).await; // Continue processing service - serve_inner(service, (sink, stream), peer, peer_rx, ct).await + serve_inner( + service, + (sink, stream), + peer, + peer_rx, + ct, + req_extensions, + workspace_id, + ) + .await } macro_rules! method { diff --git a/crates/rmcp/src/transport/sse_server.rs b/crates/rmcp/src/transport/sse_server.rs index ed04b3d0a..98e2b36b2 100644 --- a/crates/rmcp/src/transport/sse_server.rs +++ b/crates/rmcp/src/transport/sse_server.rs @@ -2,8 +2,8 @@ use std::{collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration}; use axum::{ Json, Router, - extract::{Query, State}, - http::StatusCode, + extract::{Path, Query, Request, State}, + http::{Extensions, StatusCode}, response::{ Response, sse::{Event, KeepAlive, Sse}, @@ -32,13 +32,17 @@ struct App { txs: TxStore, transport_tx: tokio::sync::mpsc::UnboundedSender, post_path: Arc, + full_message_path: Arc, sse_ping_interval: Duration, + ct: CancellationToken, } impl App { pub fn new( post_path: String, + full_message_path: String, sse_ping_interval: Duration, + ct: CancellationToken, ) -> ( Self, tokio::sync::mpsc::UnboundedReceiver, @@ -49,7 +53,9 @@ impl App { txs: Default::default(), transport_tx, post_path: post_path.into(), + full_message_path: full_message_path.into(), sse_ping_interval, + ct, }, transport_rx, ) @@ -88,6 +94,8 @@ async fn post_event_handler( async fn sse_handler( State(app): State, + Path(workspace_id): Path, + request: Request, ) -> Result>>, Response> { let session = session_id(); tracing::info!(%session, "sse connection"); @@ -107,6 +115,8 @@ async fn sse_handler( sink, session_id: session.clone(), tx_store: app.txs.clone(), + req_extensions: request.extensions().clone(), + workspace_id: workspace_id.clone(), }; let transport_send_result = app.transport_tx.send(transport); if transport_send_result.is_err() { @@ -116,20 +126,65 @@ async fn sse_handler( *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; return Err(response); } - let post_path = app.post_path.as_ref(); let ping_interval = app.sse_ping_interval; - let stream = futures::stream::once(futures::future::ok( - Event::default() - .event("endpoint") - .data(format!("{post_path}?sessionId={session}")), - )) - .chain(ReceiverStream::new(to_client_rx).map(|message| { - match serde_json::to_string(&message) { - Ok(bytes) => Ok(Event::default().event("message").data(&bytes)), - Err(e) => Err(io::Error::new(io::ErrorKind::InvalidData, e)), + let post_path = app.post_path.clone(); + let full_endpoint_path = app + .full_message_path + .replace(":workspace_id", &workspace_id); + let server_ct = app.ct.clone(); + + // Clone variables needed for the cleanup task *before* they are moved by async_stream + let session_for_cleanup = session.clone(); + let server_ct_for_cleanup = server_ct.clone(); + let tx_store_for_cleanup = app.txs.clone(); + + let mut message_stream = ReceiverStream::new(to_client_rx); + let client_stream = async_stream::stream! { + yield Ok(Event::default() + .event("endpoint") + .data(format!("{full_endpoint_path}{post_path}?sessionId={session}"))); + + loop { + tokio::select! { + biased; + _ = server_ct.cancelled() => { + tracing::info!(%session, "SSE connection cancelled via token."); + break; + } + maybe_message = message_stream.next() => { + match maybe_message { + Some(message) => { + match serde_json::to_string(&message) { + Ok(bytes) => yield Ok(Event::default().event("message").data(bytes)), + Err(e) => { + tracing::error!(%session, "Failed to serialize message: {}", e); + yield Err(io::Error::new(io::ErrorKind::InvalidData, e)); + break; + } + } + } + None => { + tracing::info!(%session, "Message channel closed, ending SSE stream."); + break; + } + } + } + } } - })); - Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(ping_interval))) + tracing::debug!(%session, "SSE client_stream finished."); + }; + + // Clean up the tx entry when the SSE connection handler finishes (either normally or cancelled) + tokio::spawn(async move { + server_ct_for_cleanup.cancelled().await; + tracing::debug!(session=%session_for_cleanup, "Removing session from tx store due to cancellation or handler exit."); + tx_store_for_cleanup + .write() + .await + .remove(&session_for_cleanup); + }); + + Ok(Sse::new(client_stream).keep_alive(KeepAlive::new().interval(ping_interval))) } pub struct SseServerTransport { @@ -137,6 +192,18 @@ pub struct SseServerTransport { sink: PollSender>, session_id: SessionId, tx_store: TxStore, + req_extensions: Extensions, + workspace_id: String, +} + +impl crate::service::ProvidesAxiumExtensions for SseServerTransport { + fn get_extensions(&self) -> &Extensions { + &self.req_extensions + } + + fn get_workspace_id(&self) -> String { + self.workspace_id.clone() + } } impl Sink> for SseServerTransport { @@ -207,6 +274,7 @@ pub struct SseServerConfig { pub post_path: String, pub ct: CancellationToken, pub sse_keep_alive: Option, + pub full_message_path: String, } #[derive(Debug)] @@ -223,6 +291,7 @@ impl SseServer { post_path: "/message".to_string(), ct: CancellationToken::new(), sse_keep_alive: None, + full_message_path: "/message".to_string(), }) .await } @@ -250,7 +319,9 @@ impl SseServer { pub fn new(config: SseServerConfig) -> (SseServer, Router) { let (app, transport_rx) = App::new( config.post_path.clone(), + config.full_message_path.clone(), config.sse_keep_alive.unwrap_or(DEFAULT_AUTO_PING_INTERVAL), + config.ct.clone(), ); let router = Router::new() .route(&config.sse_path, get(sse_handler)) diff --git a/crates/rmcp/tests/common/calculator.rs b/crates/rmcp/tests/common/calculator.rs deleted file mode 100644 index e179f2583..000000000 --- a/crates/rmcp/tests/common/calculator.rs +++ /dev/null @@ -1,44 +0,0 @@ -use rmcp::{ - ServerHandler, - model::{ServerCapabilities, ServerInfo}, - schemars, tool, -}; -#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] -pub struct SumRequest { - #[schemars(description = "the left hand side number")] - pub a: i32, - pub b: i32, -} -#[derive(Debug, Clone, Default)] -pub struct Calculator; -#[tool(tool_box)] -impl Calculator { - #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { - (a + b).to_string() - } - - #[tool(description = "Calculate the sub of two numbers")] - fn sub( - &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, - ) -> String { - (a - b).to_string() - } -} - -#[tool(tool_box)] -impl ServerHandler for Calculator { - fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } - } -} diff --git a/crates/rmcp/tests/common/handlers.rs b/crates/rmcp/tests/common/handlers.rs deleted file mode 100644 index d2212b632..000000000 --- a/crates/rmcp/tests/common/handlers.rs +++ /dev/null @@ -1,193 +0,0 @@ -use std::{ - future::Future, - sync::{Arc, Mutex}, -}; - -use rmcp::{ - ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler, - model::*, - service::{Peer, RequestContext}, -}; -use serde_json::json; -use tokio::sync::Notify; - -#[derive(Clone)] -pub struct TestClientHandler { - pub peer: Option>, - pub honor_this_server: bool, - pub honor_all_servers: bool, - pub receive_signal: Arc, - pub received_messages: Arc>>, -} - -impl TestClientHandler { - #[allow(dead_code)] - pub fn new(honor_this_server: bool, honor_all_servers: bool) -> Self { - Self { - peer: None, - honor_this_server, - honor_all_servers, - receive_signal: Arc::new(Notify::new()), - received_messages: Arc::new(Mutex::new(Vec::new())), - } - } - - #[allow(dead_code)] - pub fn with_notification( - honor_this_server: bool, - honor_all_servers: bool, - receive_signal: Arc, - received_messages: Arc>>, - ) -> Self { - Self { - peer: None, - honor_this_server, - honor_all_servers, - receive_signal, - received_messages, - } - } -} - -impl ClientHandler for TestClientHandler { - fn get_peer(&self) -> Option> { - self.peer.clone() - } - - fn set_peer(&mut self, peer: Peer) { - self.peer = Some(peer); - } - - async fn create_message( - &self, - params: CreateMessageRequestParam, - _context: RequestContext, - ) -> Result { - // First validate that there's at least one User message - if !params.messages.iter().any(|msg| msg.role == Role::User) { - return Err(McpError::invalid_request( - "Message sequence must contain at least one user message", - Some(json!({"messages": params.messages})), - )); - } - - // Create response based on context inclusion - let response = match params.include_context { - Some(ContextInclusion::ThisServer) if self.honor_this_server => { - "Test response with context: test context" - } - Some(ContextInclusion::AllServers) if self.honor_all_servers => { - "Test response with context: test context" - } - _ => "Test response without context", - }; - - Ok(CreateMessageResult { - message: SamplingMessage { - role: Role::Assistant, - content: Content::text(response.to_string()), - }, - model: "test-model".to_string(), - stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), - }) - } - - fn on_logging_message( - &self, - params: LoggingMessageNotificationParam, - ) -> impl Future + Send + '_ { - let receive_signal = self.receive_signal.clone(); - let received_messages = self.received_messages.clone(); - - async move { - println!("Client: Received log message: {:?}", params); - let mut messages = received_messages.lock().unwrap(); - messages.push(params); - receive_signal.notify_one(); - } - } -} - -pub struct TestServer {} - -impl TestServer { - #[allow(dead_code)] - pub fn new() -> Self { - Self {} - } -} - -impl ServerHandler for TestServer { - fn get_info(&self) -> ServerInfo { - ServerInfo { - capabilities: ServerCapabilities::builder().enable_logging().build(), - ..Default::default() - } - } - - fn set_level( - &self, - request: SetLevelRequestParam, - context: RequestContext, - ) -> impl Future> + Send + '_ { - let peer = context.peer; - async move { - let (data, logger) = match request.level { - LoggingLevel::Error => ( - serde_json::json!({ - "message": "Failed to process request", - "error_code": "E1001", - "error_details": "Connection timeout", - "timestamp": chrono::Utc::now().to_rfc3339(), - }), - Some("error_handler".to_string()), - ), - LoggingLevel::Debug => ( - serde_json::json!({ - "message": "Processing request", - "function": "handle_request", - "line": 42, - "context": { - "request_id": "req-123", - "user_id": "user-456" - }, - "timestamp": chrono::Utc::now().to_rfc3339(), - }), - Some("debug_logger".to_string()), - ), - LoggingLevel::Info => ( - serde_json::json!({ - "message": "System status update", - "status": "healthy", - "metrics": { - "requests_per_second": 150, - "average_latency_ms": 45, - "error_rate": 0.01 - }, - "timestamp": chrono::Utc::now().to_rfc3339(), - }), - Some("monitoring".to_string()), - ), - _ => ( - serde_json::json!({ - "message": format!("Message at level {:?}", request.level), - "timestamp": chrono::Utc::now().to_rfc3339(), - }), - None, - ), - }; - - if let Err(e) = peer - .notify_logging_message(LoggingMessageNotificationParam { - level: request.level, - data, - logger, - }) - .await - { - panic!("Failed to send notification: {}", e); - } - Ok(()) - } - } -} diff --git a/crates/rmcp/tests/common/mod.rs b/crates/rmcp/tests/common/mod.rs deleted file mode 100644 index 491960651..000000000 --- a/crates/rmcp/tests/common/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod calculator; -pub mod handlers; diff --git a/crates/rmcp/tests/test_complex_schema.rs b/crates/rmcp/tests/test_complex_schema.rs deleted file mode 100644 index b9370fcec..000000000 --- a/crates/rmcp/tests/test_complex_schema.rs +++ /dev/null @@ -1,63 +0,0 @@ -use rmcp::{Error as McpError, model::*, schemars, tool}; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] -pub enum ChatRole { - System, - User, - Assistant, - Tool, -} - -#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] -pub struct ChatMessage { - pub role: ChatRole, - pub content: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)] -pub struct ChatRequest { - pub system: Option, - pub messages: Vec, -} - -#[derive(Clone, Default)] -pub struct Demo; - -#[tool(tool_box)] -impl Demo { - pub fn new() -> Self { - Self - } - - #[tool(description = "LLM")] - async fn chat( - &self, - #[tool(aggr)] chat_request: ChatRequest, - ) -> Result { - let content = Content::json(chat_request)?; - Ok(CallToolResult::success(vec![content])) - } -} - -#[test] -fn test_complex_schema() { - let attr = Demo::chat_tool_attr(); - let input_schema = attr.input_schema; - let enum_number = input_schema - .get("definitions") - .unwrap() - .as_object() - .unwrap() - .get("ChatRole") - .unwrap() - .as_object() - .unwrap() - .get("enum") - .unwrap() - .as_array() - .unwrap() - .len(); - assert_eq!(enum_number, 4); - println!("{}", serde_json::to_string_pretty(&input_schema).unwrap()); -} diff --git a/crates/rmcp/tests/test_deserialization.rs b/crates/rmcp/tests/test_deserialization.rs deleted file mode 100644 index 73621f487..000000000 --- a/crates/rmcp/tests/test_deserialization.rs +++ /dev/null @@ -1,15 +0,0 @@ -use rmcp::model::{JsonRpcResponse, ServerJsonRpcMessage, ServerResult}; -#[test] -fn test_tool_list_result() { - let json = std::fs::read("tests/test_deserialization/tool_list_result.json").unwrap(); - let result: ServerJsonRpcMessage = serde_json::from_slice(&json).unwrap(); - println!("{result:#?}"); - - assert!(matches!( - result, - ServerJsonRpcMessage::Response(JsonRpcResponse { - result: ServerResult::ListToolsResult(_), - .. - }) - )); -} diff --git a/crates/rmcp/tests/test_deserialization/tool_list_result.json b/crates/rmcp/tests/test_deserialization/tool_list_result.json deleted file mode 100644 index 674fdc058..000000000 --- a/crates/rmcp/tests/test_deserialization/tool_list_result.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "result": { - "tools": [ - { - "name": "add", - "inputSchema": { - "type": "object", - "properties": { - "a": { - "type": "number" - }, - "b": { - "type": "number" - } - }, - "required": [ - "a", - "b" - ], - "additionalProperties": false, - "$schema": "http://json-schema.org/draft-07/schema#" - } - } - ] - }, - "jsonrpc": "2.0", - "id": 2 -} \ No newline at end of file diff --git a/crates/rmcp/tests/test_logging.rs b/crates/rmcp/tests/test_logging.rs deleted file mode 100644 index eb63773fc..000000000 --- a/crates/rmcp/tests/test_logging.rs +++ /dev/null @@ -1,351 +0,0 @@ -// cargo test --features "server client" --package rmcp test_logging -mod common; - -use std::sync::{Arc, Mutex}; - -use common::handlers::{TestClientHandler, TestServer}; -use rmcp::{ - ServiceExt, - model::{LoggingLevel, LoggingMessageNotificationParam, SetLevelRequestParam}, -}; -use serde_json::json; -use tokio::sync::Notify; - -#[tokio::test] -async fn test_logging_spec_compliance() -> anyhow::Result<()> { - let (server_transport, client_transport) = tokio::io::duplex(4096); - let receive_signal = Arc::new(Notify::new()); - let received_messages = Arc::new(Mutex::new(Vec::::new())); - - // Start server in a separate task - let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - - // Test server can send messages before level is set - server - .peer() - .notify_logging_message(LoggingMessageNotificationParam { - level: LoggingLevel::Info, - data: serde_json::json!({ - "message": "Server initiated message", - "timestamp": chrono::Utc::now().to_rfc3339(), - }), - logger: Some("test_server".to_string()), - }) - .await?; - - server.waiting().await?; - anyhow::Ok(()) - }); - - let client = TestClientHandler::with_notification( - true, - true, - receive_signal.clone(), - received_messages.clone(), - ) - .serve(client_transport) - .await?; - - // Wait for the initial server message - receive_signal.notified().await; - { - let mut messages = received_messages.lock().unwrap(); - assert_eq!(messages.len(), 1, "Should receive server-initiated message"); - messages.clear(); - } - - // Test level filtering and message format - for level in [ - LoggingLevel::Emergency, - LoggingLevel::Warning, - LoggingLevel::Debug, - ] { - client - .peer() - .set_level(SetLevelRequestParam { level }) - .await?; - - // Wait for each message response - receive_signal.notified().await; - - let mut messages = received_messages.lock().unwrap(); - let msg = messages.last().unwrap(); - - // Verify required fields - assert_eq!(msg.level, level); - assert!(msg.data.is_object()); - - // Verify data format - let data = msg.data.as_object().unwrap(); - assert!(data.contains_key("message")); - assert!(data.contains_key("timestamp")); - - // Verify timestamp - let timestamp = data["timestamp"].as_str().unwrap(); - chrono::DateTime::parse_from_rfc3339(timestamp).expect("RFC3339 timestamp"); - - messages.clear(); - } - - // Important: Cancel the client before ending the test - client.cancel().await?; - - // Wait for server to complete - server_handle.await??; - - Ok(()) -} - -#[tokio::test] -async fn test_logging_user_scenarios() -> anyhow::Result<()> { - let (server_transport, client_transport) = tokio::io::duplex(4096); - let receive_signal = Arc::new(Notify::new()); - let received_messages = Arc::new(Mutex::new(Vec::::new())); - - let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; - anyhow::Ok(()) - }); - - let client = TestClientHandler::with_notification( - true, - true, - receive_signal.clone(), - received_messages.clone(), - ) - .serve(client_transport) - .await?; - - // Test 1: Error reporting scenario - client - .peer() - .set_level(SetLevelRequestParam { - level: LoggingLevel::Error, - }) - .await?; - receive_signal.notified().await; // Wait for response - { - let messages = received_messages.lock().unwrap(); - let msg = &messages[0]; - let data = msg.data.as_object().unwrap(); - assert!( - data.contains_key("error_code"), - "Error should have an error code" - ); - assert!( - data.contains_key("error_details"), - "Error should have details" - ); - assert!( - data.contains_key("timestamp"), - "Should know when error occurred" - ); - } - - // Test 2: Debug scenario - client - .peer() - .set_level(SetLevelRequestParam { - level: LoggingLevel::Debug, - }) - .await?; - receive_signal.notified().await; // Wait for response - { - let messages = received_messages.lock().unwrap(); - let msg = messages.last().unwrap(); - let data = msg.data.as_object().unwrap(); - assert!( - data.contains_key("function"), - "Debug should show function name" - ); - assert!(data.contains_key("line"), "Debug should show line number"); - assert!( - data.contains_key("context"), - "Debug should show execution context" - ); - } - - // Test 3: Production monitoring scenario - client - .peer() - .set_level(SetLevelRequestParam { - level: LoggingLevel::Info, - }) - .await?; - receive_signal.notified().await; // Wait for response - { - let messages = received_messages.lock().unwrap(); - let msg = messages.last().unwrap(); - let data = msg.data.as_object().unwrap(); - assert!(data.contains_key("status"), "Should show system status"); - assert!(data.contains_key("metrics"), "Should include metrics"); - } - - // Important: Cancel client and wait for server before ending - client.cancel().await?; - server_handle.await??; - - Ok(()) -} - -#[test] -fn test_logging_level_serialization() { - // Test all levels match spec exactly - let test_cases = [ - (LoggingLevel::Alert, "alert"), - (LoggingLevel::Critical, "critical"), - (LoggingLevel::Debug, "debug"), - (LoggingLevel::Emergency, "emergency"), - (LoggingLevel::Error, "error"), - (LoggingLevel::Info, "info"), - (LoggingLevel::Notice, "notice"), - (LoggingLevel::Warning, "warning"), - ]; - - for (level, expected) in test_cases { - let serialized = serde_json::to_string(&level).unwrap(); - // Remove quotes from serialized string - let serialized = serialized.trim_matches('"'); - assert_eq!( - serialized, expected, - "LoggingLevel::{:?} should serialize to \"{}\"", - level, expected - ); - } - - // Test deserialization from spec strings - for (level, spec_string) in test_cases { - let deserialized: LoggingLevel = - serde_json::from_str(&format!("\"{}\"", spec_string)).unwrap(); - assert_eq!( - deserialized, level, - "\"{}\" should deserialize to LoggingLevel::{:?}", - spec_string, level - ); - } -} - -#[tokio::test] -async fn test_logging_edge_cases() -> anyhow::Result<()> { - let (server_transport, client_transport) = tokio::io::duplex(4096); - let receive_signal = Arc::new(Notify::new()); - let received_messages = Arc::new(Mutex::new(Vec::::new())); - - let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; - anyhow::Ok(()) - }); - - let client = TestClientHandler::with_notification( - true, - true, - receive_signal.clone(), - received_messages.clone(), - ) - .serve(client_transport) - .await?; - - // Test all logging levels from spec - for level in [ - LoggingLevel::Alert, - LoggingLevel::Critical, - LoggingLevel::Notice, // These weren't tested before - ] { - client - .peer() - .set_level(SetLevelRequestParam { level }) - .await?; - receive_signal.notified().await; - - let messages = received_messages.lock().unwrap(); - let msg = messages.last().unwrap(); - assert_eq!(msg.level, level); - } - - client.cancel().await?; - server_handle.await??; - Ok(()) -} - -#[tokio::test] -async fn test_logging_optional_fields() -> anyhow::Result<()> { - let (server_transport, client_transport) = tokio::io::duplex(4096); - let receive_signal = Arc::new(Notify::new()); - let received_messages = Arc::new(Mutex::new(Vec::::new())); - - let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - - // Test message with and without optional logger field - for (level, has_logger) in [(LoggingLevel::Info, true), (LoggingLevel::Debug, false)] { - server - .peer() - .notify_logging_message(LoggingMessageNotificationParam { - level, - data: json!({"test": "data"}), - logger: has_logger.then(|| "test_logger".to_string()), - }) - .await?; - } - - server.waiting().await?; - anyhow::Ok(()) - }); - - let client = TestClientHandler::with_notification( - true, - true, - receive_signal.clone(), - received_messages.clone(), - ) - .serve(client_transport) - .await?; - - // Wait for the initial server message - receive_signal.notified().await; - { - let mut messages = received_messages.lock().unwrap(); - assert_eq!(messages.len(), 2, "Should receive two messages"); - messages.clear(); - } - - // Test level filtering and message format - for level in [LoggingLevel::Info, LoggingLevel::Debug] { - client - .peer() - .set_level(SetLevelRequestParam { level }) - .await?; - - // Wait for each message response - receive_signal.notified().await; - - let mut messages = received_messages.lock().unwrap(); - let msg = messages.last().unwrap(); - - // Verify required fields - assert_eq!(msg.level, level); - assert!(msg.data.is_object()); - - // Verify data format - let data = msg.data.as_object().unwrap(); - assert!(data.contains_key("message")); - assert!(data.contains_key("timestamp")); - - // Verify timestamp - let timestamp = data["timestamp"].as_str().unwrap(); - chrono::DateTime::parse_from_rfc3339(timestamp).expect("RFC3339 timestamp"); - - messages.clear(); - } - - // Important: Cancel the client before ending the test - client.cancel().await?; - - // Wait for server to complete - server_handle.await??; - - Ok(()) -} diff --git a/crates/rmcp/tests/test_message_protocol.rs b/crates/rmcp/tests/test_message_protocol.rs deleted file mode 100644 index 602f93dab..000000000 --- a/crates/rmcp/tests/test_message_protocol.rs +++ /dev/null @@ -1,559 +0,0 @@ -//cargo test --test test_message_protocol --features "client server" - -mod common; -use common::handlers::{TestClientHandler, TestServer}; -use rmcp::{ - ServiceExt, - model::*, - service::{RequestContext, Service}, -}; -use tokio_util::sync::CancellationToken; - -// Tests start here -#[tokio::test] -async fn test_message_roles() { - let messages = vec![ - SamplingMessage { - role: Role::User, - content: Content::text("user message"), - }, - SamplingMessage { - role: Role::Assistant, - content: Content::text("assistant message"), - }, - ]; - - // Verify all roles can be serialized/deserialized correctly - let json = serde_json::to_string(&messages).unwrap(); - let deserialized: Vec = serde_json::from_str(&json).unwrap(); - assert_eq!(messages, deserialized); -} - -#[tokio::test] -async fn test_context_inclusion_integration() -> anyhow::Result<()> { - let (server_transport, client_transport) = tokio::io::duplex(4096); - - // Start server - let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; - anyhow::Ok(()) - }); - - // Start client that honors context requests - let handler = TestClientHandler::new(true, true); - let client = handler.clone().serve(client_transport).await?; - - // Test ThisServer context inclusion - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParam { - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test message"), - }], - include_context: Some(ContextInclusion::ThisServer), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - }, - extensions: Default::default(), - }); - - let result = handler - .handle_request( - request.clone(), - RequestContext { - peer: client.peer().clone(), - ct: CancellationToken::new(), - id: NumberOrString::Number(1), - meta: Default::default(), - extensions: Default::default(), - }, - ) - .await?; - - if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); - assert!( - text.contains("test context"), - "Response should include context for ThisServer" - ); - } else { - panic!("Expected CreateMessageResult"); - } - - // Test AllServers context inclusion - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParam { - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test message"), - }], - include_context: Some(ContextInclusion::AllServers), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - }, - extensions: Default::default(), - }); - - let result = handler - .handle_request( - request.clone(), - RequestContext { - peer: client.peer().clone(), - ct: CancellationToken::new(), - id: NumberOrString::Number(2), - meta: Default::default(), - extensions: Default::default(), - }, - ) - .await?; - - if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); - assert!( - text.contains("test context"), - "Response should include context for AllServers" - ); - } else { - panic!("Expected CreateMessageResult"); - } - - // Test No context inclusion - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParam { - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test message"), - }], - include_context: Some(ContextInclusion::None), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - }, - extensions: Default::default(), - }); - - let result = handler - .handle_request( - request.clone(), - RequestContext { - peer: client.peer().clone(), - ct: CancellationToken::new(), - id: NumberOrString::Number(3), - meta: Default::default(), - extensions: Default::default(), - }, - ) - .await?; - - if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); - assert!( - !text.contains("test context"), - "Response should not include context for None" - ); - } else { - panic!("Expected CreateMessageResult"); - } - - client.cancel().await?; - server_handle.await??; - Ok(()) -} - -#[tokio::test] -async fn test_context_inclusion_ignored_integration() -> anyhow::Result<()> { - let (server_transport, client_transport) = tokio::io::duplex(4096); - - // Start server - let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; - anyhow::Ok(()) - }); - - // Start client that ignores context requests - let handler = TestClientHandler::new(false, false); - let client = handler.clone().serve(client_transport).await?; - - // Test that context requests are ignored - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParam { - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test message"), - }], - include_context: Some(ContextInclusion::ThisServer), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - }, - extensions: Default::default(), - }); - - let result = handler - .handle_request( - request.clone(), - RequestContext { - peer: client.peer().clone(), - ct: CancellationToken::new(), - id: NumberOrString::Number(1), - meta: Meta::default(), - extensions: Default::default(), - }, - ) - .await?; - - if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); - assert!( - !text.contains("test context"), - "Context should be ignored when client chooses not to honor requests" - ); - } else { - panic!("Expected CreateMessageResult"); - } - - client.cancel().await?; - server_handle.await??; - Ok(()) -} - -#[tokio::test] -async fn test_message_sequence_integration() -> anyhow::Result<()> { - let (server_transport, client_transport) = tokio::io::duplex(4096); - - // Start server - let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; - anyhow::Ok(()) - }); - - // Start client - let handler = TestClientHandler::new(true, true); - let client = handler.clone().serve(client_transport).await?; - - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParam { - messages: vec![ - SamplingMessage { - role: Role::User, - content: Content::text("first message"), - }, - SamplingMessage { - role: Role::Assistant, - content: Content::text("second message"), - }, - ], - include_context: Some(ContextInclusion::ThisServer), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - }, - extensions: Default::default(), - }); - - let result = handler - .handle_request( - request.clone(), - RequestContext { - peer: client.peer().clone(), - ct: CancellationToken::new(), - id: NumberOrString::Number(1), - meta: Meta::default(), - extensions: Default::default(), - }, - ) - .await?; - - if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); - assert!( - text.contains("test context"), - "Response should include context when ThisServer is specified" - ); - assert_eq!(result.model, "test-model"); - assert_eq!( - result.stop_reason, - Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()) - ); - } else { - panic!("Expected CreateMessageResult"); - } - - client.cancel().await?; - server_handle.await??; - Ok(()) -} - -#[tokio::test] -async fn test_message_sequence_validation_integration() -> anyhow::Result<()> { - let (server_transport, client_transport) = tokio::io::duplex(4096); - - let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; - anyhow::Ok(()) - }); - - let handler = TestClientHandler::new(true, true); - let client = handler.clone().serve(client_transport).await?; - - // Test valid sequence: User -> Assistant -> User - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParam { - messages: vec![ - SamplingMessage { - role: Role::User, - content: Content::text("first user message"), - }, - SamplingMessage { - role: Role::Assistant, - content: Content::text("first assistant response"), - }, - SamplingMessage { - role: Role::User, - content: Content::text("second user message"), - }, - ], - include_context: None, - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - }, - extensions: Default::default(), - }); - - let result = handler - .handle_request( - request.clone(), - RequestContext { - peer: client.peer().clone(), - ct: CancellationToken::new(), - id: NumberOrString::Number(1), - meta: Meta::default(), - extensions: Default::default(), - }, - ) - .await?; - - assert!(matches!(result, ClientResult::CreateMessageResult(_))); - - // Test invalid: No user message - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParam { - messages: vec![SamplingMessage { - role: Role::Assistant, - content: Content::text("assistant message"), - }], - include_context: None, - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - }, - extensions: Default::default(), - }); - - let result = handler - .handle_request( - request.clone(), - RequestContext { - peer: client.peer().clone(), - ct: CancellationToken::new(), - id: NumberOrString::Number(2), - meta: Meta::default(), - extensions: Default::default(), - }, - ) - .await; - - assert!(result.is_err()); - - client.cancel().await?; - server_handle.await??; - Ok(()) -} - -#[tokio::test] -async fn test_selective_context_handling_integration() -> anyhow::Result<()> { - let (server_transport, client_transport) = tokio::io::duplex(4096); - - let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; - anyhow::Ok(()) - }); - - // Client that only honors ThisServer but ignores AllServers - let handler = TestClientHandler::new(true, false); - let client = handler.clone().serve(client_transport).await?; - - // Test ThisServer is honored - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParam { - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test message"), - }], - include_context: Some(ContextInclusion::ThisServer), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - }, - extensions: Default::default(), - }); - - let result = handler - .handle_request( - request.clone(), - RequestContext { - peer: client.peer().clone(), - ct: CancellationToken::new(), - id: NumberOrString::Number(1), - meta: Meta::default(), - extensions: Default::default(), - }, - ) - .await?; - - if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); - assert!( - text.contains("test context"), - "ThisServer context request should be honored" - ); - } - - // Test AllServers is ignored - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParam { - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test message"), - }], - include_context: Some(ContextInclusion::AllServers), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - }, - extensions: Default::default(), - }); - - let result = handler - .handle_request( - request.clone(), - RequestContext { - peer: client.peer().clone(), - ct: CancellationToken::new(), - id: NumberOrString::Number(2), - meta: Meta::default(), - extensions: Default::default(), - }, - ) - .await?; - - if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); - assert!( - !text.contains("test context"), - "AllServers context request should be ignored" - ); - } - - client.cancel().await?; - server_handle.await??; - Ok(()) -} - -#[tokio::test] -async fn test_context_inclusion() -> anyhow::Result<()> { - let (server_transport, client_transport) = tokio::io::duplex(4096); - let server_handle = tokio::spawn(async move { - let server = TestServer::new().serve(server_transport).await?; - server.waiting().await?; - anyhow::Ok(()) - }); - - let handler = TestClientHandler::new(true, true); - let client = handler.clone().serve(client_transport).await?; - - // Test context handling - let request = ServerRequest::CreateMessageRequest(CreateMessageRequest { - method: Default::default(), - params: CreateMessageRequestParam { - messages: vec![SamplingMessage { - role: Role::User, - content: Content::text("test"), - }], - include_context: Some(ContextInclusion::ThisServer), - model_preferences: None, - system_prompt: None, - temperature: None, - max_tokens: 100, - stop_sequences: None, - metadata: None, - }, - extensions: Default::default(), - }); - - let result = handler - .handle_request( - request.clone(), - RequestContext { - peer: client.peer().clone(), - ct: CancellationToken::new(), - id: NumberOrString::Number(1), - meta: Meta::default(), - extensions: Default::default(), - }, - ) - .await?; - - if let ClientResult::CreateMessageResult(result) = result { - let text = result.message.content.as_text().unwrap().text.as_str(); - assert!(text.contains("test context")); - } - - client.cancel().await?; - server_handle.await??; - Ok(()) -} diff --git a/crates/rmcp/tests/test_notification.rs b/crates/rmcp/tests/test_notification.rs deleted file mode 100644 index 4d4c0f6e3..000000000 --- a/crates/rmcp/tests/test_notification.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::sync::Arc; - -use rmcp::{ - ClientHandler, Peer, RoleClient, ServerHandler, ServiceExt, - model::{ - ResourceUpdatedNotificationParam, ServerCapabilities, ServerInfo, SubscribeRequestParam, - }, -}; -use tokio::sync::Notify; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; - -pub struct Server {} - -impl ServerHandler for Server { - fn get_info(&self) -> ServerInfo { - ServerInfo { - capabilities: ServerCapabilities::builder() - .enable_resources() - .enable_resources_subscribe() - .enable_resources_list_changed() - .build(), - ..Default::default() - } - } - - async fn subscribe( - &self, - request: rmcp::model::SubscribeRequestParam, - context: rmcp::service::RequestContext, - ) -> Result<(), rmcp::Error> { - let uri = request.uri; - let peer = context.peer; - - tokio::spawn(async move { - let span = tracing::info_span!("subscribe", uri = %uri); - let _enter = span.enter(); - - if let Err(e) = peer - .notify_resource_updated(ResourceUpdatedNotificationParam { uri: uri.clone() }) - .await - { - panic!("Failed to send notification: {}", e); - } - }); - - Ok(()) - } -} - -pub struct Client { - receive_signal: Arc, - peer: Option>, -} - -impl ClientHandler for Client { - async fn on_resource_updated(&self, params: rmcp::model::ResourceUpdatedNotificationParam) { - let uri = params.uri; - tracing::info!("Resource updated: {}", uri); - self.receive_signal.notify_one(); - } - - fn set_peer(&mut self, peer: Peer) { - self.peer.replace(peer); - } - - fn get_peer(&self) -> Option> { - self.peer.clone() - } -} - -#[tokio::test] -async fn test_server_notification() -> anyhow::Result<()> { - let _ = tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "debug".to_string().into()), - ) - .with(tracing_subscriber::fmt::layer()) - .try_init(); - let (server_transport, client_transport) = tokio::io::duplex(4096); - tokio::spawn(async move { - let server = Server {}.serve(server_transport).await?; - server.waiting().await?; - anyhow::Ok(()) - }); - let receive_signal = Arc::new(Notify::new()); - let client = Client { - peer: Default::default(), - receive_signal: receive_signal.clone(), - } - .serve(client_transport) - .await?; - client - .subscribe(SubscribeRequestParam { - uri: "test://test-resource".to_owned(), - }) - .await?; - receive_signal.notified().await; - client.cancel().await?; - Ok(()) -} diff --git a/crates/rmcp/tests/test_tool_handler.rs b/crates/rmcp/tests/test_tool_handler.rs deleted file mode 100644 index 8b1378917..000000000 --- a/crates/rmcp/tests/test_tool_handler.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs deleted file mode 100644 index daa5ee3d4..000000000 --- a/crates/rmcp/tests/test_tool_macros.rs +++ /dev/null @@ -1,102 +0,0 @@ -use std::sync::Arc; - -use rmcp::{ServerHandler, handler::server::tool::ToolCallContext, tool}; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; - -#[derive(Serialize, Deserialize, JsonSchema)] -pub struct GetWeatherRequest { - pub city: String, - pub date: String, -} - -impl ServerHandler for Server { - async fn call_tool( - &self, - request: rmcp::model::CallToolRequestParam, - context: rmcp::service::RequestContext, - ) -> Result { - let tcc = ToolCallContext::new(self, request, context); - match tcc.name() { - "get-weather" => Self::get_weather_tool_call(tcc).await, - _ => Err(rmcp::Error::invalid_params("method not found", None)), - } - } -} - -#[derive(Debug, Clone, Default)] -pub struct Server {} - -impl Server { - /// This tool is used to get the weather of a city. - #[tool(name = "get-weather", description = "Get the weather of a city.", vis = )] - pub async fn get_weather(&self, #[tool(param)] city: String) -> String { - drop(city); - "rain".to_string() - } - #[tool(description = "Empty Parameter")] - async fn empty_param(&self) {} -} - -// define generic service trait -pub trait DataService: Send + Sync + 'static { - fn get_data(&self) -> String; -} - -// mock service for test -#[derive(Clone)] -struct MockDataService; -impl DataService for MockDataService { - fn get_data(&self) -> String { - "mock data".to_string() - } -} - -// define generic server -#[derive(Debug, Clone)] -pub struct GenericServer { - data_service: Arc, -} - -#[tool(tool_box)] -impl GenericServer { - pub fn new(data_service: DS) -> Self { - Self { - data_service: Arc::new(data_service), - } - } - - #[tool(description = "Get data from the service")] - async fn get_data(&self) -> String { - self.data_service.get_data() - } -} - -#[tokio::test] -async fn test_tool_macros() { - let server = Server::default(); - let _attr = Server::get_weather_tool_attr(); - let _get_weather_call_fn = Server::get_weather_tool_call; - let _get_weather_fn = Server::get_weather; - server.get_weather("harbin".into()).await; -} - -#[tokio::test] -async fn test_tool_macros_with_empty_param() { - let _attr = Server::empty_param_tool_attr(); - println!("{_attr:?}"); - assert_eq!(_attr.input_schema.get("type").unwrap(), "object"); - assert!(_attr.input_schema.get("properties").is_none()); -} - -#[tokio::test] -async fn test_tool_macros_with_generics() { - let mock_service = MockDataService; - let server = GenericServer::new(mock_service); - let _attr = GenericServer::::get_data_tool_attr(); - let _get_data_call_fn = GenericServer::::get_data_tool_call; - let _get_data_fn = GenericServer::::get_data; - assert_eq!(server.get_data().await, "mock data"); -} - -impl GetWeatherRequest {} diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs deleted file mode 100644 index f5431beef..000000000 --- a/crates/rmcp/tests/test_with_js.rs +++ /dev/null @@ -1,68 +0,0 @@ -use rmcp::{ - ServiceExt, - transport::{SseServer, TokioChildProcess}, -}; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -mod common; -use common::calculator::Calculator; - -const BIND_ADDRESS: &str = "127.0.0.1:8000"; - -#[tokio::test] -async fn test_with_js_client() -> anyhow::Result<()> { - let _ = tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "debug".to_string().into()), - ) - .with(tracing_subscriber::fmt::layer()) - .try_init(); - tokio::process::Command::new("npm") - .arg("install") - .current_dir("tests/test_with_js") - .spawn()? - .wait() - .await?; - - let ct = SseServer::serve(BIND_ADDRESS.parse()?) - .await? - .with_service(Calculator::default); - - let exit_status = tokio::process::Command::new("node") - .arg("tests/test_with_js/client.js") - .spawn()? - .wait() - .await?; - assert!(exit_status.success()); - ct.cancel(); - Ok(()) -} - -#[tokio::test] -async fn test_with_js_server() -> anyhow::Result<()> { - let _ = tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "debug".to_string().into()), - ) - .with(tracing_subscriber::fmt::layer()) - .try_init(); - tokio::process::Command::new("npm") - .arg("install") - .current_dir("tests/test_with_js") - .spawn()? - .wait() - .await?; - let transport = TokioChildProcess::new( - tokio::process::Command::new("node").arg("tests/test_with_js/server.js"), - )?; - - let client = ().serve(transport).await?; - let resources = client.list_all_resources().await?; - tracing::info!("{:#?}", resources); - let tools = client.list_all_tools().await?; - tracing::info!("{:#?}", tools); - - client.cancel().await?; - Ok(()) -} diff --git a/crates/rmcp/tests/test_with_js/.gitignore b/crates/rmcp/tests/test_with_js/.gitignore deleted file mode 100644 index 572406bfd..000000000 --- a/crates/rmcp/tests/test_with_js/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -/node_modules -package-lock.json \ No newline at end of file diff --git a/crates/rmcp/tests/test_with_js/client.js b/crates/rmcp/tests/test_with_js/client.js deleted file mode 100644 index 54b7cad78..000000000 --- a/crates/rmcp/tests/test_with_js/client.js +++ /dev/null @@ -1,28 +0,0 @@ -import { Client } from "@modelcontextprotocol/sdk/client/index.js"; -import { SSEClientTransport } from "@modelcontextprotocol/sdk/client/sse.js"; - -const transport = new SSEClientTransport( new URL(`http://127.0.0.1:8000/sse`)); - -const client = new Client( - { - name: "example-client", - version: "1.0.0" - }, - { - capabilities: { - prompts: {}, - resources: {}, - tools: {} - } - } -); -await client.connect(transport); -const tools = await client.listTools(); -console.log(tools); -const resources = await client.listResources(); -console.log(resources); -const templates = await client.listResourceTemplates(); -console.log(templates); -const prompts = await client.listPrompts(); -console.log(prompts); -await client.close(); diff --git a/crates/rmcp/tests/test_with_js/package.json b/crates/rmcp/tests/test_with_js/package.json deleted file mode 100644 index 6dee815cb..000000000 --- a/crates/rmcp/tests/test_with_js/package.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "dependencies": { - "@modelcontextprotocol/sdk": "^1.7.0" - }, - "type": "module", - "name": "test_with_ts", - "version": "1.0.0", - "main": "index.js", - "scripts": { - "test": "echo \"Error: no test specified\" && exit 1" - }, - "author": "", - "license": "ISC", - "description": "" -} diff --git a/crates/rmcp/tests/test_with_js/server.js b/crates/rmcp/tests/test_with_js/server.js deleted file mode 100644 index c128340f6..000000000 --- a/crates/rmcp/tests/test_with_js/server.js +++ /dev/null @@ -1,35 +0,0 @@ -import { McpServer, ResourceTemplate } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import { z } from "zod"; - -const server = new McpServer({ - name: "Demo", - version: "1.0.0" -}); - -server.resource( - "greeting", - new ResourceTemplate("greeting://{name}", { list: undefined }), - async (uri, { name }) => ({ - contents: [{ - uri: uri.href, - text: `Hello, ${name}` - }] - }) -); - -server.tool( - "add", - { a: z.number(), b: z.number() }, - async ({ a, b }) => ({ - "content": [ - { - "type": "text", - "text": `${a + b}` - } - ] - }) -); - -const transport = new StdioServerTransport(); -await server.connect(transport); diff --git a/crates/rmcp/tests/test_with_python.rs b/crates/rmcp/tests/test_with_python.rs deleted file mode 100644 index 7e0fbd713..000000000 --- a/crates/rmcp/tests/test_with_python.rs +++ /dev/null @@ -1,70 +0,0 @@ -use rmcp::{ - ServiceExt, - transport::{SseServer, TokioChildProcess}, -}; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -mod common; -use common::calculator::Calculator; - -const BIND_ADDRESS: &str = "127.0.0.1:8000"; - -#[tokio::test] -async fn test_with_python_client() -> anyhow::Result<()> { - let _ = tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "debug".to_string().into()), - ) - .with(tracing_subscriber::fmt::layer()) - .try_init(); - tokio::process::Command::new("uv") - .args(["pip", "install", "-r", "pyproject.toml"]) - .current_dir("tests/test_with_python") - .spawn()? - .wait() - .await?; - - let ct = SseServer::serve(BIND_ADDRESS.parse()?) - .await? - .with_service(Calculator::default); - - let status = tokio::process::Command::new("uv") - .arg("run") - .arg("tests/test_with_python/client.py") - .spawn()? - .wait() - .await?; - assert!(status.success()); - ct.cancel(); - Ok(()) -} - -#[tokio::test] -async fn test_with_python_server() -> anyhow::Result<()> { - let _ = tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "debug".to_string().into()), - ) - .with(tracing_subscriber::fmt::layer()) - .try_init(); - tokio::process::Command::new("uv") - .args(["pip", "install", "-r", "pyproject.toml"]) - .current_dir("tests/test_with_python") - .spawn()? - .wait() - .await?; - let transport = TokioChildProcess::new( - tokio::process::Command::new("uv") - .arg("run") - .arg("tests/test_with_python/server.py"), - )?; - - let client = ().serve(transport).await?; - let resources = client.list_all_resources().await?; - tracing::info!("{:#?}", resources); - let tools = client.list_all_tools().await?; - tracing::info!("{:#?}", tools); - client.cancel().await?; - Ok(()) -} diff --git a/crates/rmcp/tests/test_with_python/.gitignore b/crates/rmcp/tests/test_with_python/.gitignore deleted file mode 100644 index e69de29bb..000000000 diff --git a/crates/rmcp/tests/test_with_python/client.py b/crates/rmcp/tests/test_with_python/client.py deleted file mode 100644 index efa6b50d6..000000000 --- a/crates/rmcp/tests/test_with_python/client.py +++ /dev/null @@ -1,28 +0,0 @@ -from mcp import ClientSession, StdioServerParameters, types -from mcp.client.sse import sse_client - - - -async def run(): - async with sse_client("http://localhost:8000/sse") as (read, write): - async with ClientSession( - read, write - ) as session: - # Initialize the connection - await session.initialize() - - # List available prompts - prompts = await session.list_prompts() - print(prompts) - # List available resources - resources = await session.list_resources() - print(resources) - - # List available tools - tools = await session.list_tools() - print(tools) - -if __name__ == "__main__": - import asyncio - - asyncio.run(run()) \ No newline at end of file diff --git a/crates/rmcp/tests/test_with_python/pyproject.toml b/crates/rmcp/tests/test_with_python/pyproject.toml deleted file mode 100644 index f4b5a700e..000000000 --- a/crates/rmcp/tests/test_with_python/pyproject.toml +++ /dev/null @@ -1,14 +0,0 @@ -[build-system] -requires = ["setuptools>=42", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "test_with_python" -version = "0.1.0" -description = "Test Python client for RMCP" -dependencies = [ - "mcp", -] - -[tool.setuptools] -py-modules = ["client", "server"] \ No newline at end of file diff --git a/crates/rmcp/tests/test_with_python/server.py b/crates/rmcp/tests/test_with_python/server.py deleted file mode 100644 index 1e33efa42..000000000 --- a/crates/rmcp/tests/test_with_python/server.py +++ /dev/null @@ -1,20 +0,0 @@ -from mcp.server.fastmcp import FastMCP - -mcp = FastMCP("Demo") - -@mcp.tool() -def add(a: int, b: int) -> int: - """Add two numbers""" - return a + b - - -# Add a dynamic greeting resource -@mcp.resource("greeting://{name}") -def get_greeting(name: str) -> str: - """Get a personalized greeting""" - return f"Hello, {name}!" - - - -if __name__ == "__main__": - mcp.run() \ No newline at end of file diff --git a/examples/README.md b/examples/README.md deleted file mode 100644 index e4fe12da7..000000000 --- a/examples/README.md +++ /dev/null @@ -1,88 +0,0 @@ -# Quick Start With Claude Desktop - -1. **Build the Server (Counter Example)** - - ```sh - cargo build --release --example servers_std_io - ``` - - This builds a standard input/output MCP server binary. - -2. **Add or update this section in your** `PATH-TO/claude_desktop_config.json` - - Windows - - ```json - { - "mcpServers": { - "counter": { - "command": "PATH-TO/rust-sdk/target/release/examples/servers_std_io.exe", - "args": [] - } - } - } - ``` - - MacOS/Linux - - ```json - { - "mcpServers": { - "counter": { - "command": "PATH-TO/rust-sdk/target/release/examples/servers_std_io", - "args": [] - } - } - } - ``` - -3. **Ensure that the MCP UI elements appear in Claude Desktop** - The MCP UI elements will only show up in Claude for Desktop if at least one server is properly configured. It may require to restart Claude for Desktop. - -4. **Once Claude Desktop is running, try chatting:** - - ```text - counter.say_hello - ``` - - Or test other tools like: - - ```texts - counter.increment - counter.get_value - counter.sum {"a": 3, "b": 4} - ``` - -# Client Examples - -- [Client SSE](clients/src/sse.rs), using reqwest and eventsource-client. -- [Client stdio](clients/src/std_io.rs), using tokio to spawn child process. -- [Everything](clients/src/everything_stdio.rs), test with `@modelcontextprotocol/server-everything` -- [Collection](clients/src/collection.rs), How to transpose service into dynamic object, so they will have a same type. - -# Server Examples - -- [Server SSE](servers/src/axum.rs), using axum as web server. -- [Server stdio](servers/src/std_io.rs), using tokio async io. - -# Transport Examples - -- [Tcp](transport/src/tcp.rs) -- [Transport on http upgrade](transport/src/http_upgrade.rs) -- [Unix Socket](transport/src/unix_socket.rs) -- [Websocket](transport/src/websocket.rs) - -# Integration - -- [Rig](examples/rig-integration) A stream chatbot with rig -- [Simple Chat Client](examples/simple-chat-client) A simple chat client implementation using the Model Context Protocol (MCP) SDK. - -# WASI - -- [WASI-P2 runtime](wasi) How it works with wasip2 - -## Use Mcp Inspector - -```sh -npx @modelcontextprotocol/inspector -``` diff --git a/examples/clients/Cargo.toml b/examples/clients/Cargo.toml deleted file mode 100644 index 52226ed47..000000000 --- a/examples/clients/Cargo.toml +++ /dev/null @@ -1,42 +0,0 @@ - - -[package] -name = "mcp-client-examples" -version = "0.1.5" -edition = "2024" -publish = false - -[dependencies] -rmcp = { path = "../../crates/rmcp", features = [ - "client", - "transport-sse", - "transport-child-process", - "tower" -] } -tokio = { version = "1", features = ["full"] } -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -rand = "0.8" -futures = "0.3" -anyhow = "1.0" - -tower = "0.5" - -[[example]] -name = "clients_sse" -path = "src/sse.rs" - -[[example]] -name = "clients_std_io" -path = "src/std_io.rs" - -[[example]] -name = "clients_everything_stdio" -path = "src/everything_stdio.rs" - -[[example]] -name = "clients_collection" -path = "src/collection.rs" - diff --git a/examples/clients/src/collection.rs b/examples/clients/src/collection.rs deleted file mode 100644 index f474ea85f..000000000 --- a/examples/clients/src/collection.rs +++ /dev/null @@ -1,49 +0,0 @@ -use std::collections::HashMap; - -use anyhow::Result; -use rmcp::{model::CallToolRequestParam, service::ServiceExt, transport::TokioChildProcess}; -use tokio::process::Command; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; - -#[tokio::main] -async fn main() -> Result<()> { - // Initialize logging - tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()), - ) - .with(tracing_subscriber::fmt::layer()) - .init(); - - let mut client_list = HashMap::new(); - for idx in 0..10 { - let service = () - .into_dyn() - .serve(TokioChildProcess::new( - Command::new("uvx").arg("mcp-server-git"), - )?) - .await?; - client_list.insert(idx, service); - } - - for (_, service) in client_list.iter() { - // Initialize - let _server_info = service.peer_info(); - - // List tools - let _tools = service.list_tools(Default::default()).await?; - - // Call tool 'git_status' with arguments = {"repo_path": "."} - let _tool_result = service - .call_tool(CallToolRequestParam { - name: "git_status".into(), - arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), - }) - .await?; - } - for (_, service) in client_list { - service.cancel().await?; - } - Ok(()) -} diff --git a/examples/clients/src/everything_stdio.rs b/examples/clients/src/everything_stdio.rs deleted file mode 100644 index 091e9053e..000000000 --- a/examples/clients/src/everything_stdio.rs +++ /dev/null @@ -1,98 +0,0 @@ -use anyhow::Result; -use rmcp::{ - ServiceExt, - model::{CallToolRequestParam, GetPromptRequestParam, ReadResourceRequestParam}, - object, - transport::TokioChildProcess, -}; -use tokio::process::Command; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; - -#[tokio::main] -async fn main() -> Result<()> { - // Initialize logging - tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()), - ) - .with(tracing_subscriber::fmt::layer()) - .init(); - - // Start server - let service = () - .serve(TokioChildProcess::new( - Command::new("npx") - .arg("-y") - .arg("@modelcontextprotocol/server-everything"), - )?) - .await?; - - // Initialize - let server_info = service.peer_info(); - tracing::info!("Connected to server: {server_info:#?}"); - - // List tools - let tools = service.list_all_tools().await?; - tracing::info!("Available tools: {tools:#?}"); - - // Call tool echo - let tool_result = service - .call_tool(CallToolRequestParam { - name: "echo".into(), - arguments: Some(object!({ "message": "hi from rmcp" })), - }) - .await?; - tracing::info!("Tool result for echo: {tool_result:#?}"); - - // Call tool longRunningOperation - let tool_result = service - .call_tool(CallToolRequestParam { - name: "longRunningOperation".into(), - arguments: Some(object!({ "duration": 3, "steps": 1 })), - }) - .await?; - tracing::info!("Tool result for longRunningOperation: {tool_result:#?}"); - - // List resources - let resources = service.list_all_resources().await?; - tracing::info!("Available resources: {resources:#?}"); - - // Read resource - let resource = service - .read_resource(ReadResourceRequestParam { - uri: "test://static/resource/3".into(), - }) - .await?; - tracing::info!("Resource: {resource:#?}"); - - // List prompts - let prompts = service.list_all_prompts().await?; - tracing::info!("Available prompts: {prompts:#?}"); - - // Get simple prompt - let prompt = service - .get_prompt(GetPromptRequestParam { - name: "simple_prompt".into(), - arguments: None, - }) - .await?; - tracing::info!("Prompt - simple: {prompt:#?}"); - - // Get complex prompt (returns text & image) - let prompt = service - .get_prompt(GetPromptRequestParam { - name: "complex_prompt".into(), - arguments: Some(object!({ "temperature": "0.5", "style": "formal" })), - }) - .await?; - tracing::info!("Prompt - complex: {prompt:#?}"); - - // List resource templates - let resource_templates = service.list_all_resource_templates().await?; - tracing::info!("Available resource templates: {resource_templates:#?}"); - - service.cancel().await?; - - Ok(()) -} diff --git a/examples/clients/src/sse.rs b/examples/clients/src/sse.rs deleted file mode 100644 index a4f3df391..000000000 --- a/examples/clients/src/sse.rs +++ /dev/null @@ -1,49 +0,0 @@ -use anyhow::Result; -use rmcp::{ - ServiceExt, - model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation}, - transport::SseTransport, -}; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; - -#[tokio::main] -async fn main() -> Result<()> { - // Initialize logging - tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()), - ) - .with(tracing_subscriber::fmt::layer()) - .init(); - let transport = SseTransport::start("http://localhost:8000/sse").await?; - let client_info = ClientInfo { - protocol_version: Default::default(), - capabilities: ClientCapabilities::default(), - client_info: Implementation { - name: "test sse client".to_string(), - version: "0.0.1".to_string(), - }, - }; - let client = client_info.serve(transport).await.inspect_err(|e| { - tracing::error!("client error: {:?}", e); - })?; - - // Initialize - let server_info = client.peer_info(); - tracing::info!("Connected to server: {server_info:#?}"); - - // List tools - let tools = client.list_tools(Default::default()).await?; - tracing::info!("Available tools: {tools:#?}"); - - let tool_result = client - .call_tool(CallToolRequestParam { - name: "increment".into(), - arguments: serde_json::json!({}).as_object().cloned(), - }) - .await?; - tracing::info!("Tool result: {tool_result:#?}"); - client.cancel().await?; - Ok(()) -} diff --git a/examples/clients/src/std_io.rs b/examples/clients/src/std_io.rs deleted file mode 100644 index 987418915..000000000 --- a/examples/clients/src/std_io.rs +++ /dev/null @@ -1,47 +0,0 @@ -use anyhow::Result; -use rmcp::{model::CallToolRequestParam, service::ServiceExt, transport::TokioChildProcess}; -use tokio::process::Command; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; - -#[tokio::main] -async fn main() -> Result<()> { - // Initialize logging - tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()), - ) - .with(tracing_subscriber::fmt::layer()) - .init(); - let service = () - .serve(TokioChildProcess::new( - Command::new("uvx").arg("mcp-server-git"), - )?) - .await?; - - // or - // serve_client( - // (), - // TokioChildProcess::new(Command::new("uvx").arg("mcp-server-git"))?, - // ) - // .await?; - - // Initialize - let server_info = service.peer_info(); - tracing::info!("Connected to server: {server_info:#?}"); - - // List tools - let tools = service.list_tools(Default::default()).await?; - tracing::info!("Available tools: {tools:#?}"); - - // Call tool 'git_status' with arguments = {"repo_path": "."} - let tool_result = service - .call_tool(CallToolRequestParam { - name: "git_status".into(), - arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), - }) - .await?; - tracing::info!("Tool result: {tool_result:#?}"); - service.cancel().await?; - Ok(()) -} diff --git a/examples/rig-integration/Cargo.toml b/examples/rig-integration/Cargo.toml deleted file mode 100644 index 8fd14eb0e..000000000 --- a/examples/rig-integration/Cargo.toml +++ /dev/null @@ -1,33 +0,0 @@ -[package] -name = "rig-integration" -edition = { workspace = true } -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -repository = { workspace = true } -description = { workspace = true } -keywords = { workspace = true } -homepage = { workspace = true } -categories = { workspace = true } -readme = { workspace = true } - -[dependencies] -rig-core = "0.10.0" -tokio = { version = "1", features = ["full"] } -rmcp = { path = "../../crates/rmcp", features = [ - "client", - "transport-child-process", - "transport-sse", -] } -anyhow = "1.0" -serde_json = "1" -serde = { version = "1", features = ["derive"] } -toml = "0.8" -futures = "0.3" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = [ - "env-filter", - "std", - "fmt", -] } -tracing-appender = "0.2" diff --git a/examples/rig-integration/config.toml b/examples/rig-integration/config.toml deleted file mode 100644 index 1affe7aaa..000000000 --- a/examples/rig-integration/config.toml +++ /dev/null @@ -1,10 +0,0 @@ -deepseek_key = "" -cohere_key = "" - -[mcp] - -[[mcp.server]] -name = "git" -protocol = "stdio" -command = "uvx" -args = ["mcp-server-git"] diff --git a/examples/rig-integration/src/chat.rs b/examples/rig-integration/src/chat.rs deleted file mode 100644 index 7997d7684..000000000 --- a/examples/rig-integration/src/chat.rs +++ /dev/null @@ -1,124 +0,0 @@ -use futures::StreamExt; -use rig::{ - agent::Agent, - message::Message, - streaming::{StreamingChat, StreamingChoice, StreamingCompletionModel}, -}; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}; - -pub async fn cli_chatbot(chatbot: Agent) -> anyhow::Result<()> { - let mut chat_log = vec![]; - - let mut output = BufWriter::new(tokio::io::stdout()); - let mut input = BufReader::new(tokio::io::stdin()); - output.write_all(b"Enter :q to quit\n").await?; - loop { - output.write_all(b"\x1b[32muser>\x1b[0m ").await?; - // Flush stdout to ensure the prompt appears before input - output.flush().await?; - let mut input_buf = String::new(); - input.read_line(&mut input_buf).await?; - // Remove the newline character from the input - let input = input_buf.trim(); - // Check for a command to exit - if input == ":q" { - break; - } - match chatbot.stream_chat(input, chat_log.clone()).await { - Ok(mut response) => { - tracing::info!(%input); - chat_log.push(Message::user(input)); - stream_output_agent_start(&mut output).await?; - let mut message_buf = String::new(); - while let Some(message) = response.next().await { - match message { - Ok(StreamingChoice::Message(text)) => { - message_buf.push_str(&text); - output_agent(text, &mut output).await?; - } - Ok(StreamingChoice::ToolCall(name, _, param)) => { - chat_log.push(Message::assistant(format!( - "Calling tool: {name} with args: {param}" - ))); - let result = chatbot.tools.call(&name, param.to_string()).await; - match result { - Ok(tool_call_result) => { - stream_output_agent_finished(&mut output).await?; - stream_output_toolcall(&tool_call_result, &mut output).await?; - stream_output_agent_start(&mut output).await?; - chat_log.push(Message::user(tool_call_result)); - } - Err(e) => { - output_error(e, &mut output).await?; - } - } - } - Err(error) => { - output_error(error, &mut output).await?; - } - } - } - chat_log.push(Message::assistant(message_buf)); - stream_output_agent_finished(&mut output).await?; - } - Err(error) => { - output_error(error, &mut output).await?; - } - } - } - - Ok(()) -} - -pub async fn output_error( - e: impl std::fmt::Display, - output: &mut BufWriter, -) -> std::io::Result<()> { - output - .write_all(b"\x1b[1;31m\xE2\x9D\x8C ERROR: \x1b[0m") - .await?; - output.write_all(e.to_string().as_bytes()).await?; - output.write_all(b"\n").await?; - output.flush().await?; - Ok(()) -} - -pub async fn output_agent( - content: impl std::fmt::Display, - output: &mut BufWriter, -) -> std::io::Result<()> { - output.write_all(content.to_string().as_bytes()).await?; - output.flush().await?; - Ok(()) -} - -pub async fn stream_output_toolcall( - content: impl std::fmt::Display, - output: &mut BufWriter, -) -> std::io::Result<()> { - output - .write_all(b"\x1b[1;33m\xF0\x9F\x9B\xA0 Tool Call: \x1b[0m") - .await?; - output.write_all(content.to_string().as_bytes()).await?; - output.write_all(b"\n").await?; - output.flush().await?; - Ok(()) -} - -pub async fn stream_output_agent_start( - output: &mut BufWriter, -) -> std::io::Result<()> { - output - .write_all(b"\x1b[1;34m\xF0\x9F\xA4\x96 Agent: \x1b[0m") - .await?; - output.flush().await?; - Ok(()) -} - -pub async fn stream_output_agent_finished( - output: &mut BufWriter, -) -> std::io::Result<()> { - output.write_all(b"\n").await?; - output.flush().await?; - Ok(()) -} diff --git a/examples/rig-integration/src/config.rs b/examples/rig-integration/src/config.rs deleted file mode 100644 index 387a4f686..000000000 --- a/examples/rig-integration/src/config.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::path::Path; - -use serde::{Deserialize, Serialize}; - -pub mod mcp; - -#[derive(Debug, Deserialize, Serialize)] -pub struct Config { - pub mcp: mcp::McpConfig, - pub deepseek_key: Option, - pub cohere_key: Option, -} - -impl Config { - pub async fn retrieve(path: impl AsRef) -> anyhow::Result { - let content = tokio::fs::read_to_string(path).await?; - let config: Self = toml::from_str(&content)?; - Ok(config) - } -} diff --git a/examples/rig-integration/src/config/mcp.rs b/examples/rig-integration/src/config/mcp.rs deleted file mode 100644 index ba42f2bc0..000000000 --- a/examples/rig-integration/src/config/mcp.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::{collections::HashMap, process::Stdio}; - -use rmcp::{RoleClient, ServiceExt, service::RunningService}; -use serde::{Deserialize, Serialize}; - -use crate::mcp_adaptor::McpManager; -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct McpServerConfig { - name: String, - #[serde(flatten)] - transport: McpServerTransportConfig, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -#[serde(tag = "protocol", rename_all = "lowercase")] -pub enum McpServerTransportConfig { - Sse { - url: String, - }, - Stdio { - command: String, - #[serde(default)] - args: Vec, - #[serde(default)] - envs: HashMap, - }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct McpConfig { - server: Vec, -} - -impl McpConfig { - pub async fn create_manager(&self) -> anyhow::Result { - let mut clients = HashMap::new(); - let mut task_set = tokio::task::JoinSet::>::new(); - for server in &self.server { - let server = server.clone(); - task_set.spawn(async move { - let client = server.transport.start().await?; - anyhow::Result::Ok((server.name.clone(), client)) - }); - } - let start_up_result = task_set.join_all().await; - for result in start_up_result { - match result { - Ok((name, client)) => { - clients.insert(name, client); - } - Err(e) => { - eprintln!("Failed to start server: {:?}", e); - } - } - } - Ok(McpManager { clients }) - } -} - -impl McpServerTransportConfig { - pub async fn start(&self) -> anyhow::Result> { - let client = match self { - McpServerTransportConfig::Sse { url } => { - let transport = rmcp::transport::SseTransport::start(url).await?; - ().serve(transport).await? - } - McpServerTransportConfig::Stdio { - command, - args, - envs, - } => { - let transport = rmcp::transport::TokioChildProcess::new( - tokio::process::Command::new(command) - .args(args) - .envs(envs) - .stderr(Stdio::null()), - )?; - ().serve(transport).await? - } - }; - Ok(client) - } -} diff --git a/examples/rig-integration/src/main.rs b/examples/rig-integration/src/main.rs deleted file mode 100644 index ddf220f83..000000000 --- a/examples/rig-integration/src/main.rs +++ /dev/null @@ -1,68 +0,0 @@ -use rig::{ - embeddings::EmbeddingsBuilder, - providers::{cohere, deepseek}, - vector_store::in_memory_store::InMemoryVectorStore, -}; -use tracing_appender::rolling::{RollingFileAppender, Rotation}; -pub mod chat; -pub mod config; -pub mod mcp_adaptor; - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let file_appender = RollingFileAppender::new( - Rotation::DAILY, - "logs", - format!("{}.log", env!("CARGO_CRATE_NAME")), - ); - tracing_subscriber::fmt() - .with_env_filter( - tracing_subscriber::EnvFilter::from_default_env() - .add_directive(tracing::Level::INFO.into()), - ) - .with_writer(file_appender) - .with_file(false) - .with_ansi(false) - .init(); - - let config = config::Config::retrieve("config.toml").await?; - let openai_client = { - if let Some(key) = config.deepseek_key { - deepseek::Client::new(&key) - } else { - deepseek::Client::from_env() - } - }; - let cohere_client = { - if let Some(key) = config.cohere_key { - cohere::Client::new(&key) - } else { - cohere::Client::from_env() - } - }; - let mcp_manager = config.mcp.create_manager().await?; - tracing::info!( - "MCP Manager created, {} servers started", - mcp_manager.clients.len() - ); - let tool_set = mcp_manager.get_tool_set().await?; - let embedding_model = - cohere_client.embedding_model(cohere::EMBED_MULTILINGUAL_V3, "search_document"); - let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .documents(tool_set.schemas()?)? - .build() - .await?; - let store = InMemoryVectorStore::from_documents_with_id_f(embeddings, |f| { - tracing::info!("store tool {}", f.name); - f.name.clone() - }); - let index = store.index(embedding_model); - let dpsk = openai_client - .agent(deepseek::DEEPSEEK_CHAT) - .dynamic_tools(4, index, tool_set) - .build(); - - chat::cli_chatbot(dpsk).await?; - - Ok(()) -} diff --git a/examples/rig-integration/src/mcp_adaptor.rs b/examples/rig-integration/src/mcp_adaptor.rs deleted file mode 100644 index 483c6e026..000000000 --- a/examples/rig-integration/src/mcp_adaptor.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::collections::HashMap; - -use rig::tool::{ToolDyn as RigTool, ToolEmbeddingDyn, ToolSet}; -use rmcp::{ - RoleClient, - model::{CallToolRequestParam, CallToolResult, Tool as McpTool}, - service::{RunningService, ServerSink}, -}; - -pub struct McpToolAdaptor { - tool: McpTool, - server: ServerSink, -} - -impl RigTool for McpToolAdaptor { - fn name(&self) -> String { - self.tool.name.to_string() - } - - fn definition( - &self, - _prompt: String, - ) -> std::pin::Pin + Send + Sync + '_>> - { - Box::pin(std::future::ready(rig::completion::ToolDefinition { - name: self.name(), - description: self - .tool - .description - .as_deref() - .unwrap_or_default() - .to_string(), - parameters: self.tool.schema_as_json_value(), - })) - } - - fn call( - &self, - args: String, - ) -> std::pin::Pin< - Box> + Send + Sync + '_>, - > { - let server = self.server.clone(); - Box::pin(async move { - let call_mcp_tool_result = server - .call_tool(CallToolRequestParam { - name: self.tool.name.clone(), - arguments: serde_json::from_str(&args) - .map_err(rig::tool::ToolError::JsonError)?, - }) - .await - .inspect(|result| tracing::info!(?result)) - .inspect_err(|error| tracing::error!(%error)) - .map_err(|e| rig::tool::ToolError::ToolCallError(Box::new(e)))?; - - Ok(convert_mcp_call_tool_result_to_string(call_mcp_tool_result)) - }) - } -} - -impl ToolEmbeddingDyn for McpToolAdaptor { - fn context(&self) -> serde_json::Result { - serde_json::to_value(self.tool.clone()) - } - - fn embedding_docs(&self) -> Vec { - vec![ - self.tool - .description - .as_deref() - .unwrap_or_default() - .to_string(), - ] - } -} - -pub struct McpManager { - pub clients: HashMap>, -} - -impl McpManager { - pub async fn get_tool_set(&self) -> anyhow::Result { - let mut tool_set = ToolSet::default(); - let mut task = tokio::task::JoinSet::>::new(); - for client in self.clients.values() { - let server = client.peer().clone(); - task.spawn(get_tool_set(server)); - } - let results = task.join_all().await; - for result in results { - match result { - Err(e) => { - tracing::error!(error = %e, "Failed to get tool set"); - } - Ok(tools) => { - tool_set.add_tools(tools); - } - } - } - Ok(tool_set) - } -} - -pub fn convert_mcp_call_tool_result_to_string(result: CallToolResult) -> String { - serde_json::to_string(&result).unwrap() -} - -pub async fn get_tool_set(server: ServerSink) -> anyhow::Result { - let tools = server.list_all_tools().await?; - let mut tool_builder = ToolSet::builder(); - for tool in tools { - tracing::info!("get tool: {}", tool.name); - let adaptor = McpToolAdaptor { - tool: tool.clone(), - server: server.clone(), - }; - tool_builder = tool_builder.dynamic_tool(adaptor); - } - let tool_set = tool_builder.build(); - Ok(tool_set) -} diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml deleted file mode 100644 index 63d9d6d24..000000000 --- a/examples/servers/Cargo.toml +++ /dev/null @@ -1,46 +0,0 @@ - - -[package] -name = "mcp-server-examples" -version = "0.1.5" -edition = "2024" -publish = false - -[dependencies] -rmcp= { path = "../../crates/rmcp", features = ["server", "transport-sse-server", "transport-io"] } -tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread", "io-std", "signal"] } -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -anyhow = "1.0" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = [ - "env-filter", - "std", - "fmt", -] } -futures = "0.3" -rand = { version = "0.9" } -axum = { version = "0.8", features = ["macros"] } -schemars = { version = "0.8", optional = true } -# [dev-dependencies.'cfg(target_arch="linux")'.dependencies] - -[dev-dependencies] -tokio-stream = { version = "0.1" } -# tokio-util = { version = "0.7", features = ["io", "codec"] } -tokio-util = { version = "0.7", features = ["codec"] } - -[[example]] -name = "servers_std_io" -path = "src/std_io.rs" - -[[example]] -name = "servers_axum" -path = "src/axum.rs" - -[[example]] -name = "servers_axum_router" -path = "src/axum_router.rs" - -[[example]] -name = "servers_generic_server" -path = "src/generic_service.rs" \ No newline at end of file diff --git a/examples/servers/src/axum.rs b/examples/servers/src/axum.rs deleted file mode 100644 index e11462651..000000000 --- a/examples/servers/src/axum.rs +++ /dev/null @@ -1,29 +0,0 @@ -use rmcp::transport::sse_server::SseServer; -use tracing_subscriber::{ - layer::SubscriberExt, - util::SubscriberInitExt, - {self}, -}; -mod common; -use common::counter::Counter; - -const BIND_ADDRESS: &str = "127.0.0.1:8000"; - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "debug".to_string().into()), - ) - .with(tracing_subscriber::fmt::layer()) - .init(); - - let ct = SseServer::serve(BIND_ADDRESS.parse()?) - .await? - .with_service(Counter::new); - - tokio::signal::ctrl_c().await?; - ct.cancel(); - Ok(()) -} diff --git a/examples/servers/src/axum_router.rs b/examples/servers/src/axum_router.rs deleted file mode 100644 index 373a8a6d7..000000000 --- a/examples/servers/src/axum_router.rs +++ /dev/null @@ -1,54 +0,0 @@ -use rmcp::transport::sse_server::{SseServer, SseServerConfig}; -use tracing_subscriber::{ - layer::SubscriberExt, - util::SubscriberInitExt, - {self}, -}; -mod common; -use common::counter::Counter; - -const BIND_ADDRESS: &str = "127.0.0.1:8000"; - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "debug".to_string().into()), - ) - .with(tracing_subscriber::fmt::layer()) - .init(); - - let config = SseServerConfig { - bind: BIND_ADDRESS.parse()?, - sse_path: "/sse".to_string(), - post_path: "/message".to_string(), - ct: tokio_util::sync::CancellationToken::new(), - sse_keep_alive: None, - }; - - let (sse_server, router) = SseServer::new(config); - - // Do something with the router, e.g., add routes or middleware - - let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?; - - let ct = sse_server.config.ct.child_token(); - - let server = axum::serve(listener, router).with_graceful_shutdown(async move { - ct.cancelled().await; - tracing::info!("sse server cancelled"); - }); - - tokio::spawn(async move { - if let Err(e) = server.await { - tracing::error!(error = %e, "sse server shutdown with error"); - } - }); - - let ct = sse_server.with_service(Counter::new); - - tokio::signal::ctrl_c().await?; - ct.cancel(); - Ok(()) -} diff --git a/examples/servers/src/common/calculator.rs b/examples/servers/src/common/calculator.rs deleted file mode 100644 index 68beecc0c..000000000 --- a/examples/servers/src/common/calculator.rs +++ /dev/null @@ -1,46 +0,0 @@ -use rmcp::{ - ServerHandler, - handler::server::wrapper::Json, - model::{ServerCapabilities, ServerInfo}, - schemars, tool, -}; - -#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] -pub struct SumRequest { - #[schemars(description = "the left hand side number")] - pub a: i32, - pub b: i32, -} -#[derive(Debug, Clone)] -pub struct Calculator; -#[tool(tool_box)] -impl Calculator { - #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { - (a + b).to_string() - } - - #[tool(description = "Calculate the difference of two numbers")] - fn sub( - &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, - ) -> Json { - Json(a - b) - } -} - -#[tool(tool_box)] -impl ServerHandler for Calculator { - fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } - } -} diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs deleted file mode 100644 index 7bed523ab..000000000 --- a/examples/servers/src/common/counter.rs +++ /dev/null @@ -1,197 +0,0 @@ -use std::sync::Arc; - -use rmcp::{ - Error as McpError, RoleServer, ServerHandler, const_string, model::*, schemars, - service::RequestContext, tool, -}; -use serde_json::json; -use tokio::sync::Mutex; - -#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] -pub struct StructRequest { - pub a: i32, - pub b: i32, -} - -#[derive(Clone)] -pub struct Counter { - counter: Arc>, -} -#[tool(tool_box)] -impl Counter { - #[allow(dead_code)] - pub fn new() -> Self { - Self { - counter: Arc::new(Mutex::new(0)), - } - } - - fn _create_resource_text(&self, uri: &str, name: &str) -> Resource { - RawResource::new(uri, name.to_string()).no_annotation() - } - - #[tool(description = "Increment the counter by 1")] - async fn increment(&self) -> Result { - let mut counter = self.counter.lock().await; - *counter += 1; - Ok(CallToolResult::success(vec![Content::text( - counter.to_string(), - )])) - } - - #[tool(description = "Decrement the counter by 1")] - async fn decrement(&self) -> Result { - let mut counter = self.counter.lock().await; - *counter -= 1; - Ok(CallToolResult::success(vec![Content::text( - counter.to_string(), - )])) - } - - #[tool(description = "Get the current counter value")] - async fn get_value(&self) -> Result { - let counter = self.counter.lock().await; - Ok(CallToolResult::success(vec![Content::text( - counter.to_string(), - )])) - } - - #[tool(description = "Say hello to the client")] - fn say_hello(&self) -> Result { - Ok(CallToolResult::success(vec![Content::text("hello")])) - } - - #[tool(description = "Repeat what you say")] - fn echo( - &self, - #[tool(param)] - #[schemars(description = "Repeat what you say")] - saying: String, - ) -> Result { - Ok(CallToolResult::success(vec![Content::text(saying)])) - } - - #[tool(description = "Calculate the sum of two numbers")] - fn sum( - &self, - #[tool(aggr)] StructRequest { a, b }: StructRequest, - ) -> Result { - Ok(CallToolResult::success(vec![Content::text( - (a + b).to_string(), - )])) - } -} -const_string!(Echo = "echo"); -#[tool(tool_box)] -impl ServerHandler for Counter { - fn get_info(&self) -> ServerInfo { - ServerInfo { - protocol_version: ProtocolVersion::V_2024_11_05, - capabilities: ServerCapabilities::builder() - .enable_prompts() - .enable_resources() - .enable_tools() - .build(), - server_info: Implementation::from_build_env(), - instructions: Some("This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string()), - } - } - - async fn list_resources( - &self, - _request: Option, - _: RequestContext, - ) -> Result { - Ok(ListResourcesResult { - resources: vec![ - self._create_resource_text("str:////Users/to/some/path/", "cwd"), - self._create_resource_text("memo://insights", "memo-name"), - ], - next_cursor: None, - }) - } - - async fn read_resource( - &self, - ReadResourceRequestParam { uri }: ReadResourceRequestParam, - _: RequestContext, - ) -> Result { - match uri.as_str() { - "str:////Users/to/some/path/" => { - let cwd = "/Users/to/some/path/"; - Ok(ReadResourceResult { - contents: vec![ResourceContents::text(cwd, uri)], - }) - } - "memo://insights" => { - let memo = "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ..."; - Ok(ReadResourceResult { - contents: vec![ResourceContents::text(memo, uri)], - }) - } - _ => Err(McpError::resource_not_found( - "resource_not_found", - Some(json!({ - "uri": uri - })), - )), - } - } - - async fn list_prompts( - &self, - _request: Option, - _: RequestContext, - ) -> Result { - Ok(ListPromptsResult { - next_cursor: None, - prompts: vec![Prompt::new( - "example_prompt", - Some("This is an example prompt that takes one required argument, message"), - Some(vec![PromptArgument { - name: "message".to_string(), - description: Some("A message to put in the prompt".to_string()), - required: Some(true), - }]), - )], - }) - } - - async fn get_prompt( - &self, - GetPromptRequestParam { name, arguments }: GetPromptRequestParam, - _: RequestContext, - ) -> Result { - match name.as_str() { - "example_prompt" => { - let message = arguments - .and_then(|json| json.get("message")?.as_str().map(|s| s.to_string())) - .ok_or_else(|| { - McpError::invalid_params("No message provided to example_prompt", None) - })?; - - let prompt = - format!("This is an example prompt with your message here: '{message}'"); - Ok(GetPromptResult { - description: None, - messages: vec![PromptMessage { - role: PromptMessageRole::User, - content: PromptMessageContent::text(prompt), - }], - }) - } - _ => Err(McpError::invalid_params("prompt not found", None)), - } - } - - async fn list_resource_templates( - &self, - _request: Option, - _: RequestContext, - ) -> Result { - Ok(ListResourceTemplatesResult { - next_cursor: None, - resource_templates: Vec::new(), - }) - } -} diff --git a/examples/servers/src/common/generic_service.rs b/examples/servers/src/common/generic_service.rs deleted file mode 100644 index 433a4308f..000000000 --- a/examples/servers/src/common/generic_service.rs +++ /dev/null @@ -1,73 +0,0 @@ -use std::sync::Arc; - -use rmcp::{ - ServerHandler, - model::{ServerCapabilities, ServerInfo}, - schemars, tool, -}; - -#[allow(dead_code)] -pub trait DataService: Send + Sync + 'static { - fn get_data(&self) -> String; - fn set_data(&mut self, data: String); -} - -#[derive(Debug, Clone)] -pub struct MemoryDataService { - data: String, -} - -impl MemoryDataService { - #[allow(dead_code)] - pub fn new(initial_data: impl Into) -> Self { - Self { - data: initial_data.into(), - } - } -} - -impl DataService for MemoryDataService { - fn get_data(&self) -> String { - self.data.clone() - } - - fn set_data(&mut self, data: String) { - self.data = data; - } -} - -#[derive(Debug, Clone)] -pub struct GenericService { - #[allow(dead_code)] - data_service: Arc, -} - -#[tool(tool_box)] -impl GenericService { - pub fn new(data_service: DS) -> Self { - Self { - data_service: Arc::new(data_service), - } - } - - #[tool(description = "get memory from service")] - pub async fn get_data(&self) -> String { - self.data_service.get_data() - } - - #[tool(description = "set memory to service")] - pub async fn set_data(&self, #[tool(param)] data: String) -> String { - let new_data = data.clone(); - format!("Current memory: {}", new_data) - } -} - -impl ServerHandler for GenericService { - fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("generic data service".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } - } -} diff --git a/examples/servers/src/common/mod.rs b/examples/servers/src/common/mod.rs deleted file mode 100644 index 5919bccdc..000000000 --- a/examples/servers/src/common/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod calculator; -pub mod counter; -pub mod generic_service; diff --git a/examples/servers/src/generic_service.rs b/examples/servers/src/generic_service.rs deleted file mode 100644 index 546621e34..000000000 --- a/examples/servers/src/generic_service.rs +++ /dev/null @@ -1,18 +0,0 @@ -use std::error::Error; -mod common; -use common::generic_service::{GenericService, MemoryDataService}; -use rmcp::serve_server; - -#[tokio::main] -async fn main() -> Result<(), Box> { - let memory_service = MemoryDataService::new("initial data"); - - let generic_service = GenericService::new(memory_service); - - println!("start server, connect to standard input/output"); - - let io = (tokio::io::stdin(), tokio::io::stdout()); - - serve_server(generic_service, io).await?; - Ok(()) -} diff --git a/examples/servers/src/std_io.rs b/examples/servers/src/std_io.rs deleted file mode 100644 index 9339ab866..000000000 --- a/examples/servers/src/std_io.rs +++ /dev/null @@ -1,25 +0,0 @@ -use anyhow::Result; -use common::counter::Counter; -use rmcp::{ServiceExt, transport::stdio}; -use tracing_subscriber::{self, EnvFilter}; -mod common; -/// npx @modelcontextprotocol/inspector cargo run -p mcp-server-examples --example std_io -#[tokio::main] -async fn main() -> Result<()> { - // Initialize the tracing subscriber with file and stdout logging - tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::DEBUG.into())) - .with_writer(std::io::stderr) - .with_ansi(false) - .init(); - - tracing::info!("Starting MCP server"); - - // Create an instance of our counter router - let service = Counter::new().serve(stdio()).await.inspect_err(|e| { - tracing::error!("serving error: {:?}", e); - })?; - - service.waiting().await?; - Ok(()) -} diff --git a/examples/simple-chat-client/Cargo.toml b/examples/simple-chat-client/Cargo.toml deleted file mode 100644 index 24bcb78bd..000000000 --- a/examples/simple-chat-client/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "simple-chat-client" -version = "0.1.0" -edition = "2021" - -[dependencies] -tokio = { version = "1", features = ["full"] } -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -reqwest = { version = "0.11", features = ["json"] } -anyhow = "1.0" -thiserror = "1.0" -async-trait = "0.1" -futures = "0.3" -toml = "0.8" -rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", features = [ - "client", - "transport-child-process", - "transport-sse", -], no-default-features = true } \ No newline at end of file diff --git a/examples/simple-chat-client/README.md b/examples/simple-chat-client/README.md deleted file mode 100644 index 82478f056..000000000 --- a/examples/simple-chat-client/README.md +++ /dev/null @@ -1,15 +0,0 @@ -# Simple Chat Client - -A simple chat client implementation using the Model Context Protocol (MCP) SDK. It just a example for developers to understand how to use the MCP SDK. This example use the easiest way to start a MCP server, and call the tool directly. No need embedding or complex third library or function call(because some models can't support function call).Just add tool in system prompt, and the client will call the tool automatically. - - -## Config -the config file is in `src/config.toml`. you can change the config to your own.Move the config file to `/etc/simple-chat-client/config.toml` for system-wide configuration. - -## Usage - -After configuring the config file, you can run the example: -```bash -cargo run --bin simple_chat -``` - diff --git a/examples/simple-chat-client/src/bin/simple_chat.rs b/examples/simple-chat-client/src/bin/simple_chat.rs deleted file mode 100644 index 48ff17259..000000000 --- a/examples/simple-chat-client/src/bin/simple_chat.rs +++ /dev/null @@ -1,84 +0,0 @@ -use std::sync::Arc; - -use anyhow::Result; -use simple_chat_client::{ - chat::ChatSession, - client::OpenAIClient, - config::Config, - tool::{Tool, ToolSet, get_mcp_tools}, -}; - -//default config path -const DEFAULT_CONFIG_PATH: &str = "/etc/simple-chat-client/config.toml"; - -#[tokio::main] -async fn main() -> Result<()> { - // load config - let config = Config::load(DEFAULT_CONFIG_PATH).await?; - - // create openai client - let api_key = config - .openai_key - .clone() - .unwrap_or_else(|| std::env::var("OPENAI_API_KEY").expect("need set api key")); - let url = config.chat_url.clone(); - println!("url is {:?}", url); - let openai_client = Arc::new(OpenAIClient::new(api_key, url)); - - // create tool set - let mut tool_set = ToolSet::default(); - - // load mcp - if config.mcp.is_some() { - let mcp_clients = config.create_mcp_clients().await?; - - for (name, client) in mcp_clients { - println!("loading mcp tools: {}", name); - let server = client.peer().clone(); - let tools = get_mcp_tools(server).await?; - - for tool in tools { - println!("adding tool: {}", tool.name()); - tool_set.add_tool(tool); - } - } - } - - // create chat session - let mut session = ChatSession::new( - openai_client, - tool_set, - config - .model_name - .unwrap_or_else(|| "gpt-4o-mini".to_string()), - ); - - // build system prompt with tool info - let mut system_prompt = - "you are a assistant, you can help user to complete various tasks. you have the following tools to use:\n".to_string(); - - // add tool info to system prompt - for tool in session.get_tools() { - system_prompt.push_str(&format!( - "\ntool name: {}\ndescription: {}\nparameters: {}\n", - tool.name(), - tool.description(), - serde_json::to_string_pretty(&tool.parameters()).unwrap_or_default() - )); - } - - // add tool call format guidance - system_prompt.push_str( - "\nif you need to call tool, please use the following format:\n\ - Tool: \n\ - Inputs: \n", - ); - - // add system prompt - session.add_system_prompt(system_prompt); - - // start chat - session.chat().await?; - - Ok(()) -} diff --git a/examples/simple-chat-client/src/chat.rs b/examples/simple-chat-client/src/chat.rs deleted file mode 100644 index 34386cdce..000000000 --- a/examples/simple-chat-client/src/chat.rs +++ /dev/null @@ -1,170 +0,0 @@ -use std::{ - io::{self, Write}, - sync::Arc, -}; - -use anyhow::Result; -use serde_json::Value; - -use crate::{ - client::ChatClient, - model::{CompletionRequest, Message, Tool as ModelTool}, - tool::{Tool as ToolTrait, ToolSet}, -}; - -pub struct ChatSession { - client: Arc, - tool_set: ToolSet, - model: String, - messages: Vec, -} - -impl ChatSession { - pub fn new(client: Arc, tool_set: ToolSet, model: String) -> Self { - Self { - client, - tool_set, - model, - messages: Vec::new(), - } - } - - pub fn add_system_prompt(&mut self, prompt: impl ToString) { - self.messages.push(Message::system(prompt)); - } - - pub fn get_tools(&self) -> Vec> { - self.tool_set.tools() - } - - pub async fn chat(&mut self) -> Result<()> { - println!("welcome to use simple chat client, use 'exit' to quit"); - - loop { - print!("> "); - io::stdout().flush()?; - - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - input = input.trim().to_string(); - - if input.is_empty() { - continue; - } - - if input == "exit" { - break; - } - - self.messages.push(Message::user(&input)); - - // prepare tool list - let tools = self.tool_set.tools(); - let tool_definitions = if !tools.is_empty() { - Some( - tools - .iter() - .map(|tool| crate::model::Tool { - name: tool.name(), - description: tool.description(), - parameters: tool.parameters(), - }) - .collect(), - ) - } else { - None - }; - - // create request - let request = CompletionRequest { - model: self.model.clone(), - messages: self.messages.clone(), - temperature: Some(0.7), - tools: tool_definitions, - }; - - // send request - let response = self.client.complete(request).await?; - - if let Some(choice) = response.choices.first() { - println!("AI: {}", choice.message.content); - self.messages.push(choice.message.clone()); - - // check if message contains tool call - if choice.message.content.contains("Tool:") { - let lines: Vec<&str> = choice.message.content.split('\n').collect(); - - // simple parse tool call - let mut tool_name = None; - let mut args_text = Vec::new(); - let mut parsing_args = false; - - for line in lines { - if line.starts_with("Tool:") { - tool_name = line.strip_prefix("Tool:").map(|s| s.trim().to_string()); - parsing_args = false; - } else if line.starts_with("Inputs:") { - parsing_args = true; - } else if parsing_args { - args_text.push(line.trim()); - } - } - - if let Some(name) = tool_name { - if let Some(tool) = self.tool_set.get_tool(&name) { - println!("calling tool: {}", name); - - // simple handle args - let args_str = args_text.join("\n"); - let args = match serde_json::from_str(&args_str) { - Ok(v) => v, - Err(_) => { - // try to handle args as string - serde_json::Value::String(args_str) - } - }; - - // call tool - match tool.call(args).await { - Ok(result) => { - println!("tool result: {}", result); - - // add tool result to dialog - self.messages.push(Message::user(result)); - } - Err(e) => { - println!("tool call failed: {}", e); - self.messages - .push(Message::user(format!("tool call failed: {}", e))); - } - } - } else { - println!("tool not found: {}", name); - } - } - } - } - } - - Ok(()) - } -} - -#[async_trait::async_trait] -impl ToolTrait for ModelTool { - fn name(&self) -> String { - self.name.clone() - } - - fn description(&self) -> String { - self.description.clone() - } - - fn parameters(&self) -> Value { - self.parameters.clone() - } - - async fn call(&self, _args: Value) -> Result { - unimplemented!("ModelTool can't be called directly, only for tool definition") - } -} diff --git a/examples/simple-chat-client/src/client.rs b/examples/simple-chat-client/src/client.rs deleted file mode 100644 index c2c292204..000000000 --- a/examples/simple-chat-client/src/client.rs +++ /dev/null @@ -1,68 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use reqwest::Client as HttpClient; - -use crate::model::{CompletionRequest, CompletionResponse}; - -#[async_trait] -pub trait ChatClient: Send + Sync { - async fn complete(&self, request: CompletionRequest) -> Result; -} - -pub struct OpenAIClient { - api_key: String, - client: HttpClient, - base_url: String, -} - -impl OpenAIClient { - pub fn new(api_key: String, url: Option) -> Self { - let base_url = url.unwrap_or("https://api.openai.com/v1/chat/completions".to_string()); - - // create http client without proxy - let client = HttpClient::builder() - .no_proxy() - .build() - .unwrap_or_else(|_| HttpClient::new()); - - Self { - api_key, - client, - base_url, - } - } - - pub fn with_base_url(mut self, base_url: impl Into) -> Self { - self.base_url = base_url.into(); - self - } -} - -#[async_trait] -impl ChatClient for OpenAIClient { - async fn complete(&self, request: CompletionRequest) -> Result { - println!("sending request to {}", self.base_url); - println!("using api key: {}", self.api_key); - let request_json = serde_json::to_string(&request)?; - println!("request content: {}", request_json); - // no proxy - - let response = self - .client - .post(&self.base_url) - .header("Authorization", format!("Bearer {}", self.api_key)) - .header("Content-Type", "application/json") - .json(&request) - .send() - .await?; - - if !response.status().is_success() { - let error_text = response.text().await?; - println!("API error: {}", error_text); - return Err(anyhow::anyhow!("API Error: {}", error_text)); - } - - let completion: CompletionResponse = response.json().await?; - Ok(completion) - } -} diff --git a/examples/simple-chat-client/src/config.rs b/examples/simple-chat-client/src/config.rs deleted file mode 100644 index 9431e7b82..000000000 --- a/examples/simple-chat-client/src/config.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::{collections::HashMap, path::Path, process::Stdio}; - -use anyhow::Result; -use rmcp::{RoleClient, ServiceExt, service::RunningService}; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct Config { - pub openai_key: Option, - pub chat_url: Option, - pub mcp: Option, - pub model_name: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct McpConfig { - pub server: Vec, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct McpServerConfig { - pub name: String, - #[serde(flatten)] - pub transport: McpServerTransportConfig, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -#[serde(tag = "protocol", rename_all = "lowercase")] -pub enum McpServerTransportConfig { - Sse { - url: String, - }, - Stdio { - command: String, - #[serde(default)] - args: Vec, - #[serde(default)] - envs: HashMap, - }, -} - -impl McpServerTransportConfig { - pub async fn start(&self) -> Result> { - let client = match self { - McpServerTransportConfig::Sse { url } => { - let transport = rmcp::transport::sse::SseTransport::start(url).await?; - ().serve(transport).await? - } - McpServerTransportConfig::Stdio { - command, - args, - envs, - } => { - let transport = rmcp::transport::child_process::TokioChildProcess::new( - tokio::process::Command::new(command) - .args(args) - .envs(envs) - .stderr(Stdio::inherit()) - .stdout(Stdio::inherit()), - )?; - ().serve(transport).await? - } - }; - Ok(client) - } -} - -impl Config { - pub async fn load(path: impl AsRef) -> Result { - let content = tokio::fs::read_to_string(path).await?; - let config: Self = toml::from_str(&content)?; - Ok(config) - } - - pub async fn create_mcp_clients( - &self, - ) -> Result>> { - let mut clients = HashMap::new(); - - if let Some(mcp_config) = &self.mcp { - for server in &mcp_config.server { - let client = server.transport.start().await?; - clients.insert(server.name.clone(), client); - } - } - - Ok(clients) - } -} diff --git a/examples/simple-chat-client/src/config.toml b/examples/simple-chat-client/src/config.toml deleted file mode 100644 index 42e59f986..000000000 --- a/examples/simple-chat-client/src/config.toml +++ /dev/null @@ -1,10 +0,0 @@ -openai_key = "key" -chat_url = "url" -model_name = "model_name" - -[mcp] -[[mcp.server]] -name = "MCP server name" -protocol = "stdio" -command = "MCP server path" -args = [" "] \ No newline at end of file diff --git a/examples/simple-chat-client/src/error.rs b/examples/simple-chat-client/src/error.rs deleted file mode 100644 index 92c86643f..000000000 --- a/examples/simple-chat-client/src/error.rs +++ /dev/null @@ -1,24 +0,0 @@ -use std::fmt; - -use serde::Serialize; - -#[derive(Debug, Serialize)] -pub struct McpError { - pub message: String, -} - -impl fmt::Display for McpError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.message) - } -} - -impl std::error::Error for McpError {} - -impl McpError { - pub fn new(message: impl ToString) -> Self { - Self { - message: message.to_string(), - } - } -} diff --git a/examples/simple-chat-client/src/lib.rs b/examples/simple-chat-client/src/lib.rs deleted file mode 100644 index 6b2bd9d39..000000000 --- a/examples/simple-chat-client/src/lib.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub mod chat; -pub mod client; -pub mod config; -pub mod error; -pub mod model; -pub mod tool; diff --git a/examples/simple-chat-client/src/model.rs b/examples/simple-chat-client/src/model.rs deleted file mode 100644 index 4e9aeb674..000000000 --- a/examples/simple-chat-client/src/model.rs +++ /dev/null @@ -1,90 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct Message { - pub role: String, - pub content: String, -} - -impl Message { - pub fn system(content: impl ToString) -> Self { - Self { - role: "system".to_string(), - content: content.to_string(), - } - } - - pub fn user(content: impl ToString) -> Self { - Self { - role: "user".to_string(), - content: content.to_string(), - } - } - - pub fn assistant(content: impl ToString) -> Self { - Self { - role: "assistant".to_string(), - content: content.to_string(), - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct CompletionRequest { - pub model: String, - pub messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Tool { - pub name: String, - pub description: String, - pub parameters: serde_json::Value, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct CompletionResponse { - pub id: String, - pub object: String, - pub created: u64, - pub model: String, - pub choices: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Choice { - pub index: u32, - pub message: Message, - pub finish_reason: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ToolCall { - pub name: String, - pub arguments: serde_json::Value, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ToolResult { - pub success: bool, - pub contents: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Content { - pub content_type: String, - pub body: String, -} - -impl Content { - pub fn text(content: impl ToString) -> Self { - Self { - content_type: "text/plain".to_string(), - body: content.to_string(), - } - } -} diff --git a/examples/simple-chat-client/src/tool.rs b/examples/simple-chat-client/src/tool.rs deleted file mode 100644 index 6f9993669..000000000 --- a/examples/simple-chat-client/src/tool.rs +++ /dev/null @@ -1,130 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use anyhow::Result; -use async_trait::async_trait; -use rmcp::{ - model::{CallToolRequestParam, Tool as McpTool}, - service::ServerSink, -}; -use serde_json::Value; - -use crate::{ - error::McpError, - model::{Content, ToolResult}, -}; - -#[async_trait] -pub trait Tool: Send + Sync { - fn name(&self) -> String; - fn description(&self) -> String; - fn parameters(&self) -> Value; - async fn call(&self, args: Value) -> Result; -} - -pub struct McpToolAdapter { - tool: McpTool, - server: ServerSink, -} - -impl McpToolAdapter { - pub fn new(tool: McpTool, server: ServerSink) -> Self { - Self { tool, server } - } -} - -#[async_trait] -impl Tool for McpToolAdapter { - fn name(&self) -> String { - self.tool.name.clone().to_string() - } - - fn description(&self) -> String { - self.tool - .description - .clone() - .unwrap_or_default() - .to_string() - } - - fn parameters(&self) -> Value { - serde_json::to_value(&self.tool.input_schema).unwrap_or(serde_json::json!({})) - } - - async fn call(&self, args: Value) -> Result { - let arguments = match args { - Value::Object(map) => Some(map), - _ => None, - }; - - let call_result = self - .server - .call_tool(CallToolRequestParam { - name: self.tool.name.clone(), - arguments, - }) - .await?; - let result = serde_json::to_string(&call_result).unwrap(); - - Ok(result) - } -} -#[derive(Default)] -pub struct ToolSet { - tools: HashMap>, -} - -impl ToolSet { - pub fn add_tool(&mut self, tool: T) { - self.tools.insert(tool.name(), Arc::new(tool)); - } - - pub fn get_tool(&self, name: &str) -> Option> { - self.tools.get(name).cloned() - } - - pub fn tools(&self) -> Vec> { - self.tools.values().cloned().collect() - } -} - -pub async fn get_mcp_tools(server: ServerSink) -> Result> { - let tools = server.list_all_tools().await?; - Ok(tools - .into_iter() - .map(|tool| McpToolAdapter::new(tool, server.clone())) - .collect()) -} - -pub trait IntoCallToolResult { - fn into_call_tool_result(self) -> Result; -} - -impl IntoCallToolResult for Result -where - T: serde::Serialize, -{ - fn into_call_tool_result(self) -> Result { - match self { - Ok(response) => { - let content = Content { - content_type: "application/json".to_string(), - body: serde_json::to_string(&response).unwrap_or_default(), - }; - Ok(ToolResult { - success: true, - contents: vec![content], - }) - } - Err(error) => { - let content = Content { - content_type: "application/json".to_string(), - body: serde_json::to_string(&error).unwrap_or_default(), - }; - Ok(ToolResult { - success: false, - contents: vec![content], - }) - } - } - } -} diff --git a/examples/transport/Cargo.toml b/examples/transport/Cargo.toml deleted file mode 100644 index 06708bb9c..000000000 --- a/examples/transport/Cargo.toml +++ /dev/null @@ -1,61 +0,0 @@ -[package] -name = "transport" -edition = { workspace = true } -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -repository = { workspace = true } -description = { workspace = true } -keywords = { workspace = true } -homepage = { workspace = true } -categories = { workspace = true } -readme = { workspace = true } - -[package.metadata.docs.rs] -all-features = true - -[dependencies] -rmcp = { path = "../../crates/rmcp", features = ["server", "client"] } -tokio = { version = "1", features = [ - "macros", - "rt", - "rt-multi-thread", - "io-std", - "net", - "fs", - "time", -] } -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -anyhow = "1.0" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = [ - "env-filter", - "std", - "fmt", -] } -futures = "0.3" -rand = { version = "0.8" } -schemars = { version = "0.8", optional = true } -hyper = { version = "1", features = ["client", "server", "http1"] } -hyper-util = { version = "0.1", features = ["tokio"] } -tokio-tungstenite = "0.26.2" -reqwest = { version = "0.12" } -pin-project-lite = "0.2" - -[[example]] -name = "tcp" -path = "src/tcp.rs" - - -[[example]] -name = "http_upgrade" -path = "src/http_upgrade.rs" - -[[example]] -name = "unix_socket" -path = "src/unix_socket.rs" - -[[example]] -name = "websocket" -path = "src/websocket.rs" diff --git a/examples/transport/src/common/calculator.rs b/examples/transport/src/common/calculator.rs deleted file mode 100644 index 99b7314a1..000000000 --- a/examples/transport/src/common/calculator.rs +++ /dev/null @@ -1,41 +0,0 @@ -use rmcp::{ServerHandler, model::ServerInfo, schemars, tool, tool_box}; - -#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] -pub struct SumRequest { - #[schemars(description = "the left hand side number")] - pub a: i32, - pub b: i32, -} -#[derive(Debug, Clone)] -pub struct Calculator; -impl Calculator { - #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { - (a + b).to_string() - } - - #[tool(description = "Calculate the sub of two numbers")] - fn sub( - &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, - ) -> String { - (a - b).to_string() - } - - tool_box!(Calculator { sum, sub }); -} - -impl ServerHandler for Calculator { - tool_box!(@derive); - fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - ..Default::default() - } - } -} diff --git a/examples/transport/src/common/mod.rs b/examples/transport/src/common/mod.rs deleted file mode 100644 index 09bb58d2c..000000000 --- a/examples/transport/src/common/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod calculator; diff --git a/examples/transport/src/http_upgrade.rs b/examples/transport/src/http_upgrade.rs deleted file mode 100644 index 1c0cf6ad2..000000000 --- a/examples/transport/src/http_upgrade.rs +++ /dev/null @@ -1,67 +0,0 @@ -use common::calculator::Calculator; -use hyper::{ - Request, StatusCode, - body::Incoming, - header::{HeaderValue, UPGRADE}, -}; -use hyper_util::rt::TokioIo; -use rmcp::{RoleClient, ServiceExt, service::RunningService}; -use tracing_subscriber::EnvFilter; -mod common; -#[tokio::main] -async fn main() -> anyhow::Result<()> { - tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into())) - .init(); - start_server().await?; - let client = http_client("127.0.0.1:8001").await?; - let tools = client.list_all_tools().await?; - client.cancel().await?; - tracing::info!("{:#?}", tools); - Ok(()) -} - -async fn http_server(req: Request) -> Result, hyper::Error> { - tokio::spawn(async move { - let upgraded = hyper::upgrade::on(req).await?; - let service = Calculator.serve(TokioIo::new(upgraded)).await?; - service.waiting().await?; - anyhow::Result::<()>::Ok(()) - }); - let mut response = hyper::Response::new(String::new()); - *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS; - response - .headers_mut() - .insert(UPGRADE, HeaderValue::from_static("mcp")); - Ok(response) -} - -async fn http_client(uri: &str) -> anyhow::Result> { - let tcp_stream = tokio::net::TcpStream::connect(uri).await?; - let (mut s, c) = - hyper::client::conn::http1::handshake::<_, String>(TokioIo::new(tcp_stream)).await?; - tokio::spawn(c.with_upgrades()); - let mut req = Request::new(String::new()); - req.headers_mut() - .insert(UPGRADE, HeaderValue::from_static("mcp")); - let response = s.send_request(req).await?; - let upgraded = hyper::upgrade::on(response).await?; - let client = ().serve(TokioIo::new(upgraded)).await?; - Ok(client) -} - -async fn start_server() -> anyhow::Result<()> { - let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:8001").await?; - let service = hyper::service::service_fn(http_server); - tokio::spawn(async move { - while let Ok((stream, addr)) = tcp_listener.accept().await { - tracing::info!("accepted connection from: {}", addr); - let conn = hyper::server::conn::http1::Builder::new() - .serve_connection(TokioIo::new(stream), service) - .with_upgrades(); - tokio::spawn(conn); - } - }); - - Ok(()) -} diff --git a/examples/transport/src/tcp.rs b/examples/transport/src/tcp.rs deleted file mode 100644 index 72428fe65..000000000 --- a/examples/transport/src/tcp.rs +++ /dev/null @@ -1,32 +0,0 @@ -use common::calculator::Calculator; -use rmcp::{serve_client, serve_server}; - -mod common; -#[tokio::main] -async fn main() -> anyhow::Result<()> { - tokio::spawn(server()); - client().await?; - Ok(()) -} - -async fn server() -> anyhow::Result<()> { - let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:8001").await?; - while let Ok((stream, _)) = tcp_listener.accept().await { - tokio::spawn(async move { - let server = serve_server(Calculator, stream).await?; - server.waiting().await?; - anyhow::Ok(()) - }); - } - Ok(()) -} - -async fn client() -> anyhow::Result<()> { - let stream = tokio::net::TcpSocket::new_v4()? - .connect("127.0.0.1:8001".parse()?) - .await?; - let client = serve_client((), stream).await?; - let tools = client.peer().list_tools(Default::default()).await?; - println!("{:?}", tools); - Ok(()) -} diff --git a/examples/transport/src/unix_socket.rs b/examples/transport/src/unix_socket.rs deleted file mode 100644 index 875ed9bb0..000000000 --- a/examples/transport/src/unix_socket.rs +++ /dev/null @@ -1,82 +0,0 @@ -use std::fs; - -use common::calculator::Calculator; -use rmcp::{serve_client, serve_server}; -use tokio::net::{UnixListener, UnixStream}; - -mod common; - -const SOCKET_PATH: &str = "/tmp/rmcp_example.sock"; - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - // Remove any existing socket file - let _ = fs::remove_file(SOCKET_PATH); - match UnixListener::bind(SOCKET_PATH) { - Ok(unix_listener) => { - println!("Server successfully listening on {}", SOCKET_PATH); - tokio::spawn(server(unix_listener)); - } - Err(e) => { - println!("Unable to bind to {}: {}", SOCKET_PATH, e); - } - } - - client().await?; - - // Clean up socket file - let _ = fs::remove_file(SOCKET_PATH); - - Ok(()) -} - -async fn server(unix_listener: UnixListener) -> anyhow::Result<()> { - while let Ok((stream, addr)) = unix_listener.accept().await { - println!("Client connected: {:?}", addr); - tokio::spawn(async move { - match serve_server(Calculator, stream).await { - Ok(server) => { - println!("Server initialized successfully"); - if let Err(e) = server.waiting().await { - println!("Error while server waiting: {}", e); - } - } - Err(e) => println!("Server initialization failed: {}", e), - } - - anyhow::Ok(()) - }); - } - Ok(()) -} - -async fn client() -> anyhow::Result<()> { - println!("Client connecting to {}", SOCKET_PATH); - let stream = UnixStream::connect(SOCKET_PATH).await?; - - let client = serve_client((), stream).await?; - println!("Client connected and initialized successfully"); - - // List available tools - let tools = client.peer().list_tools(Default::default()).await?; - println!("Available tools: {:?}", tools); - - // Call the sum tool - if let Some(sum_tool) = tools.tools.iter().find(|t| t.name.contains("sum")) { - println!("Calling sum tool: {}", sum_tool.name); - let result = client - .peer() - .call_tool(rmcp::model::CallToolRequestParam { - name: sum_tool.name.clone(), - arguments: Some(rmcp::object!({ - "a": 10, - "b": 20 - })), - }) - .await?; - - println!("Result: {:?}", result); - } - - Ok(()) -} diff --git a/examples/transport/src/websocket.rs b/examples/transport/src/websocket.rs deleted file mode 100644 index 0d0fec729..000000000 --- a/examples/transport/src/websocket.rs +++ /dev/null @@ -1,168 +0,0 @@ -use std::marker::PhantomData; - -use common::calculator::Calculator; -use futures::{Sink, Stream}; -use rmcp::{ - RoleClient, RoleServer, ServiceExt, - service::{RunningService, RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage}, -}; -use tokio_tungstenite::tungstenite; -use tracing_subscriber::EnvFilter; -mod common; -#[tokio::main] -async fn main() -> anyhow::Result<()> { - tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into())) - .init(); - start_server().await?; - let client = http_client("ws://127.0.0.1:8001").await?; - let tools = client.list_all_tools().await?; - client.cancel().await?; - tracing::info!("{:#?}", tools); - Ok(()) -} - -async fn http_client(uri: &str) -> anyhow::Result> { - let (stream, response) = tokio_tungstenite::connect_async(uri).await?; - if response.status() != tungstenite::http::StatusCode::SWITCHING_PROTOCOLS { - return Err(anyhow::anyhow!("failed to upgrade connection")); - } - let transport = WebsocketTransport::new_client(stream); - let client = ().serve(transport).await?; - Ok(client) -} - -async fn start_server() -> anyhow::Result<()> { - let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:8001").await?; - tokio::spawn(async move { - while let Ok((stream, addr)) = tcp_listener.accept().await { - tracing::info!("accepted connection from: {}", addr); - tokio::spawn(async move { - let ws_stream = tokio_tungstenite::accept_async(stream).await?; - let transport = WebsocketTransport::new_server(ws_stream); - let server = Calculator.serve(transport).await?; - server.waiting().await?; - Ok::<(), anyhow::Error>(()) - }); - } - }); - Ok(()) -} - -pin_project_lite::pin_project! { - pub struct WebsocketTransport { - #[pin] - stream: S, - marker: PhantomData<(fn() -> E, fn() -> R)> - } -} - -impl WebsocketTransport { - pub fn new(stream: S) -> Self { - Self { - stream, - marker: PhantomData, - } - } -} - -impl WebsocketTransport { - pub fn new_client(stream: S) -> Self { - Self { - stream, - marker: PhantomData, - } - } -} - -impl WebsocketTransport { - pub fn new_server(stream: S) -> Self { - Self { - stream, - marker: PhantomData, - } - } -} - -impl Stream for WebsocketTransport -where - S: Stream>, - R: ServiceRole, - E: std::error::Error, -{ - type Item = RxJsonRpcMessage; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.as_mut().project(); - match this.stream.poll_next(cx) { - std::task::Poll::Ready(Some(Ok(message))) => { - let message = match message { - tungstenite::Message::Text(json) => json, - _ => return self.poll_next(cx), - }; - let message = match serde_json::from_str::>(&message) { - Ok(message) => message, - Err(e) => { - tracing::warn!(error = %e, "serde_json parse error"); - return self.poll_next(cx); - } - }; - std::task::Poll::Ready(Some(message)) - } - std::task::Poll::Ready(Some(Err(e))) => { - tracing::warn!(error = %e, "websocket error"); - self.poll_next(cx) - } - std::task::Poll::Ready(None) => std::task::Poll::Ready(None), - std::task::Poll::Pending => std::task::Poll::Pending, - } - } -} - -impl Sink> for WebsocketTransport -where - S: Sink, - R: ServiceRole, -{ - type Error = E; - - fn poll_ready( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project(); - this.stream.poll_ready(cx) - } - - fn start_send( - self: std::pin::Pin<&mut Self>, - item: TxJsonRpcMessage, - ) -> Result<(), Self::Error> { - let this = self.project(); - let message = tungstenite::Message::Text( - serde_json::to_string(&item) - .expect("jsonrpc should be valid json") - .into(), - ); - this.stream.start_send(message) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project(); - this.stream.poll_flush(cx) - } - - fn poll_close( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let this = self.project(); - this.stream.poll_close(cx) - } -} diff --git a/examples/wasi/Cargo.toml b/examples/wasi/Cargo.toml deleted file mode 100644 index 509bfde46..000000000 --- a/examples/wasi/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ - - -[package] -name = "wasi" -edition = { workspace = true } -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -repository = { workspace = true } -description = { workspace = true } -keywords = { workspace = true } -homepage = { workspace = true } -categories = { workspace = true } -readme = { workspace = true } - -[lib] -crate-type = ["cdylib"] - -[dependencies] -wasi = { version = "0.14.2"} -tokio = { version = "1", features = ["rt", "io-util", "sync", "macros", "time"] } -rmcp= { path = "../../crates/rmcp", features = ["server", "macros"] } -serde = { version = "1", features = ["derive"]} -tracing-subscriber = { version = "0.3", features = [ - "env-filter", - "std", - "fmt", -] } -tracing = "0.1" \ No newline at end of file diff --git a/examples/wasi/README.md b/examples/wasi/README.md deleted file mode 100644 index ca5e661dc..000000000 --- a/examples/wasi/README.md +++ /dev/null @@ -1,4 +0,0 @@ -```sh -cargo build -p wasi --target wasm32-wasip2 -npx @modelcontextprotocol/inspector wasmtime target/wasm32-wasip2/debug/wasi.wasm -``` diff --git a/examples/wasi/config.toml b/examples/wasi/config.toml deleted file mode 100644 index dac24597a..000000000 --- a/examples/wasi/config.toml +++ /dev/null @@ -1,2 +0,0 @@ -[build] -target = "wasm32-wasip2" \ No newline at end of file diff --git a/examples/wasi/src/calculator.rs b/examples/wasi/src/calculator.rs deleted file mode 100644 index f1c35eeac..000000000 --- a/examples/wasi/src/calculator.rs +++ /dev/null @@ -1,46 +0,0 @@ -use rmcp::{ - ServerHandler, - model::{ServerCapabilities, ServerInfo}, - schemars, tool, tool_box, -}; - -#[derive(Debug, rmcp::serde::Deserialize, schemars::JsonSchema)] -pub struct SumRequest { - #[schemars(description = "the left hand side number")] - pub a: i32, - pub b: i32, -} -#[derive(Debug, Clone)] -pub struct Calculator; -impl Calculator { - #[tool(description = "Calculate the sum of two numbers")] - fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { - (a + b).to_string() - } - - #[tool(description = "Calculate the sub of two numbers")] - fn sub( - &self, - #[tool(param)] - #[schemars(description = "the left hand side number")] - a: i32, - #[tool(param)] - #[schemars(description = "the right hand side number")] - b: i32, - ) -> String { - (a - b).to_string() - } - - tool_box!(Calculator { sum, sub }); -} - -impl ServerHandler for Calculator { - tool_box!(@derive); - fn get_info(&self) -> ServerInfo { - ServerInfo { - instructions: Some("A simple calculator".into()), - capabilities: ServerCapabilities::builder().enable_tools().build(), - ..Default::default() - } - } -} diff --git a/examples/wasi/src/lib.rs b/examples/wasi/src/lib.rs deleted file mode 100644 index 3b2904a81..000000000 --- a/examples/wasi/src/lib.rs +++ /dev/null @@ -1,121 +0,0 @@ -pub mod calculator; -use std::task::{Poll, Waker}; - -use rmcp::ServiceExt; -use tokio::io::{AsyncRead, AsyncWrite}; -use tracing_subscriber::EnvFilter; -use wasi::{ - cli::{ - stdin::{InputStream, get_stdin}, - stdout::{OutputStream, get_stdout}, - }, - io::streams::Pollable, -}; - -pub fn wasi_io() -> (AsyncInputStream, AsyncOutputStream) { - let input = AsyncInputStream { inner: get_stdin() }; - let output = AsyncOutputStream { - inner: get_stdout(), - }; - (input, output) -} - -pub struct AsyncInputStream { - inner: InputStream, -} - -impl AsyncRead for AsyncInputStream { - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - let bytes = self - .inner - .read(buf.remaining() as u64) - .map_err(std::io::Error::other)?; - if bytes.is_empty() { - let pollable = self.inner.subscribe(); - let waker = cx.waker().clone(); - runtime_poll(waker, pollable); - return Poll::Pending; - } - buf.put_slice(&bytes); - std::task::Poll::Ready(Ok(())) - } -} - -pub struct AsyncOutputStream { - inner: OutputStream, -} -fn runtime_poll(waker: Waker, pollable: Pollable) { - tokio::task::spawn(async move { - loop { - if pollable.ready() { - waker.wake(); - break; - } else { - tokio::task::yield_now().await; - } - } - }); -} -impl AsyncWrite for AsyncOutputStream { - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { - let writable_len = self.inner.check_write().map_err(std::io::Error::other)?; - if writable_len == 0 { - let pollable = self.inner.subscribe(); - let waker = cx.waker().clone(); - runtime_poll(waker, pollable); - return Poll::Pending; - } - let bytes_to_write = buf.len().min(writable_len as usize); - self.inner - .write(&buf[0..bytes_to_write]) - .map_err(std::io::Error::other)?; - Poll::Ready(Ok(bytes_to_write)) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.inner.flush().map_err(std::io::Error::other)?; - Poll::Ready(Ok(())) - } - - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.poll_flush(cx) - } -} - -struct TokioCliRunner; - -impl wasi::exports::cli::run::Guest for TokioCliRunner { - fn run() -> Result<(), ()> { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - rt.block_on(async move { - tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::from_default_env().add_directive(tracing::Level::DEBUG.into()), - ) - .with_writer(std::io::stderr) - .with_ansi(false) - .init(); - let server = calculator::Calculator.serve(wasi_io()).await.unwrap(); - server.waiting().await.unwrap(); - }); - Ok(()) - } -} -wasi::cli::command::export!(TokioCliRunner);