Path: blob/main/src/vs/workbench/contrib/mcp/common/mcpServerRequestHandler.ts
3296 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { equals } from '../../../../base/common/arrays.js';
import { assertNever } from '../../../../base/common/assert.js';
import { DeferredPromise, IntervalTimer } from '../../../../base/common/async.js';
import { CancellationToken } from '../../../../base/common/cancellation.js';
import { CancellationError } from '../../../../base/common/errors.js';
import { Emitter } from '../../../../base/common/event.js';
import { Iterable } from '../../../../base/common/iterator.js';
import { Disposable, DisposableStore } from '../../../../base/common/lifecycle.js';
import { autorun } from '../../../../base/common/observable.js';
import { IInstantiationService } from '../../../../platform/instantiation/common/instantiation.js';
import { canLog, ILogger, log, LogLevel } from '../../../../platform/log/common/log.js';
import { IProductService } from '../../../../platform/product/common/productService.js';
import { IMcpMessageTransport } from './mcpRegistryTypes.js';
import { IMcpClientMethods, McpConnectionState, McpError, MpcResponseError } from './mcpTypes.js';
import { MCP } from './modelContextProtocol.js';

/**
 * Bookkeeping for a request this client sent to the server and is still awaiting.
 * Entries are kept in a map keyed by the JSON-RPC request ID.
 */
interface PendingRequest {
	/** Completed by the matching response, errored by an error response, or cancelled. */
	promise: DeferredPromise<MCP.Result>;
}

export interface McpRoot {
	uri: string;
	name?: string;
}

export interface IMcpServerRequestHandlerOptions extends IMcpClientMethods {
	/** MCP message transport */
	launch: IMcpMessageTransport;
	/** Logger instance. */
	logger: ILogger;
	/** Log level at which MCP messages are logged */
	requestLogLevel?: LogLevel;
}

/**
 * Request handler for communicating with an MCP server.
 *
 * Handles sending requests and receiving responses, with automatic
 * handling of ping requests and typed client request methods.
 */
export class McpServerRequestHandler extends Disposable {
	private _nextRequestId = 1;
	/** Requests we sent to the server that have not been answered yet, keyed by our request ID. */
	private readonly _pendingRequests = new Map<MCP.RequestId, PendingRequest>();
	/**
	 * IDs of requests the *server* sent to us that are still being processed. An ID is
	 * removed when the server cancels it (or on dispose), which suppresses the reply.
	 * These IDs are in the server's ID space and are unrelated to `_pendingRequests`.
	 */
	private readonly _inFlightServerRequests = new Set<MCP.RequestId>();

	private _hasAnnouncedRoots = false;
	private _roots: MCP.Root[] = [];

	/**
	 * Roots reported to the server through `roots/list`. When the list changes and the
	 * server has already fetched it, a `notifications/roots/list_changed` is sent; further
	 * change notifications are suppressed until the server lists roots again.
	 */
	public set roots(roots: MCP.Root[]) {
		if (!equals(this._roots, roots)) {
			this._roots = roots;
			if (this._hasAnnouncedRoots) {
				this.sendNotification({ method: 'notifications/roots/list_changed' });
				this._hasAnnouncedRoots = false;
			}
		}
	}

	/** Result of the `initialize` handshake; assigned in {@link create} before the instance is returned. */
	private _serverInit!: MCP.InitializeResult;
	public get capabilities(): MCP.ServerCapabilities {
		return this._serverInit.capabilities;
	}

	public get serverInfo(): MCP.Implementation {
		return this._serverInit.serverInfo;
	}

	public get serverInstructions(): string | undefined {
		return this._serverInit.instructions;
	}

	// Event emitters for server notifications
	private readonly _onDidReceiveCancelledNotification = this._register(new Emitter<MCP.CancelledNotification>());
	readonly onDidReceiveCancelledNotification = this._onDidReceiveCancelledNotification.event;

	private readonly _onDidReceiveProgressNotification = this._register(new Emitter<MCP.ProgressNotification>());
	readonly onDidReceiveProgressNotification = this._onDidReceiveProgressNotification.event;

	private readonly _onDidChangeResourceList = this._register(new Emitter<void>());
	readonly onDidChangeResourceList = this._onDidChangeResourceList.event;

