Path: blob/main/src/vs/workbench/contrib/mcp/test/common/mcpServerRequestHandler.test.ts
5260 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import * as assert from 'assert';
import * as sinon from 'sinon';
import { upcast } from '../../../../../base/common/types.js';
import { ensureNoDisposablesAreLeakedInTestSuite } from '../../../../../base/test/common/utils.js';
import { ServiceCollection } from '../../../../../platform/instantiation/common/serviceCollection.js';
import { TestInstantiationService } from '../../../../../platform/instantiation/test/common/instantiationServiceMock.js';
import { ILoggerService } from '../../../../../platform/log/common/log.js';
import { IProductService } from '../../../../../platform/product/common/productService.js';
import { IStorageService } from '../../../../../platform/storage/common/storage.js';
import { TestLoggerService, TestProductService, TestStorageService } from '../../../../test/common/workbenchTestServices.js';
import { IMcpHostDelegate } from '../../common/mcpRegistryTypes.js';
import { McpServerRequestHandler, McpTask } from '../../common/mcpServerRequestHandler.js';
import { McpConnectionState, McpServerDefinition, McpServerLaunch } from '../../common/mcpTypes.js';
import { MCP } from '../../common/modelContextProtocol.js';
import { TestMcpMessageTransport } from './mcpRegistryTypes.js';
import { IOutputService } from '../../../../services/output/common/output.js';
import { Disposable } from '../../../../../base/common/lifecycle.js';
import { CancellationTokenSource } from '../../../../../base/common/cancellation.js';
import { McpTaskManager } from '../../common/mcpTaskManager.js';
import { upcastPartial } from '../../../../../base/test/common/mock.js';

/**
 * Validator for `assert.rejects` that accepts only errors whose `name` is `'Canceled'`.
 *
 * Prefer `await assert.rejects(promise, assertCanceledError)` over
 * `try { await promise; assert.fail(...) } catch (e) { ... }`. In the latter, the
 * AssertionError thrown by `assert.fail` is caught by the `catch` block itself. The real
 * problem ("the promise resolved") is then reported as an unrelated name mismatch.
 */
function assertCanceledError(error: Error): boolean {
	assert.strictEqual(error.name, 'Canceled');
	return true;
}

/**
 * Minimal {@link IMcpHostDelegate} for tests. It accepts every launch and always hands
 * out the same {@link TestMcpMessageTransport}, which the tests drive directly.
 */
class TestMcpHostDelegate extends Disposable implements IMcpHostDelegate {
	private readonly _transport: TestMcpMessageTransport;

	priority = 0;

	constructor() {
		super();
		this._transport = this._register(new TestMcpMessageTransport());
	}

	/** Returns `launch` unchanged; no variable substitution is performed in tests. */
	substituteVariables(serverDefinition: McpServerDefinition, launch: McpServerLaunch): Promise<McpServerLaunch> {
		return Promise.resolve(launch);
	}

	canStart(): boolean {
		return true;
	}

	/** Always returns the single shared transport instance. */
	start(): TestMcpMessageTransport {
		return this._transport;
	}

	/** Exposes the transport so tests can inspect sent messages and simulate incoming ones. */
	getTransport(): TestMcpMessageTransport {
		return this._transport;
	}

	waitForInitialProviderPromises(): Promise<void> {
		return Promise.resolve();
	}
}

