Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/src/vs/workbench/api/browser/mainThreadLanguageModels.ts
5222 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
6
import { AsyncIterableSource, DeferredPromise } from '../../../base/common/async.js';
7
import { VSBuffer } from '../../../base/common/buffer.js';
8
import { CancellationToken } from '../../../base/common/cancellation.js';
9
import { toErrorMessage } from '../../../base/common/errorMessage.js';
10
import { SerializedError, transformErrorForSerialization, transformErrorFromSerialization } from '../../../base/common/errors.js';
11
import { Emitter, Event } from '../../../base/common/event.js';
12
import { Disposable, DisposableMap, DisposableStore, IDisposable, toDisposable } from '../../../base/common/lifecycle.js';
13
import { URI, UriComponents } from '../../../base/common/uri.js';
14
import { localize } from '../../../nls.js';
15
import { ExtensionIdentifier } from '../../../platform/extensions/common/extensions.js';
16
import { ILogService } from '../../../platform/log/common/log.js';
17
import { resizeImage } from '../../contrib/chat/browser/chatImageUtils.js';
18
import { ILanguageModelIgnoredFilesService } from '../../contrib/chat/common/ignoredFiles.js';
19
import { IChatMessage, IChatResponsePart, ILanguageModelChatResponse, ILanguageModelChatSelector, ILanguageModelsService } from '../../contrib/chat/common/languageModels.js';
20
import { IAuthenticationAccessService } from '../../services/authentication/browser/authenticationAccessService.js';
21
import { AuthenticationSession, AuthenticationSessionsChangeEvent, IAuthenticationProvider, IAuthenticationService, INTERNAL_AUTH_PROVIDER_PREFIX } from '../../services/authentication/common/authentication.js';
22
import { IExtHostContext, extHostNamedCustomer } from '../../services/extensions/common/extHostCustomers.js';
23
import { IExtensionService } from '../../services/extensions/common/extensions.js';
24
import { SerializableObjectWithBuffers } from '../../services/extensions/common/proxyIdentifier.js';
25
import { ExtHostContext, ExtHostLanguageModelsShape, MainContext, MainThreadLanguageModelsShape } from '../common/extHost.protocol.js';
26
import { LanguageModelError } from '../common/extHostTypes.js';
27
28
/**
 * Main-thread side of the language model API.
 *
 * - Registers language model providers whose implementation lives in the extension host and
 *   forwards chat requests / token counting to them through `_proxy`.
 * - Serves chat requests issued by extensions (`$tryStartChatRequest`) by calling the
 *   `ILanguageModelsService` and streaming the response parts back to the extension host.
 * - Bridges ignored-file providers and the per-extension "language model access" auth providers.
 */
