Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/src/vs/workbench/api/browser/mainThreadLanguageModels.ts
3296 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
6
import { AsyncIterableSource, DeferredPromise } from '../../../base/common/async.js';
7
import { VSBuffer } from '../../../base/common/buffer.js';
8
import { CancellationToken } from '../../../base/common/cancellation.js';
9
import { toErrorMessage } from '../../../base/common/errorMessage.js';
10
import { SerializedError, transformErrorForSerialization, transformErrorFromSerialization } from '../../../base/common/errors.js';
11
import { Emitter, Event } from '../../../base/common/event.js';
12
import { Disposable, DisposableMap, DisposableStore, IDisposable, toDisposable } from '../../../base/common/lifecycle.js';
13
import { URI, UriComponents } from '../../../base/common/uri.js';
14
import { localize } from '../../../nls.js';
15
import { ExtensionIdentifier } from '../../../platform/extensions/common/extensions.js';
16
import { ILogService } from '../../../platform/log/common/log.js';
17
import { resizeImage } from '../../contrib/chat/browser/imageUtils.js';
18
import { ILanguageModelIgnoredFilesService } from '../../contrib/chat/common/ignoredFiles.js';
19
import { IChatMessage, IChatResponsePart, ILanguageModelChatResponse, ILanguageModelChatSelector, ILanguageModelsService } from '../../contrib/chat/common/languageModels.js';
20
import { IAuthenticationAccessService } from '../../services/authentication/browser/authenticationAccessService.js';
21
import { AuthenticationSession, AuthenticationSessionsChangeEvent, IAuthenticationProvider, IAuthenticationService, INTERNAL_AUTH_PROVIDER_PREFIX } from '../../services/authentication/common/authentication.js';
22
import { IExtHostContext, extHostNamedCustomer } from '../../services/extensions/common/extHostCustomers.js';
23
import { IExtensionService } from '../../services/extensions/common/extensions.js';
24
import { SerializableObjectWithBuffers } from '../../services/extensions/common/proxyIdentifier.js';
25
import { ExtHostContext, ExtHostLanguageModelsShape, MainContext, MainThreadLanguageModelsShape } from '../common/extHost.protocol.js';
26
import { LanguageModelError } from '../common/extHostTypes.js';
27
28
@extHostNamedCustomer(MainContext.MainThreadLanguageModels)
export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {

	private readonly _proxy: ExtHostLanguageModelsShape;
	private readonly _store = new DisposableStore();
	/** Registrations created by `$registerLanguageModelProvider`, keyed by vendor. */
	private readonly _providerRegistrations = new DisposableMap<string>();
	/** Fired from `$onLMProviderChange`; each provider registration filters it by vendor. */
	private readonly _lmProviderChange = new Emitter<{ vendor: string }>();
	/** Requests forwarded to the extension host that have not yet received `$reportResponseDone`, keyed by request id. */
	private readonly _pendingProgress = new Map<number, { defer: DeferredPromise<any>; stream: AsyncIterableSource<IChatResponsePart | IChatResponsePart[]> }>();
	/** Ignored-file providers created by `$registerFileIgnoreProvider`, keyed by handle. */
	private readonly _ignoredFileProviderRegistrations = new DisposableMap<number>();

	constructor(
		extHostContext: IExtHostContext,
		@ILanguageModelsService private readonly _chatProviderService: ILanguageModelsService,
		@ILogService private readonly _logService: ILogService,
		@IAuthenticationService private readonly _authenticationService: IAuthenticationService,
		@IAuthenticationAccessService private readonly _authenticationAccessService: IAuthenticationAccessService,
		@IExtensionService private readonly _extensionService: IExtensionService,
		@ILanguageModelIgnoredFilesService private readonly _ignoredFilesService: ILanguageModelIgnoredFilesService,
	) {
		this._proxy = extHostContext.getProxy(ExtHostContext.ExtHostChatProvider);
	}

	dispose(): void {
		this._lmProviderChange.dispose();
		this._providerRegistrations.dispose();
		this._ignoredFileProviderRegistrations.dispose();
		this._store.dispose();
	}

	/**
	 * Registers a language model provider for `vendor` with the language models service.
	 * Every call made on the registered provider is forwarded to the extension host.
	 */
	$registerLanguageModelProvider(vendor: string): void {
		const disposables = new DisposableStore();

		// Only forward change notifications for this vendor. The consumer expects a
		// payload-less event, so map the `{ vendor }` payload away instead of casting.
		const onDidChange = Event.map(
			Event.filter(this._lmProviderChange.event, e => e.vendor === vendor, disposables),
			(): void => undefined,
			disposables
		);

		disposables.add(this._chatProviderService.registerLanguageModelProvider(vendor, {
			onDidChange,
			provideLanguageModelChatInfo: async (options, token) => {
				const modelsAndIdentifiers = await this._proxy.$provideLanguageModelChatInfo(vendor, options, token);
				modelsAndIdentifiers.forEach(m => {
					if (m.metadata.auth) {
						disposables.add(this._registerAuthenticationProvider(m.metadata.extension, m.metadata.auth));
					}
				});
				return modelsAndIdentifiers;
			},
			sendChatRequest: async (modelId, messages, from, options, token) => {
				// Choose a random id that is not already in flight. A collision would
				// overwrite another request's pending entry and misroute its parts.
				let requestId: number;
				do {
					requestId = (Math.random() * 1e6) | 0;
				} while (this._pendingProgress.has(requestId));

				const defer = new DeferredPromise<any>();
				const stream = new AsyncIterableSource<IChatResponsePart | IChatResponsePart[]>();

				try {
					this._pendingProgress.set(requestId, { defer, stream });
					// Downscale image parts in place before sending them to the extension host.
					await Promise.all(
						messages.flatMap(msg => msg.content)
							.filter(part => part.type === 'image_url')
							.map(async part => {
								part.value.data = VSBuffer.wrap(await resizeImage(part.value.data.buffer));
							})
					);
					await this._proxy.$startChatRequest(modelId, requestId, from, new SerializableObjectWithBuffers(messages), options, token);
				} catch (err) {
					// The request never started, so no done report will arrive. Drop the entry.
					this._pendingProgress.delete(requestId);
					throw err;
				}

				// Both values are settled later by `$reportResponsePart` / `$reportResponseDone`.
				return {
					result: defer.p,
					stream: stream.asyncIterable
				} satisfies ILanguageModelChatResponse;
			},
			provideTokenCount: (modelId, str, token) => {
				return this._proxy.$provideTokenLength(modelId, str, token);
			},
		}));
		this._providerRegistrations.set(vendor, disposables);
	}

	/** Notifies the provider registered for `vendor` that its models changed. */
	$onLMProviderChange(vendor: string): void {
		this._lmProviderChange.fire({ vendor });
	}

	/** Pushes a streamed part into a pending request. Parts for unknown ids are traced and dropped. */
	async $reportResponsePart(requestId: number, chunk: SerializableObjectWithBuffers<IChatResponsePart | IChatResponsePart[]>): Promise<void> {
		const data = this._pendingProgress.get(requestId);
		this._logService.trace('[LM] report response PART', Boolean(data), requestId, chunk);
		if (data) {
			data.stream.emitOne(chunk.value);
		}
	}

	/**
	 * Completes a pending request. When `err` is set, the stream and the result
	 * are rejected with the deserialized error. Otherwise both complete normally.
	 */
	async $reportResponseDone(requestId: number, err: SerializedError | undefined): Promise<void> {
		const data = this._pendingProgress.get(requestId);
		this._logService.trace('[LM] report response DONE', Boolean(data), requestId, err);
		if (data) {
			this._pendingProgress.delete(requestId);
			if (err) {
				const error = LanguageModelError.tryDeserialize(err) ?? transformErrorFromSerialization(err);
				data.stream.reject(error);
				data.defer.error(error);
			} else {
				data.stream.resolve();
				data.defer.complete(undefined);
			}
		}
	}

	$unregisterProvider(vendor: string): void {
		this._providerRegistrations.deleteAndDispose(vendor);
	}

	$selectChatModels(selector: ILanguageModelChatSelector): Promise<string[]> {
		return this._chatProviderService.selectLanguageModels(selector);
	}

	/**
	 * Starts a chat request on behalf of an extension and relays the response back to
	 * the extension host. Parts go through `$acceptResponsePart`, and completion goes
	 * through exactly one `$acceptResponseDone` call.
	 * Rethrows if the request cannot be started.
	 */
	async $tryStartChatRequest(extension: ExtensionIdentifier, modelIdentifier: string, requestId: number, messages: SerializableObjectWithBuffers<IChatMessage[]>, options: {}, token: CancellationToken): Promise<any> {
		this._logService.trace('[CHAT] request STARTED', extension.value, requestId);

		let response: ILanguageModelChatResponse;
		try {
			response = await this._chatProviderService.sendChatRequest(modelIdentifier, extension, messages.value, options, token);
		} catch (err) {
			this._logService.error('[CHAT] request FAILED', extension.value, requestId, err);
			throw err;
		}

		// The extension host must receive exactly one done report. A stream failure is
		// reported right away; the final outcome is then ignored.
		let doneReported = false;
		const reportDone = (error: SerializedError | undefined): void => {
			if (!doneReported) {
				doneReported = true;
				this._proxy.$acceptResponseDone(requestId, error);
			}
		};

		// !!! IMPORTANT !!!
		// This method must return before the response is done (has streamed all parts)
		// and because of that we consume the stream without awaiting
		// !!! IMPORTANT !!!
		const streaming = (async () => {
			try {
				for await (const part of response.stream) {
					this._logService.trace('[CHAT] request PART', extension.value, requestId, part);
					await this._proxy.$acceptResponsePart(requestId, new SerializableObjectWithBuffers(part));
				}
				this._logService.trace('[CHAT] request DONE', extension.value, requestId);
			} catch (err) {
				this._logService.error('[CHAT] extension request ERRORED in STREAM', toErrorMessage(err, true), extension.value, requestId);
				reportDone(transformErrorForSerialization(err));
			}
		})();

		// When the response is done (signaled via its result), tell the EH.
		// `allSettled` never rejects, so a failed `response.result` must be detected from
		// its settled outcome. Otherwise it would be reported as a success.
		Promise.allSettled([response.result, streaming]).then(([result]) => {
			if (result.status === 'rejected') {
				this._logService.error('[CHAT] extension request ERRORED', toErrorMessage(result.reason, true), extension.value, requestId);
				reportDone(transformErrorForSerialization(result.reason));
			} else {
				this._logService.debug('[CHAT] extension request DONE', extension.value, requestId);
				reportDone(undefined);
			}
		});
	}

	$countTokens(modelId: string, value: string | IChatMessage, token: CancellationToken): Promise<number> {
		return this._chatProviderService.computeTokenLength(modelId, value, token);
	}

	/**
	 * Registers the internal auth provider that gates access to `extension`'s models.
	 * Returns `Disposable.None` if one is already registered for that extension.
	 * Whenever extension session access changes, the resulting access list is pushed
	 * to the extension host.
	 */
	private _registerAuthenticationProvider(extension: ExtensionIdentifier, auth: { providerLabel: string; accountLabel?: string | undefined }): IDisposable {
		// This needs to be done in both MainThread & ExtHost ChatProvider
		const authProviderId = INTERNAL_AUTH_PROVIDER_PREFIX + extension.value;

		// Only register one auth provider per extension
		if (this._authenticationService.getProviderIds().includes(authProviderId)) {
			return Disposable.None;
		}

		const accountLabel = auth.accountLabel ?? localize('languageModelsAccountId', 'Language Models');
		const disposables = new DisposableStore();
		this._authenticationService.registerAuthenticationProvider(authProviderId, new LanguageModelAccessAuthProvider(authProviderId, auth.providerLabel, accountLabel));
		disposables.add(toDisposable(() => {
			this._authenticationService.unregisterAuthenticationProvider(authProviderId);
		}));
		disposables.add(this._authenticationAccessService.onDidChangeExtensionSessionAccess(async (e) => {
			const allowedExtensions = this._authenticationAccessService.readAllowedExtensions(authProviderId, accountLabel);
			const accessList = [];
			for (const allowedExtension of allowedExtensions) {
				const from = await this._extensionService.getExtension(allowedExtension.id);
				// Skip extensions that are no longer known to the extension service.
				if (from) {
					accessList.push({
						from: from.identifier,
						to: extension,
						enabled: allowedExtension.allowed ?? true
					});
				}
			}
			this._proxy.$updateModelAccesslist(accessList);
		}));
		return disposables;
	}

	$fileIsIgnored(uri: UriComponents, token: CancellationToken): Promise<boolean> {
		return this._ignoredFilesService.fileIsIgnored(URI.revive(uri), token);
	}

	/** Registers an ignored-file provider that delegates to the extension host provider `handle`. */
	$registerFileIgnoreProvider(handle: number): void {
		this._ignoredFileProviderRegistrations.set(handle, this._ignoredFilesService.registerIgnoredFileProvider({
			isFileIgnored: async (uri: URI, token: CancellationToken) => this._proxy.$isFileIgnored(handle, uri, token)
		}));
	}

	$unregisterFileIgnoreProvider(handle: number): void {
		this._ignoredFileProviderRegistrations.deleteAndDispose(handle);
	}
}
228
229
// The fake AuthenticationProvider that will be used to gate access to the Language Model. There will be one per provider.
// It holds at most one in-memory session with a fixed placeholder id and token.
class LanguageModelAccessAuthProvider implements IAuthenticationProvider {
	supportsMultipleAccounts = false;

