Path: blob/main/src/vs/workbench/contrib/mcp/common/mcpSamplingService.ts
3296 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { mapFindFirst } from '../../../../base/common/arraysFind.js';
import { decodeBase64 } from '../../../../base/common/buffer.js';
import { CancellationToken } from '../../../../base/common/cancellation.js';
import { Event } from '../../../../base/common/event.js';
import { Disposable } from '../../../../base/common/lifecycle.js';
import { isDefined } from '../../../../base/common/types.js';
import { localize } from '../../../../nls.js';
import { ICommandService } from '../../../../platform/commands/common/commands.js';
import { ConfigurationTarget, getConfigValueInTarget, IConfigurationService } from '../../../../platform/configuration/common/configuration.js';
import { IDialogService } from '../../../../platform/dialogs/common/dialogs.js';
import { ExtensionIdentifier } from '../../../../platform/extensions/common/extensions.js';
import { IInstantiationService } from '../../../../platform/instantiation/common/instantiation.js';
import { INotificationService, Severity } from '../../../../platform/notification/common/notification.js';
import { ChatImageMimeType, ChatMessageRole, IChatMessage, IChatMessagePart, ILanguageModelsService } from '../../chat/common/languageModels.js';
import { McpCommandIds } from './mcpCommandIds.js';
import { IMcpServerSamplingConfiguration, mcpServerSamplingSection } from './mcpConfiguration.js';
import { McpSamplingLog } from './mcpSamplingLog.js';
import { IMcpSamplingService, IMcpServer, ISamplingOptions, ISamplingResult, McpError } from './mcpTypes.js';
import { MCP } from './modelContextProtocol.js';

/**
 * Reasons why model selection did not produce a concrete model id.
 *
 * These are numeric values starting at 0, so always compare them explicitly
 * (`=== ModelMatch.X`) rather than relying on truthiness.
 */
const enum ModelMatch {
	UnsureAllowedDuringChat,
	UnsureAllowedOutsideChat,
	NotAllowed,
	NoMatchingModel,
}

/** Permission key shared by the persisted sampling configuration and the per-session decision maps. */
type SamplingPermissionKey = 'allowedDuringChat' | 'allowedOutsideChat';

export class McpSamplingService extends Disposable implements IMcpSamplingService {
	declare readonly _serviceBrand: undefined;

	/**
	 * Per-session (non-persisted) sampling decisions, keyed by server definition id.
	 * `true` is recorded by "Allow in this Session", `false` by "Not Now".
	 */
	private readonly _sessionSets = {
		allowedDuringChat: new Map<string, boolean>(),
		allowedOutsideChat: new Map<string, boolean>(),
	};

	private readonly _logs: McpSamplingLog;

	constructor(
		@ILanguageModelsService private readonly _languageModelsService: ILanguageModelsService,
		@IConfigurationService private readonly _configurationService: IConfigurationService,
		@IDialogService private readonly _dialogService: IDialogService,
		@INotificationService private readonly _notificationService: INotificationService,
		@ICommandService private readonly _commandService: ICommandService,
		@IInstantiationService instaService: IInstantiationService,
	) {
		super();
		this._logs = this._register(instaService.createInstance(McpSamplingLog));
	}