@extHostNamedCustomer(MainContext.MainThreadLanguageModels)
export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {

	private readonly _proxy: ExtHostLanguageModelsShape;
	private readonly _store = new DisposableStore();
	/** Provider registrations created by `$registerLanguageModelProvider`, keyed by vendor. */
	private readonly _providerRegistrations = new DisposableMap<string>();
	/** Fired by the extension host (via `$onLMProviderChange`) when a vendor's models change. */
	private readonly _lmProviderChange = new Emitter<{ vendor: string }>();
	/** Requests forwarded to extension-host providers that have not reported completion yet, keyed by request id. */
	private readonly _pendingProgress = new Map<number, { defer: DeferredPromise<unknown>; stream: AsyncIterableSource<IChatResponsePart | IChatResponsePart[]> }>();
	/** Ignored-file provider registrations, keyed by the extension host's handle. */
	private readonly _ignoredFileProviderRegistrations = new DisposableMap<number>();
	/**
	 * Source of ids for requests forwarded to extension-host providers. A monotonically increasing
	 * counter guarantees that concurrent requests never share a key in `_pendingProgress`
	 * (random ids could collide, overwriting a pending entry and misrouting response parts).
	 */
	private _nextRequestId = 0;

	constructor(
		extHostContext: IExtHostContext,
		@ILanguageModelsService private readonly _chatProviderService: ILanguageModelsService,
		@ILogService private readonly _logService: ILogService,
		@IAuthenticationService private readonly _authenticationService: IAuthenticationService,
		@IAuthenticationAccessService private readonly _authenticationAccessService: IAuthenticationAccessService,
		@IExtensionService private readonly _extensionService: IExtensionService,
		@ILanguageModelIgnoredFilesService private readonly _ignoredFilesService: ILanguageModelIgnoredFilesService,
	) {
		this._proxy = extHostContext.getProxy(ExtHostContext.ExtHostChatProvider);
	}

	dispose(): void {
		this._lmProviderChange.dispose();
		this._providerRegistrations.dispose();
		this._ignoredFileProviderRegistrations.dispose();
		this._store.dispose();
	}

	/**
	 * Registers a language model provider for `vendor` whose model info, chat requests and token
	 * counting are all delegated to the extension host. Everything created for the registration
	 * (including auth providers registered for its models) is collected in one store and stored
	 * under `vendor` in `_providerRegistrations`.
	 */
	$registerLanguageModelProvider(vendor: string): void {
		const disposables = new DisposableStore();
		try {
			disposables.add(this._chatProviderService.registerLanguageModelProvider(vendor, {
				// Only forward change notifications that concern this vendor
				onDidChange: Event.filter(this._lmProviderChange.event, e => e.vendor === vendor, disposables) as unknown as Event<void>,
				provideLanguageModelChatInfo: async (options, token) => {
					const modelsAndIdentifiers = await this._proxy.$provideLanguageModelChatInfo(vendor, options, token);
					modelsAndIdentifiers.forEach(m => {
						if (m.metadata.auth) {
							disposables.add(this._registerAuthenticationProvider(m.metadata.extension, m.metadata.auth));
						}
					});
					return modelsAndIdentifiers;
				},
				sendChatRequest: async (modelId, messages, from, options, token) => {
					const requestId = this._nextRequestId++;
					const defer = new DeferredPromise<unknown>();
					const stream = new AsyncIterableSource<IChatResponsePart | IChatResponsePart[]>();

					try {
						// Register before starting so parts reported by the extension host are never dropped
						this._pendingProgress.set(requestId, { defer, stream });
						// Images are resized in place before the messages are sent to the extension host
						await Promise.all(
							messages.flatMap(msg => msg.content)
								.filter(part => part.type === 'image_url')
								.map(async part => {
									part.value.data = VSBuffer.wrap(await resizeImage(part.value.data.buffer));
								})
						);
						await this._proxy.$startChatRequest(modelId, requestId, from, new SerializableObjectWithBuffers(messages), options, token);
					} catch (err) {
						this._pendingProgress.delete(requestId);
						throw err;
					}

					// Parts and completion arrive later via $reportResponsePart / $reportResponseDone
					return {
						result: defer.p,
						stream: stream.asyncIterable
					} satisfies ILanguageModelChatResponse;
				},
				provideTokenCount: (modelId, str, token) => {
					return this._proxy.$provideTokenLength(modelId, str, token);
				},
			}));
			this._providerRegistrations.set(vendor, disposables);
		} catch (err) {
			disposables.dispose();
			throw err;
		}
	}

	$onLMProviderChange(vendor: string): void {
		this._lmProviderChange.fire({ vendor });
	}

	/**
	 * Receives one streamed part (or batch of parts) for a request forwarded to an extension-host
	 * provider. Parts for unknown or already-completed requests are ignored.
	 */
	async $reportResponsePart(requestId: number, chunk: SerializableObjectWithBuffers<IChatResponsePart | IChatResponsePart[]>): Promise<void> {
		const data = this._pendingProgress.get(requestId);
		this._logService.trace('[LM] report response PART', Boolean(data), requestId, chunk);
		if (data) {
			data.stream.emitOne(chunk.value);
		}
	}

	/**
	 * Completes a request forwarded to an extension-host provider. With `err`, both the stream and
	 * the result are rejected; the error is deserialized as a `LanguageModelError` when possible.
	 */
	async $reportResponseDone(requestId: number, err: SerializedError | undefined): Promise<void> {
		const data = this._pendingProgress.get(requestId);
		this._logService.trace('[LM] report response DONE', Boolean(data), requestId, err);
		if (data) {
			this._pendingProgress.delete(requestId);
			if (err) {
				const error = LanguageModelError.tryDeserialize(err) ?? transformErrorFromSerialization(err);
				data.stream.reject(error);
				data.defer.error(error);
			} else {
				data.stream.resolve();
				data.defer.complete(undefined);
			}
		}
	}

	$unregisterProvider(vendor: string): void {
		this._providerRegistrations.deleteAndDispose(vendor);
	}

	$selectChatModels(selector: ILanguageModelChatSelector): Promise<string[]> {
		return this._chatProviderService.selectLanguageModels(selector);
	}

	/**
	 * Starts a chat request on behalf of `extension`. Streams response parts back to the extension
	 * host with `$acceptResponsePart` and finally calls `$acceptResponseDone` exactly once, carrying
	 * the error if either the stream or the result failed.
	 *
	 * @throws whatever `sendChatRequest` throws when the request cannot be started.
	 */
	async $tryStartChatRequest(extension: ExtensionIdentifier, modelIdentifier: string, requestId: number, messages: SerializableObjectWithBuffers<IChatMessage[]>, options: {}, token: CancellationToken): Promise<void> {
		this._logService.trace('[CHAT] request STARTED', extension.value, requestId);

		let response: ILanguageModelChatResponse;
		try {
			response = await this._chatProviderService.sendChatRequest(modelIdentifier, extension, messages.value, options, token);
		} catch (err) {
			this._logService.error('[CHAT] request FAILED', extension.value, requestId, err);
			throw err;
		}

		// !!! IMPORTANT !!!
		// This method must return before the response is done (has streamed all parts)
		// and because of that we consume the stream without awaiting
		// !!! IMPORTANT !!!
		let streamErrorReported = false;
		const streaming = (async () => {
			try {
				for await (const part of response.stream) {
					this._logService.trace('[CHAT] request PART', extension.value, requestId, part);
					await this._proxy.$acceptResponsePart(requestId, new SerializableObjectWithBuffers(part));
				}
				this._logService.trace('[CHAT] request DONE', extension.value, requestId);
			} catch (err) {
				streamErrorReported = true;
				this._logService.error('[CHAT] extension request ERRORED in STREAM', toErrorMessage(err, true), extension.value, requestId);
				this._proxy.$acceptResponseDone(requestId, transformErrorForSerialization(err));
			}
		})();

		// When the response is done (signaled via its result and the end of the stream) we tell the
		// EH exactly once. `Promise.allSettled` never rejects, so a failed result must be read from
		// its settled outcome rather than from a rejection handler.
		void Promise.allSettled([response.result, streaming]).then(([resultOutcome]) => {
			if (streamErrorReported) {
				// The stream failure has already been reported to the EH
				return;
			}
			if (resultOutcome.status === 'rejected') {
				const err = resultOutcome.reason;
				this._logService.error('[CHAT] extension request ERRORED', toErrorMessage(err, true), extension.value, requestId);
				this._proxy.$acceptResponseDone(requestId, transformErrorForSerialization(err));
				return;
			}
			this._logService.debug('[CHAT] extension request DONE', extension.value, requestId);
			this._proxy.$acceptResponseDone(requestId, undefined);
		});
	}


	$countTokens(modelId: string, value: string | IChatMessage, token: CancellationToken): Promise<number> {
		return this._chatProviderService.computeTokenLength(modelId, value, token);
	}

	/**
	 * Registers the placeholder auth provider that gates access to the models of `extension`, and
	 * keeps the extension host's model access list in sync with the user's per-extension access
	 * decisions for that provider. Returns `Disposable.None` when the provider already exists.
	 */
	private _registerAuthenticationProvider(extension: ExtensionIdentifier, auth: { providerLabel: string; accountLabel?: string | undefined }): IDisposable {
		// This needs to be done in both MainThread & ExtHost ChatProvider
		const authProviderId = INTERNAL_AUTH_PROVIDER_PREFIX + extension.value;

		// Only register one auth provider per extension
		if (this._authenticationService.getProviderIds().includes(authProviderId)) {
			return Disposable.None;
		}

		const accountLabel = auth.accountLabel ?? localize('languageModelsAccountId', 'Language Models');
		const disposables = new DisposableStore();
		this._authenticationService.registerAuthenticationProvider(authProviderId, new LanguageModelAccessAuthProvider(authProviderId, auth.providerLabel, accountLabel));
		disposables.add(toDisposable(() => {
			this._authenticationService.unregisterAuthenticationProvider(authProviderId);
		}));
		disposables.add(this._authenticationAccessService.onDidChangeExtensionSessionAccess(async (e) => {
			const allowedExtensions = this._authenticationAccessService.readAllowedExtensions(authProviderId, accountLabel);
			const accessList = [];
			for (const allowedExtension of allowedExtensions) {
				const from = await this._extensionService.getExtension(allowedExtension.id);
				if (from) {
					accessList.push({
						from: from.identifier,
						to: extension,
						// Entries without an explicit `allowed` flag count as allowed
						enabled: allowedExtension.allowed ?? true
					});
				}
			}
			this._proxy.$updateModelAccesslist(accessList);
		}));
		return disposables;
	}

	$fileIsIgnored(uri: UriComponents, token: CancellationToken): Promise<boolean> {
		return this._ignoredFilesService.fileIsIgnored(URI.revive(uri), token);
	}

	/** Registers an ignored-file provider implemented in the extension host under `handle`. */
	$registerFileIgnoreProvider(handle: number): void {
		this._ignoredFileProviderRegistrations.set(handle, this._ignoredFilesService.registerIgnoredFileProvider({
			isFileIgnored: async (uri: URI, token: CancellationToken) => this._proxy.$isFileIgnored(handle, uri, token)
		}));
	}

	$unregisterFileIgnoreProvider(handle: number): void {
		this._ignoredFileProviderRegistrations.deleteAndDispose(handle);
	}
}
233
234
/**
 * Placeholder authentication provider used to gate access to the language models of one
 * extension; one instance exists per provider.
 *
 * It does not talk to any identity service: it hands out at most one fake session whose
 * account is identified by this provider's id and labelled with the supplied account label.
 */
