Path: blob/main/src/vs/workbench/contrib/mcp/common/mcpSamplingService.ts
5250 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { asArray } from '../../../../base/common/arrays.js';
import { mapFindFirst } from '../../../../base/common/arraysFind.js';
import { Sequencer } from '../../../../base/common/async.js';
import { decodeBase64 } from '../../../../base/common/buffer.js';
import { CancellationToken } from '../../../../base/common/cancellation.js';
import { Event } from '../../../../base/common/event.js';
import { Disposable } from '../../../../base/common/lifecycle.js';
import { isDefined } from '../../../../base/common/types.js';
import { localize } from '../../../../nls.js';
import { ICommandService } from '../../../../platform/commands/common/commands.js';
import { ConfigurationTarget, getConfigValueInTarget, IConfigurationService } from '../../../../platform/configuration/common/configuration.js';
import { IDialogService } from '../../../../platform/dialogs/common/dialogs.js';
import { ExtensionIdentifier } from '../../../../platform/extensions/common/extensions.js';
import { IInstantiationService } from '../../../../platform/instantiation/common/instantiation.js';
import { INotificationService, Severity } from '../../../../platform/notification/common/notification.js';
import { ChatAgentLocation, ChatConfiguration } from '../../chat/common/constants.js';
import { ChatImageMimeType, ChatMessageRole, IChatMessage, IChatMessagePart, ILanguageModelsService } from '../../chat/common/languageModels.js';
import { McpCommandIds } from './mcpCommandIds.js';
import { IMcpServerSamplingConfiguration, mcpServerSamplingSection } from './mcpConfiguration.js';
import { McpSamplingLog } from './mcpSamplingLog.js';
import { IMcpSamplingService, IMcpServer, ISamplingOptions, ISamplingResult, McpError } from './mcpTypes.js';
import { MCP } from './modelContextProtocol.js';

/**
 * Reasons why {@link McpSamplingService} could not immediately pick a model
 * for a sampling request.
 */
const enum ModelMatch {
	/** Sampling during a tool call is neither configured nor decided for this session. */
	UnsureAllowedDuringChat,
	/** Sampling outside of a tool call is neither configured nor decided for this session. */
	UnsureAllowedOutsideChat,
	/** Sampling in this context was denied, by configuration or by a session decision. */
	NotAllowed,
	/** Sampling is allowed, but no usable model was found. */
	NoMatchingModel,
}

/**
 * Handles language model sampling requests issued by MCP servers: checks
 * (and if needed asks for) permission, picks a model, sends the chat request
 * and returns the collected text response.
 */
export class McpSamplingService extends Disposable implements IMcpSamplingService {
	declare readonly _serviceBrand: undefined;

	/**
	 * Per-session permission decisions, keyed by server definition ID.
	 * `true` is stored for "Allow in this Session" (and for auto-approval),
	 * `false` is stored for "Not Now".
	 */
	private readonly _sessionSets = {
		allowedDuringChat: new Map<string, boolean>(),
		allowedOutsideChat: new Map<string, boolean>(),
	};

	private readonly _logs: McpSamplingLog;

	/** Model resolution (which may prompt the user) is queued through this sequencer. */
	private readonly _modelSequencer = new Sequencer();

	constructor(
		@ILanguageModelsService private readonly _languageModelsService: ILanguageModelsService,
		@IConfigurationService private readonly _configurationService: IConfigurationService,
		@IDialogService private readonly _dialogService: IDialogService,
		@INotificationService private readonly _notificationService: INotificationService,
		@ICommandService private readonly _commandService: ICommandService,
		@IInstantiationService instaService: IInstantiationService,
	) {
		super();
		this._logs = this._register(instaService.createInstance(McpSamplingLog));
	}