	/**
	 * Performs a sampling request on behalf of an MCP server: converts the
	 * server-provided messages to chat messages, picks a model the server is
	 * permitted to use (prompting the user if needed), sends the request, and
	 * logs the exchange.
	 *
	 * Only text parts of the model's response are kept.
	 *
	 * @throws McpError.notAllowed when the user or configuration denies the request.
	 * @throws McpError.unknown wrapping any error raised while the response is consumed.
	 */
	async sample(opts: ISamplingOptions, token = CancellationToken.None): Promise<ISamplingResult> {
		const messages = opts.params.messages.map((message): IChatMessage | undefined => {
			// NOTE(review): audio content is forwarded as an `image_url` part with its
			// MIME type cast to ChatImageMimeType; confirm this is intended.
			const content: IChatMessagePart | undefined = message.content.type === 'text'
				? { type: 'text', value: message.content.text }
				: message.content.type === 'image' || message.content.type === 'audio'
					? { type: 'image_url', value: { mimeType: message.content.mimeType as ChatImageMimeType, data: decodeBase64(message.content.data) } }
					: undefined;
			if (!content) {
				return undefined;
			}
			return {
				role: message.role === 'assistant' ? ChatMessageRole.Assistant : ChatMessageRole.User,
				content: [content]
			};
		}).filter(isDefined);

		if (opts.params.systemPrompt) {
			messages.unshift({ role: ChatMessageRole.System, content: [{ type: 'text', value: opts.params.systemPrompt }] });
		}

		const model = await this._getMatchingModel(opts);
		// todo@connor4312: nullExtensionDescription.identifier -> undefined with API update
		const response = await this._languageModelsService.sendChatRequest(model, new ExtensionIdentifier('core'), messages, {}, token);

		let responseText = '';

		// MCP doesn't have a notion of a multi-part sampling response, so we only preserve text
		// Ref https://github.com/modelcontextprotocol/modelcontextprotocol/issues/91
		const streaming = (async () => {
			for await (const part of response.stream) {
				if (Array.isArray(part)) {
					for (const p of part) {
						if (p.type === 'text') {
							responseText += p.value;
						}
					}
				} else if (part.type === 'text') {
					responseText += part.value;
				}
			}
		})();

		try {
			await Promise.all([response.result, streaming]);
			this._logs.add(opts.server, opts.params.messages, responseText, model);
			return {
				sample: {
					model,
					content: { type: 'text', text: responseText },
					role: 'assistant', // it came from the model!
				},
			};
		} catch (err) {
			throw McpError.unknown(err);
		}
	}

	/** Whether any sampling requests have been logged for the server. */
	hasLogs(server: IMcpServer): boolean {
		return this._logs.has(server);
	}

	/** Returns the logged sampling requests for the server as text. */
	getLogText(server: IMcpServer): string {
		return this._logs.getAsText(server);
	}

	/**
	 * Resolves the model id to use for the request. When permission is
	 * undecided, or no allowlisted model is available, it asks the user.
	 * If the user makes a choice that may change the outcome, it retries.
	 *
	 * @throws McpError.notAllowed if the request is denied or the user declines.
	 */
	private async _getMatchingModel(opts: ISamplingOptions): Promise<string> {
		const model = await this._getMatchingModelInner(opts.server, opts.isDuringToolCall, opts.params.modelPreferences);

		if (model === ModelMatch.UnsureAllowedDuringChat) {
			const retry = await this._showContextual(
				opts.isDuringToolCall,
				localize('mcp.sampling.allowDuringChat.title', 'Allow MCP tools from "{0}" to make LLM requests?', opts.server.definition.label),
				localize('mcp.sampling.allowDuringChat.desc', 'The MCP server "{0}" has issued a request to make a language model call. Do you want to allow it to make requests during chat?', opts.server.definition.label),
				this.allowButtons(opts.server, 'allowedDuringChat')
			);
			if (retry) {
				return this._getMatchingModel(opts);
			}
			throw McpError.notAllowed();
		} else if (model === ModelMatch.UnsureAllowedOutsideChat) {
			const retry = await this._showContextual(
				opts.isDuringToolCall,
				localize('mcp.sampling.allowOutsideChat.title', 'Allow MCP server "{0}" to make LLM requests?', opts.server.definition.label),
				localize('mcp.sampling.allowOutsideChat.desc', 'The MCP server "{0}" has issued a request to make a language model call. Do you want to allow it to make requests, outside of tool calls during chat?', opts.server.definition.label),
				this.allowButtons(opts.server, 'allowedOutsideChat')
			);
			if (retry) {
				return this._getMatchingModel(opts);
			}
			throw McpError.notAllowed();
		} else if (model === ModelMatch.NotAllowed) {
			throw McpError.notAllowed();
		} else if (model === ModelMatch.NoMatchingModel) {
			const newlyPickedModels = opts.isDuringToolCall
				? await this._commandService.executeCommand<number>(McpCommandIds.ConfigureSamplingModels, opts.server)
				: await this._notify(
					localize('mcp.sampling.needsModels', 'MCP server "{0}" triggered a language model request, but it has no allowlisted models.', opts.server.definition.label),
					{
						[localize('configure', 'Configure')]: () => this._commandService.executeCommand<number>(McpCommandIds.ConfigureSamplingModels, opts.server),
						[localize('cancel', 'Cancel')]: () => Promise.resolve(undefined),
					}
				);
			if (newlyPickedModels) {
				return this._getMatchingModel(opts);
			}
			throw McpError.notAllowed();
		}

		return model;
	}

