/*
 * Scraped page header (CoCalc share view):
 *   Book a Demo! · CoCalc Logo Icon
 *   Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
 *   microsoft
 *   GitHub Repository: microsoft/vscode
 *   Path: blob/main/src/vs/workbench/contrib/mcp/common/mcpSamplingService.ts
 *   5251 views
 */
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
5
6
import { asArray } from '../../../../base/common/arrays.js';
7
import { mapFindFirst } from '../../../../base/common/arraysFind.js';
8
import { Sequencer } from '../../../../base/common/async.js';
9
import { decodeBase64 } from '../../../../base/common/buffer.js';
10
import { CancellationToken } from '../../../../base/common/cancellation.js';
11
import { Event } from '../../../../base/common/event.js';
12
import { Disposable } from '../../../../base/common/lifecycle.js';
13
import { isDefined } from '../../../../base/common/types.js';
14
import { localize } from '../../../../nls.js';
15
import { ICommandService } from '../../../../platform/commands/common/commands.js';
16
import { ConfigurationTarget, getConfigValueInTarget, IConfigurationService } from '../../../../platform/configuration/common/configuration.js';
17
import { IDialogService } from '../../../../platform/dialogs/common/dialogs.js';
18
import { ExtensionIdentifier } from '../../../../platform/extensions/common/extensions.js';
19
import { IInstantiationService } from '../../../../platform/instantiation/common/instantiation.js';
20
import { INotificationService, Severity } from '../../../../platform/notification/common/notification.js';
21
import { ChatAgentLocation, ChatConfiguration } from '../../chat/common/constants.js';
22
import { ChatImageMimeType, ChatMessageRole, IChatMessage, IChatMessagePart, ILanguageModelsService } from '../../chat/common/languageModels.js';
23
import { McpCommandIds } from './mcpCommandIds.js';
24
import { IMcpServerSamplingConfiguration, mcpServerSamplingSection } from './mcpConfiguration.js';
25
import { McpSamplingLog } from './mcpSamplingLog.js';
26
import { IMcpSamplingService, IMcpServer, ISamplingOptions, ISamplingResult, McpError } from './mcpTypes.js';
27
import { MCP } from './modelContextProtocol.js';
28
29
/**
 * Outcomes of model selection that are not a concrete model ID.
 *
 * Model selection yields either a model ID (a string) or one of these
 * numeric members, so a strict equality check against a member cleanly
 * distinguishes "no model, and here is why" from a real model ID.
 */
const enum ModelMatch {
	/** The server has not yet been granted or denied permission to sample during chat tool calls. */
	UnsureAllowedDuringChat = 0,
	/** The server has not yet been granted or denied permission to sample outside chat tool calls. */
	UnsureAllowedOutsideChat = 1,
	/** Sampling has been explicitly denied for this server in the current context. */
	NotAllowed = 2,
	/** Sampling is permitted, but no allowlisted or default model could be found. */
	NoMatchingModel = 3,
}
35
36
export class McpSamplingService extends Disposable implements IMcpSamplingService {
	declare readonly _serviceBrand: undefined;

	/**
	 * Sampling decisions made during this session, keyed by server definition ID.
	 * `true` means the user allowed sampling for the session ("Allow in this Session",
	 * or auto-approval); `false` means the user chose "Not Now", which denies
	 * further requests for the rest of the session without prompting again.
	 */
	private readonly _sessionSets = {
		allowedDuringChat: new Map<string, boolean>(),
		allowedOutsideChat: new Map<string, boolean>(),
	};

	private readonly _logs: McpSamplingLog;

	/**
	 * Serializes calls to `_getMatchingModel`, which may prompt the user, so that
	 * only one model selection (and its prompt) runs at a time.
	 */
	private readonly _modelSequencer = new Sequencer();

	constructor(
		@ILanguageModelsService private readonly _languageModelsService: ILanguageModelsService,
		@IConfigurationService private readonly _configurationService: IConfigurationService,
		@IDialogService private readonly _dialogService: IDialogService,
		@INotificationService private readonly _notificationService: INotificationService,
		@ICommandService private readonly _commandService: ICommandService,
		@IInstantiationService instaService: IInstantiationService,
	) {
		super();
		this._logs = this._register(instaService.createInstance(McpSamplingLog));
	}