	/**
	 * Runs a sampling request on behalf of an MCP server.
	 *
	 * @param opts The server, the MCP request parameters, and whether the
	 * request was made during a tool call.
	 * @param token Cancellation token forwarded to the chat request.
	 * @returns The text of the model response, attributed to the `assistant` role.
	 * @throws McpError if sampling is not allowed, or wrapping any failure of
	 * the chat request or its response stream.
	 */
	async sample(opts: ISamplingOptions, token = CancellationToken.None): Promise<ISamplingResult> {
		// Convert MCP messages into chat messages. Parts of unsupported types are
		// dropped, and messages left without any content are dropped entirely.
		const messages = opts.params.messages.map((message): IChatMessage | undefined => {
			// NOTE(review): audio parts are forwarded as `image_url` parts with their
			// mime type cast to ChatImageMimeType -- confirm models accept this.
			const content: IChatMessagePart[] = asArray(message.content).map((part): IChatMessagePart | undefined => part.type === 'text'
				? { type: 'text', value: part.text }
				: part.type === 'image' || part.type === 'audio'
					? { type: 'image_url', value: { mimeType: part.mimeType as ChatImageMimeType, data: decodeBase64(part.data) } }
					: undefined
			).filter(isDefined);

			if (!content.length) {
				return undefined;
			}

			return {
				role: message.role === 'assistant' ? ChatMessageRole.Assistant : ChatMessageRole.User,
				content,
			};
		}).filter(isDefined);

		if (opts.params.systemPrompt) {
			messages.unshift({ role: ChatMessageRole.System, content: [{ type: 'text', value: opts.params.systemPrompt }] });
		}

		const model = await this._modelSequencer.queue(() => this._getMatchingModel(opts));
		// todo@connor4312: nullExtensionDescription.identifier -> undefined with API update
		const response = await this._languageModelsService.sendChatRequest(model, new ExtensionIdentifier('core'), messages, {}, token);

		let responseText = '';

		// MCP doesn't have a notion of a multi-part sampling response, so we only preserve text
		// Ref https://github.com/modelcontextprotocol/modelcontextprotocol/issues/91
		const streaming = (async () => {
			for await (const part of response.stream) {
				if (Array.isArray(part)) {
					for (const p of part) {
						if (p.type === 'text') {
							responseText += p.value;
						}
					}
				} else if (part.type === 'text') {
					responseText += part.value;
				}
			}
		})();

		try {
			await Promise.all([response.result, streaming]);
			this._logs.add(opts.server, opts.params.messages, responseText, model);
			return {
				sample: {
					model,
					content: { type: 'text', text: responseText },
					role: 'assistant', // it came from the model!
				},
			};
		} catch (err) {
			throw McpError.unknown(err);
		}
	}

	/** Whether any sampling requests have been logged for the server. */
	hasLogs(server: IMcpServer): boolean {
		return this._logs.has(server);
	}

	/** Gets the sampling log of the server as text. */
	getLogText(server: IMcpServer): string {
		return this._logs.getAsText(server);
	}

	/**
	 * Resolves the model ID to use for the request. If permission is undecided
	 * or no model is configured, the user is asked, and the lookup is retried
	 * when they give a positive answer.
	 *
	 * @throws McpError.notAllowed if sampling is denied or the user declines.
	 */
	private async _getMatchingModel(opts: ISamplingOptions): Promise<string> {
		const model = await this._getMatchingModelInner(opts.server, opts.isDuringToolCall, opts.params.modelPreferences);
		const globalAutoApprove = this._configurationService.getValue<boolean>(ChatConfiguration.GlobalAutoApprove);

		if (model === ModelMatch.UnsureAllowedDuringChat) {
			// In YOLO mode, auto-approve MCP sampling requests without prompting
			if (globalAutoApprove) {
				this._sessionSets.allowedDuringChat.set(opts.server.definition.id, true);
				return this._getMatchingModel(opts);
			}

			const retry = await this._showContextual(
				opts.isDuringToolCall,
				localize('mcp.sampling.allowDuringChat.title', 'Allow MCP tools from "{0}" to make LLM requests?', opts.server.definition.label),
				localize('mcp.sampling.allowDuringChat.desc', 'The MCP server "{0}" has issued a request to make a language model call. Do you want to allow it to make requests during chat?', opts.server.definition.label),
				this.allowButtons(opts.server, 'allowedDuringChat')
			);
			if (retry) {
				return this._getMatchingModel(opts);
			}
			throw McpError.notAllowed();
		} else if (model === ModelMatch.UnsureAllowedOutsideChat) {
			// In YOLO mode, auto-approve MCP sampling requests without prompting
			if (globalAutoApprove) {
				this._sessionSets.allowedOutsideChat.set(opts.server.definition.id, true);
				return this._getMatchingModel(opts);
			}

			const retry = await this._showContextual(
				opts.isDuringToolCall,
				localize('mcp.sampling.allowOutsideChat.title', 'Allow MCP server "{0}" to make LLM requests?', opts.server.definition.label),
				localize('mcp.sampling.allowOutsideChat.desc', 'The MCP server "{0}" has issued a request to make a language model call. Do you want to allow it to make requests, outside of tool calls during chat?', opts.server.definition.label),
				this.allowButtons(opts.server, 'allowedOutsideChat')
			);
			if (retry) {
				return this._getMatchingModel(opts);
			}
			throw McpError.notAllowed();
		} else if (model === ModelMatch.NotAllowed) {
			throw McpError.notAllowed();
		} else if (model === ModelMatch.NoMatchingModel) {
			// During a tool call the configuration command is run directly;
			// otherwise the user is first asked through a notification.
			const newlyPickedModels = opts.isDuringToolCall
				? await this._commandService.executeCommand<number>(McpCommandIds.ConfigureSamplingModels, opts.server)
				: await this._notify(
					localize('mcp.sampling.needsModels', 'MCP server "{0}" triggered a language model request, but it has no allowlisted models.', opts.server.definition.label),
					{
						[localize('configure', 'Configure')]: () => this._commandService.executeCommand<number>(McpCommandIds.ConfigureSamplingModels, opts.server),
						[localize('cancel', 'Cancel')]: () => Promise.resolve(undefined),
					}
				);
			if (newlyPickedModels) {
				return this._getMatchingModel(opts);
			}
			throw McpError.notAllowed();
		}

		return model;
	}

