Path: blob/main/src/vs/workbench/api/browser/mainThreadLanguageModels.ts
3296 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { AsyncIterableSource, DeferredPromise } from '../../../base/common/async.js';
import { VSBuffer } from '../../../base/common/buffer.js';
import { CancellationToken } from '../../../base/common/cancellation.js';
import { toErrorMessage } from '../../../base/common/errorMessage.js';
import { SerializedError, transformErrorForSerialization, transformErrorFromSerialization } from '../../../base/common/errors.js';
import { Emitter, Event } from '../../../base/common/event.js';
import { Disposable, DisposableMap, DisposableStore, IDisposable, toDisposable } from '../../../base/common/lifecycle.js';
import { URI, UriComponents } from '../../../base/common/uri.js';
import { localize } from '../../../nls.js';
import { ExtensionIdentifier } from '../../../platform/extensions/common/extensions.js';
import { ILogService } from '../../../platform/log/common/log.js';
import { resizeImage } from '../../contrib/chat/browser/imageUtils.js';
import { ILanguageModelIgnoredFilesService } from '../../contrib/chat/common/ignoredFiles.js';
import { IChatMessage, IChatResponsePart, ILanguageModelChatResponse, ILanguageModelChatSelector, ILanguageModelsService } from '../../contrib/chat/common/languageModels.js';
import { IAuthenticationAccessService } from '../../services/authentication/browser/authenticationAccessService.js';
import { AuthenticationSession, AuthenticationSessionsChangeEvent, IAuthenticationProvider, IAuthenticationService, INTERNAL_AUTH_PROVIDER_PREFIX } from '../../services/authentication/common/authentication.js';
import { IExtHostContext, extHostNamedCustomer } from '../../services/extensions/common/extHostCustomers.js';
import { IExtensionService } from '../../services/extensions/common/extensions.js';
import { SerializableObjectWithBuffers } from '../../services/extensions/common/proxyIdentifier.js';
import { ExtHostContext, ExtHostLanguageModelsShape, MainContext, MainThreadLanguageModelsShape } from '../common/extHost.protocol.js';
import { LanguageModelError } from '../common/extHostTypes.js';

/**
 * Main-thread counterpart of the extension host's language model API.
 *
 * It registers extension-contributed language model providers with `ILanguageModelsService`.
 * It relays chat requests and their streamed responses between that service and the extension
 * host (via `_proxy`). It also bridges extension-contributed ignored-file providers.
 */
