Path: blob/main/src/vs/workbench/api/browser/mainThreadMcp.ts
5237 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { mapFindFirst } from '../../../base/common/arraysFind.js';
import { disposableTimeout, RunOnceScheduler } from '../../../base/common/async.js';
import { CancellationError } from '../../../base/common/errors.js';
import { Emitter } from '../../../base/common/event.js';
import { Disposable, DisposableMap, DisposableStore, MutableDisposable } from '../../../base/common/lifecycle.js';
import { autorun, ISettableObservable, observableValue } from '../../../base/common/observable.js';
import Severity from '../../../base/common/severity.js';
import { URI } from '../../../base/common/uri.js';
import { generateUuid } from '../../../base/common/uuid.js';
import * as nls from '../../../nls.js';
import { ContextKeyExpr, IContextKeyService } from '../../../platform/contextkey/common/contextkey.js';
import { IDialogService, IPromptButton } from '../../../platform/dialogs/common/dialogs.js';
import { ExtensionIdentifier } from '../../../platform/extensions/common/extensions.js';
import { LogLevel } from '../../../platform/log/common/log.js';
import { ITelemetryService } from '../../../platform/telemetry/common/telemetry.js';
import { IMcpGatewayResult, IWorkbenchMcpGatewayService } from '../../contrib/mcp/common/mcpGatewayService.js';
import { IMcpMessageTransport, IMcpRegistry } from '../../contrib/mcp/common/mcpRegistryTypes.js';
import { extensionPrefixedIdentifier, McpCollectionDefinition, McpConnectionState, McpServerDefinition, McpServerLaunch, McpServerTransportType, McpServerTrust, UserInteractionRequiredError } from '../../contrib/mcp/common/mcpTypes.js';
import { MCP } from '../../contrib/mcp/common/modelContextProtocol.js';
import { IAuthenticationMcpAccessService } from '../../services/authentication/browser/authenticationMcpAccessService.js';
import { IAuthenticationMcpService } from '../../services/authentication/browser/authenticationMcpService.js';
import { IAuthenticationMcpUsageService } from '../../services/authentication/browser/authenticationMcpUsageService.js';
import { AuthenticationSession, AuthenticationSessionAccount, IAuthenticationService } from '../../services/authentication/common/authentication.js';
import { IDynamicAuthenticationProviderStorageService } from '../../services/authentication/common/dynamicAuthenticationProviderStorage.js';
import { ExtensionHostKind, extensionHostKindToString } from '../../services/extensions/common/extensionHostKind.js';
import { IExtensionService } from '../../services/extensions/common/extensions.js';
import { IExtHostContext, extHostNamedCustomer } from '../../services/extensions/common/extHostCustomers.js';
import { Proxied } from '../../services/extensions/common/proxyIdentifier.js';
import { ExtHostContext, ExtHostMcpShape, IMcpAuthenticationDetails, IMcpAuthenticationOptions, IAuthMetadataSource, MainContext, MainThreadMcpShape } from '../common/extHost.protocol.js';

/**
 * Main-thread half of the MCP extension API.
 *
 * Registers this extension host as an MCP launch delegate with the
 * {@link IMcpRegistry}, mirrors extension-contributed server collections into
 * the registry, relays server messages/logs/state between the registry and the
 * extension host, and brokers authentication tokens and gateways on behalf of
 * running servers.
 */
@extHostNamedCustomer(MainContext.MainThreadMcp)
export class MainThreadMcp extends Disposable implements MainThreadMcpShape {

	/** Monotonic source of ids for servers started through this ext host. */
	private _serverIdCounter = 0;

	/** Live transports keyed by the id handed to the ext host in `$startMcp`. */
	private readonly _servers = new Map<number, ExtHostMcpServerLaunch>();
	/** Definition each live server was started from, keyed like `_servers`. */
	private readonly _serverDefinitions = new Map<number, McpServerDefinition>();
	/** Which servers obtained tokens from which auth providers (with which scopes). */
	private readonly _serverAuthTracking = new McpServerAuthTracker();
	private readonly _proxy: Proxied<ExtHostMcpShape>;
	private readonly _collectionDefinitions = this._register(new DisposableMap<string, {
		servers: ISettableObservable<readonly McpServerDefinition[]>;
		dispose(): void;
	}>());
	/** Gateways created via `$startMcpGateway`, keyed by generated gateway id. */
	private readonly _gateways = this._register(new DisposableMap<string, IMcpGatewayResult>());