	/**
	 * Builds the permission prompt buttons for the given context. Each action
	 * resolves to `true` if the sampling request should be retried.
	 */
	private allowButtons(server: IMcpServer, key: 'allowedDuringChat' | 'allowedOutsideChat') {
		return {
			[localize('mcp.sampling.allow.inSession', 'Allow in this Session')]: async () => {
				this._sessionSets[key].set(server.definition.id, true);
				return true;
			},
			[localize('mcp.sampling.allow.always', 'Always')]: async () => {
				await this.updateConfig(server, c => c[key] = true);
				return true;
			},
			[localize('mcp.sampling.allow.notNow', 'Not Now')]: async () => {
				this._sessionSets[key].set(server.definition.id, false);
				return false;
			},
			[localize('mcp.sampling.allow.never', 'Never')]: async () => {
				await this.updateConfig(server, c => c[key] = false);
				return false;
			},
		};
	}

	/**
	 * Asks the user a question with a dialog (during a tool call) or with a
	 * notification (otherwise), and returns the result of the chosen button.
	 */
	private async _showContextual<T>(isDuringToolCall: boolean, title: string, message: string, buttons: Record<string, () => T>): Promise<Awaited<T> | undefined> {
		if (isDuringToolCall) {
			const result = await this._dialogService.prompt({
				type: 'question',
				title: title,
				message,
				buttons: Object.entries(buttons).map(([label, run]) => ({ label, run })),
			});
			return await result.result;
		} else {
			return await this._notify(message, buttons);
		}
	}

	/**
	 * Shows an info notification with one choice per entry of `buttons`.
	 * Resolves with the result of the chosen action, or `undefined` if the
	 * notification is closed first.
	 */
	private async _notify<T>(message: string, buttons: Record<string, () => T>): Promise<Awaited<T> | undefined> {
		return await new Promise<T | undefined>(resolve => {
			const handle = this._notificationService.prompt(
				Severity.Info,
				message,
				Object.entries(buttons).map(([label, action]) => ({
					label,
					run: () => resolve(action()),
				}))
			);
			Event.once(handle.onDidClose)(() => resolve(undefined));
		});
	}