	private readonly _onDidUpdateResource = this._register(new Emitter<MCP.ResourceUpdatedNotification>());
	readonly onDidUpdateResource = this._onDidUpdateResource.event;

	private readonly _onDidChangeToolList = this._register(new Emitter<void>());
	readonly onDidChangeToolList = this._onDidChangeToolList.event;

	private readonly _onDidChangePromptList = this._register(new Emitter<void>());
	readonly onDidChangePromptList = this._onDidChangePromptList.event;

	/**
	 * Connects to the MCP server and does the initialization handshake.
	 *
	 * Sends `initialize`, records the result, then sends `notifications/initialized`
	 * and only afterwards forwards the current log level (servers may reject
	 * non-ping requests until the handshake is complete). While waiting for the
	 * `initialize` response, a progress message is logged every 5 seconds.
	 *
	 * @throws MpcResponseError if the server fails to initialize.
	 */
	public static async create(instaService: IInstantiationService, opts: IMcpServerRequestHandlerOptions, token?: CancellationToken) {
		const mcp = new McpServerRequestHandler(opts);
		const store = new DisposableStore();
		try {
			const timer = store.add(new IntervalTimer());
			timer.cancelAndSet(() => {
				opts.logger.info('Waiting for server to respond to `initialize` request...');
			}, 5000);

			await instaService.invokeFunction(async accessor => {
				const productService = accessor.get(IProductService);
				const initialized = await mcp.sendRequest<MCP.InitializeRequest, MCP.InitializeResult>({
					method: 'initialize',
					params: {
						protocolVersion: MCP.LATEST_PROTOCOL_VERSION,
						capabilities: {
							roots: { listChanged: true },
							sampling: opts.createMessageRequestHandler ? {} : undefined,
							elicitation: opts.elicitationRequestHandler ? {} : undefined,
						},
						clientInfo: {
							name: productService.nameLong,
							version: productService.version,
						}
					}
				}, token);

				mcp._serverInit = initialized;

				// Finish the handshake before issuing any further request.
				mcp.sendNotification<MCP.InitializedNotification>({
					method: 'notifications/initialized'
				});

				mcp._sendLogLevelToServer(opts.logger.getLevel());
			});

			return mcp;
		} catch (e) {
			mcp.dispose();
			throw e;
		} finally {
			store.dispose();
		}
	}

	public readonly logger: ILogger;
	private readonly _launch: IMcpMessageTransport;
	private readonly _requestLogLevel: LogLevel;
	private readonly _createMessageRequestHandler: IMcpServerRequestHandlerOptions['createMessageRequestHandler'];
	private readonly _elicitationRequestHandler: IMcpServerRequestHandlerOptions['elicitationRequestHandler'];

	protected constructor({
		launch,
		logger,
		createMessageRequestHandler,
		elicitationRequestHandler,
		requestLogLevel = LogLevel.Debug,
	}: IMcpServerRequestHandlerOptions) {
		super();
		this._launch = launch;
		this.logger = logger;
		this._requestLogLevel = requestLogLevel;
		this._createMessageRequestHandler = createMessageRequestHandler;
		this._elicitationRequestHandler = elicitationRequestHandler;

		this._register(launch.onDidReceiveMessage(message => this.handleMessage(message)));
		this._register(autorun(reader => {
			const state = launch.state.read(reader).state;
			// the handler will get disposed when the launch stops, but if we're still
			// create()'ing we need to make sure to cancel the initialize request.
			if (state === McpConnectionState.Kind.Error || state === McpConnectionState.Kind.Stopped) {
				this.cancelAllRequests();
			}
		}));

		// Listen for log level changes and forward them to the MCP server
		this._register(logger.onDidChangeLogLevel((logLevel) => {
			this._sendLogLevelToServer(logLevel);
		}));
	}