class LanguageModelAccessAuthProvider implements IAuthenticationProvider {
	supportsMultipleAccounts = false;

	// Session changes are announced so the UI can refresh
	private _onDidChangeSessions: Emitter<AuthenticationSessionsChangeEvent> = new Emitter<AuthenticationSessionsChangeEvent>();
	readonly onDidChangeSessions: Event<AuthenticationSessionsChangeEvent> = this._onDidChangeSessions.event;

	/** The single fake session, once one has been requested. */
	private _session: AuthenticationSession | undefined;

	constructor(readonly id: string, readonly label: string, private readonly _accountLabel: string) { }

	/**
	 * Returns the existing session if there is one. Otherwise, a call without scopes means the
	 * user is merely opening the Accounts menu (no extension asked for access), so nothing is
	 * returned; a call with scopes creates the session.
	 */
	async getSessions(scopes?: string[] | undefined): Promise<readonly AuthenticationSession[]> {
		const existing = this._session;
		if (existing) {
			return [existing];
		}
		if (scopes === undefined) {
			return [];
		}
		const created = await this.createSession(scopes ?? []);
		return [created];
	}

	/** Creates (or replaces) the fake session and announces it as added. */
	async createSession(scopes: string[]): Promise<AuthenticationSession> {
		const fresh = this._createFakeSession(scopes);
		this._session = fresh;
		this._onDidChangeSessions.fire({ added: [fresh], changed: [], removed: [] });
		return this._session;
	}

	/** Drops the current fake session, if any, regardless of `sessionId`, and announces the removal. */
	removeSession(sessionId: string): Promise<void> {
		if (!this._session) {
			return Promise.resolve();
		}
		this._onDidChangeSessions.fire({ added: [], changed: [], removed: [this._session] });
		this._session = undefined;
		return Promise.resolve();
	}

	confirmation(extensionName: string, _recreatingSession: boolean): string {
		return localize('confirmLanguageModelAccess', "The extension '{0}' wants to access the language models provided by {1}.", extensionName, this.label);
	}

	private _createFakeSession(scopes: string[]): AuthenticationSession {
		const account = {
			id: this.id,
			label: this._accountLabel,
		};
		const session: AuthenticationSession = {
			id: 'fake-session',
			account,
			accessToken: 'fake-access-token',
			scopes,
		};
		return session;
	}
}
286
287