	/**
	 * Builds the permission prompt buttons for `key`. Each button records its
	 * decision, either for this session or persistently in configuration. It
	 * resolves to `true` when the request should be retried.
	 */
	private allowButtons(server: IMcpServer, key: SamplingPermissionKey) {
		return {
			[localize('mcp.sampling.allow.inSession', 'Allow in this Session')]: async () => {
				this._sessionSets[key].set(server.definition.id, true);
				return true;
			},
			[localize('mcp.sampling.allow.always', 'Always')]: async () => {
				await this.updateConfig(server, c => c[key] = true);
				return true;
			},
			[localize('mcp.sampling.allow.notNow', 'Not Now')]: async () => {
				this._sessionSets[key].set(server.definition.id, false);
				return false;
			},
			[localize('mcp.sampling.allow.never', 'Never')]: async () => {
				await this.updateConfig(server, c => c[key] = false);
				return false;
			},
		};
	}

	/**
	 * Asks the user with a modal dialog during tool calls, or with a
	 * notification otherwise. Resolves to the chosen button's result, or
	 * `undefined` if nothing was chosen.
	 */
	private async _showContextual<T>(isDuringToolCall: boolean, title: string, message: string, buttons: Record<string, () => T>): Promise<Awaited<T> | undefined> {
		if (isDuringToolCall) {
			const result = await this._dialogService.prompt({
				type: 'question',
				title: title,
				message,
				buttons: Object.entries(buttons).map(([label, run]) => ({ label, run })),
			});
			return await result.result;
		} else {
			return await this._notify(message, buttons);
		}
	}

	/**
	 * Shows an info notification with the given buttons. Resolves to the
	 * clicked button's result, or `undefined` once the notification closes
	 * without a choice.
	 */
	private async _notify<T>(message: string, buttons: Record<string, () => T>): Promise<Awaited<T> | undefined> {
		return await new Promise<T | undefined>(resolve => {
			const handle = this._notificationService.prompt(
				Severity.Info,
				message,
				Object.entries(buttons).map(([label, action]) => ({
					label,
					run: () => resolve(action()),
				}))
			);
			// A later close is a no-op if a button already resolved the promise.
			Event.once(handle.onDidClose)(() => resolve(undefined));
		});
	}

	/**
	 * Decides whether sampling is permitted for `key`, combining the persisted
	 * configuration with this session's decisions.
	 *
	 * Returns `undefined` when sampling is permitted. Returns `unsure` when
	 * neither source has recorded a decision. Returns `ModelMatch.NotAllowed`
	 * when either source recorded an explicit denial.
	 */
	private _checkAllowed(server: IMcpServer, config: IMcpServerSamplingConfiguration, key: SamplingPermissionKey, unsure: ModelMatch): ModelMatch | undefined {
		const configured = config[key];
		if (configured) {
			return undefined;
		}

		// Read the stored value rather than testing membership: "Not Now" stores
		// `false`, which must deny the request, not grant it.
		const sessionDecision = this._sessionSets[key].get(server.definition.id);
		if (sessionDecision === true) {
			return undefined;
		}

		return configured === undefined && sessionDecision === undefined ? unsure : ModelMatch.NotAllowed;
	}