	/**
	 * Send a client request to the server and return the response.
	 *
	 * If `token` is cancelled while the request is outstanding, a
	 * `notifications/cancelled` is sent to the server and the returned promise is
	 * rejected with a cancellation error.
	 *
	 * @param request The request method and params to send
	 * @param token Cancellation token
	 * @returns A promise that resolves with the response result, or rejects with
	 * `MpcResponseError` on an error response and `CancellationError` if the handler
	 * is disposed or the token is already cancelled.
	 */
	private async sendRequest<T extends MCP.ClientRequest, R extends MCP.ServerResult>(
		request: Pick<T, 'params' | 'method'>,
		token: CancellationToken = CancellationToken.None
	): Promise<R> {
		// Don't put a request on the wire that nobody will wait for.
		if (this._store.isDisposed || token.isCancellationRequested) {
			return Promise.reject(new CancellationError());
		}

		const id = this._nextRequestId++;

		// Create the full JSON-RPC request
		const jsonRpcRequest: MCP.JSONRPCRequest = {
			jsonrpc: MCP.JSONRPC_VERSION,
			id,
			...request
		};

		const promise = new DeferredPromise<MCP.ServerResult>();
		// Store the pending request
		this._pendingRequests.set(id, { promise });
		// Set up cancellation; the listener is disposed in the `finally` below once the promise settles.
		const cancelListener = token.onCancellationRequested(() => {
			if (!promise.isSettled) {
				this._pendingRequests.delete(id);
				this.sendNotification({ method: 'notifications/cancelled', params: { requestId: id } });
				promise.cancel();
			}
		});

		const ret = promise.p.finally(() => {
			cancelListener.dispose();
			this._pendingRequests.delete(id);
		});

		// Send the request. If the transport throws synchronously, surface the error
		// through the promise so the cleanup above still runs.
		try {
			this.send(jsonRpcRequest);
		} catch (e) {
			promise.error(e);
		}

		return ret as Promise<R>;
	}

	/** Log (if enabled at the request log level) and write a JSON-RPC message to the transport. */
	private send(mcp: MCP.JSONRPCMessage) {
		if (canLog(this.logger.getLevel(), this._requestLogLevel)) { // avoid building the string if we don't need to
			log(this.logger, this._requestLogLevel, `[editor -> server] ${JSON.stringify(mcp)}`);
		}

		this._launch.send(mcp);
	}

	/**
	 * Handles paginated requests by making multiple requests until all items are retrieved.
	 *
	 * Each page is requested only when the consumer pulls the next value. Iteration
	 * stops when the server returns no `nextCursor` or the token is cancelled.
	 *
	 * @param method The method name to call
	 * @param getItems Function to extract the array of items from a result
	 * @param initialParams Initial parameters (the `cursor` is managed by this method)
	 * @param token Cancellation token
	 * @returns An async iterable yielding the items of each page in turn
	 */
	private async *sendRequestPaginated<T extends MCP.PaginatedRequest & MCP.ClientRequest, R extends MCP.PaginatedResult, I>(method: T['method'], getItems: (result: R) => I[], initialParams?: Omit<T['params'], 'jsonrpc' | 'id'>, token: CancellationToken = CancellationToken.None): AsyncIterable<I[]> {
		let nextCursor: MCP.Cursor | undefined = undefined;

		do {
			const params: T['params'] = {
				...initialParams,
				cursor: nextCursor
			};

			const result: R = await this.sendRequest<T, R>({ method, params }, token);
			yield getItems(result);
			nextCursor = result.nextCursor;
		} while (nextCursor !== undefined && !token.isCancellationRequested);
	}

	/** Send a JSON-RPC notification (no ID, no response expected) to the server. */
	private sendNotification<N extends MCP.ClientNotification>(notification: N): void {
		this.send({ ...notification, jsonrpc: MCP.JSONRPC_VERSION });
	}

	/**
	 * Handle incoming messages from the server.
	 *
	 * Dispatch rules:
	 * - a message with an `id` and a `result` or `error` settles one of our pending requests;
	 * - a message with a `method` and an `id` is a server request;
	 * - a message with a `method` and no `id` is a server notification.
	 */
	private handleMessage(message: MCP.JSONRPCMessage): void {
		if (canLog(this.logger.getLevel(), this._requestLogLevel)) { // avoid building the string if we don't need to
			log(this.logger, this._requestLogLevel, `[server -> editor] ${JSON.stringify(message)}`);
		}

		// Handle responses to our requests
		if ('id' in message) {
			if ('result' in message) {
				this.handleResult(message);
			} else if ('error' in message) {
				this.handleError(message);
			}
		}

		// Handle requests from the server
		if ('method' in message) {
			if ('id' in message) {
				this.handleServerRequest(message as MCP.JSONRPCRequest & MCP.ServerRequest);
			} else {
				this.handleServerNotification(message as MCP.JSONRPCNotification & MCP.ServerNotification);
			}
		}
	}