	constructor(
		private readonly _extHostContext: IExtHostContext,
		@IMcpRegistry private readonly _mcpRegistry: IMcpRegistry,
		@IDialogService private readonly dialogService: IDialogService,
		@IAuthenticationService private readonly _authenticationService: IAuthenticationService,
		@IAuthenticationMcpService private readonly authenticationMcpServersService: IAuthenticationMcpService,
		@IAuthenticationMcpAccessService private readonly authenticationMCPServerAccessService: IAuthenticationMcpAccessService,
		@IAuthenticationMcpUsageService private readonly authenticationMCPServerUsageService: IAuthenticationMcpUsageService,
		@IDynamicAuthenticationProviderStorageService private readonly _dynamicAuthenticationProviderStorageService: IDynamicAuthenticationProviderStorageService,
		@IExtensionService private readonly _extensionService: IExtensionService,
		@IContextKeyService private readonly _contextKeyService: IContextKeyService,
		@ITelemetryService private readonly _telemetryService: ITelemetryService,
		@IWorkbenchMcpGatewayService private readonly _mcpGatewayService: IWorkbenchMcpGatewayService,
	) {
		super();
		this._register(_authenticationService.onDidChangeSessions(e => this._onDidChangeAuthSessions(e.providerId, e.label)));
		const proxy = this._proxy = _extHostContext.getProxy(ExtHostContext.ExtHostMcp);
		this._register(this._mcpRegistry.registerDelegate({
			// Prefer Node.js extension hosts when they're available. No CORS issues etc.
			priority: _extHostContext.extensionHostKind === ExtensionHostKind.LocalWebWorker ? 0 : 1,
			waitForInitialProviderPromises() {
				return proxy.$waitForInitialCollectionProviders();
			},
			canStart(collection, serverDefinition) {
				// Only start servers that belong to the same remote as this ext host.
				if (collection.remoteAuthority !== _extHostContext.remoteAuthority) {
					return false;
				}
				// A web worker ext host cannot spawn stdio processes.
				if (serverDefinition.launch.type === McpServerTransportType.Stdio && _extHostContext.extensionHostKind === ExtensionHostKind.LocalWebWorker) {
					return false;
				}
				return true;
			},
			async substituteVariables(serverDefinition, launch) {
				const ser = await proxy.$substituteVariables(serverDefinition.variableReplacement?.folder?.uri, McpServerLaunch.toSerialized(launch));
				return McpServerLaunch.fromSerialized(ser);
			},
			start: (_collection, serverDefinition, resolveLaunch, options) => {
				const id = ++this._serverIdCounter;
				const launch = new ExtHostMcpServerLaunch(
					_extHostContext.extensionHostKind,
					() => proxy.$stopMcp(id),
					msg => proxy.$sendMessage(id, JSON.stringify(msg)),
				);
				this._servers.set(id, launch);
				this._serverDefinitions.set(id, serverDefinition);
				proxy.$startMcp(id, {
					launch: resolveLaunch,
					defaultCwd: serverDefinition.variableReplacement?.folder?.uri,
					errorOnUserInteraction: options?.errorOnUserInteraction,
				});

				return launch;
			},
		}));

		// Subscribe to MCP server definition changes and notify ext host
		const onDidChangeMcpServerDefinitionsTrigger = this._register(new RunOnceScheduler(() => this._publishServerDefinitions(), 500));
		this._register(autorun(reader => {
			const collections = this._mcpRegistry.collections.read(reader);
			// Read all server definitions to track changes
			for (const collection of collections) {
				collection.serverDefinitions.read(reader);
			}
			// Notify ext host that definitions changed (it will re-fetch if needed)
			if (!onDidChangeMcpServerDefinitionsTrigger.isScheduled()) {
				onDidChangeMcpServerDefinitionsTrigger.schedule();
			}
		}));

		onDidChangeMcpServerDefinitionsTrigger.schedule();
	}

