Path: blob/main/src/vs/workbench/api/browser/mainThreadMcp.ts
3296 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { disposableTimeout } from '../../../base/common/async.js';
import { CancellationError } from '../../../base/common/errors.js';
import { Emitter } from '../../../base/common/event.js';
import { Disposable, DisposableMap } from '../../../base/common/lifecycle.js';
import { IAuthorizationProtectedResourceMetadata, IAuthorizationServerMetadata } from '../../../base/common/oauth.js';
import { ISettableObservable, observableValue } from '../../../base/common/observable.js';
import Severity from '../../../base/common/severity.js';
import { URI, UriComponents } from '../../../base/common/uri.js';
import * as nls from '../../../nls.js';
import { IDialogService, IPromptButton } from '../../../platform/dialogs/common/dialogs.js';
import { ExtensionIdentifier } from '../../../platform/extensions/common/extensions.js';
import { LogLevel } from '../../../platform/log/common/log.js';
import { IMcpMessageTransport, IMcpRegistry } from '../../contrib/mcp/common/mcpRegistryTypes.js';
import { McpCollectionDefinition, McpConnectionState, McpServerDefinition, McpServerLaunch, McpServerTransportType, McpServerTrust } from '../../contrib/mcp/common/mcpTypes.js';
import { MCP } from '../../contrib/mcp/common/modelContextProtocol.js';
import { IAuthenticationMcpAccessService } from '../../services/authentication/browser/authenticationMcpAccessService.js';
import { IAuthenticationMcpService } from '../../services/authentication/browser/authenticationMcpService.js';
import { IAuthenticationMcpUsageService } from '../../services/authentication/browser/authenticationMcpUsageService.js';
import { AuthenticationSession, AuthenticationSessionAccount, IAuthenticationService } from '../../services/authentication/common/authentication.js';
import { ExtensionHostKind, extensionHostKindToString } from '../../services/extensions/common/extensionHostKind.js';
import { IExtHostContext, extHostNamedCustomer } from '../../services/extensions/common/extHostCustomers.js';
import { Proxied } from '../../services/extensions/common/proxyIdentifier.js';
import { ExtHostContext, ExtHostMcpShape, MainContext, MainThreadMcpShape } from '../common/extHost.protocol.js';

/**
 * Main-thread half of the MCP extension-host protocol.
 *
 * - Registers a delegate with {@link IMcpRegistry} so servers can be started in this extension host.
 * - Mirrors server collections published by the extension host into the registry.
 * - Relays state, log and message traffic for running servers.
 * - Obtains authentication tokens on behalf of those servers.
 */
@extHostNamedCustomer(MainContext.MainThreadMcp)
export class MainThreadMcp extends Disposable implements MainThreadMcpShape {

	/** Incremented for every server started through the delegate; the result is the server id. */
	private _serverIdCounter = 0;

	/** Transports of servers started through the delegate, keyed by server id. */
	private readonly _servers = new Map<number, ExtHostMcpServerLaunch>();
	/** Definitions of servers started through the delegate, keyed by the same id as `_servers`. */
	private readonly _serverDefinitions = new Map<number, McpServerDefinition>();
	private readonly _proxy: Proxied<ExtHostMcpShape>;
	/** Collections registered with the MCP registry on behalf of the extension host, keyed by collection id. */
	private readonly _collectionDefinitions = this._register(new DisposableMap<string, {
		fromExtHost: McpCollectionDefinition.FromExtHost;
		servers: ISettableObservable<readonly McpServerDefinition[]>;
		dispose(): void;
	}>());

	constructor(
		private readonly _extHostContext: IExtHostContext,
		@IMcpRegistry private readonly _mcpRegistry: IMcpRegistry,
		@IDialogService private readonly dialogService: IDialogService,
		@IAuthenticationService private readonly _authenticationService: IAuthenticationService,
		@IAuthenticationMcpService private readonly authenticationMcpServersService: IAuthenticationMcpService,
		@IAuthenticationMcpAccessService private readonly authenticationMCPServerAccessService: IAuthenticationMcpAccessService,
		@IAuthenticationMcpUsageService private readonly authenticationMCPServerUsageService: IAuthenticationMcpUsageService,
	) {
		super();
		const proxy = this._proxy = _extHostContext.getProxy(ExtHostContext.ExtHostMcp);
		this._register(this._mcpRegistry.registerDelegate({
			// Prefer Node.js extension hosts when they're available. No CORS issues etc.
			priority: _extHostContext.extensionHostKind === ExtensionHostKind.LocalWebWorker ? 0 : 1,
			waitForInitialProviderPromises() {
				return proxy.$waitForInitialCollectionProviders();
			},
			canStart(collection, serverDefinition) {
				// Only handle collections that target the same remote as this extension host.
				if (collection.remoteAuthority !== _extHostContext.remoteAuthority) {
					return false;
				}
				// Stdio servers are never started from a web worker extension host.
				if (serverDefinition.launch.type === McpServerTransportType.Stdio && _extHostContext.extensionHostKind === ExtensionHostKind.LocalWebWorker) {
					return false;
				}
				return true;
			},
			start: (_collection, serverDefinition, resolveLaunch) => {
				const id = ++this._serverIdCounter;
				const launch = new ExtHostMcpServerLaunch(
					_extHostContext.extensionHostKind,
					() => proxy.$stopMcp(id),
					msg => proxy.$sendMessage(id, JSON.stringify(msg)),
				);
				this._servers.set(id, launch);
				this._serverDefinitions.set(id, serverDefinition);
				proxy.$startMcp(id, resolveLaunch);

				return launch;
			},
		}));
	}