	/**
	 * Handle successful responses by completing the matching pending request, if any.
	 */
	private handleResult(response: MCP.JSONRPCResponse): void {
		const request = this._pendingRequests.get(response.id);
		if (request) {
			this._pendingRequests.delete(response.id);
			request.promise.complete(response.result);
		}
	}

	/**
	 * Handle error responses by rejecting the matching pending request with an `MpcResponseError`.
	 */
	private handleError(response: MCP.JSONRPCError): void {
		const request = this._pendingRequests.get(response.id);
		if (request) {
			this._pendingRequests.delete(response.id);
			request.promise.error(new MpcResponseError(response.error.message, response.error.code, response.error.data));
		}
	}

	/**
	 * Handle an incoming server request and reply with a result or a JSON-RPC error.
	 *
	 * Supported methods are `ping` and `roots/list`. `sampling/createMessage` and
	 * `elicitation/create` are supported only when the corresponding handler was
	 * supplied. Any other method yields `methodNotFound`. Errors that are not
	 * `McpError`s are logged and wrapped with `McpError.unknown`.
	 *
	 * No reply is sent if the server cancelled the request, or this handler was
	 * disposed, while it was being processed.
	 */
	private async handleServerRequest(request: MCP.JSONRPCRequest & MCP.ServerRequest): Promise<void> {
		this._inFlightServerRequests.add(request.id);
		try {
			let response: MCP.Result | undefined;
			if (request.method === 'ping') {
				response = this.handlePing(request);
			} else if (request.method === 'roots/list') {
				response = this.handleRootsList(request);
			} else if (request.method === 'sampling/createMessage' && this._createMessageRequestHandler) {
				response = await this._createMessageRequestHandler(request.params as MCP.CreateMessageRequest['params']);
			} else if (request.method === 'elicitation/create' && this._elicitationRequestHandler) {
				response = await this._elicitationRequestHandler(request.params as MCP.ElicitRequest['params']);
			} else {
				throw McpError.methodNotFound(request.method);
			}

			if (this._inFlightServerRequests.has(request.id)) {
				this.respondToRequest(request, response);
			}
		} catch (e) {
			if (!(e instanceof McpError)) {
				this.logger.error(`Error handling request ${request.method}:`, e);
				e = McpError.unknown(e);
			}

			if (this._inFlightServerRequests.has(request.id)) {
				const errorResponse: MCP.JSONRPCError = {
					jsonrpc: MCP.JSONRPC_VERSION,
					id: request.id,
					error: {
						code: e.code,
						message: e.message,
						data: e.data,
					}
				};

				this.send(errorResponse);
			}
		} finally {
			this._inFlightServerRequests.delete(request.id);
		}
	}

	/**
	 * Handle incoming server notifications. Logging and cancellation are handled here;
	 * the others are re-fired on the corresponding public events. Unknown
	 * notifications are ignored.
	 */
	private handleServerNotification(request: MCP.JSONRPCNotification & MCP.ServerNotification): void {
		switch (request.method) {
			case 'notifications/message':
				return this.handleLoggingNotification(request);
			case 'notifications/cancelled':
				this._onDidReceiveCancelledNotification.fire(request);
				return this.handleCancelledNotification(request);
			case 'notifications/progress':
				this._onDidReceiveProgressNotification.fire(request);
				return;
			case 'notifications/resources/list_changed':
				this._onDidChangeResourceList.fire();
				return;
			case 'notifications/resources/updated':
				this._onDidUpdateResource.fire(request);
				return;
			case 'notifications/tools/list_changed':
				this._onDidChangeToolList.fire();
				return;
			case 'notifications/prompts/list_changed':
				this._onDidChangePromptList.fire();
				return;
		}
	}