	/**
	 * Sends a flattened, serialized snapshot of every server definition in every
	 * registry collection to the extension host.
	 */
	private _publishServerDefinitions() {
		const collections = this._mcpRegistry.collections.get();
		const allServers: McpServerDefinition.Serialized[] = [];

		for (const collection of collections) {
			const servers = collection.serverDefinitions.get();
			for (const server of servers) {
				allServers.push(McpServerDefinition.toSerialized(server));
			}
		}

		this._proxy.$onDidChangeMcpServerDefinitions(allServers);
	}

	/**
	 * Creates or updates an extension-contributed collection.
	 *
	 * On first sight of a collection id, it is registered with the registry. If the
	 * contributing extension declared a `when` clause for this provider, the
	 * registration is added/removed as that clause's context match changes.
	 * Subsequent calls only update the collection's server list.
	 */
	$upsertMcpCollection(collection: McpCollectionDefinition.FromExtHost, serversDto: McpServerDefinition.Serialized[]): void {
		const servers = serversDto.map(McpServerDefinition.fromSerialized);
		const existing = this._collectionDefinitions.get(collection.id);
		if (existing) {
			existing.servers.set(servers, undefined);
		} else {
			const serverDefinitions = observableValue<readonly McpServerDefinition[]>('mcpServers', servers);
			const extensionId = new ExtensionIdentifier(collection.extensionId);
			const store = new DisposableStore();
			const handle = store.add(new MutableDisposable());
			// Idempotent: `??=` leaves an existing registration in place.
			const register = () => {
				handle.value ??= this._mcpRegistry.registerCollection({
					...collection,
					source: extensionId,
					resolveServerLanch: collection.canResolveLaunch ? (async def => {
						const r = await this._proxy.$resolveMcpLaunch(collection.id, def.label);
						return r ? McpServerLaunch.fromSerialized(r) : undefined;
					}) : undefined,
					trustBehavior: collection.isTrustedByDefault ? McpServerTrust.Kind.Trusted : McpServerTrust.Kind.TrustedOnNonce,
					remoteAuthority: this._extHostContext.remoteAuthority,
					serverDefinitions,
				});
			};

			// Look up the `when` clause from the extension's
			// `mcpServerDefinitionProviders` contribution matching this collection.
			const whenClauseStr = mapFindFirst(this._extensionService.extensions, e =>
				ExtensionIdentifier.equals(extensionId, e.identifier)
					? e.contributes?.mcpServerDefinitionProviders?.find(p => extensionPrefixedIdentifier(extensionId, p.id) === collection.id)?.when
					: undefined);
			const whenClause = whenClauseStr && ContextKeyExpr.deserialize(whenClauseStr);

			if (!whenClause) {
				register();
			} else {
				const evaluate = () => {
					if (this._contextKeyService.contextMatchesRules(whenClause)) {
						register();
					} else {
						handle.clear();
					}
				};

				store.add(this._contextKeyService.onDidChangeContext(evaluate));
				evaluate();
			}

			this._collectionDefinitions.set(collection.id, {
				servers: serverDefinitions,
				dispose: () => store.dispose(),
			});
		}
	}

	/** Removes a collection and disposes its registry registration and listeners. */
	$deleteMcpCollection(collectionId: string): void {
		this._collectionDefinitions.deleteAndDispose(collectionId);
	}

	/**
	 * Applies a connection-state update reported by the extension host. Once the
	 * server is no longer running, its transport is disposed and all bookkeeping
	 * for the id is dropped.
	 */
	$onDidChangeState(id: number, update: McpConnectionState): void {
		const server = this._servers.get(id);
		if (!server) {
			return;
		}

		server.state.set(update, undefined);
		if (!McpConnectionState.isRunning(update)) {
			server.dispose();
			this._servers.delete(id);
			this._serverDefinitions.delete(id);
			this._serverAuthTracking.untrack(id);
		}
	}

	/**
	 * Forwards a log line from the extension host to the server's log stream.
	 *
	 * If `level` arrives as a string, the caller used the older `(id, message)`
	 * shape. In that case the string is the message and it is logged at Info.
	 */
	$onDidPublishLog(id: number, level: LogLevel, log: string): void {
		if (typeof level === 'string') {
			// Capture the message BEFORE overwriting `level`; the previous order
			// replaced the message with the numeric LogLevel.Info value.
			log = level as unknown as string;
			level = LogLevel.Info;
		}

		this._servers.get(id)?.pushLog(level, log);
	}