	/**
	 * Gets the matching model for the MCP server in this context, or
	 * a reason why no model could be selected.
	 */
	private async _getMatchingModelInner(server: IMcpServer, isDuringToolCall: boolean, preferences: MCP.ModelPreferences | undefined): Promise<ModelMatch | string> {
		const config = this.getConfig(server);

		// 1. Ensure the server is allowed to sample in this context.
		// `ModelMatch.UnsureAllowedDuringChat` is 0, so compare against undefined explicitly.
		const denial = isDuringToolCall
			? this._checkAllowed(server, config, 'allowedDuringChat', ModelMatch.UnsureAllowedDuringChat)
			: this._checkAllowed(server, config, 'allowedOutsideChat', ModelMatch.UnsureAllowedOutsideChat);
		if (denial !== undefined) {
			return denial;
		}

		// 2. Get the configured models that are currently available, or the default model(s)
		// when no allowlist is configured. A configured but empty or unavailable allowlist
		// deliberately yields no models rather than falling back to the defaults.
		const languageModels = this._languageModelsService;
		const foundModelIds = config.allowedModels
			? config.allowedModels.filter(m => !!languageModels.lookupLanguageModel(m))
			: languageModels.getLanguageModelIds().filter(m => languageModels.lookupLanguageModel(m)?.isDefault);

		foundModelIds.sort((a, b) => b.length - a.length); // Sort by length to prefer most specific

		if (!foundModelIds.length) {
			return ModelMatch.NoMatchingModel;
		}

		// 3. If preferences are provided, try to match them from the allowed models
		if (preferences?.hints) {
			const found = mapFindFirst(preferences.hints, hint => {
				// `name` is optional on MCP model hints. An absent or empty name cannot
				// narrow the choice (every id contains ''), so skip such hints.
				const needle = hint.name?.toLowerCase();
				return needle ? foundModelIds.find(model => model.toLowerCase().includes(needle)) : undefined;
			});
			if (found) {
				return found;
			}
		}

		return foundModelIds[0]; // Return the first matching model
	}

	/** Key under which this server's sampling configuration is stored in the settings mapping. */
	private _configKey(server: IMcpServer) {
		return `${server.collection.label}: ${server.definition.label}`;
	}

	/** Returns the effective sampling configuration for the server, or `{}` when none is set. */
	public getConfig(server: IMcpServer): IMcpServerSamplingConfiguration {
		return this._getConfig(server).value || {};
	}

	/**
	 * Reads the `{ [serverKey]: samplingConfig }` mapping from configuration and
	 * returns this server's entry along with the target it was found in. We read
	 * from the most specific possible target down to the target the MCP server
	 * itself is defined in. We don't go further: workspace-specific servers
	 * would otherwise pick up entries from user settings, which is not
	 * meaningful and could lead to confusion.
	 *
	 * todo@connor4312: generalize this for other settings when we have them
	 */
	private _getConfig(server: IMcpServer) {
		const def = server.readDefinitions().get();
		const mostSpecificConfig = ConfigurationTarget.MEMORY;
		const leastSpecificConfig = def.collection?.configTarget || ConfigurationTarget.USER;
		const key = this._configKey(server);
		const resource = def.collection?.presentation?.origin;

		const configValue = this._configurationService.inspect<Record<string, IMcpServerSamplingConfiguration>>(mcpServerSamplingSection, { resource });
		for (let target = mostSpecificConfig; target >= leastSpecificConfig; target--) {
			const mapping = getConfigValueInTarget(configValue, target);
			const config = mapping?.[key];
			if (config) {
				return { value: config, key, mapping, target, resource };
			}
		}

		return { value: undefined, mapping: undefined, key, target: leastSpecificConfig, resource };
	}

	/**
	 * Applies `mutate` to a copy of the server's sampling configuration. The
	 * result is written back to the target the entry was read from, or to the
	 * server's defining target when no entry exists. Returns the new configuration.
	 */
	public async updateConfig(server: IMcpServer, mutate: (r: IMcpServerSamplingConfiguration) => unknown) {
		const { value, mapping, key, target, resource } = this._getConfig(server);

		const newConfig = { ...value };
		mutate(newConfig);

		await this._configurationService.updateValue(
			mcpServerSamplingSection,
			{ ...mapping, [key]: newConfig },
			{ resource },
			target,
		);
		return newConfig;
	}
}