	/**
	 * Handles an MCP sampling request by forwarding it to a language model.
	 *
	 * @param opts The sampling request, including the requesting server and its messages.
	 * @param token Cancels the underlying chat request.
	 * @returns The model's text response, as an assistant message.
	 * @throws McpError `notAllowed` if the server may not sample here (thrown during
	 * model selection); `unknown` if the model request or its response stream fails.
	 */
	async sample(opts: ISamplingOptions, token = CancellationToken.None): Promise<ISamplingResult> {
		// Text, image and audio parts are forwarded; other part types are dropped, and
		// messages that end up with no parts are skipped entirely.
		// NOTE(review): audio parts are sent as `image_url` with their audio mime type cast
		// to ChatImageMimeType — confirm the downstream model provider accepts this.
		const messages = opts.params.messages.map((message): IChatMessage | undefined => {
			const content: IChatMessagePart[] = asArray(message.content).map((part): IChatMessagePart | undefined => part.type === 'text'
				? { type: 'text', value: part.text }
				: part.type === 'image' || part.type === 'audio'
					? { type: 'image_url', value: { mimeType: part.mimeType as ChatImageMimeType, data: decodeBase64(part.data) } }
					: undefined
			).filter(isDefined);

			if (!content.length) {
				return undefined;
			}
			return {
				role: message.role === 'assistant' ? ChatMessageRole.Assistant : ChatMessageRole.User,
				content,
			};
		}).filter(isDefined);

		if (opts.params.systemPrompt) {
			messages.unshift({ role: ChatMessageRole.System, content: [{ type: 'text', value: opts.params.systemPrompt }] });
		}

		const model = await this._modelSequencer.queue(() => this._getMatchingModel(opts));
		// todo@connor4312: nullExtensionDescription.identifier -> undefined with API update
		const response = await this._languageModelsService.sendChatRequest(model, new ExtensionIdentifier('core'), messages, {}, token);

		let responseText = '';

		// MCP doesn't have a notion of a multi-part sampling response, so we only preserve text
		// Ref https://github.com/modelcontextprotocol/modelcontextprotocol/issues/91
		const streaming = (async () => {
			for await (const part of response.stream) {
				if (Array.isArray(part)) {
					for (const p of part) {
						if (p.type === 'text') {
							responseText += p.value;
						}
					}
				} else if (part.type === 'text') {
					responseText += part.value;
				}
			}
		})();

		try {
			await Promise.all([response.result, streaming]);
			this._logs.add(opts.server, opts.params.messages, responseText, model);
			return {
				sample: {
					model,
					content: { type: 'text', text: responseText },
					role: 'assistant', // it came from the model!
				},
			};
		} catch (err) {
			throw McpError.unknown(err);
		}
	}

	/** Whether any sampling requests have been logged for `server`. */
	hasLogs(server: IMcpServer): boolean {
		return this._logs.has(server);
	}

	/** The logged sampling requests of `server`, rendered as text. */
	getLogText(server: IMcpServer): string {
		return this._logs.getAsText(server);
	}

	/**
	 * Resolves the model ID to use for a sampling request. If permission for the
	 * current context is undecided, or no model is allowlisted, the user is asked;
	 * after they make a choice that permits a retry, selection runs again.
	 *
	 * When `ChatConfiguration.GlobalAutoApprove` is enabled, undecided permission
	 * is granted for the session without prompting.
	 *
	 * @throws McpError `notAllowed` when sampling is denied or the user declines.
	 */
	private async _getMatchingModel(opts: ISamplingOptions): Promise<string> {
		const model = await this._getMatchingModelInner(opts.server, opts.isDuringToolCall, opts.params.modelPreferences);
		const globalAutoApprove = this._configurationService.getValue<boolean>(ChatConfiguration.GlobalAutoApprove);

		if (model === ModelMatch.UnsureAllowedDuringChat) {
			// In YOLO mode, auto-approve MCP sampling requests without prompting
			if (globalAutoApprove) {
				this._sessionSets.allowedDuringChat.set(opts.server.definition.id, true);
				return this._getMatchingModel(opts);
			}
			const retry = await this._showContextual(
				opts.isDuringToolCall,
				localize('mcp.sampling.allowDuringChat.title', 'Allow MCP tools from "{0}" to make LLM requests?', opts.server.definition.label),
				localize('mcp.sampling.allowDuringChat.desc', 'The MCP server "{0}" has issued a request to make a language model call. Do you want to allow it to make requests during chat?', opts.server.definition.label),
				this.allowButtons(opts.server, 'allowedDuringChat')
			);
			if (retry) {
				return this._getMatchingModel(opts);
			}
			throw McpError.notAllowed();
		} else if (model === ModelMatch.UnsureAllowedOutsideChat) {
			// In YOLO mode, auto-approve MCP sampling requests without prompting
			if (globalAutoApprove) {
				this._sessionSets.allowedOutsideChat.set(opts.server.definition.id, true);
				return this._getMatchingModel(opts);
			}
			const retry = await this._showContextual(
				opts.isDuringToolCall,
				localize('mcp.sampling.allowOutsideChat.title', 'Allow MCP server "{0}" to make LLM requests?', opts.server.definition.label),
				localize('mcp.sampling.allowOutsideChat.desc', 'The MCP server "{0}" has issued a request to make a language model call. Do you want to allow it to make requests, outside of tool calls during chat?', opts.server.definition.label),
				this.allowButtons(opts.server, 'allowedOutsideChat')
			);
			if (retry) {
				return this._getMatchingModel(opts);
			}
			throw McpError.notAllowed();
		} else if (model === ModelMatch.NotAllowed) {
			throw McpError.notAllowed();
		} else if (model === ModelMatch.NoMatchingModel) {
			const newlyPickedModels = opts.isDuringToolCall
				? await this._commandService.executeCommand<number>(McpCommandIds.ConfigureSamplingModels, opts.server)
				: await this._notify(
					localize('mcp.sampling.needsModels', 'MCP server "{0}" triggered a language model request, but it has no allowlisted models.', opts.server.definition.label),
					{
						[localize('configure', 'Configure')]: () => this._commandService.executeCommand<number>(McpCommandIds.ConfigureSamplingModels, opts.server),
						[localize('cancel', 'Cancel')]: () => Promise.resolve(undefined),
					}
				);
			if (newlyPickedModels) {
				return this._getMatchingModel(opts);
			}
			throw McpError.notAllowed();
		}

		return model;
	}

