From 5710be172bd06fcf17b0cdedcf93f77ff4911cd4 Mon Sep 17 00:00:00 2001 From: crschnick Date: Thu, 31 Jul 2025 19:27:18 +0000 Subject: [PATCH] Rework --- app/build.gradle | 9 +- .../io/xpipe/app/beacon/AppBeaconServer.java | 10 + .../main/java/io/xpipe/app/core/AppLogs.java | 17 +- .../java/io/xpipe/app/core/mode/BaseMode.java | 2 + ...HttpStreamableServerTransportProvider.java | 753 ++++++++++++++++++ .../java/io/xpipe/app/mcp/McpSchemaFiles.java | 13 + .../main/java/io/xpipe/app/mcp/McpServer.java | 89 +++ .../main/java/io/xpipe/app/mcp/McpTools.java | 184 +++++ app/src/main/java/io/xpipe/app/mcp/find.json | 22 + app/src/main/java/io/xpipe/app/mcp/read.json | 15 + app/src/main/java/module-info.java | 5 + gradle/gradle_scripts/modules.gradle | 43 + 12 files changed, 1153 insertions(+), 9 deletions(-) create mode 100644 app/src/main/java/io/xpipe/app/mcp/HttpStreamableServerTransportProvider.java create mode 100644 app/src/main/java/io/xpipe/app/mcp/McpSchemaFiles.java create mode 100644 app/src/main/java/io/xpipe/app/mcp/McpServer.java create mode 100644 app/src/main/java/io/xpipe/app/mcp/McpTools.java create mode 100644 app/src/main/java/io/xpipe/app/mcp/find.json create mode 100644 app/src/main/java/io/xpipe/app/mcp/read.json diff --git a/app/build.gradle b/app/build.gradle index 246496230..122de1389 100644 --- a/app/build.gradle +++ b/app/build.gradle @@ -48,7 +48,14 @@ dependencies { api 'com.vladsch.flexmark:flexmark-ext-yaml-front-matter:0.64.8' api 'com.vladsch.flexmark:flexmark-ext-toc:0.64.8' - api("com.github.weisj:jsvg:1.7.1") + api 'io.modelcontextprotocol.sdk:mcp:0.11.0' + api "io.projectreactor:reactor-core:3.7.0" + api "io.micrometer:context-propagation:1.1.3" + api "io.projectreactor.tools:blockhound:1.0.13.RELEASE" + api "org.reactivestreams:reactive-streams:1.0.4" + api "com.networknt:json-schema-validator:1.5.8" + + api "com.github.weisj:jsvg:1.7.1" api 'io.xpipe:vernacular:1.15' api 'org.bouncycastle:bcprov-jdk18on:1.81' api 'info.picocli:picocli:4.7.6' diff --git a/app/src/main/java/io/xpipe/app/beacon/AppBeaconServer.java b/app/src/main/java/io/xpipe/app/beacon/AppBeaconServer.java index ad042299b..b4807286c 100644 --- a/app/src/main/java/io/xpipe/app/beacon/AppBeaconServer.java +++ b/app/src/main/java/io/xpipe/app/beacon/AppBeaconServer.java @@ -2,6 +2,7 @@ package io.xpipe.app.beacon; import io.xpipe.app.issue.ErrorEventFactory; import io.xpipe.app.issue.TrackEvent; +import io.xpipe.app.mcp.McpServer; import io.xpipe.app.util.DocumentationLink; import io.xpipe.beacon.BeaconConfig; import io.xpipe.beacon.BeaconInterface; @@ -150,6 +151,15 @@ public class AppBeaconServer { handleCatchAll(exchange); }); + server.createContext("/mcp", exchange -> { + if (exchange.getRequestMethod().equals("GET")) { + McpServer.HANDLER.doGet(exchange); + } else { + McpServer.HANDLER.doPost(exchange); + } + exchange.close(); + }); + server.start(); running = true; } diff --git a/app/src/main/java/io/xpipe/app/core/AppLogs.java b/app/src/main/java/io/xpipe/app/core/AppLogs.java index e659976b1..c7a2625ba 100644 --- a/app/src/main/java/io/xpipe/app/core/AppLogs.java +++ b/app/src/main/java/io/xpipe/app/core/AppLogs.java @@ -6,6 +6,7 @@ import io.xpipe.core.Deobfuscator; import lombok.Getter; import org.apache.commons.io.FileUtils; +import org.apache.commons.io.FilenameUtils; import org.slf4j.ILoggerFactory; import org.slf4j.IMarkerFactory; import org.slf4j.Logger; @@ -315,15 +316,15 @@ public class AppLogs { public Logger getLogger(String name) { // Only change this when debugging the logs of other libraries - return NOPLogger.NOP_LOGGER; + // return NOPLogger.NOP_LOGGER; - // // Don't use fully qualified class names - // var normalizedName = FilenameUtils.getExtension(name); - // if (normalizedName == null || normalizedName.isEmpty()) { - // normalizedName = name; - // } - // - // return loggers.computeIfAbsent(normalizedName, s -> new Slf4jLogger()); + // Don't use fully qualified class names + var normalizedName = FilenameUtils.getExtension(name); + if (normalizedName == null || normalizedName.isEmpty()) { + normalizedName = name; + } + + return loggers.computeIfAbsent(normalizedName, s -> new Slf4jLogger()); } }; diff --git a/app/src/main/java/io/xpipe/app/core/mode/BaseMode.java b/app/src/main/java/io/xpipe/app/core/mode/BaseMode.java index ac527466a..eac386839 100644 --- a/app/src/main/java/io/xpipe/app/core/mode/BaseMode.java +++ b/app/src/main/java/io/xpipe/app/core/mode/BaseMode.java @@ -17,6 +17,7 @@ import io.xpipe.app.ext.ProcessControlProvider; import io.xpipe.app.hub.comp.StoreViewState; import io.xpipe.app.icon.SystemIconManager; import io.xpipe.app.issue.TrackEvent; +import io.xpipe.app.mcp.McpServer; import io.xpipe.app.prefs.AppPrefs; import io.xpipe.app.pwman.KeePassXcPasswordManager; import io.xpipe.app.storage.DataStorage; @@ -64,6 +65,7 @@ public class BaseMode extends OperationMode { AppJavaOptionsCheck.check(); AppSid.init(); AppBeaconServer.init(); + McpServer.init(); AppLayoutModel.init(); if (OperationMode.getStartupMode() == XPipeDaemonMode.GUI) { diff --git a/app/src/main/java/io/xpipe/app/mcp/HttpStreamableServerTransportProvider.java b/app/src/main/java/io/xpipe/app/mcp/HttpStreamableServerTransportProvider.java new file mode 100644 index 000000000..2a79dfa57 --- /dev/null +++ b/app/src/main/java/io/xpipe/app/mcp/HttpStreamableServerTransportProvider.java @@ -0,0 +1,753 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.xpipe.app.mcp; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.sun.net.httpserver.HttpExchange; +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.*; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.io.*; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; + +public class HttpStreamableServerTransportProvider implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(HttpStreamableServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Header name for the response media types accepted by the requester. + */ + private static final String ACCEPT = "Accept"; + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + + /** + * The endpoint URI where clients should send their JSON-RPC messages. Defaults to + * "/mcp". + */ + private final String mcpEndpoint; + + /** + * Flag indicating whether DELETE requests are disallowed on the endpoint. + */ + private final boolean disallowDelete; + + private final ObjectMapper objectMapper; + + private McpStreamableServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by mcp-session-id. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. + */ + private KeepAliveScheduler keepAliveScheduler; + + /** + * Constructs a new HttpServletStreamableServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. + * @param disallowDelete Whether to disallow DELETE requests on the endpoint. + * @param contextExtractor The extractor for transport context from the request. + * @throws IllegalArgumentException if any parameter is null + */ + private HttpStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + boolean disallowDelete, McpTransportContextExtractor contextExtractor, + Duration keepAliveInterval) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); + + this.objectMapper = objectMapper; + this.mcpEndpoint = mcpEndpoint; + this.disallowDelete = disallowDelete; + this.contextExtractor = contextExtractor; + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + + } + + @Override + public String protocolVersion() { + return "2025-03-26"; + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients through their SSE connections. + * If any errors occur during sending to a particular client, they are logged but + * don't prevent sending to other clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Object params) { + if (this.sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size()); + + return Mono.fromRunnable(() -> { + this.sessions.values().parallelStream().forEach(session -> { + try { + session.sendNotification(method, params).block(); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()); + } + }); + }); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when all cleanup operations are finished + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + this.isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()); + + this.sessions.values().parallelStream().forEach(session -> { + try { + session.closeGracefully().block(); + } + catch (Exception e) { + logger.error("Failed to close session {}: {}", session.getId(), e.getMessage()); + } + }); + + this.sessions.clear(); + logger.debug("Graceful shutdown completed"); + }).then().doOnSuccess(v -> { + sessions.clear(); + logger.debug("Graceful shutdown completed"); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + public void doGet(HttpExchange exchange) + throws IOException { + + String requestURI = exchange.getRequestURI().toString(); + if (!requestURI.endsWith(mcpEndpoint)) { + sendError(exchange, 404, null); + return; + } + + if (this.isClosing) { + sendError(exchange, 503, "Server is shutting down"); + return; + } + + List badRequestErrors = new ArrayList<>(); + + String accept = exchange.getRequestHeaders().getFirst(ACCEPT); + if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) { + badRequestErrors.add("text/event-stream required in Accept header"); + } + + String sessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + if (sessionId == null || sessionId.isBlank()) { + badRequestErrors.add("Session ID required in mcp-session-id header"); + } + + if (!badRequestErrors.isEmpty()) { + String combinedMessage = String.join("; ", badRequestErrors); + this.sendMcpError(exchange, 400, new McpError(combinedMessage)); + return; + } + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + sendError(exchange, 404, null); + return; + } + + logger.debug("Handling GET request for session: {}", sessionId); + + McpTransportContext transportContext = this.contextExtractor.extract(exchange, new DefaultMcpTransportContext()); + + try { + exchange.getResponseHeaders().add("Content-Type", TEXT_EVENT_STREAM); + exchange.getResponseHeaders().add("Content-Encoding", UTF_8); + exchange.getResponseHeaders().add("Cache-Control", "no-cache"); + exchange.getResponseHeaders().add("Connection", "keep-alive"); + exchange.getResponseHeaders().add("Access-Control-Allow-Origin", "*"); + exchange.sendResponseHeaders(200, 0); + + var writer = new PrintWriter(exchange.getResponseBody()); + HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport( + sessionId, exchange, writer); + + // Check if this is a replay request + if (exchange.getRequestHeaders().getFirst(HttpHeaders.LAST_EVENT_ID) != null) { + String lastId = exchange.getRequestHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); + + try { + session.replay(lastId) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .toIterable() + .forEach(message -> { + try { + sessionTransport.sendMessage(message) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to replay message: {}", e.getMessage()); + exchange.close(); + } + }); + } + catch (Exception e) { + logger.error("Failed to replay messages: {}", e.getMessage()); + exchange.close(); + } + } + } + catch (Exception e) { + logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage()); + sendError(exchange, 500, null); + } + } + + private void sendError(HttpExchange exchange, int code, String message) throws IOException { + var b = message != null ? message.getBytes(StandardCharsets.UTF_8) : new byte[0]; + exchange.getResponseHeaders().add("Content-Encoding", UTF_8); + exchange.sendResponseHeaders(code, b.length != 0 ? b.length : -1); + try (OutputStream os = exchange.getResponseBody()) { + os.write(b); + } + } + + public void doPost(HttpExchange exchange) + throws IOException { + + String requestURI = exchange.getRequestURI().toString(); + if (!requestURI.endsWith(mcpEndpoint)) { + sendError(exchange, 404, null); + return; + } + + if (this.isClosing) { + sendError(exchange, 503, "Server is shutting down"); + return; + } + + List badRequestErrors = new ArrayList<>(); + + String accept = exchange.getRequestHeaders().getFirst(ACCEPT); + if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) { + badRequestErrors.add("text/event-stream required in Accept header"); + } + if (accept == null || !accept.contains(APPLICATION_JSON)) { + badRequestErrors.add("application/json required in Accept header"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(exchange, new DefaultMcpTransportContext()); + + try { + var body = new String(exchange.getRequestBody().readAllBytes(), StandardCharsets.UTF_8); + + McpSchemaFiles.JSONRPCMessage message = McpSchemaFiles.deserializeJsonRpcMessage(objectMapper, body); + + // Handle initialization request + if (message instanceof McpSchemaFiles.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchemaFiles.METHOD_INITIALIZE)) { + if (!badRequestErrors.isEmpty()) { + String combinedMessage = String.join("; ", badRequestErrors); + this.sendMcpError(exchange, 400, new McpError(combinedMessage)); + return; + } + + McpSchemaFiles.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), + new TypeReference() { + }); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); + this.sessions.put(init.session().getId(), init.session()); + + try { + McpSchemaFiles.InitializeResult initResult = init.initResult().block(); + + String jsonResponse = objectMapper.writeValueAsString(new McpSchemaFiles.JSONRPCResponse( + McpSchemaFiles.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, null)); + var jsonBytes = jsonResponse.getBytes(StandardCharsets.UTF_8); + + exchange.getResponseHeaders().add("Content-Type", APPLICATION_JSON); + exchange.getResponseHeaders().add("Content-Encoding", UTF_8); + exchange.getResponseHeaders().add(HttpHeaders.MCP_SESSION_ID, init.session().getId()); + exchange.sendResponseHeaders(200, jsonBytes.length); + exchange.getResponseBody().write(jsonBytes); + return; + } + catch (Exception e) { + logger.error("Failed to initialize session: {}", e.getMessage()); + this.sendMcpError(exchange, 500, + new McpError("Failed to initialize session: " + e.getMessage())); + return; + } + } + + String sessionId = exchange.getRequestHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + if (sessionId == null || sessionId.isBlank()) { + badRequestErrors.add("Session ID required in mcp-session-id header"); + } + + if (!badRequestErrors.isEmpty()) { + String combinedMessage = String.join("; ", badRequestErrors); + this.sendMcpError(exchange, 400, new McpError(combinedMessage)); + return; + } + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + this.sendMcpError(exchange, 404, + new McpError("Session not found: " + sessionId)); + return; + } + + if (message instanceof McpSchemaFiles.JSONRPCResponse jsonrpcResponse) { + session.accept(jsonrpcResponse) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + exchange.sendResponseHeaders(200, -1); + } + else if (message instanceof McpSchemaFiles.JSONRPCNotification jsonrpcNotification) { + session.accept(jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + exchange.sendResponseHeaders(202, -1); + } + else if (message instanceof McpSchemaFiles.JSONRPCRequest jsonrpcRequest) { + // For streaming responses, we need to return SSE + exchange.getResponseHeaders().add("Content-Type", TEXT_EVENT_STREAM); + exchange.getResponseHeaders().add("Content-Encoding", UTF_8); + exchange.getResponseHeaders().add("Cache-Control", "no-cache"); + exchange.getResponseHeaders().add("Connection", "keep-alive"); + exchange.getResponseHeaders().add("Access-Control-Allow-Origin", "*"); + exchange.sendResponseHeaders(200, 0); + + var writer = new PrintWriter(exchange.getResponseBody()); + + HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport( + sessionId, exchange, writer); + + try { + session.responseStream(jsonrpcRequest, sessionTransport) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to handle request stream: {}", e.getMessage()); + exchange.close(); + } + } + else { + this.sendMcpError(exchange, 500, new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + this.sendMcpError(exchange, 400, + new McpError("Invalid message format: " + e.getMessage())); + } + catch (Exception e) { + logger.error("Error handling message: {}", e.getMessage()); + try { + this.sendMcpError(exchange, 500, + new McpError("Error processing message: " + e.getMessage())); + } + catch (IOException ex) { + logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); + sendError(exchange, 500, "Error processing message"); + } + } + } +// +// /** +// * Handles DELETE requests for session deletion. +// * @param request The HTTP servlet request +// * @param response The HTTP servlet response +// * @throws ServletException If a servlet-specific error occurs +// * @throws IOException If an I/O error occurs +// */ +// @Override +// protected void doDelete(HttpRequest request, HttpServletResponse response) +// throws IOException { +// +// String requestURI = request.getRequestURI(); +// if (!requestURI.endsWith(mcpEndpoint)) { +// response.sendError(404); +// return; +// } +// +// if (this.isClosing) { +// response.sendError(503, "Server is shutting down"); +// return; +// } +// +// if (this.disallowDelete) { +// response.sendError(405); +// return; +// } +// +// McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); +// +// if (request.headers().firstValue(HttpHeaders.MCP_SESSION_ID).orElse(null) == null) { +// this.responseError(response, 400, +// new McpError("Session ID required in mcp-session-id header")); +// return; +// } +// +// String sessionId = request.headers().firstValue(HttpHeaders.MCP_SESSION_ID).orElse(null); +// McpStreamableServerSession session = this.sessions.get(sessionId); +// +// if (session == null) { +// response.sendError(404); +// return; +// } +// +// try { +// session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); +// this.sessions.remove(sessionId); +// response.setStatus(200); +// } +// catch (Exception e) { +// logger.error("Failed to delete session {}: {}", sessionId, e.getMessage()); +// try { +// this.responseError(response, 500, +// new McpError(e.getMessage())); +// } +// catch (IOException ex) { +// logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); +// response.sendError(500, "Error deleting session"); +// } +// } +// } + + public void sendMcpError(HttpExchange exchange, int httpCode, McpError mcpError) throws IOException { + var jsonError = objectMapper.writeValueAsString(mcpError); + var bytes = jsonError.getBytes(StandardCharsets.UTF_8); + exchange.getResponseHeaders().set("Content-Type", APPLICATION_JSON); + exchange.getResponseHeaders().add("Content-Encoding", UTF_8); + exchange.sendResponseHeaders(httpCode, bytes.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + } + } + + /** + * Sends an SSE event to a client with a specific ID. + * @param writer The writer to send the event through + * @param eventType The type of event (message or endpoint) + * @param data The event data + * @param id The event ID + * @throws IOException If an error occurs while writing the event + */ + private void sendEvent(PrintWriter writer, String eventType, String data, String id) throws IOException { + if (id != null) { + writer.write("id: " + id + "\n"); + } + writer.write("event: " + eventType + "\n"); + writer.write("data: " + data + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + + /** + * Implementation of McpStreamableServerTransport for HttpServlet SSE sessions. This + * class handles the transport-level communication for a specific client session. + * + *

+ * This class is thread-safe and uses a ReentrantLock to synchronize access to the + * underlying PrintWriter to prevent race conditions when multiple threads attempt to + * send messages concurrently. + */ + + private class HttpServletStreamableMcpSessionTransport implements McpStreamableServerTransport { + + private final String sessionId; + + private final HttpExchange exchange; + + private final PrintWriter writer; + + private volatile boolean closed = false; + + private final ReentrantLock lock = new ReentrantLock(); + + HttpServletStreamableMcpSessionTransport(String sessionId, HttpExchange exchange, PrintWriter writer) { + this.sessionId = sessionId; + this.exchange = exchange; + this.writer = writer; + logger.debug("Streamable session transport {} initialized with SSE writer", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchemaFiles.JSONRPCMessage message) { + return sendMessage(message, null); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection with a + * specific message ID. + * @param message The JSON-RPC message to send + * @param messageId The message ID for SSE event identification + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchemaFiles.JSONRPCMessage message, String messageId) { + return Mono.fromRunnable(() -> { + if (this.closed) { + logger.debug("Attempted to send message to closed session: {}", this.sessionId); + return; + } + + lock.lock(); + try { + if (this.closed) { + logger.debug("Session {} was closed during message send attempt", this.sessionId); + return; + } + + String jsonText = objectMapper.writeValueAsString(message); + HttpStreamableServerTransportProvider.this.sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText, + messageId != null ? messageId : this.sessionId); + logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); + HttpStreamableServerTransportProvider.this.sessions.remove(this.sessionId); + exchange.close(); + } + finally { + lock.unlock(); + } + }); + } + + /** + * Converts data from one type to another using the configured ObjectMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + HttpServletStreamableMcpSessionTransport.this.close(); + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + lock.lock(); + try { + if (this.closed) { + logger.debug("Session transport {} already closed", this.sessionId); + return; + } + + this.closed = true; + + // HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); + exchange.close(); + logger.debug("Successfully completed async context for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + } + finally { + lock.unlock(); + } + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of + * {@link HttpStreamableServerTransportProvider}. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String mcpEndpoint = "/mcp"; + + private boolean disallowDelete = false; + + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + + private Duration keepAliveInterval; + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param mcpEndpoint The MCP endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if mcpEndpoint is null + */ + public Builder mcpEndpoint(String mcpEndpoint) { + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + this.mcpEndpoint = mcpEndpoint; + return this; + } + + /** + * Sets whether to disallow DELETE requests on the endpoint. + * @param disallowDelete true to disallow DELETE requests, false otherwise + * @return this builder instance + */ + public Builder disallowDelete(boolean disallowDelete) { + this.disallowDelete = disallowDelete; + return this; + } + + /** + * Sets the context extractor for extracting transport context from the request. + * @param contextExtractor The context extractor to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets the keep-alive interval for the transport. If set, a keep-alive scheduler + * will be activated to periodically ping active sessions. + * @param keepAliveInterval The interval for keep-alive pings. If null, no + * keep-alive will be scheduled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of {@link HttpStreamableServerTransportProvider} + * with the configured settings. + * @return A new HttpServletStreamableServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public HttpStreamableServerTransportProvider build() { + Assert.notNull(this.objectMapper, "ObjectMapper must be set"); + Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); + + return new HttpStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, + this.disallowDelete, this.contextExtractor, this.keepAliveInterval); + } + + } + +} diff --git a/app/src/main/java/io/xpipe/app/mcp/McpSchemaFiles.java b/app/src/main/java/io/xpipe/app/mcp/McpSchemaFiles.java new file mode 100644 index 000000000..c909d7235 --- /dev/null +++ b/app/src/main/java/io/xpipe/app/mcp/McpSchemaFiles.java @@ -0,0 +1,13 @@ +package io.xpipe.app.mcp; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +public class McpSchemaFiles { + + public static String load(String name) throws IOException { + try (var in = McpTools.class.getResourceAsStream("find.json")) { + return new String(in.readAllBytes(), StandardCharsets.UTF_8); + } + } +} diff --git a/app/src/main/java/io/xpipe/app/mcp/McpServer.java b/app/src/main/java/io/xpipe/app/mcp/McpServer.java new file mode 100644 index 000000000..acbdc0c53 --- /dev/null +++ b/app/src/main/java/io/xpipe/app/mcp/McpServer.java @@ -0,0 +1,89 @@ +package io.xpipe.app.mcp; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.sun.net.httpserver.HttpExchange; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.transport.StdioServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.xpipe.app.core.AppLogs; +import io.xpipe.core.JacksonMapper; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.List; + +public class McpServer { + + public static final HttpStreamableServerTransportProvider HANDLER = HttpStreamableServerTransportProvider.builder().mcpEndpoint("/mcp").objectMapper(new ObjectMapper()).build(); + + public static void init() { + var transportProvider = HANDLER; + + McpSyncServer syncServer = io.modelcontextprotocol.server.McpServer.sync(transportProvider) + .serverInfo("my-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder() + .resources(false, true) // Enable resource support + .tools(true) // Enable tool support + .prompts(true) // Enable prompt support + .logging() // Enable logging support + .completions() // Enable completions support + .build()) + .build(); + + syncServer.loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("custom-logger") + .data("Custom log message") + .build()); + + var schema = """ + { + "type" : "object", + "id" : "urn:jsonschema:Operation", + "properties" : { + "operation" : { + "type" : "string" + }, + "a" : { + "type" : "number" + }, + "b" : { + "type" : "number" + } + } + } + """; + var syncToolSpecification = new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("calculator", "Basic calculator", schema), + (exchange, arguments) -> { + // Tool implementation + return new McpSchema.CallToolResult("test", false); + } + ); + + var syncResourceSpecification = new McpServerFeatures.SyncResourceSpecification( + new McpSchema.Resource("custom://resource", "name", "description", "mime-type", null), + (exchange, request) -> { + // Resource read implementation + return new McpSchema.ReadResourceResult(List.of(new McpSchema.TextResourceContents("custom://resource", "name", "test"))); + } + ); + + // Sync prompt specification + var syncPromptSpecification = new McpServerFeatures.SyncPromptSpecification( + new McpSchema.Prompt("greeting", "description", List.of( + new McpSchema.PromptArgument("name", "description", true) + )), + (exchange, request) -> { + // Prompt implementation + return new McpSchema.GetPromptResult("test", List.of(new McpSchema.PromptMessage(McpSchema.Role.USER, new McpSchema.TextContent("abc")))); + } + ); + + // Register tools, resources, and prompts + syncServer.addTool(syncToolSpecification); + syncServer.addResource(syncResourceSpecification); + syncServer.addPrompt(syncPromptSpecification); + } +} diff --git a/app/src/main/java/io/xpipe/app/mcp/McpTools.java b/app/src/main/java/io/xpipe/app/mcp/McpTools.java new file mode 100644 index 000000000..8230fa231 --- /dev/null +++ b/app/src/main/java/io/xpipe/app/mcp/McpTools.java @@ -0,0 +1,184 @@ +package io.xpipe.app.mcp; + +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.spec.McpSchema; +import io.xpipe.core.FilePath; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +public final class McpTools { + + private static final Logger log = LoggerFactory.getLogger(McpTools.class); + + /** + * Create an MCP tool to search for files or directories within the filesystem + * starting from the specified {@code start} path. This method recursively traverses + * the directory beginning at the provided {@code start} path, identifying all entries + * (both files and directories) whose names contain the specified target {@code name}. + * The search is case-sensitive and matches partial names (e.g., "temp" would match + * "template.log"). + * @return A list of absolute path strings for all matching entries found during the + * search, wrapped as a {@link McpServerFeatures.SyncToolSpecification} object. + * @throws IOException If an I/O error occurs during filesystem traversal + */ + public static McpServerFeatures.SyncToolSpecification find() throws IOException { + // Step 1: Load the JSON schema for the tool input arguments. + final String schema = McpSchemaFiles.load("find.json"); + + // Step 2: Create a tool with name, description, and JSON schema. + McpSchema.Tool tool = McpSchema.Tool.builder().name("find").description( + "Start from the specified starting path and recursively search for sub-files or sub-directories.") + .inputSchema(schema).build(); + + return McpServerFeatures.SyncToolSpecification.builder().tool(tool).callHandler((exchange, arguments) -> { + // Step 4: List files and return the result. + final String start = arguments.arguments().getOrDefault("start", "").toString(); + final String name = arguments.arguments().getOrDefault("name", "").toString(); + boolean error = false; + String result; + + if (start.isBlank()) { + result = "Please provide a valid start path to find."; + } + else if (Files.notExists(Path.of(start))) { + result = "Start path does not exist: " + start + ", stopped finding."; + } + else if (name.isBlank()) { + result = "Please provide a valid file/directory name to find."; + } + else { + try { + List paths = FileHelper.fuzzySearch(start, name); + if (paths.isEmpty()) { + result = String.format("No file (or directory) found with name '%s'", name); + } + else { + result = String.format("The following are the search results of name '%s': %s", name, paths); + } + } + catch (IOException e) { + error = true; + result = String.format("Error searching file: %s, %s: %s", name, e, e.getMessage()); + log.error(result, e); + } + } + + McpSchema.Content content = new McpSchema.TextContent(result); + return new McpSchema.CallToolResult(List.of(content), error); + }).build(); + } + + /** + * Create an MCP tool to read and return the content of a file or the list of + * immediate subdirectories and files within a directory from the filesystem. This + * method checks the type of the specified path: If the path points to a file, it + * reads the entire content of the file and returns it as a string. If the path points + * to a directory, it returns a list of strings representing the direct children + * (immediate subdirectories and files) directly under the specified directory + * (non-recursive). + * @return If the path points to a file, it returns a string containing the file's + * content. If the path points to a directory, it returns a list of strings + * representing the direct children (immediate subdirectories and files) directly + * under the specified directory (non-recursive), wrapped as a + * {@link McpServerFeatures.SyncToolSpecification} object. + * @throws IOException If an I/O error occurs during reading. + */ + public static McpServerFeatures.SyncToolSpecification read() throws IOException { + // Step 1: Load the JSON schema for the tool input arguments. + final String schema = McpSchemaFiles.load("read.json"); + + // Step 2: Create a tool with name, description, and JSON schema. + McpSchema.Tool tool = McpSchema.Tool.builder().name("read").description( + "Read the contents of a file or non-recursively read the sub-files and sub-directories under a directory.") + .inputSchema(schema).build(); + + return McpServerFeatures.SyncToolSpecification.builder().tool(tool).callHandler((exchange, arguments) -> { + // Step 4: Read the path and return the result. + var path = arguments.arguments().get("path"); + boolean error = false; + String result; + + if (!(path instanceof String s) || s.isBlank()) { + return new McpSchema.CallToolResult("Please provide a valid path to read.", true); + } + + var file = FilePath.of(s); + else { + Path filepath = Path.of(path); + if (Files.notExists(filepath)) { + result = "The path does not exist: " + filepath + ", stopped reading."; + } + else if (Files.isDirectory(filepath)) { + try { + List paths = FileHelper.listDirectory(path); + result = String.format("The directory '%s' contains: %s", path, paths); + } + catch (IOException e) { + error = true; + result = String.format("Error reading directory: %s, %s: %s", path, e, e.getMessage()); + log.error(result, e); + } + } + else { + try { + result = FileHelper.readAsString(filepath); + } + catch (IOException e) { + error = true; + result = String.format("Error reading file: %s, %s: %s", path, e, e.getMessage()); + log.error(result, e); + } + } + } + + McpSchema.Content content = new McpSchema.TextContent(result); + return new McpSchema.CallToolResult(List.of(content), error); + }); + } + + /** + * Create an MCP tool to delete a file or directory from the filesystem. + * @return The operation result, wrapped as a + * {@link McpServerFeatures.SyncToolSpecification} object. + * @throws IOException If an I/O error occurs during deletion. + */ + public static McpServerFeatures.SyncToolSpecification delete() throws IOException { + // Step 1: Load the JSON schema for the tool input arguments. + final String schema = FileHelper.readResourceAsString("schema/delete.json"); + + // Step 2: Create a tool with name, description, and JSON schema. + McpSchema.Tool tool = new McpSchema.Tool("delete", "Delete a file or directory from the filesystem.", schema); + + // Step 3: Create a tool specification with the tool and the call function. + return new McpServerFeatures.SyncToolSpecification(tool, (exchange, arguments) -> { + // Step 4: Delete the path and return the result. + final String path = arguments.getOrDefault("path", StringHelper.EMPTY).toString(); + boolean error = false; + String result; + + if (path.isBlank()) { + result = "Please provide a valid path to delete."; + } + else { + try { + final boolean deleted = Files.deleteIfExists(Path.of(path)); + result = (deleted ? "Successfully deleted path: " : "Failed to delete path: ") + path; + } + catch (IOException e) { + error = true; + result = String.format("Error deleting path: %s, %s: %s", path, e, e.getMessage()); + log.error(result, e); + } + } + + McpSchema.Content content = new McpSchema.TextContent(result); + return new McpSchema.CallToolResult(List.of(content), error); + }); + } + +} diff --git a/app/src/main/java/io/xpipe/app/mcp/find.json b/app/src/main/java/io/xpipe/app/mcp/find.json new file mode 100644 index 000000000..d4d586113 --- /dev/null +++ b/app/src/main/java/io/xpipe/app/mcp/find.json @@ -0,0 +1,22 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "start": { + "type": "string", + "minLength": 1, + "maxLength": 256, + "description": "The starting path to search, required." + }, + "name": { + "type": "string", + "minLength": 1, + "maxLength": 256, + "description": "The name of the target file or directory to search, supports fuzzy matching, required." + } + }, + "required": [ + "start", + "name" + ] +} diff --git a/app/src/main/java/io/xpipe/app/mcp/read.json b/app/src/main/java/io/xpipe/app/mcp/read.json new file mode 100644 index 000000000..12eb52732 --- /dev/null +++ b/app/src/main/java/io/xpipe/app/mcp/read.json @@ -0,0 +1,15 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "path": { + "type": "string", + "minLength": 1, + "maxLength": 256, + "description": "The path to read, can be a file or directory, required." + } + }, + "required": [ + "path" + ] +} diff --git a/app/src/main/java/module-info.java b/app/src/main/java/module-info.java index ef0e7f52e..60f21cece 100644 --- a/app/src/main/java/module-info.java +++ b/app/src/main/java/module-info.java @@ -95,6 +95,11 @@ open module io.xpipe.app { requires java.net.http; requires org.bouncycastle.provider; requires org.jetbrains.annotations; + requires io.modelcontextprotocol.sdk.mcp; + requires reactor.core; + requires reactor.blockhound; + requires org.reactivestreams; + requires context.propagation; // Required runtime modules requires jdk.charsets; diff --git a/gradle/gradle_scripts/modules.gradle b/gradle/gradle_scripts/modules.gradle index 05c1ea95d..ab444fffe 100644 --- a/gradle/gradle_scripts/modules.gradle +++ b/gradle/gradle_scripts/modules.gradle @@ -64,3 +64,46 @@ extraJavaModuleInfo { exportAllPackages() } } + +extraJavaModuleInfo { + module("io.modelcontextprotocol.sdk:mcp", "io.modelcontextprotocol.sdk.mcp") { + exportAllPackages() + requires("com.fasterxml.jackson.core") + requires("com.fasterxml.jackson.databind") + requires("com.fasterxml.jackson.annotation") + requires("org.slf4j") + requires("reactor.core") + requires("org.reactivestreams") + requires("com.networknt.schema") + } +} + +extraJavaModuleInfo { + module("io.projectreactor:reactor-core", "reactor.core") { + requires("context.propagation") + requires("reactor.blockhound") + requires("org.reactivestreams") + exportAllPackages() + } +} + +extraJavaModuleInfo { + module("io.micrometer:context-propagation", "context.propagation") { + exportAllPackages() + uses("io.micrometer.context.ContextAccessor") + uses("io.micrometer.context.ThreadLocalAccessor") + } +} + +extraJavaModuleInfo { + module("io.projectreactor.tools:blockhound", "reactor.blockhound") { + exportAllPackages() + uses("reactor.blockhound.integration.BlockHoundIntegration") + } +} + +extraJavaModuleInfo { + module("org.reactivestreams:reactive-streams", "org.reactivestreams") { + exportAllPackages() + } +}