	/**
	 * Creates or updates a collection of server definitions published by the extension host.
	 *
	 * An already-known collection only has its server list replaced. A new collection is
	 * registered with the MCP registry, and the registration handle is kept so that
	 * {@link $deleteMcpCollection} can undo it.
	 */
	$upsertMcpCollection(collection: McpCollectionDefinition.FromExtHost, serversDto: McpServerDefinition.Serialized[]): void {
		const servers = serversDto.map(McpServerDefinition.fromSerialized);
		const existing = this._collectionDefinitions.get(collection.id);
		if (existing) {
			existing.servers.set(servers, undefined);
		} else {
			const serverDefinitions = observableValue<readonly McpServerDefinition[]>('mcpServers', servers);
			const handle = this._mcpRegistry.registerCollection({
				...collection,
				source: new ExtensionIdentifier(collection.extensionId),
				// NOTE(review): property name spelled as the registry interface expects; renaming it requires changing that interface too — confirm.
				resolveServerLanch: collection.canResolveLaunch ? (async def => {
					const r = await this._proxy.$resolveMcpLaunch(collection.id, def.label);
					return r ? McpServerLaunch.fromSerialized(r) : undefined;
				}) : undefined,
				trustBehavior: collection.isTrustedByDefault ? McpServerTrust.Kind.Trusted : McpServerTrust.Kind.TrustedOnNonce,
				remoteAuthority: this._extHostContext.remoteAuthority,
				serverDefinitions,
			});

			this._collectionDefinitions.set(collection.id, {
				fromExtHost: collection,
				servers: serverDefinitions,
				dispose: () => handle.dispose(),
			});
		}
	}

	/** Removes a collection and disposes its registry registration. */
	$deleteMcpCollection(collectionId: string): void {
		this._collectionDefinitions.deleteAndDispose(collectionId);
	}

	/**
	 * Applies a connection-state update from the extension host.
	 * Once the server is no longer running, its transport is disposed and its bookkeeping dropped.
	 */
	$onDidChangeState(id: number, update: McpConnectionState): void {
		const server = this._servers.get(id);
		if (!server) {
			return;
		}

		server.state.set(update, undefined);
		if (!McpConnectionState.isRunning(update)) {
			server.dispose();
			this._servers.delete(id);
			this._serverDefinitions.delete(id);
		}
	}

	/**
	 * Forwards a log line from the extension host to the server's transport.
	 *
	 * If `level` arrives as a string, it is treated as the message text and logged at
	 * {@link LogLevel.Info}. Presumably this supports callers that omit the level — TODO confirm.
	 */
	$onDidPublishLog(id: number, level: LogLevel, log: string): void {
		if (typeof level === 'string') {
			// Capture the text before overwriting `level`; the reverse order would log the
			// numeric LogLevel value instead of the message.
			log = level as unknown as string;
			level = LogLevel.Info;
		}

		this._servers.get(id)?.pushLog(level, log);
	}

	/** Forwards a raw (JSON-encoded) message from the extension host to the server's transport. */
	$onDidReceiveMessage(id: number, message: string): void {
		this._servers.get(id)?.pushMessage(message);
	}