suite('Workbench - MCP - ServerRequestHandler', () => {
	const store = ensureNoDisposablesAreLeakedInTestSuite();

	let instantiationService: TestInstantiationService;
	let delegate: TestMcpHostDelegate;
	let transport: TestMcpMessageTransport;
	let handler: McpServerRequestHandler;
	let cts: CancellationTokenSource;

	setup(async () => {
		delegate = store.add(new TestMcpHostDelegate());
		transport = delegate.getTransport();
		cts = store.add(new CancellationTokenSource());

		// Setup test services
		const services = new ServiceCollection(
			[ILoggerService, store.add(new TestLoggerService())],
			[IOutputService, upcast({ showChannel: () => { } })],
			[IStorageService, store.add(new TestStorageService())],
			[IProductService, TestProductService],
		);

		instantiationService = store.add(new TestInstantiationService(services));

		transport.setConnectionState({ state: McpConnectionState.Kind.Running });

		// Manually create the handler since we need the transport already set up
		const logger = store.add((instantiationService.get(ILoggerService) as TestLoggerService)
			.createLogger('mcpServerTest', { hidden: true, name: 'MCP Test' }));

		// Creating the handler sends handshake traffic over `transport`. The tests below rely on
		// exactly two such messages (indices 0 and 1) preceding the first request they issue.
		const handlerPromise = McpServerRequestHandler.create(instantiationService, { logger, launch: transport, taskManager: store.add(new McpTaskManager()) }, cts.token);

		handler = await handlerPromise;
		store.add(handler);
	});

	test('should send and receive JSON-RPC requests', async () => {
		// Setup request
		const requestPromise = handler.listResources();

		// Two handshake messages from setup, then the listResources request
		const sentMessages = transport.getSentMessages();
		assert.strictEqual(sentMessages.length, 3);

		// Verify listResources request format
		const listResourcesRequest = sentMessages[2] as MCP.JSONRPCRequest;
		assert.strictEqual(listResourcesRequest.method, 'resources/list');
		assert.strictEqual(listResourcesRequest.jsonrpc, MCP.JSONRPC_VERSION);
		assert.ok(typeof listResourcesRequest.id === 'number');

		// Simulate server response with mock resources that match the expected Resource interface
		transport.simulateReceiveMessage({
			jsonrpc: MCP.JSONRPC_VERSION,
			id: listResourcesRequest.id,
			result: {
				resources: [
					{ uri: 'resource1', type: 'text/plain', name: 'Test Resource 1' },
					{ uri: 'resource2', type: 'text/plain', name: 'Test Resource 2' }
				]
			}
		});

		// Verify the result
		const resources = await requestPromise;
		assert.strictEqual(resources.length, 2);
		assert.strictEqual(resources[0].uri, 'resource1');
		assert.strictEqual(resources[1].name, 'Test Resource 2');
	});

	test('should handle paginated requests', async () => {
		// Setup request
		const requestPromise = handler.listResources();

		// Get the first request (after the two setup messages) and respond with pagination
		const sentMessages = transport.getSentMessages();
		const listResourcesRequest = sentMessages[2] as MCP.JSONRPCRequest;

		// Send first page with nextCursor
		transport.simulateReceiveMessage({
			jsonrpc: MCP.JSONRPC_VERSION,
			id: listResourcesRequest.id,
			result: {
				resources: [
					{ uri: 'resource1', type: 'text/plain', name: 'Test Resource 1' }
				],
				nextCursor: 'page2'
			}
		});

		// Clear the sent messages to only capture the next page request
		transport.clearSentMessages();

		// Wait a bit to allow the handler to process and send the next request
		await new Promise(resolve => setTimeout(resolve, 0));

		// Get the second request and verify cursor is included
		const sentMessages2 = transport.getSentMessages();
		assert.strictEqual(sentMessages2.length, 1);

		const listResourcesRequest2 = sentMessages2[0] as MCP.JSONRPCRequest;
		assert.strictEqual(listResourcesRequest2.method, 'resources/list');
		assert.deepStrictEqual(listResourcesRequest2.params, { cursor: 'page2' });

		// Send final page with no nextCursor
		transport.simulateReceiveMessage({
			jsonrpc: MCP.JSONRPC_VERSION,
			id: listResourcesRequest2.id,
			result: {
				resources: [
					{ uri: 'resource2', type: 'text/plain', name: 'Test Resource 2' }
				]
			}
		});

		// Verify the combined result
		const resources = await requestPromise;
		assert.strictEqual(resources.length, 2);
		assert.strictEqual(resources[0].uri, 'resource1');
		assert.strictEqual(resources[1].uri, 'resource2');
	});

	test('should handle error responses', async () => {
		// Setup request
		const requestPromise = handler.readResource({ uri: 'non-existent' });

		// Get the sent message; [0] and [1] are the setup handshake messages
		const sentMessages = transport.getSentMessages();
		const readResourceRequest = sentMessages[2] as MCP.JSONRPCRequest;

		// Simulate error response
		transport.simulateReceiveMessage({
			jsonrpc: MCP.JSONRPC_VERSION,
			id: readResourceRequest.id,
			error: {
				code: MCP.METHOD_NOT_FOUND,
				message: 'Resource not found'
			}
		});

		// Verify the error is thrown correctly
		await assert.rejects(
			requestPromise,
			(e: Error & { code?: number }) => {
				assert.strictEqual(e.message, 'MPC -32601: Resource not found');
				assert.strictEqual(e.code, MCP.METHOD_NOT_FOUND);
				return true;
			},
			'Expected error was not thrown'
		);
	});

	test('should handle server requests', async () => {
		// Simulate ping request from server
		const pingRequest: MCP.JSONRPCRequest & MCP.PingRequest = {
			jsonrpc: MCP.JSONRPC_VERSION,
			id: 100,
			method: 'ping'
		};

		transport.simulateReceiveMessage(pingRequest);

		// The handler should have sent a response
		const sentMessages = transport.getSentMessages();
		const pingResponse = sentMessages.find(m =>
			'id' in m && m.id === pingRequest.id && 'result' in m
		) as MCP.JSONRPCResultResponse;

		assert.ok(pingResponse, 'No ping response was sent');
		assert.deepStrictEqual(pingResponse.result, {});
	});

	test('should handle roots list requests', async () => {
		// Set roots
		handler.roots = [
			{ uri: 'file:///test/root1', name: 'Root 1' },
			{ uri: 'file:///test/root2', name: 'Root 2' }
		];

		// Simulate roots/list request from server
		const rootsRequest: MCP.JSONRPCRequest & MCP.ListRootsRequest = {
			jsonrpc: MCP.JSONRPC_VERSION,
			id: 101,
			method: 'roots/list'
		};

		transport.simulateReceiveMessage(rootsRequest);

		// The handler should have sent a response
		const sentMessages = transport.getSentMessages();
		const rootsResponse = sentMessages.find(m =>
			'id' in m && m.id === rootsRequest.id && 'result' in m
		) as MCP.JSONRPCResultResponse;

		assert.ok(rootsResponse, 'No roots/list response was sent');
		const rootsResult = rootsResponse.result as MCP.ListRootsResult;
		assert.strictEqual(rootsResult.roots.length, 2);
		assert.strictEqual(rootsResult.roots[0].uri, 'file:///test/root1');
	});

	test('should handle server notifications', async () => {
		// Record what the listener observes and assert afterwards. An assertion thrown inside
		// an event listener may be routed to the emitter's error handling instead of failing
		// this test directly.
		const received: { method: string; progressToken: unknown; progress: unknown }[] = [];
		store.add(handler.onDidReceiveProgressNotification(notification => {
			received.push({
				method: notification.method,
				progressToken: notification.params.progressToken,
				progress: notification.params.progress,
			});
		}));

		// Simulate progress notification with correct format
		const progressNotification: MCP.JSONRPCNotification & MCP.ProgressNotification = {
			jsonrpc: MCP.JSONRPC_VERSION,
			method: 'notifications/progress',
			params: {
				progressToken: 'token1',
				progress: 50,
				total: 100
			}
		};

		transport.simulateReceiveMessage(progressNotification);
		assert.deepStrictEqual(received, [
			{ method: 'notifications/progress', progressToken: 'token1', progress: 50 }
		]);
	});

	test('should handle cancellation', async () => {
		// Setup a new cancellation token source for this specific test
		const testCts = store.add(new CancellationTokenSource());
		const requestPromise = handler.listResources(undefined, testCts.token);

		// Get the request ID (after the two setup messages)
		const sentMessages = transport.getSentMessages();
		const listResourcesRequest = sentMessages[2] as MCP.JSONRPCRequest;
		const requestId = listResourcesRequest.id;

		// Cancel the request
		testCts.cancel();

		// Check that a cancellation notification was sent
		const cancelNotification = transport.getSentMessages().find(m =>
			!('id' in m) &&
			'method' in m &&
			m.method === 'notifications/cancelled' &&
			'params' in m &&
			m.params && m.params.requestId === requestId
		);

		assert.ok(cancelNotification, 'No cancellation notification was sent');

		// Verify the promise was cancelled
		await assert.rejects(requestPromise, assertCanceledError, 'Promise should have been cancelled');
	});

	test('should handle cancelled notification from server', async () => {
		// Setup request
		const requestPromise = handler.listResources();

		// Get the request ID (after the two setup messages)
		const sentMessages = transport.getSentMessages();
		const listResourcesRequest = sentMessages[2] as MCP.JSONRPCRequest;
		const requestId = listResourcesRequest.id;

		// Simulate cancelled notification from server
		const cancelledNotification: MCP.JSONRPCNotification & MCP.CancelledNotification = {
			jsonrpc: MCP.JSONRPC_VERSION,
			method: 'notifications/cancelled',
			params: {
				requestId
			}
		};

		transport.simulateReceiveMessage(cancelledNotification);

		// Verify the promise was cancelled
		await assert.rejects(requestPromise, assertCanceledError, 'Promise should have been cancelled');
	});

	test('should dispose properly and cancel pending requests', async () => {
		// Setup multiple requests
		const request1 = handler.listResources();
		const request2 = handler.listTools();

		// Dispose the handler
		handler.dispose();

		// Verify all promises were cancelled. Both rejection handlers are attached up front,
		// so neither rejection sits unobserved while the other is awaited.
		await Promise.all([
			assert.rejects(request1, assertCanceledError, 'Promise 1 should have been cancelled'),
			assert.rejects(request2, assertCanceledError, 'Promise 2 should have been cancelled'),
		]);
	});

	test('should handle connection error by cancelling requests', async () => {
		// Setup request
		const requestPromise = handler.listResources();

		// Simulate connection error
		transport.setConnectionState({
			state: McpConnectionState.Kind.Error,
			message: 'Connection lost'
		});

		// Verify the promise was cancelled
		await assert.rejects(requestPromise, assertCanceledError, 'Promise should have been cancelled');
	});
});