	/**
	 * Handle `notifications/cancelled` sent by the server.
	 *
	 * Per the MCP spec, a cancellation only references requests issued in the same
	 * direction, i.e. a request the server sent to us. Our own outgoing requests use a
	 * separate ID space whose values can coincide, so they must not be touched here.
	 * Dropping the ID from the in-flight set makes {@link handleServerRequest} skip
	 * its reply.
	 */
	private handleCancelledNotification(request: MCP.CancelledNotification): void {
		this._inFlightServerRequests.delete(request.params.requestId);
	}

	/**
	 * Forward a server `notifications/message` to the logger. The MCP severity is
	 * mapped onto the nearest logger method, and the message is prefixed with the
	 * server-side logger name if one is given.
	 */
	private handleLoggingNotification(request: MCP.LoggingMessageNotification): void {
		let contents = typeof request.params.data === 'string' ? request.params.data : JSON.stringify(request.params.data);
		if (request.params.logger) {
			contents = `${request.params.logger}: ${contents}`;
		}

		switch (request.params.level) {
			case 'debug':
				this.logger.debug(contents);
				break;
			case 'info':
			case 'notice':
				this.logger.info(contents);
				break;
			case 'warning':
				this.logger.warn(contents);
				break;
			case 'error':
			case 'critical':
			case 'alert':
			case 'emergency':
				this.logger.error(contents);
				break;
			default:
				this.logger.info(contents);
				break;
		}
	}

	/**
	 * Send a generic response to a request
	 */
	private respondToRequest(request: MCP.JSONRPCRequest, result: MCP.Result): void {
		const response: MCP.JSONRPCResponse = {
			jsonrpc: MCP.JSONRPC_VERSION,
			id: request.id,
			result
		};
		this.send(response);
	}

	/**
	 * Build the (empty) result for a ping request
	 */
	private handlePing(_request: MCP.PingRequest): {} {
		return {};
	}

	/**
	 * Build the result for a roots/list request, and record that the server has seen
	 * the roots so that later changes trigger a list_changed notification.
	 */
	private handleRootsList(_request: MCP.ListRootsRequest): MCP.ListRootsResult {
		this._hasAnnouncedRoots = true;
		return { roots: this._roots };
	}

	/** Cancel every request we are still waiting on and forget them. */
	private cancelAllRequests() {
		this._pendingRequests.forEach(pending => pending.promise.cancel());
		this._pendingRequests.clear();
	}

	public override dispose(): void {
		this.cancelAllRequests();
		// Server requests still being processed can no longer be answered.
		this._inFlightServerRequests.clear();
		super.dispose();
	}

	/**
	 * Forwards log level changes to the MCP server if it supports logging.
	 * Failures are logged rather than thrown.
	 */
	private async _sendLogLevelToServer(logLevel: LogLevel): Promise<void> {
		try {
			// Skip before the handshake has completed (create() sends the level once it
			// has), and skip if the server doesn't advertise the logging capability.
			if (!this._serverInit || !this.capabilities.logging) {
				return;
			}

			await this.setLevel({ level: mapLogLevelToMcp(logLevel) });
		} catch (error) {
			this.logger.error(`Failed to set MCP server log level: ${error}`);
		}
	}

	/**
	 * Send an initialize request
	 */
	initialize(params: MCP.InitializeRequest['params'], token?: CancellationToken): Promise<MCP.InitializeResult> {
		return this.sendRequest<MCP.InitializeRequest, MCP.InitializeResult>({ method: 'initialize', params }, token);
	}

	/**
	 * List available resources (all pages)
	 */
	listResources(params?: MCP.ListResourcesRequest['params'], token?: CancellationToken): Promise<MCP.Resource[]> {
		return Iterable.asyncToArrayFlat(this.listResourcesIterable(params, token));
	}

	/**
	 * List available resources (iterable, one page at a time)
	 */
	listResourcesIterable(params?: MCP.ListResourcesRequest['params'], token?: CancellationToken): AsyncIterable<MCP.Resource[]> {
		return this.sendRequestPaginated<MCP.ListResourcesRequest, MCP.ListResourcesResult, MCP.Resource>('resources/list', result => result.resources, params, token);
	}

	/**
	 * Read a specific resource
	 */
	readResource(params: MCP.ReadResourceRequest['params'], token?: CancellationToken): Promise<MCP.ReadResourceResult> {
		return this.sendRequest<MCP.ReadResourceRequest, MCP.ReadResourceResult>({ method: 'resources/read', params }, token);
	}