	/**
	 * Returns an access token for the server with the given id, authenticating against the
	 * described authorization server.
	 *
	 * An existing session is reused without prompting when either:
	 * - it matches the stored account preference and access is allowed, or
	 * - the provider does not support multiple accounts and access is allowed for the first session.
	 *
	 * Otherwise the user is asked to consent, after which a session is selected or created. The
	 * access grant, account preference and usage are then recorded.
	 *
	 * @returns the access token, or `undefined` if the server id is unknown or no auth provider could be obtained.
	 * @throws Error if the user declines the login prompt.
	 * @throws CancellationError if the user cancels the incorrect-account prompt.
	 */
	async $getTokenFromServerMetadata(id: number, authServerComponents: UriComponents, serverMetadata: IAuthorizationServerMetadata, resourceMetadata: IAuthorizationProtectedResourceMetadata | undefined): Promise<string | undefined> {
		const server = this._serverDefinitions.get(id);
		if (!server) {
			return undefined;
		}

		const authorizationServer = URI.revive(authServerComponents);
		const scopesSupported = resourceMetadata?.scopes_supported || serverMetadata.scopes_supported || [];
		let providerId = await this._authenticationService.getOrActivateProviderIdForServer(authorizationServer);
		if (!providerId) {
			const provider = await this._authenticationService.createDynamicAuthenticationProvider(authorizationServer, serverMetadata, resourceMetadata);
			if (!provider) {
				return undefined;
			}
			providerId = provider.id;
		}
		const sessions = await this._authenticationService.getSessions(providerId, scopesSupported, { authorizationServer: authorizationServer }, true);
		const accountNamePreference = this.authenticationMcpServersService.getAccountPreference(server.id, providerId);
		let matchingAccountPreferenceSession: AuthenticationSession | undefined;
		if (accountNamePreference) {
			matchingAccountPreferenceSession = sessions.find(session => session.account.label === accountNamePreference);
		}
		const provider = this._authenticationService.getProvider(providerId);
		let session: AuthenticationSession;
		if (sessions.length) {
			// If we have an existing session preference, use that. If not, we'll return any valid session at the end of this function.
			if (matchingAccountPreferenceSession && this.authenticationMCPServerAccessService.isAccessAllowed(providerId, matchingAccountPreferenceSession.account.label, server.id)) {
				this.authenticationMCPServerUsageService.addAccountUsage(providerId, matchingAccountPreferenceSession.account.label, scopesSupported, server.id, server.label);
				return matchingAccountPreferenceSession.accessToken;
			}
			// If we only have one account for a single auth provider, let's just check if it's allowed and return it if it is.
			if (!provider.supportsMultipleAccounts && this.authenticationMCPServerAccessService.isAccessAllowed(providerId, sessions[0].account.label, server.id)) {
				this.authenticationMCPServerUsageService.addAccountUsage(providerId, sessions[0].account.label, scopesSupported, server.id, server.label);
				return sessions[0].accessToken;
			}
		}

		const isAllowed = await this.loginPrompt(server.label, provider.label, false);
		if (!isAllowed) {
			throw new Error('User did not consent to login.');
		}

		if (sessions.length) {
			session = provider.supportsMultipleAccounts
				? await this.authenticationMcpServersService.selectSession(providerId, server.id, server.label, scopesSupported, sessions)
				: sessions[0];
		}
		else {
			// NOTE(review): `sessions` is empty in this branch, so `matchingAccountPreferenceSession`
			// (found via `sessions.find`) is always undefined here and the loop body runs once.
			const accountToCreate: AuthenticationSessionAccount | undefined = matchingAccountPreferenceSession?.account;
			do {
				session = await this._authenticationService.createSession(
					providerId,
					scopesSupported,
					{
						activateImmediate: true,
						account: accountToCreate,
						authorizationServer
					});
			} while (
				accountToCreate
				&& accountToCreate.label !== session.account.label
				&& !await this.continueWithIncorrectAccountPrompt(session.account.label, accountToCreate.label)
			);
		}

		this.authenticationMCPServerAccessService.updateAllowedMcpServers(providerId, session.account.label, [{ id: server.id, name: server.label, allowed: true }]);
		this.authenticationMcpServersService.updateAccountPreference(server.id, providerId, session.account);
		this.authenticationMCPServerUsageService.addAccountUsage(providerId, session.account.label, scopesSupported, server.id, server.label);
		return session.accessToken;
	}

	/**
	 * Asks whether to keep an account that differs from the one requested.
	 *
	 * @returns `true` to keep the chosen account, `false` to try logging in with the requested one.
	 * @throws CancellationError if the dialog is cancelled.
	 */
	private async continueWithIncorrectAccountPrompt(chosenAccountLabel: string, requestedAccountLabel: string): Promise<boolean> {
		const result = await this.dialogService.prompt({
			message: nls.localize('incorrectAccount', "Incorrect account detected"),
			detail: nls.localize('incorrectAccountDetail', "The chosen account, {0}, does not match the requested account, {1}.", chosenAccountLabel, requestedAccountLabel),
			type: Severity.Warning,
			cancelButton: true,
			buttons: [
				{
					label: nls.localize('keep', 'Keep {0}', chosenAccountLabel),
					run: () => chosenAccountLabel
				},
				{
					label: nls.localize('loginWith', 'Login with {0}', requestedAccountLabel),
					run: () => requestedAccountLabel
				}
			],
		});

		if (!result.result) {
			throw new CancellationError();
		}

		return result.result === chosenAccountLabel;
	}