	// Important for updating the UI
	private readonly _onDidChangeSessions = new Emitter<AuthenticationSessionsChangeEvent>();
	readonly onDidChangeSessions: Event<AuthenticationSessionsChangeEvent> = this._onDidChangeSessions.event;

	private _session: AuthenticationSession | undefined;

	constructor(readonly id: string, readonly label: string, private readonly _accountLabel: string) { }

	/**
	 * Returns the existing session if there is one. Without a session:
	 * - `scopes === undefined` means no extension has asked for a session yet (e.g. the
	 *   user is just opening the Accounts menu), so no "sessions" are returned;
	 * - otherwise a session is created for the given scopes.
	 */
	async getSessions(scopes?: string[] | undefined): Promise<readonly AuthenticationSession[]> {
		if (this._session) {
			return [this._session];
		}
		if (scopes === undefined) {
			return [];
		}
		return [await this.createSession(scopes)];
	}

	/** Replaces the current session with a new fake one and announces it as added. */
	async createSession(scopes: string[]): Promise<AuthenticationSession> {
		this._session = this._createFakeSession(scopes);
		this._onDidChangeSessions.fire({ added: [this._session], changed: [], removed: [] });
		return this._session;
	}

	/**
	 * Removes the current session, if any, and announces it as removed. `sessionId` is
	 * ignored because this provider only ever holds one session.
	 */
	removeSession(sessionId: string): Promise<void> {
		const session = this._session;
		if (session) {
			// Clear state before notifying so listeners that re-query see it gone.
			this._session = undefined;
			this._onDidChangeSessions.fire({ added: [], changed: [], removed: [session] });
		}
		return Promise.resolve();
	}

	confirmation(extensionName: string, _recreatingSession: boolean): string {
		return localize('confirmLanguageModelAccess', "The extension '{0}' wants to access the language models provided by {1}.", extensionName, this.label);
	}

	/** Builds a placeholder session whose account id is this provider's id. */
	private _createFakeSession(scopes: string[]): AuthenticationSession {
		return {
			id: 'fake-session',
			account: {
				id: this.id,
				label: this._accountLabel,
			},
			accessToken: 'fake-access-token',
			scopes,
		};
	}
}
281
282