@extHostNamedCustomer(MainContext.MainThreadLanguageModels)
export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {

	private readonly _proxy: ExtHostLanguageModelsShape;
	private readonly _store = new DisposableStore();
	private readonly _providerRegistrations = new DisposableMap<string>();
	private readonly _lmProviderChange = new Emitter<{ vendor: string }>();
	/** In-flight requests sent to the extension host, keyed by request id. */
	private readonly _pendingProgress = new Map<number, { defer: DeferredPromise<any>; stream: AsyncIterableSource<IChatResponsePart | IChatResponsePart[]> }>();
	private readonly _ignoredFileProviderRegistrations = new DisposableMap<number>();

	/**
	 * Source of ids for requests forwarded to the extension host by `sendChatRequest`.
	 * A counter guarantees that two in-flight requests never share an id. Random ids can
	 * collide and overwrite an entry in `_pendingProgress`, which leaves one request unresolved.
	 */
	private _requestIdPool = 0;

	constructor(
		extHostContext: IExtHostContext,
		@ILanguageModelsService private readonly _chatProviderService: ILanguageModelsService,
		@ILogService private readonly _logService: ILogService,
		@IAuthenticationService private readonly _authenticationService: IAuthenticationService,
		@IAuthenticationAccessService private readonly _authenticationAccessService: IAuthenticationAccessService,
		@IExtensionService private readonly _extensionService: IExtensionService,
		@ILanguageModelIgnoredFilesService private readonly _ignoredFilesService: ILanguageModelIgnoredFilesService,
	) {
		this._proxy = extHostContext.getProxy(ExtHostContext.ExtHostChatProvider);
	}

	dispose(): void {
		this._lmProviderChange.dispose();
		this._providerRegistrations.dispose();
		this._ignoredFileProviderRegistrations.dispose();
		this._store.dispose();
	}

	/**
	 * Registers a language model provider for `vendor` whose operations are all delegated to
	 * the extension host. Re-registering the same vendor replaces the previous registration
	 * in `_providerRegistrations`.
	 */
	$registerLanguageModelProvider(vendor: string): void {
		const disposables = new DisposableStore();
		disposables.add(this._chatProviderService.registerLanguageModelProvider(vendor, {
			// Only changes for this vendor; re-typed as Event<void> to match what the service expects.
			onDidChange: Event.filter(this._lmProviderChange.event, e => e.vendor === vendor, disposables) as unknown as Event<void>,
			provideLanguageModelChatInfo: async (options, token) => {
				const modelsAndIdentifiers = await this._proxy.$provideLanguageModelChatInfo(vendor, options, token);
				modelsAndIdentifiers.forEach(m => {
					if (m.metadata.auth) {
						disposables.add(this._registerAuthenticationProvider(m.metadata.extension, m.metadata.auth));
					}
				});
				return modelsAndIdentifiers;
			},
			sendChatRequest: async (modelId, messages, from, options, token) => {
				const requestId = this._requestIdPool++;
				const defer = new DeferredPromise<any>();
				const stream = new AsyncIterableSource<IChatResponsePart | IChatResponsePart[]>();

				try {
					// Register before starting so that parts/done reported back through
					// $reportResponsePart/$reportResponseDone can be routed to this request.
					this._pendingProgress.set(requestId, { defer, stream });
					// Image parts are resized in place before the messages are sent.
					await Promise.all(
						messages.flatMap(msg => msg.content)
							.filter(part => part.type === 'image_url')
							.map(async part => {
								part.value.data = VSBuffer.wrap(await resizeImage(part.value.data.buffer));
							})
					);
					await this._proxy.$startChatRequest(modelId, requestId, from, new SerializableObjectWithBuffers(messages), options, token);
				} catch (err) {
					this._pendingProgress.delete(requestId);
					throw err;
				}

				return {
					result: defer.p,
					stream: stream.asyncIterable
				} satisfies ILanguageModelChatResponse;
			},
			provideTokenCount: (modelId, str, token) => {
				return this._proxy.$provideTokenLength(modelId, str, token);
			},
		}));
		this._providerRegistrations.set(vendor, disposables);
	}

	$onLMProviderChange(vendor: string): void {
		this._lmProviderChange.fire({ vendor });
	}

	/** Forwards one streamed chunk from the extension host to the matching pending request, if any. */
	async $reportResponsePart(requestId: number, chunk: SerializableObjectWithBuffers<IChatResponsePart | IChatResponsePart[]>): Promise<void> {
		const data = this._pendingProgress.get(requestId);
		this._logService.trace('[LM] report response PART', Boolean(data), requestId, chunk);
		if (data) {
			data.stream.emitOne(chunk.value);
		}
	}

	/**
	 * Completes (or fails, when `err` is set) the matching pending request and forgets it.
	 * Unknown request ids are ignored.
	 */
	async $reportResponseDone(requestId: number, err: SerializedError | undefined): Promise<void> {
		const data = this._pendingProgress.get(requestId);
		this._logService.trace('[LM] report response DONE', Boolean(data), requestId, err);
		if (data) {
			this._pendingProgress.delete(requestId);
			if (err) {
				const error = LanguageModelError.tryDeserialize(err) ?? transformErrorFromSerialization(err);
				data.stream.reject(error);
				data.defer.error(error);
			} else {
				data.stream.resolve();
				data.defer.complete(undefined);
			}
		}
	}

	$unregisterProvider(vendor: string): void {
		this._providerRegistrations.deleteAndDispose(vendor);
	}

	$selectChatModels(selector: ILanguageModelChatSelector): Promise<string[]> {
		return this._chatProviderService.selectLanguageModels(selector);
	}

	/**
	 * Starts a chat request on behalf of `extension`. Response parts are pushed to the
	 * extension host via `$acceptResponsePart`. Exactly one `$acceptResponseDone` is sent
	 * once the request finished, carrying the error if the stream or the result failed.
	 *
	 * @throws whatever `ILanguageModelsService.sendChatRequest` throws when the request cannot be started.
	 */
	async $tryStartChatRequest(extension: ExtensionIdentifier, modelIdentifier: string, requestId: number, messages: SerializableObjectWithBuffers<IChatMessage[]>, options: {}, token: CancellationToken): Promise<any> {
		this._logService.trace('[CHAT] request STARTED', extension.value, requestId);

		let response: ILanguageModelChatResponse;
		try {
			response = await this._chatProviderService.sendChatRequest(modelIdentifier, extension, messages.value, options, token);
		} catch (err) {
			this._logService.error('[CHAT] request FAILED', extension.value, requestId, err);
			throw err;
		}

		// !!! IMPORTANT !!!
		// This method must return before the response is done (has streamed all parts)
		// and because of that we consume the stream without awaiting
		// !!! IMPORTANT !!!

		// Set once a stream failure has been reported, so the completion handler below does not
		// send a second, contradictory "done" for the same request.
		let streamErrorReported = false;
		const streaming = (async () => {
			try {
				for await (const part of response.stream) {
					this._logService.trace('[CHAT] request PART', extension.value, requestId, part);
					await this._proxy.$acceptResponsePart(requestId, new SerializableObjectWithBuffers(part));
				}
				this._logService.trace('[CHAT] request DONE', extension.value, requestId);
			} catch (err) {
				streamErrorReported = true;
				this._logService.error('[CHAT] extension request ERRORED in STREAM', toErrorMessage(err, true), extension.value, requestId);
				this._proxy.$acceptResponseDone(requestId, transformErrorForSerialization(err));
			}
		})();

		// When the response is done (signaled via its result) we tell the EH.
		// `allSettled` never rejects, so a failed result must be read from its settled outcome.
		Promise.allSettled([response.result, streaming]).then(([resultOutcome]) => {
			if (streamErrorReported) {
				return;
			}
			if (resultOutcome.status === 'rejected') {
				const err = resultOutcome.reason;
				this._logService.error('[CHAT] extension request ERRORED', toErrorMessage(err, true), extension.value, requestId);
				this._proxy.$acceptResponseDone(requestId, transformErrorForSerialization(err));
				return;
			}
			this._logService.debug('[CHAT] extension request DONE', extension.value, requestId);
			this._proxy.$acceptResponseDone(requestId, undefined);
		});
	}


	$countTokens(modelId: string, value: string | IChatMessage, token: CancellationToken): Promise<number> {
		return this._chatProviderService.computeTokenLength(modelId, value, token);
	}

	/**
	 * Registers a fake authentication provider that gates access to the language models
	 * contributed by `extension`. At most one provider is registered per extension. It also
	 * pushes the resulting model access list to the extension host whenever extension
	 * session access changes.
	 *
	 * @returns a disposable that unregisters the provider, or `Disposable.None` when a provider
	 *  for this extension was already registered.
	 */
	private _registerAuthenticationProvider(extension: ExtensionIdentifier, auth: { providerLabel: string; accountLabel?: string | undefined }): IDisposable {
		// This needs to be done in both MainThread & ExtHost ChatProvider
		const authProviderId = INTERNAL_AUTH_PROVIDER_PREFIX + extension.value;

		// Only register one auth provider per extension
		if (this._authenticationService.getProviderIds().includes(authProviderId)) {
			return Disposable.None;
		}

		const accountLabel = auth.accountLabel ?? localize('languageModelsAccountId', 'Language Models');
		const disposables = new DisposableStore();
		this._authenticationService.registerAuthenticationProvider(authProviderId, new LanguageModelAccessAuthProvider(authProviderId, auth.providerLabel, accountLabel));
		disposables.add(toDisposable(() => {
			this._authenticationService.unregisterAuthenticationProvider(authProviderId);
		}));
		disposables.add(this._authenticationAccessService.onDidChangeExtensionSessionAccess(async () => {
			const allowedExtensions = this._authenticationAccessService.readAllowedExtensions(authProviderId, accountLabel);
			const accessList = [];
			for (const allowedExtension of allowedExtensions) {
				const from = await this._extensionService.getExtension(allowedExtension.id);
				// Extensions that are not (or no longer) known to the extension service are skipped.
				if (from) {
					accessList.push({
						from: from.identifier,
						to: extension,
						enabled: allowedExtension.allowed ?? true
					});
				}
			}
			this._proxy.$updateModelAccesslist(accessList);
		}));
		return disposables;
	}

	$fileIsIgnored(uri: UriComponents, token: CancellationToken): Promise<boolean> {
		return this._ignoredFilesService.fileIsIgnored(URI.revive(uri), token);
	}

	$registerFileIgnoreProvider(handle: number): void {
		this._ignoredFileProviderRegistrations.set(handle, this._ignoredFilesService.registerIgnoredFileProvider({
			isFileIgnored: async (uri: URI, token: CancellationToken) => this._proxy.$isFileIgnored(handle, uri, token)
		}));
	}

	$unregisterFileIgnoreProvider(handle: number): void {
		this._ignoredFileProviderRegistrations.deleteAndDispose(handle);
	}
}