	/**
	 * Shows a consent dialog for an MCP server that wants to authenticate with a provider.
	 *
	 * @param recreatingSession selects the alternate message wording.
	 * @returns `true` if the user clicked Allow, `false` if the dialog was cancelled.
	 */
	private async loginPrompt(mcpLabel: string, providerLabel: string, recreatingSession: boolean): Promise<boolean> {
		const message = recreatingSession
			? nls.localize('confirmRelogin', "The MCP Server Definition '{0}' wants you to authenticate to {1}.", mcpLabel, providerLabel)
			: nls.localize('confirmLogin', "The MCP Server Definition '{0}' wants to authenticate to {1}.", mcpLabel, providerLabel);

		const buttons: IPromptButton<boolean | undefined>[] = [
			{
				label: nls.localize({ key: 'allow', comment: ['&& denotes a mnemonic'] }, "&&Allow"),
				run() {
					return true;
				},
			}
		];
		const { result } = await this.dialogService.prompt({
			type: Severity.Info,
			message,
			buttons,
			cancelButton: true,
		});

		return result ?? false;
	}

	/**
	 * Marks every tracked server as shut down by the extension host, then clears the maps
	 * and disposes registered resources (including collection registrations).
	 */
	override dispose(): void {
		for (const server of this._servers.values()) {
			server.extHostDispose();
		}
		this._servers.clear();
		this._serverDefinitions.clear();
		super.dispose();
	}
}


/**
 * Main-thread transport for a single MCP server running in an extension host.
 *
 * State, logs and inbound messages are pushed in by {@link MainThreadMcp}. Outbound messages
 * and stop requests go through the callbacks supplied at construction.
 */
class ExtHostMcpServerLaunch extends Disposable implements IMcpMessageTransport {
	public readonly state = observableValue<McpConnectionState>('mcpServerState', { state: McpConnectionState.Kind.Starting });

	private readonly _onDidLog = this._register(new Emitter<{ level: LogLevel; message: string }>());
	public readonly onDidLog = this._onDidLog.event;

	private readonly _onDidReceiveMessage = this._register(new Emitter<MCP.JSONRPCMessage>());
	public readonly onDidReceiveMessage = this._onDidReceiveMessage.event;

	/** Emits a log entry to `onDidLog` listeners. */
	pushLog(level: LogLevel, message: string): void {
		this._onDidLog.fire({ message, level });
	}

	/**
	 * Parses a JSON-RPC payload and emits it on `onDidReceiveMessage`.
	 * An array payload is emitted element by element. Unparseable input is logged as a warning.
	 */
	pushMessage(message: string): void {
		let parsed: MCP.JSONRPCMessage | undefined;
		try {
			parsed = JSON.parse(message);
		} catch {
			this.pushLog(LogLevel.Warning, `Failed to parse message: ${JSON.stringify(message)}`);
		}

		if (parsed) {
			if (Array.isArray(parsed)) { // streamable HTTP supports batching
				parsed.forEach(p => this._onDidReceiveMessage.fire(p));
			} else {
				this._onDidReceiveMessage.fire(parsed);
			}
		}
	}

	constructor(
		extHostKind: ExtensionHostKind,
		public readonly stop: () => void,
		public readonly send: (message: MCP.JSONRPCMessage) => void,
	) {
		super();

		// Deferred via a timeout, presumably so listeners attached right after construction receive it.
		this._register(disposableTimeout(() => {
			this.pushLog(LogLevel.Info, `Starting server from ${extensionHostKindToString(extHostKind)} extension host`);
		}));
	}

	/**
	 * Called when the extension host goes away. A still-running server is marked Stopped
	 * (without invoking `stop`, since the state is no longer running), and the transport is disposed.
	 */
	public extHostDispose() {
		if (McpConnectionState.isRunning(this.state.get())) {
			this.pushLog(LogLevel.Warning, 'Extension host shut down, server will stop.');
			this.state.set({ state: McpConnectionState.Kind.Stopped }, undefined);
		}
		this.dispose();
	}

	/** Requests a stop if the server is still running, then releases emitters and timers. */
	public override dispose(): void {
		if (McpConnectionState.isRunning(this.state.get())) {
			this.stop();
		}

		super.dispose();
	}
}