	/** Delivers a raw (JSON string) message from the server to its transport. */
	$onDidReceiveMessage(id: number, message: string): void {
		this._servers.get(id)?.pushMessage(message);
	}

	/**
	 * Returns an access token for server `id` from a known authentication provider,
	 * or `undefined` if the server is unknown.
	 */
	async $getTokenForProviderId(id: number, providerId: string, scopes: string[], options: IMcpAuthenticationOptions = {}): Promise<string | undefined> {
		const server = this._serverDefinitions.get(id);
		if (!server) {
			return undefined;
		}
		return this._getSessionForProvider(id, server, providerId, scopes, undefined, options.errorOnUserInteraction);
	}

	/**
	 * Returns an access token for server `id` using OAuth metadata discovered by
	 * the extension host.
	 *
	 * Resolves (or dynamically creates) the auth provider for the authorization
	 * server. When `forceNewRegistration` is set, an existing dynamic provider is
	 * first unregistered and removed from storage.
	 *
	 * @returns the token, or `undefined` if the server is unknown or no provider
	 * could be created.
	 * @throws if `forceNewRegistration` targets a non-dynamic provider.
	 */
	async $getTokenFromServerMetadata(id: number, authDetails: IMcpAuthenticationDetails, { errorOnUserInteraction, forceNewRegistration }: IMcpAuthenticationOptions = {}): Promise<string | undefined> {
		const server = this._serverDefinitions.get(id);
		if (!server) {
			return undefined;
		}
		const authorizationServer = URI.revive(authDetails.authorizationServer);
		const resourceServer = authDetails.resourceMetadata?.resource ? URI.parse(authDetails.resourceMetadata.resource) : undefined;
		// Scope precedence: explicit > resource metadata > authorization server metadata > none.
		const resolvedScopes = authDetails.scopes ?? authDetails.resourceMetadata?.scopes_supported ?? authDetails.authorizationServerMetadata.scopes_supported ?? [];
		let providerId = await this._authenticationService.getOrActivateProviderIdForServer(authorizationServer, resourceServer);
		if (forceNewRegistration && providerId) {
			if (!this._authenticationService.isDynamicAuthenticationProvider(providerId)) {
				throw new Error('Cannot force new registration for a non-dynamic authentication provider.');
			}
			this._authenticationService.unregisterAuthenticationProvider(providerId);
			// TODO: Encapsulate this and the unregister in one call in the auth service
			await this._dynamicAuthenticationProviderStorageService.removeDynamicProvider(providerId);
			providerId = undefined;
		}

		if (!providerId) {
			const provider = await this._authenticationService.createDynamicAuthenticationProvider(authorizationServer, authDetails.authorizationServerMetadata, authDetails.resourceMetadata);
			if (!provider) {
				return undefined;
			}
			providerId = provider.id;
		}

		return this._getSessionForProvider(id, server, providerId, resolvedScopes, authorizationServer, errorOnUserInteraction);
	}

