diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index 41749adac..9270488ba 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -1,14 +1,15 @@ use futures::{SinkExt, StreamExt}; +use thiserror::Error; use super::*; use crate::model::{ - CancelledNotification, CancelledNotificationParam, ClientInfo, ClientNotification, - ClientRequest, ClientResult, CreateMessageRequest, CreateMessageRequestParam, - CreateMessageResult, ListRootsRequest, ListRootsResult, LoggingMessageNotification, - LoggingMessageNotificationParam, ProgressNotification, ProgressNotificationParam, - PromptListChangedNotification, ResourceListChangedNotification, ResourceUpdatedNotification, - ResourceUpdatedNotificationParam, ServerInfo, ServerMessage, ServerNotification, ServerRequest, - ServerResult, ToolListChangedNotification, + CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, + ClientMessage, ClientNotification, ClientRequest, ClientResult, CreateMessageRequest, + CreateMessageRequestParam, CreateMessageResult, ListRootsRequest, ListRootsResult, + LoggingMessageNotification, LoggingMessageNotificationParam, ProgressNotification, + ProgressNotificationParam, PromptListChangedNotification, ResourceListChangedNotification, + ResourceUpdatedNotification, ResourceUpdatedNotificationParam, ServerInfo, ServerMessage, + ServerNotification, ServerRequest, ServerResult, ToolListChangedNotification, }; #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] @@ -26,6 +27,24 @@ impl ServiceRole for RoleServer { const IS_CLIENT: bool = false; } +/// It represents the error that may occur when serving the server. +/// +/// if you want to handle the error, you can use `serve_server_with_ct` or `serve_server` with `Result, ServerError>` +#[derive(Error, Debug)] +pub enum ServerError { + #[error("expect initialized request, but received: {0:?}")] + ExpectedInitRequest(Option), + + #[error("expect initialized notification, but received: {0:?}")] + ExpectedInitNotification(Option), + + #[error("connection closed: {0}")] + ConnectionClosed(String), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} + pub type ClientSink = Peer; impl> ServiceExt for S { @@ -55,6 +74,46 @@ where serve_server_with_ct(service, transport, CancellationToken::new()).await } +/// Helper function to get the next message from the stream +async fn expect_next_message(stream: &mut S, context: &str) -> Result +where + S: StreamExt + Unpin, +{ + Ok(stream + .next() + .await + .ok_or_else(|| ServerError::ConnectionClosed(context.to_string()))? + .into_message()) +} + +/// Helper function to expect a request from the stream +async fn expect_request( + stream: &mut S, + context: &str, +) -> Result<(ClientRequest, RequestId), ServerError> +where + S: StreamExt + Unpin, +{ + let msg = expect_next_message(stream, context).await?; + let msg_clone = msg.clone(); + msg.into_request() + .ok_or(ServerError::ExpectedInitRequest(Some(msg_clone))) +} + +/// Helper function to expect a notification from the stream +async fn expect_notification( + stream: &mut S, + context: &str, +) -> Result +where + S: StreamExt + Unpin, +{ + let msg = expect_next_message(stream, context).await?; + let msg_clone = msg.clone(); + msg.into_notification() + .ok_or(ServerError::ExpectedInitNotification(Some(msg_clone))) +} + pub async fn serve_server_with_ct( service: S, transport: T, @@ -70,54 +129,45 @@ where let mut stream = Box::pin(stream); let id_provider = >::default(); - // service - let (request, id) = stream - .next() + // Convert ServerError to std::io::Error, then to E + let handle_server_error = |e: ServerError| -> E { + match e { + ServerError::Io(io_err) => io_err.into(), + other => std::io::Error::new(std::io::ErrorKind::Other, format!("{}", other)).into(), + } + }; + + // Get initialize request + let (request, id) = expect_request(&mut stream, "initialized request") .await - .ok_or(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "expect initialize request", - ))? - .into_message() - .into_request() - .ok_or(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "expect initialize request", - ))?; + .map_err(handle_server_error)?; + let ClientRequest::InitializeRequest(peer_info) = request else { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "expect initialize request", - ) - .into()); + return Err(handle_server_error(ServerError::ExpectedInitRequest(Some( + ClientMessage::Request(request, id), + )))); }; + + // Send initialize response let init_response = service.get_info(); sink.send( ServerMessage::Response(ServerResult::InitializeResult(init_response), id) .into_json_rpc_message(), ) .await?; - // waiting for notification - let notification = stream - .next() + + // Wait for initialize notification + let notification = expect_notification(&mut stream, "initialize notification") .await - .ok_or(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "expect initialize notification", - ))? - .into_message() - .into_notification() - .ok_or(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "expect initialize notification", - ))?; + .map_err(handle_server_error)?; + let ClientNotification::InitializedNotification(_) = notification else { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "expect initialize notification", - ) - .into()); + return Err(handle_server_error(ServerError::ExpectedInitNotification( + Some(ClientMessage::Notification(notification)), + ))); }; + + // Continue processing service serve_inner(service, (sink, stream), peer_info.params, id_provider, ct).await }