	/**
	 * Gets the matching model for the MCP server in this context, or
	 * a reason why no model could be selected.
	 */
	private async _getMatchingModelInner(server: IMcpServer, isDuringToolCall: boolean, preferences: MCP.ModelPreferences | undefined): Promise<ModelMatch | string> {
		const config = this.getConfig(server);

		// 1. Ensure the server is allowed to sample in this context
		const configured = isDuringToolCall ? config.allowedDuringChat : config.allowedOutsideChat;
		const sessionDecisions = isDuringToolCall ? this._sessionSets.allowedDuringChat : this._sessionSets.allowedOutsideChat;
		// Read the stored value rather than testing for presence: "Not Now" stores
		// `false`, which must deny the request instead of being treated as consent.
		const sessionDecision = sessionDecisions.get(server.definition.id);
		if (!configured && sessionDecision !== true) {
			if (sessionDecision === false || configured !== undefined) {
				return ModelMatch.NotAllowed;
			}
			return isDuringToolCall ? ModelMatch.UnsureAllowedDuringChat : ModelMatch.UnsureAllowedOutsideChat;
		}

		// 2. Get the configured models, or the default model(s)
		const foundModelIdsDeep = config.allowedModels?.filter(m => !!this._languageModelsService.lookupLanguageModel(m)) || this._languageModelsService.getLanguageModelIds().filter(m => this._languageModelsService.lookupLanguageModel(m)?.isDefaultForLocation[ChatAgentLocation.Chat]);

		const foundModelIds = foundModelIdsDeep.flat().sort((a, b) => b.length - a.length); // Sort by length to prefer most specific

		if (!foundModelIds.length) {
			return ModelMatch.NoMatchingModel;
		}

		// 3. If preferences are provided, try to match them from the allowed models
		if (preferences?.hints) {
			const found = mapFindFirst(preferences.hints, hint => {
				// Skip hints without a usable name instead of crashing on them. An
				// empty name is skipped too, since it would match every model.
				const needle = hint.name?.toLowerCase();
				return needle ? foundModelIds.find(model => model.toLowerCase().includes(needle)) : undefined;
			});
			if (found) {
				return found;
			}
		}

		return foundModelIds[0]; // Return the first matching model
	}

	/** Key of the server within the `{ server: data }` sampling configuration mapping. */
	private _configKey(server: IMcpServer) {
		return `${server.collection.label}: ${server.definition.label}`;
	}

	/** Gets the sampling configuration of the server, or an empty object if it has none. */
	public getConfig(server: IMcpServer): IMcpServerSamplingConfiguration {
		return this._getConfig(server).value || {};
	}

	/**
	 * _getConfig reads the `{ server: data }` sampling config mapping
	 * from the appropriate config. We read from the most specific possible
	 * config up to the default configuration location that the MCP server itself
	 * is defined in. We don't go further because then workspace-specific servers
	 * would get in the user settings which is not meaningful and could lead
	 * to confusion.
	 *
	 * todo@connor4312: generalize this for other settings when we have them
	 */
	private _getConfig(server: IMcpServer) {
		const def = server.readDefinitions().get();
		const mostSpecificConfig = ConfigurationTarget.MEMORY;
		const leastSpecificConfig = def.collection?.configTarget || ConfigurationTarget.USER;
		const key = this._configKey(server);
		const resource = def.collection?.presentation?.origin;

		const configValue = this._configurationService.inspect<Record<string, IMcpServerSamplingConfiguration>>(mcpServerSamplingSection, { resource });
		for (let target = mostSpecificConfig; target >= leastSpecificConfig; target--) {
			const mapping = getConfigValueInTarget(configValue, target);
			const config = mapping?.[key];
			if (config) {
				return { value: config, key, mapping, target, resource };
			}
		}

		// Nothing found: report the least specific target, which is where
		// `updateConfig` will then write a new entry.
		return { value: undefined, mapping: getConfigValueInTarget(configValue, leastSpecificConfig), key, target: leastSpecificConfig, resource };
	}

	/**
	 * Updates the sampling configuration of the server in the configuration
	 * target it was read from.
	 *
	 * @param mutate Callback that modifies a copy of the current configuration in place.
	 * @returns The new configuration that was written.
	 */
	public async updateConfig(server: IMcpServer, mutate: (r: IMcpServerSamplingConfiguration) => unknown) {
		const { value, mapping, key, target, resource } = this._getConfig(server);

		const newConfig = { ...value };
		mutate(newConfig);

		await this._configurationService.updateValue(
			mcpServerSamplingSection,
			{ ...mapping, [key]: newConfig },
			{ resource },
			target,
		);
		return newConfig;
	}
}