	/**
	 * Builds the permission prompt's buttons for `server` and the given context key.
	 * Each action resolves to `true` when the request should be retried (permission
	 * granted) and `false` when it should be denied. "Allow in this Session" and
	 * "Not Now" are remembered only for this session; "Always" and "Never" are
	 * written to the sampling configuration.
	 */
	private allowButtons(server: IMcpServer, key: 'allowedDuringChat' | 'allowedOutsideChat') {
		return {
			[localize('mcp.sampling.allow.inSession', 'Allow in this Session')]: async () => {
				this._sessionSets[key].set(server.definition.id, true);
				return true;
			},
			[localize('mcp.sampling.allow.always', 'Always')]: async () => {
				await this.updateConfig(server, c => c[key] = true);
				return true;
			},
			[localize('mcp.sampling.allow.notNow', 'Not Now')]: async () => {
				this._sessionSets[key].set(server.definition.id, false);
				return false;
			},
			[localize('mcp.sampling.allow.never', 'Never')]: async () => {
				await this.updateConfig(server, c => c[key] = false);
				return false;
			},
		};
	}

	/**
	 * Asks the user a question: as a dialog during a tool call, otherwise as a
	 * notification. Resolves with the chosen button's result, or `undefined` if
	 * the notification is dismissed without a choice.
	 */
	private async _showContextual<T>(isDuringToolCall: boolean, title: string, message: string, buttons: Record<string, () => T>): Promise<Awaited<T> | undefined> {
		if (isDuringToolCall) {
			const result = await this._dialogService.prompt({
				type: 'question',
				title: title,
				message,
				buttons: Object.entries(buttons).map(([label, run]) => ({ label, run })),
			});
			return await result.result;
		} else {
			return await this._notify(message, buttons);
		}
	}

	/**
	 * Shows an info notification with one action per entry of `buttons`. Resolves
	 * with the selected action's result, or `undefined` if the notification closes
	 * without a selection.
	 */
	private async _notify<T>(message: string, buttons: Record<string, () => T>): Promise<Awaited<T> | undefined> {
		return await new Promise<T | undefined>(resolve => {
			const handle = this._notificationService.prompt(
				Severity.Info,
				message,
				Object.entries(buttons).map(([label, action]) => ({
					label,
					run: () => resolve(action()),
				}))
			);
			// If an action already ran, this later resolve is a no-op.
			Event.once(handle.onDidClose)(() => resolve(undefined));
		});
	}

