GitHub Repository: microsoft/vscode
Path: blob/main/src/vs/workbench/contrib/mcp/common/mcpSamplingService.ts
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { mapFindFirst } from '../../../../base/common/arraysFind.js';
import { decodeBase64 } from '../../../../base/common/buffer.js';
import { CancellationToken } from '../../../../base/common/cancellation.js';
import { Event } from '../../../../base/common/event.js';
import { Disposable } from '../../../../base/common/lifecycle.js';
import { isDefined } from '../../../../base/common/types.js';
import { localize } from '../../../../nls.js';
import { ICommandService } from '../../../../platform/commands/common/commands.js';
import { ConfigurationTarget, getConfigValueInTarget, IConfigurationService } from '../../../../platform/configuration/common/configuration.js';
import { IDialogService } from '../../../../platform/dialogs/common/dialogs.js';
import { ExtensionIdentifier } from '../../../../platform/extensions/common/extensions.js';
import { IInstantiationService } from '../../../../platform/instantiation/common/instantiation.js';
import { INotificationService, Severity } from '../../../../platform/notification/common/notification.js';
import { ChatImageMimeType, ChatMessageRole, IChatMessage, IChatMessagePart, ILanguageModelsService } from '../../chat/common/languageModels.js';
import { McpCommandIds } from './mcpCommandIds.js';
import { IMcpServerSamplingConfiguration, mcpServerSamplingSection } from './mcpConfiguration.js';
import { McpSamplingLog } from './mcpSamplingLog.js';
import { IMcpSamplingService, IMcpServer, ISamplingOptions, ISamplingResult, McpError } from './mcpTypes.js';
import { MCP } from './modelContextProtocol.js';

const enum ModelMatch {
	UnsureAllowedDuringChat,
	UnsureAllowedOutsideChat,
	NotAllowed,
	NoMatchingModel,
}

export class McpSamplingService extends Disposable implements IMcpSamplingService {
	declare readonly _serviceBrand: undefined;

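	// Per-session answers to the consent prompts below, keyed by server definition id.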
	private readonly _sessionSets = {
		allowedDuringChat: new Map<string, boolean>(),
		allowedOutsideChat: new Map<string, boolean>(),
	};

	private readonly _logs: McpSamplingLog;

	constructor(
		@ILanguageModelsService private readonly _languageModelsService: ILanguageModelsService,
		@IConfigurationService private readonly _configurationService: IConfigurationService,
		@IDialogService private readonly _dialogService: IDialogService,
		@INotificationService private readonly _notificationService: INotificationService,
		@ICommandService private readonly _commandService: ICommandService,
		@IInstantiationService instaService: IInstantiationService,
	) {
		super();
		this._logs = this._register(instaService.createInstance(McpSamplingLog));
	}

	async sample(opts: ISamplingOptions, token = CancellationToken.None): Promise<ISamplingResult> {
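		// Convert MCP sampling messages into chat messages; only text, image, and
		// audio content is supported here, anything else is dropped from the request.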
		const messages = opts.params.messages.map((message): IChatMessage | undefined => {
			const content: IChatMessagePart | undefined = message.content.type === 'text'
				? { type: 'text', value: message.content.text }
				: message.content.type === 'image' || message.content.type === 'audio'
					? { type: 'image_url', value: { mimeType: message.content.mimeType as ChatImageMimeType, data: decodeBase64(message.content.data) } }
					: undefined;
			if (!content) {
				return undefined;
			}
			return {
				role: message.role === 'assistant' ? ChatMessageRole.Assistant : ChatMessageRole.User,
				content: [content]
			};
		}).filter(isDefined);

		if (opts.params.systemPrompt) {
			messages.unshift({ role: ChatMessageRole.System, content: [{ type: 'text', value: opts.params.systemPrompt }] });
		}

		const model = await this._getMatchingModel(opts);
		// todo@connor4312: nullExtensionDescription.identifier -> undefined with API update
		const response = await this._languageModelsService.sendChatRequest(model, new ExtensionIdentifier('core'), messages, {}, token);

		let responseText = '';

		// MCP doesn't have a notion of a multi-part sampling response, so we only preserve text
		// Ref https://github.com/modelcontextprotocol/modelcontextprotocol/issues/91
		const streaming = (async () => {
			for await (const part of response.stream) {
				if (Array.isArray(part)) {
					for (const p of part) {
						if (p.type === 'text') {
							responseText += p.value;
						}
					}
				} else if (part.type === 'text') {
					responseText += part.value;
				}
			}
		})();

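		// Wait for both the final result and the streamed parts so that
		// `responseText` is complete before it is logged and returned.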
		try {
			await Promise.all([response.result, streaming]);
			this._logs.add(opts.server, opts.params.messages, responseText, model);
			return {
				sample: {
					model,
					content: { type: 'text', text: responseText },
					role: 'assistant', // it came from the model!
				},
			};
		} catch (err) {
			throw McpError.unknown(err);
		}
	}

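	// Both of the following expose the per-server sampling log populated in `sample` above.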
	hasLogs(server: IMcpServer): boolean {
		return this._logs.has(server);
	}

	getLogText(server: IMcpServer): string {
		return this._logs.getAsText(server);
	}

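	/**
	 * Resolves the model to use for this sampling request. When the user has not
	 * yet decided whether the server may sample in this context, or no allowed
	 * model is available, the user is prompted and, if they allow it, the lookup
	 * is retried; otherwise `McpError.notAllowed` is thrown.
	 */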
	private async _getMatchingModel(opts: ISamplingOptions): Promise<string> {
		const model = await this._getMatchingModelInner(opts.server, opts.isDuringToolCall, opts.params.modelPreferences);

		if (model === ModelMatch.UnsureAllowedDuringChat) {
			const retry = await this._showContextual(
				opts.isDuringToolCall,
				localize('mcp.sampling.allowDuringChat.title', 'Allow MCP tools from "{0}" to make LLM requests?', opts.server.definition.label),
				localize('mcp.sampling.allowDuringChat.desc', 'The MCP server "{0}" has issued a request to make a language model call. Do you want to allow it to make requests during chat?', opts.server.definition.label),
				this.allowButtons(opts.server, 'allowedDuringChat')
			);
			if (retry) {
				return this._getMatchingModel(opts);
			}
			throw McpError.notAllowed();
		} else if (model === ModelMatch.UnsureAllowedOutsideChat) {
			const retry = await this._showContextual(
				opts.isDuringToolCall,
				localize('mcp.sampling.allowOutsideChat.title', 'Allow MCP server "{0}" to make LLM requests?', opts.server.definition.label),
				localize('mcp.sampling.allowOutsideChat.desc', 'The MCP server "{0}" has issued a request to make a language model call. Do you want to allow it to make requests, outside of tool calls during chat?', opts.server.definition.label),
				this.allowButtons(opts.server, 'allowedOutsideChat')
			);
			if (retry) {
				return this._getMatchingModel(opts);
			}
			throw McpError.notAllowed();
		} else if (model === ModelMatch.NotAllowed) {
			throw McpError.notAllowed();
		} else if (model === ModelMatch.NoMatchingModel) {
			const newlyPickedModels = opts.isDuringToolCall
				? await this._commandService.executeCommand<number>(McpCommandIds.ConfigureSamplingModels, opts.server)
				: await this._notify(
					localize('mcp.sampling.needsModels', 'MCP server "{0}" triggered a language model request, but it has no allowlisted models.', opts.server.definition.label),
					{
						[localize('configure', 'Configure')]: () => this._commandService.executeCommand<number>(McpCommandIds.ConfigureSamplingModels, opts.server),
						[localize('cancel', 'Cancel')]: () => Promise.resolve(undefined),
					}
				);
			if (newlyPickedModels) {
				return this._getMatchingModel(opts);
			}
			throw McpError.notAllowed();
		}

		return model;
	}

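	/**
	 * Buttons offered when asking whether a server may sample: the first two allow
	 * (for this session only, or persisted in settings), the last two decline
	 * (for this session only, or persisted). Each handler resolves to whether the
	 * caller should retry model selection.
	 */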
	private allowButtons(server: IMcpServer, key: 'allowedDuringChat' | 'allowedOutsideChat') {
		return {
			[localize('mcp.sampling.allow.inSession', 'Allow in this Session')]: async () => {
				this._sessionSets[key].set(server.definition.id, true);
				return true;
			},
			[localize('mcp.sampling.allow.always', 'Always')]: async () => {
				await this.updateConfig(server, c => c[key] = true);
				return true;
			},
			[localize('mcp.sampling.allow.notNow', 'Not Now')]: async () => {
				this._sessionSets[key].set(server.definition.id, false);
				return false;
			},
			[localize('mcp.sampling.allow.never', 'Never')]: async () => {
				await this.updateConfig(server, c => c[key] = false);
				return false;
			},
		};
	}

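	// During a tool call the question is raised as a dialog; otherwise it is shown
	// as a notification, and dismissing the notification resolves to undefined.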
	private async _showContextual<T>(isDuringToolCall: boolean, title: string, message: string, buttons: Record<string, () => T>): Promise<Awaited<T> | undefined> {
		if (isDuringToolCall) {
			const result = await this._dialogService.prompt({
				type: 'question',
				title: title,
				message,
				buttons: Object.entries(buttons).map(([label, run]) => ({ label, run })),
			});
			return await result.result;
		} else {
			return await this._notify(message, buttons);
		}
	}

	private async _notify<T>(message: string, buttons: Record<string, () => T>): Promise<Awaited<T> | undefined> {
		return await new Promise<T | undefined>(resolve => {
			const handle = this._notificationService.prompt(
				Severity.Info,
				message,
				Object.entries(buttons).map(([label, action]) => ({
					label,
					run: () => resolve(action()),
				}))
			);
			Event.once(handle.onDidClose)(() => resolve(undefined));
		});
	}

	/**
	 * Gets the matching model for the MCP server in this context, or
	 * a reason why no model could be selected.
	 */
	private async _getMatchingModelInner(server: IMcpServer, isDuringToolCall: boolean, preferences: MCP.ModelPreferences | undefined): Promise<ModelMatch | string> {
		const config = this.getConfig(server);
		// 1. Ensure the server is allowed to sample in this context
		if (isDuringToolCall && !config.allowedDuringChat && !this._sessionSets.allowedDuringChat.has(server.definition.id)) {
			return config.allowedDuringChat === undefined ? ModelMatch.UnsureAllowedDuringChat : ModelMatch.NotAllowed;
		} else if (!isDuringToolCall && !config.allowedOutsideChat && !this._sessionSets.allowedOutsideChat.has(server.definition.id)) {
			return config.allowedOutsideChat === undefined ? ModelMatch.UnsureAllowedOutsideChat : ModelMatch.NotAllowed;
		}

		// 2. Get the configured models, or the default model(s)
		const foundModelIdsDeep = config.allowedModels?.filter(m => !!this._languageModelsService.lookupLanguageModel(m)) || this._languageModelsService.getLanguageModelIds().filter(m => this._languageModelsService.lookupLanguageModel(m)?.isDefault);

		const foundModelIds = foundModelIdsDeep.flat().sort((a, b) => b.length - a.length); // Sort by length to prefer most specific

		if (!foundModelIds.length) {
			return ModelMatch.NoMatchingModel;
		}

		// 3. If preferences are provided, try to match them from the allowed models
		if (preferences?.hints) {
			const found = mapFindFirst(preferences.hints, hint => foundModelIds.find(model => model.toLowerCase().includes(hint.name!.toLowerCase())));
			if (found) {
				return found;
			}
		}

		return foundModelIds[0]; // Return the first matching model
	}

	private _configKey(server: IMcpServer) {
		return `${server.collection.label}: ${server.definition.label}`;
	}

	public getConfig(server: IMcpServer): IMcpServerSamplingConfiguration {
		return this._getConfig(server).value || {};
	}

	/**
	 * _getConfig reads the sampling config, the `{ server: data }` mapping,
	 * from the appropriate config. We read from the most specific possible
	 * config up to the default configuration location that the MCP server itself
	 * is defined in. We don't go further because then workspace-specific servers
	 * would end up in the user settings, which is not meaningful and could lead
	 * to confusion.
	 *
	 * todo@connor4312: generalize this for other settings when we have them
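	 *
	 * Illustrative shape of the stored mapping (the labels and model id here are
	 * placeholders, not real values):
	 * `{ "<collection label>: <server label>": { "allowedDuringChat": true, "allowedModels": ["<model id>"] } }`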
	 */
	private _getConfig(server: IMcpServer) {
		const def = server.readDefinitions().get();
		const mostSpecificConfig = ConfigurationTarget.MEMORY;
		const leastSpecificConfig = def.collection?.configTarget || ConfigurationTarget.USER;
		const key = this._configKey(server);
		const resource = def.collection?.presentation?.origin;

		const configValue = this._configurationService.inspect<Record<string, IMcpServerSamplingConfiguration>>(mcpServerSamplingSection, { resource });
		for (let target = mostSpecificConfig; target >= leastSpecificConfig; target--) {
			const mapping = getConfigValueInTarget(configValue, target);
			const config = mapping?.[key];
			if (config) {
				return { value: config, key, mapping, target, resource };
			}
		}

		return { value: undefined, mapping: undefined, key, target: leastSpecificConfig, resource };
	}

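	/**
	 * Applies `mutate` to a copy of the server's current sampling config and writes
	 * the result back to the same configuration target it was read from.
	 */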
	public async updateConfig(server: IMcpServer, mutate: (r: IMcpServerSamplingConfiguration) => unknown) {
		const { value, mapping, key, target, resource } = this._getConfig(server);

		const newConfig = { ...value };
		mutate(newConfig);

		await this._configurationService.updateValue(
			mcpServerSamplingSection,
			{ ...mapping, [key]: newConfig },
			{ resource },
			target,
		);
		return newConfig;
	}
}