	/**
	 * List available resource templates (all pages)
	 */
	listResourceTemplates(params?: MCP.ListResourceTemplatesRequest['params'], token?: CancellationToken): Promise<MCP.ResourceTemplate[]> {
		return Iterable.asyncToArrayFlat(this.sendRequestPaginated<MCP.ListResourceTemplatesRequest, MCP.ListResourceTemplatesResult, MCP.ResourceTemplate>('resources/templates/list', result => result.resourceTemplates, params, token));
	}

	/**
	 * Subscribe to resource updates
	 */
	subscribe(params: MCP.SubscribeRequest['params'], token?: CancellationToken): Promise<MCP.EmptyResult> {
		return this.sendRequest<MCP.SubscribeRequest, MCP.EmptyResult>({ method: 'resources/subscribe', params }, token);
	}

	/**
	 * Unsubscribe from resource updates
	 */
	unsubscribe(params: MCP.UnsubscribeRequest['params'], token?: CancellationToken): Promise<MCP.EmptyResult> {
		return this.sendRequest<MCP.UnsubscribeRequest, MCP.EmptyResult>({ method: 'resources/unsubscribe', params }, token);
	}

	/**
	 * List available prompts (all pages)
	 */
	listPrompts(params?: MCP.ListPromptsRequest['params'], token?: CancellationToken): Promise<MCP.Prompt[]> {
		return Iterable.asyncToArrayFlat(this.sendRequestPaginated<MCP.ListPromptsRequest, MCP.ListPromptsResult, MCP.Prompt>('prompts/list', result => result.prompts, params, token));
	}

	/**
	 * Get a specific prompt
	 */
	getPrompt(params: MCP.GetPromptRequest['params'], token?: CancellationToken): Promise<MCP.GetPromptResult> {
		return this.sendRequest<MCP.GetPromptRequest, MCP.GetPromptResult>({ method: 'prompts/get', params }, token);
	}

	/**
	 * List available tools (all pages)
	 */
	listTools(params?: MCP.ListToolsRequest['params'], token?: CancellationToken): Promise<MCP.Tool[]> {
		return Iterable.asyncToArrayFlat(this.sendRequestPaginated<MCP.ListToolsRequest, MCP.ListToolsResult, MCP.Tool>('tools/list', result => result.tools, params, token));
	}

	/**
	 * Call a specific tool
	 */
	callTool(params: MCP.CallToolRequest['params'] & MCP.Request['params'], token?: CancellationToken): Promise<MCP.CallToolResult> {
		return this.sendRequest<MCP.CallToolRequest, MCP.CallToolResult>({ method: 'tools/call', params }, token);
	}

	/**
	 * Set the logging level
	 */
	setLevel(params: MCP.SetLevelRequest['params'], token?: CancellationToken): Promise<MCP.EmptyResult> {
		return this.sendRequest<MCP.SetLevelRequest, MCP.EmptyResult>({ method: 'logging/setLevel', params }, token);
	}

	/**
	 * Find completions for an argument
	 */
	complete(params: MCP.CompleteRequest['params'], token?: CancellationToken): Promise<MCP.CompleteResult> {
		return this.sendRequest<MCP.CompleteRequest, MCP.CompleteResult>({ method: 'completion/complete', params }, token);
	}
}


/**
 * Maps VSCode LogLevel to MCP LoggingLevel.
 * MCP has no `trace` or `off`: Trace maps to 'debug' and Off to 'emergency'
 * (the most severe level, so effectively nothing but emergencies is sent).
 */
function mapLogLevelToMcp(logLevel: LogLevel): MCP.LoggingLevel {
	switch (logLevel) {
		case LogLevel.Trace:
			return 'debug'; // MCP doesn't have trace, use debug
		case LogLevel.Debug:
			return 'debug';
		case LogLevel.Info:
			return 'info';
		case LogLevel.Warning:
			return 'warning';
		case LogLevel.Error:
			return 'error';
		case LogLevel.Off:
			return 'emergency'; // MCP doesn't have off, use emergency
		default:
			return assertNever(logLevel); // exhaustiveness check: every LogLevel is handled above
	}
}