suite.skip('Workbench - MCP - McpTask', () => { // TODO@connor4312 https://github.com/microsoft/vscode/issues/280126
	const store = ensureNoDisposablesAreLeakedInTestSuite();
	let clock: sinon.SinonFakeTimers;

	setup(() => {
		clock = sinon.useFakeTimers();
	});

	teardown(() => {
		clock.restore();
	});

	/** Builds a `working` task with id `task1` and no TTL, with any fields overridden. */
	function createTask(overrides: Partial<MCP.Task> = {}): MCP.Task {
		return {
			taskId: 'task1',
			status: 'working',
			createdAt: new Date().toISOString(),
			lastUpdatedAt: new Date().toISOString(),
			ttl: null,
			...overrides
		};
	}

	test('should resolve when task completes', async () => {
		const mockHandler = upcastPartial<McpServerRequestHandler>({
			getTask: sinon.stub().resolves(createTask({ status: 'completed' })),
			getTaskResult: sinon.stub().resolves({ content: [{ type: 'text', text: 'result' }] })
		});

		const task = store.add(new McpTask(createTask()));
		task.setHandler(mockHandler);

		// Advance time to trigger polling
		await clock.tickAsync(2000);

		// Update to completed state
		task.onDidUpdateState(createTask({ status: 'completed' }));

		const result = await task.result;
		assert.deepStrictEqual(result, { content: [{ type: 'text', text: 'result' }] });
		assert.ok((mockHandler.getTaskResult as sinon.SinonStub).calledWith({ taskId: 'task1' }));
	});

	test('should poll for task updates', async () => {
		const getTaskStub = sinon.stub();
		getTaskStub.onCall(0).resolves(createTask({ status: 'working' }));
		getTaskStub.onCall(1).resolves(createTask({ status: 'working' }));
		getTaskStub.onCall(2).resolves(createTask({ status: 'completed' }));

		const mockHandler = upcastPartial<McpServerRequestHandler>({
			getTask: getTaskStub,
			getTaskResult: sinon.stub().resolves({ content: [{ type: 'text', text: 'result' }] })
		});

		const task = store.add(new McpTask(createTask({ pollInterval: 1000 })));
		task.setHandler(mockHandler);

		// First poll
		await clock.tickAsync(1000);
		assert.strictEqual(getTaskStub.callCount, 1);

		// Second poll
		await clock.tickAsync(1000);
		assert.strictEqual(getTaskStub.callCount, 2);

		// Third poll - completes
		await clock.tickAsync(1000);
		assert.strictEqual(getTaskStub.callCount, 3);

		const result = await task.result;
		assert.deepStrictEqual(result, { content: [{ type: 'text', text: 'result' }] });
	});

	test('should use default poll interval if not specified', async () => {
		const getTaskStub = sinon.stub();
		getTaskStub.resolves(createTask({ status: 'working' }));

		const mockHandler = upcastPartial<McpServerRequestHandler>({
			getTask: getTaskStub,
		});

		const task = store.add(new McpTask(createTask()));
		task.setHandler(mockHandler);

		// Default poll interval is 2000ms
		await clock.tickAsync(2000);
		assert.strictEqual(getTaskStub.callCount, 1);

		await clock.tickAsync(2000);
		assert.strictEqual(getTaskStub.callCount, 2);

		task.dispose();
	});

	test('should reject when task fails', async () => {
		const mockHandler = upcastPartial<McpServerRequestHandler>({
			getTask: sinon.stub().resolves(createTask({
				status: 'failed',
				statusMessage: 'Something went wrong'
			}))
		});

		const task = store.add(new McpTask(createTask()));
		task.setHandler(mockHandler);

		// Update to failed state
		task.onDidUpdateState(createTask({
			status: 'failed',
			statusMessage: 'Something went wrong'
		}));

		await assert.rejects(
			task.result,
			(error: Error) => {
				assert.ok(error.message.includes('Task task1 failed'));
				assert.ok(error.message.includes('Something went wrong'));
				return true;
			}
		);
	});

	test('should cancel when task is cancelled', async () => {
		const task = store.add(new McpTask(createTask()));

		// Update to cancelled state
		task.onDidUpdateState(createTask({ status: 'cancelled' }));

		await assert.rejects(task.result, assertCanceledError);
	});

	test('should cancel when cancellation token is triggered', async () => {
		const cts = store.add(new CancellationTokenSource());
		const task = store.add(new McpTask(createTask(), cts.token));

		// Cancel the token
		cts.cancel();

		await assert.rejects(task.result, assertCanceledError);
	});

	test('should handle TTL expiration', async () => {
		const now = Date.now();
		clock.setSystemTime(now);

		const task = store.add(new McpTask(createTask({ ttl: 5000 })));

		// Advance time past TTL
		await clock.tickAsync(6000);

		await assert.rejects(task.result, assertCanceledError);
	});

	test('should stop polling when in terminal state', async () => {
		const getTaskStub = sinon.stub();
		getTaskStub.resolves(createTask({ status: 'completed' }));

		const mockHandler = upcastPartial<McpServerRequestHandler>({
			getTask: getTaskStub,
			getTaskResult: sinon.stub().resolves({ content: [{ type: 'text', text: 'result' }] })
		});

		const task = store.add(new McpTask(createTask({ pollInterval: 1000 })));
		task.setHandler(mockHandler);

		// Update to completed state immediately
		task.onDidUpdateState(createTask({ status: 'completed' }));

		await task.result;

		// Advance time - should not poll anymore
		const initialCallCount = getTaskStub.callCount;
		await clock.tickAsync(5000);
		assert.strictEqual(getTaskStub.callCount, initialCallCount);
	});

	test('should handle handler reconnection', async () => {
		const getTaskStub1 = sinon.stub();
		getTaskStub1.resolves(createTask({ status: 'working' }));

		const mockHandler1 = upcastPartial<McpServerRequestHandler>({
			getTask: getTaskStub1,
		});

		const task = store.add(new McpTask(createTask({ pollInterval: 1000 })));
		task.setHandler(mockHandler1);

		// First poll with handler1
		await clock.tickAsync(1000);
		assert.strictEqual(getTaskStub1.callCount, 1);

		// Switch to a new handler
		const getTaskStub2 = sinon.stub();
		getTaskStub2.resolves(createTask({ status: 'completed' }));

		const mockHandler2 = upcastPartial<McpServerRequestHandler>({
			getTask: getTaskStub2,
			getTaskResult: sinon.stub().resolves({ content: [{ type: 'text', text: 'result' }] })
		});

		task.setHandler(mockHandler2);

		// Second poll with handler2
		await clock.tickAsync(1000);
		assert.strictEqual(getTaskStub1.callCount, 1); // No more calls to old handler
		assert.strictEqual(getTaskStub2.callCount, 1); // New handler is called

		const result = await task.result;
		assert.deepStrictEqual(result, { content: [{ type: 'text', text: 'result' }] });
	});

	test('should not poll when handler is undefined', async () => {
		const task = store.add(new McpTask(createTask({ pollInterval: 1000 })));

		// Advance time - should not crash
		await clock.tickAsync(5000);

		// Now set a handler and it should start polling
		const getTaskStub = sinon.stub();
		getTaskStub.resolves(createTask({ status: 'completed' }));

		const mockHandler = upcastPartial<McpServerRequestHandler>({
			getTask: getTaskStub,
			getTaskResult: sinon.stub().resolves({ content: [{ type: 'text', text: 'result' }] })
		});

		task.setHandler(mockHandler);
		await clock.tickAsync(1000);
		assert.strictEqual(getTaskStub.callCount, 1);

		task.dispose();
	});

	test('should handle input_required state', async () => {
		const getTaskStub = sinon.stub();
		// getTask call returns completed (triggered by input_required handling)
		getTaskStub.resolves(createTask({ status: 'completed' }));

		const mockHandler = upcastPartial<McpServerRequestHandler>({
			getTask: getTaskStub,
			getTaskResult: sinon.stub().resolves({ content: [{ type: 'text', text: 'result' }] })
		});

		const task = store.add(new McpTask(createTask({ pollInterval: 1000 })));
		task.setHandler(mockHandler);

		// Update to input_required - this triggers a getTask call
		task.onDidUpdateState(createTask({ status: 'input_required' }));

		// Allow the promise to settle
		await clock.tickAsync(0);

		// Verify getTask was called
		assert.strictEqual(getTaskStub.callCount, 1);

		// Once getTask resolves with completed, should fetch result
		const result = await task.result;
		assert.deepStrictEqual(result, { content: [{ type: 'text', text: 'result' }] });
	});

	test('should handle getTask returning cancelled during polling', async () => {
		const getTaskStub = sinon.stub();
		getTaskStub.resolves(createTask({ status: 'cancelled' }));

		const mockHandler = upcastPartial<McpServerRequestHandler>({
			getTask: getTaskStub,
		});

		const task = store.add(new McpTask(createTask({ pollInterval: 1000 })));
		task.setHandler(mockHandler);

		// Advance time to trigger polling
		await clock.tickAsync(1000);

		await assert.rejects(task.result, assertCanceledError);
	});

	test('should return correct task id', () => {
		const task = store.add(new McpTask(createTask({ taskId: 'my-task-id' })));
		assert.strictEqual(task.id, 'my-task-id');
	});

	test('should dispose cleanly', async () => {
		const getTaskStub = sinon.stub();
		getTaskStub.resolves(createTask({ status: 'working' }));

		const mockHandler = upcastPartial<McpServerRequestHandler>({
			getTask: getTaskStub,
		});

		const task = store.add(new McpTask(createTask({ pollInterval: 1000 })));
		task.setHandler(mockHandler);

		// Poll once
		await clock.tickAsync(1000);
		const callCountBeforeDispose = getTaskStub.callCount;

		// Dispose
		task.dispose();

		// Advance time - should not poll anymore
		await clock.tickAsync(5000);
		assert.strictEqual(getTaskStub.callCount, callCountBeforeDispose);
	});
});