	/**
	 * Gets the matching model for the MCP server in this context, or
	 * a reason why no model could be selected.
	 */
	private async _getMatchingModelInner(server: IMcpServer, isDuringToolCall: boolean, preferences: MCP.ModelPreferences | undefined): Promise<ModelMatch | string> {
		const config = this.getConfig(server);

		// 1. Ensure the server is allowed to sample in this context.
		// Compare against `undefined`: ModelMatch.UnsureAllowedDuringChat is 0 and hence falsy.
		const denial = this._checkSamplingAllowed(server, config, isDuringToolCall);
		if (denial !== undefined) {
			return denial;
		}

		// 2. Get the configured models, or the default model(s). An allowlist that is
		// present but resolves to no known models deliberately yields NoMatchingModel.
		const candidateModelIds = config.allowedModels?.filter(m => !!this._languageModelsService.lookupLanguageModel(m))
			|| this._languageModelsService.getLanguageModelIds().filter(m => this._languageModelsService.lookupLanguageModel(m)?.isDefaultForLocation[ChatAgentLocation.Chat]);

		// Sort by length (longest first) so hint matching prefers the most specific ID.
		const foundModelIds = [...candidateModelIds].sort((a, b) => b.length - a.length);

		if (!foundModelIds.length) {
			return ModelMatch.NoMatchingModel;
		}

		// 3. If preferences are provided, try to match them from the allowed models.
		// Hints are tried in order; a hint matches a model whose ID contains the hint's
		// name, case-insensitively.
		if (preferences?.hints) {
			const found = mapFindFirst(preferences.hints, hint => {
				// `name` is optional on MCP model hints; skip nameless hints instead of throwing.
				const hintName = hint.name?.toLowerCase();
				return hintName ? foundModelIds.find(model => model.toLowerCase().includes(hintName)) : undefined;
			});
			if (found) {
				return found;
			}
		}

		return foundModelIds[0]; // Return the first matching model
	}

	/**
	 * Decides whether `server` may sample in the current context.
	 *
	 * Persisted configuration ("Always") or a session grant allows the request.
	 * A session "Not Now" (`false`) or a persisted "Never" denies it. If neither
	 * the session nor the configuration has decided, the user must be asked.
	 *
	 * @returns `undefined` if sampling is allowed; otherwise the ModelMatch reason.
	 */
	private _checkSamplingAllowed(server: IMcpServer, config: IMcpServerSamplingConfiguration, isDuringToolCall: boolean): ModelMatch | undefined {
		const key = isDuringToolCall ? 'allowedDuringChat' : 'allowedOutsideChat';
		const configured = config[key];
		if (configured) {
			return undefined;
		}

		const sessionDecision = this._sessionSets[key].get(server.definition.id);
		if (sessionDecision === true) {
			return undefined;
		}

		// A recorded session `false` must deny. Checking `Map.has` here would treat
		// "Not Now" as a grant.
		if (sessionDecision === false || configured !== undefined) {
			return ModelMatch.NotAllowed;
		}

		return isDuringToolCall ? ModelMatch.UnsureAllowedDuringChat : ModelMatch.UnsureAllowedOutsideChat;
	}

	/** Key under which a server's entry is stored in the sampling configuration mapping. */
	private _configKey(server: IMcpServer) {
		return `${server.collection.label}: ${server.definition.label}`;
	}

	/** The effective sampling configuration of `server`, or an empty object if none is set. */
	public getConfig(server: IMcpServer): IMcpServerSamplingConfiguration {
		return this._getConfig(server).value || {};
	}

	/**
	 * _getConfig reads the `{ server: data }` sampling mapping from the
	 * appropriate config. We read from the most specific possible config up
	 * to the default configuration location that the MCP server itself is
	 * defined in. We don't go further because then workspace-specific servers
	 * would get in the user settings, which is not meaningful and could lead
	 * to confusion.
	 *
	 * todo@connor4312: generalize this for other settings when we have them
	 */
	private _getConfig(server: IMcpServer) {
		const def = server.readDefinitions().get();
		const mostSpecificConfig = ConfigurationTarget.MEMORY;
		const leastSpecificConfig = def.collection?.configTarget || ConfigurationTarget.USER;
		const key = this._configKey(server);
		const resource = def.collection?.presentation?.origin;

		const configValue = this._configurationService.inspect<Record<string, IMcpServerSamplingConfiguration>>(mcpServerSamplingSection, { resource });
		for (let target = mostSpecificConfig; target >= leastSpecificConfig; target--) {
			const mapping = getConfigValueInTarget(configValue, target);
			const config = mapping?.[key];
			if (config) {
				return { value: config, key, mapping, target, resource };
			}
		}

		return { value: undefined, mapping: getConfigValueInTarget(configValue, leastSpecificConfig), key, target: leastSpecificConfig, resource };
	}

	/**
	 * Applies `mutate` to a copy of the server's sampling configuration. The copy
	 * is written back to the configuration target that the current value was read
	 * from, or to the least specific target if no value exists.
	 *
	 * @returns The updated configuration object.
	 */
	public async updateConfig(server: IMcpServer, mutate: (r: IMcpServerSamplingConfiguration) => unknown) {
		const { value, mapping, key, target, resource } = this._getConfig(server);

		const newConfig = { ...value };
		mutate(newConfig);

		await this._configurationService.updateValue(
			mcpServerSamplingSection,
			{ ...mapping, [key]: newConfig },
			{ resource },
			target,
		);
		return newConfig;
	}
}
318
319