	/**
	 * Obtains an access token for `server` from `providerId`.
	 *
	 * Silent paths are tried first:
	 * 1. a session for the server's preferred account that is already allowed;
	 * 2. the single session of a single-account provider, if allowed.
	 *
	 * Otherwise the user is prompted to log in. They either pick among existing
	 * sessions (multi-account providers) or a new session is created. On success
	 * the account is recorded as allowed, preferred and used, and the
	 * server/provider/scopes triple is tracked so session changes can re-validate it.
	 *
	 * @throws UserInteractionRequiredError when interaction would be needed but
	 * `errorOnUserInteraction` is set.
	 * @throws Error when the user declines the login prompt.
	 * @throws CancellationError when the user cancels the incorrect-account prompt.
	 */
	private async _getSessionForProvider(
		serverId: number,
		server: McpServerDefinition,
		providerId: string,
		scopes: string[],
		authorizationServer?: URI,
		errorOnUserInteraction: boolean = false
	): Promise<string | undefined> {
		const sessions = await this._authenticationService.getSessions(providerId, scopes, { authorizationServer }, true);
		const accountNamePreference = this.authenticationMcpServersService.getAccountPreference(server.id, providerId);
		let matchingAccountPreferenceSession: AuthenticationSession | undefined;
		if (accountNamePreference) {
			matchingAccountPreferenceSession = sessions.find(session => session.account.label === accountNamePreference);
		}
		const provider = this._authenticationService.getProvider(providerId);
		let session: AuthenticationSession;
		if (sessions.length) {
			// If we have an existing session preference, use that. If not, we'll return any valid session at the end of this function.
			if (matchingAccountPreferenceSession && this.authenticationMCPServerAccessService.isAccessAllowed(providerId, matchingAccountPreferenceSession.account.label, server.id)) {
				this.authenticationMCPServerUsageService.addAccountUsage(providerId, matchingAccountPreferenceSession.account.label, scopes, server.id, server.label);
				this._serverAuthTracking.track(providerId, serverId, scopes);
				return matchingAccountPreferenceSession.accessToken;
			}
			// If we only have one account for a single auth provider, lets just check if it's allowed and return it if it is.
			if (!provider.supportsMultipleAccounts && this.authenticationMCPServerAccessService.isAccessAllowed(providerId, sessions[0].account.label, server.id)) {
				this.authenticationMCPServerUsageService.addAccountUsage(providerId, sessions[0].account.label, scopes, server.id, server.label);
				this._serverAuthTracking.track(providerId, serverId, scopes);
				return sessions[0].accessToken;
			}
		}

		// Every path below prompts the user, so this single guard covers all of
		// them (the former per-branch re-checks were unreachable).
		if (errorOnUserInteraction) {
			throw new UserInteractionRequiredError('authentication');
		}

		const isAllowed = await this.loginPrompt(server.label, provider.label, false);
		if (!isAllowed) {
			throw new Error('User did not consent to login.');
		}

		if (sessions.length) {
			session = provider.supportsMultipleAccounts
				? await this.authenticationMcpServersService.selectSession(providerId, server.id, server.label, scopes, sessions)
				: sessions[0];
		} else {
			const accountToCreate: AuthenticationSessionAccount | undefined = matchingAccountPreferenceSession?.account;
			// Keep creating sessions while the provider hands back a different
			// account than requested and the user asks to retry with the requested one.
			do {
				session = await this._authenticationService.createSession(
					providerId,
					scopes,
					{
						activateImmediate: true,
						account: accountToCreate,
						authorizationServer
					});
			} while (
				accountToCreate
				&& accountToCreate.label !== session.account.label
				&& !await this.continueWithIncorrectAccountPrompt(session.account.label, accountToCreate.label)
			);
		}

		this.authenticationMCPServerAccessService.updateAllowedMcpServers(providerId, session.account.label, [{ id: server.id, name: server.label, allowed: true }]);
		this.authenticationMcpServersService.updateAccountPreference(server.id, providerId, session.account);
		this.authenticationMCPServerUsageService.addAccountUsage(providerId, session.account.label, scopes, server.id, server.label);
		this._serverAuthTracking.track(providerId, serverId, scopes);
		return session.accessToken;
	}

	/**
	 * Asks the user whether to keep an account that differs from the requested one.
	 *
	 * @returns `true` to keep the chosen account, `false` to retry login with the
	 * requested account.
	 * @throws CancellationError if the dialog is cancelled.
	 */
	private async continueWithIncorrectAccountPrompt(chosenAccountLabel: string, requestedAccountLabel: string): Promise<boolean> {
		const result = await this.dialogService.prompt({
			message: nls.localize('incorrectAccount', "Incorrect account detected"),
			detail: nls.localize('incorrectAccountDetail', "The chosen account, {0}, does not match the requested account, {1}.", chosenAccountLabel, requestedAccountLabel),
			type: Severity.Warning,
			cancelButton: true,
			buttons: [
				{
					label: nls.localize('keep', 'Keep {0}', chosenAccountLabel),
					run: () => chosenAccountLabel
				},
				{
					label: nls.localize('loginWith', 'Login with {0}', requestedAccountLabel),
					run: () => requestedAccountLabel
				}
			],
		});

		if (!result.result) {
			throw new CancellationError();
		}

		return result.result === chosenAccountLabel;
	}