// The fake AuthenticationProvider that will be used to gate access to the Language Model.
// There will be one per provider.
//
// It hands out at most one fake session (fixed id and a placeholder access token). Its
// purpose is to let the authentication flow show the consent message from `confirmation`.
class LanguageModelAccessAuthProvider implements IAuthenticationProvider {
	supportsMultipleAccounts = false;

	// Important for updating the UI
	private readonly _onDidChangeSessions: Emitter<AuthenticationSessionsChangeEvent> = new Emitter<AuthenticationSessionsChangeEvent>();
	readonly onDidChangeSessions: Event<AuthenticationSessionsChangeEvent> = this._onDidChangeSessions.event;

	private _session: AuthenticationSession | undefined;

	constructor(readonly id: string, readonly label: string, private readonly _accountLabel: string) { }

	/**
	 * Returns the existing session if there is one. Otherwise it creates one for `scopes`,
	 * unless `scopes` is omitted.
	 */
	async getSessions(scopes?: string[] | undefined): Promise<readonly AuthenticationSession[]> {
		if (this._session) {
			return [this._session];
		}
		// If there are no scopes and no session that means no extension has requested a session yet
		// and the user is simply opening the Account menu. In that case, we should not return any "sessions".
		if (scopes === undefined) {
			return [];
		}
		return [await this.createSession(scopes ?? [])];
	}

	/** Replaces the current session with a new fake one and announces it as added. */
	async createSession(scopes: string[]): Promise<AuthenticationSession> {
		this._session = this._createFakeSession(scopes);
		this._onDidChangeSessions.fire({ added: [this._session], changed: [], removed: [] });
		return this._session;
	}

	/**
	 * Drops the current session, whatever `sessionId` is, since there is at most one.
	 * Then it announces the removal.
	 */
	removeSession(sessionId: string): Promise<void> {
		const session = this._session;
		if (session) {
			// Clear state before firing so that listeners re-querying getSessions() observe the removal.
			this._session = undefined;
			this._onDidChangeSessions.fire({ added: [], changed: [], removed: [session] });
		}
		return Promise.resolve();
	}

	confirmation(extensionName: string, _recreatingSession: boolean): string {
		return localize('confirmLanguageModelAccess', "The extension '{0}' wants to access the language models provided by {1}.", extensionName, this.label);
	}

	private _createFakeSession(scopes: string[]): AuthenticationSession {
		return {
			id: 'fake-session',
			account: {
				id: this.id,
				label: this._accountLabel,
			},
			accessToken: 'fake-access-token',
			scopes,
		};
	}
}