	/**
	 * Re-validates, without prompting, every running server that obtained a token
	 * from `providerId`. Servers whose session now needs user interaction are
	 * stopped; other errors are ignored so one server cannot disrupt the rest.
	 */
	private async _onDidChangeAuthSessions(providerId: string, providerLabel: string): Promise<void> {
		const serversUsingProvider = this._serverAuthTracking.get(providerId);
		if (!serversUsingProvider) {
			return;
		}

		for (const { serverId, scopes } of serversUsingProvider) {
			const server = this._servers.get(serverId);
			const serverDefinition = this._serverDefinitions.get(serverId);

			if (!server || !serverDefinition) {
				continue;
			}

			// Only validate servers that are running
			const state = server.state.get();
			if (state.state !== McpConnectionState.Kind.Running) {
				continue;
			}

			// Validate if the session is still available
			try {
				await this._getSessionForProvider(serverId, serverDefinition, providerId, scopes, undefined, true);
			} catch (e) {
				if (UserInteractionRequiredError.is(e)) {
					// Session is no longer valid, stop the server
					server.pushLog(LogLevel.Warning, nls.localize('mcpAuthSessionRemoved', "Authentication session for {0} removed, stopping server", providerLabel));
					server.stop();
				}
				// Ignore other errors to avoid disrupting other servers
			}
		}
	}

	/** Reports how OAuth metadata for an MCP server was discovered. */
	$logMcpAuthSetup(data: IAuthMetadataSource): void {
		type McpAuthSetupClassification = {
			owner: 'TylerLeonhardt';
			comment: 'Tracks how MCP OAuth authentication setup was discovered and configured';
			resourceMetadataSource: { classification: 'SystemMetaData'; purpose: 'FeatureInsight'; comment: 'How resource metadata was discovered (header, wellKnown, or none)' };
			serverMetadataSource: { classification: 'SystemMetaData'; purpose: 'FeatureInsight'; comment: 'How authorization server metadata was discovered (resourceMetadata, wellKnown, or default)' };
		};
		this._telemetryService.publicLog2<IAuthMetadataSource, McpAuthSetupClassification>('mcp/authSetup', data);
	}

	/**
	 * Creates an MCP gateway and returns its address plus an id for later disposal.
	 * Returns `undefined` if the gateway service produced nothing, or if this
	 * object was disposed while the gateway was being created (the gateway is
	 * then disposed immediately).
	 */
	async $startMcpGateway(): Promise<{ address: URI; gatewayId: string } | undefined> {
		const result = await this._mcpGatewayService.createGateway(this._extHostContext.extensionHostKind === ExtensionHostKind.Remote);
		if (!result) {
			return undefined;
		}

		if (this._store.isDisposed) {
			result.dispose();
			return undefined;
		}

		const gatewayId = generateUuid();
		this._gateways.set(gatewayId, result);

		return {
			address: result.address,
			gatewayId,
		};
	}

	/** Disposes a gateway previously returned by `$startMcpGateway`. */
	$disposeMcpGateway(gatewayId: string): void {
		this._gateways.deleteAndDispose(gatewayId);
	}

	/**
	 * Shows a consent dialog for a server wanting to authenticate with a provider.
	 *
	 * @returns `true` if the user chose "Allow", otherwise `false`.
	 */
	private async loginPrompt(mcpLabel: string, providerLabel: string, recreatingSession: boolean): Promise<boolean> {
		const message = recreatingSession
			? nls.localize('confirmRelogin', "The MCP Server Definition '{0}' wants you to authenticate to {1}.", mcpLabel, providerLabel)
			: nls.localize('confirmLogin', "The MCP Server Definition '{0}' wants to authenticate to {1}.", mcpLabel, providerLabel);

		const buttons: IPromptButton<boolean | undefined>[] = [
			{
				label: nls.localize({ key: 'allow', comment: ['&& denotes a mnemonic'] }, "&&Allow"),
				run() {
					return true;
				},
			}
		];
		const { result } = await this.dialogService.prompt({
			type: Severity.Info,
			message,
			buttons,
			cancelButton: true,
		});

		return result ?? false;
	}

	/**
	 * Marks all live servers as stopped because the extension host is going away,
	 * then clears bookkeeping and disposes registered resources.
	 */
	override dispose(): void {
		for (const server of this._servers.values()) {
			server.extHostDispose();
		}
		this._servers.clear();
		this._serverDefinitions.clear();
		this._serverAuthTracking.clear();
		super.dispose();
	}
}


/**
 * Main-thread transport for a single MCP server whose process/connection is
 * owned by the extension host. Messages and logs are pushed in by
 * {@link MainThreadMcp}; `stop` and `send` call back into the extension host.
 */
class ExtHostMcpServerLaunch extends Disposable implements IMcpMessageTransport {
	public readonly state = observableValue<McpConnectionState>('mcpServerState', { state: McpConnectionState.Kind.Starting });

	private readonly _onDidLog = this._register(new Emitter<{ level: LogLevel; message: string }>());
	public readonly onDidLog = this._onDidLog.event;

	private readonly _onDidReceiveMessage = this._register(new Emitter<MCP.JSONRPCMessage>());
	public readonly onDidReceiveMessage = this._onDidReceiveMessage.event;

	/** Emits a log line on `onDidLog`. */
	pushLog(level: LogLevel, message: string): void {
		this._onDidLog.fire({ message, level });
	}

	/**
	 * Parses a JSON message and emits it on `onDidReceiveMessage`. A JSON array is
	 * emitted element by element. Unparseable input is logged as a warning and
	 * otherwise dropped.
	 */
	pushMessage(message: string): void {
		let parsed: MCP.JSONRPCMessage | undefined;
		try {
			parsed = JSON.parse(message);
		} catch (e) {
			this.pushLog(LogLevel.Warning, `Failed to parse message: ${JSON.stringify(message)}`);
		}

		if (parsed) {
			if (Array.isArray(parsed)) { // streamable HTTP supports batching
				parsed.forEach(p => this._onDidReceiveMessage.fire(p));
			} else {
				this._onDidReceiveMessage.fire(parsed);
			}
		}
	}

	constructor(
		extHostKind: ExtensionHostKind,
		public readonly stop: () => void,
		public readonly send: (message: MCP.JSONRPCMessage) => void,
	) {
		super();

		// Deferred so listeners attached right after construction receive it.
		this._register(disposableTimeout(() => {
			this.pushLog(LogLevel.Info, `Starting server from ${extensionHostKindToString(extHostKind)} extension host`);
		}));
	}

	/**
	 * Disposes because the extension host is shutting down. A running server is
	 * first marked Stopped (with a warning), so `dispose()` will not call `stop`.
	 */
	public extHostDispose() {
		if (McpConnectionState.isRunning(this.state.get())) {
			this.pushLog(LogLevel.Warning, 'Extension host shut down, server will stop.');
			this.state.set({ state: McpConnectionState.Kind.Stopped }, undefined);
		}
		this.dispose();
	}

	/** Asks the extension host to stop the server if it is still running, then disposes. */
	public override dispose(): void {
		if (McpConnectionState.isRunning(this.state.get())) {
			this.stop();
		}

		super.dispose();
	}
}

/**
 * Tracks which MCP servers are using which authentication providers.
 * Organized by provider ID for efficient lookup when auth sessions change.
 */
class McpServerAuthTracker {
	// Provider ID -> Array of serverId and scopes used
	private readonly _tracking = new Map<string, Array<{ serverId: number; scopes: string[] }>>();

	/**
	 * Track authentication for a server with a specific provider.
	 * Replaces any existing tracking for this server/provider combination.
	 */
	track(providerId: string, serverId: number, scopes: string[]): void {
		const servers = this._tracking.get(providerId) || [];
		const filtered = servers.filter(s => s.serverId !== serverId);
		filtered.push({ serverId, scopes });
		this._tracking.set(providerId, filtered);
	}

	/**
	 * Remove all authentication tracking for a server across all providers.
	 * Providers left with no servers are removed from the map.
	 */
	untrack(serverId: number): void {
		for (const [providerId, servers] of this._tracking.entries()) {
			const filtered = servers.filter(s => s.serverId !== serverId);
			if (filtered.length === 0) {
				this._tracking.delete(providerId);
			} else {
				this._tracking.set(providerId, filtered);
			}
		}
	}

	/**
	 * Get all servers using a specific authentication provider.
	 */
	get(providerId: string): ReadonlyArray<{ serverId: number; scopes: string[] }> | undefined {
		return this._tracking.get(providerId);
	}

	/**
	 * Clear all tracking data.
	 */
	clear(): void {
		this._tracking.clear();
	}
}