Path: blob/main/src/vs/workbench/api/browser/mainThreadLanguageModels.ts
5222 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { AsyncIterableSource, DeferredPromise } from '../../../base/common/async.js';
import { VSBuffer } from '../../../base/common/buffer.js';
import { CancellationToken } from '../../../base/common/cancellation.js';
import { toErrorMessage } from '../../../base/common/errorMessage.js';
import { SerializedError, transformErrorForSerialization, transformErrorFromSerialization } from '../../../base/common/errors.js';
import { Emitter, Event } from '../../../base/common/event.js';
import { Disposable, DisposableMap, DisposableStore, IDisposable, toDisposable } from '../../../base/common/lifecycle.js';
import { URI, UriComponents } from '../../../base/common/uri.js';
import { localize } from '../../../nls.js';
import { ExtensionIdentifier } from '../../../platform/extensions/common/extensions.js';
import { ILogService } from '../../../platform/log/common/log.js';
import { resizeImage } from '../../contrib/chat/browser/chatImageUtils.js';
import { ILanguageModelIgnoredFilesService } from '../../contrib/chat/common/ignoredFiles.js';
import { IChatMessage, IChatResponsePart, ILanguageModelChatResponse, ILanguageModelChatSelector, ILanguageModelsService } from '../../contrib/chat/common/languageModels.js';
import { IAuthenticationAccessService } from '../../services/authentication/browser/authenticationAccessService.js';
import { AuthenticationSession, AuthenticationSessionsChangeEvent, IAuthenticationProvider, IAuthenticationService, INTERNAL_AUTH_PROVIDER_PREFIX } from '../../services/authentication/common/authentication.js';
import { IExtHostContext, extHostNamedCustomer } from '../../services/extensions/common/extHostCustomers.js';
import { IExtensionService } from '../../services/extensions/common/extensions.js';
import { SerializableObjectWithBuffers } from '../../services/extensions/common/proxyIdentifier.js';
import { ExtHostContext, ExtHostLanguageModelsShape, MainContext, MainThreadLanguageModelsShape } from '../common/extHost.protocol.js';
import { LanguageModelError } from '../common/extHostTypes.js';

/**
 * Main-thread half of the language model API.
 *
 * Two directions are bridged through the `ExtHostLanguageModelsShape` proxy:
 * - Providers implemented in the extension host are registered with the workbench
 *   `ILanguageModelsService` (`$registerLanguageModelProvider`). Their streamed
 *   output arrives back through `$reportResponsePart` / `$reportResponseDone`.
 * - Requests started by extensions (`$tryStartChatRequest`) are forwarded to the
 *   workbench service. The result is streamed to the extension host via
 *   `$acceptResponsePart` / `$acceptResponseDone`.
 */
@extHostNamedCustomer(MainContext.MainThreadLanguageModels)
export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {

	private readonly _proxy: ExtHostLanguageModelsShape;
	private readonly _store = new DisposableStore();
	/** One disposable bundle per registered vendor (provider plus its auth providers/listeners). */
	private readonly _providerRegistrations = new DisposableMap<string>();
	/** Fired by `$onLMProviderChange`; each provider registration filters it to its own vendor. */
	private readonly _lmProviderChange = new Emitter<{ vendor: string }>();
	/** Requests sent to extension-host providers that have not yet reported DONE, keyed by request id. */
	private readonly _pendingProgress = new Map<number, { defer: DeferredPromise<unknown>; stream: AsyncIterableSource<IChatResponsePart | IChatResponsePart[]> }>();
	private readonly _ignoredFileProviderRegistrations = new DisposableMap<number>();
	/**
	 * Source of ids for `_pendingProgress`. The counter is monotonic, so two in-flight
	 * requests can never share an id. Random ids could collide and overwrite each
	 * other's pending entry.
	 */
	private _requestIdPool = 0;

	constructor(
		extHostContext: IExtHostContext,
		@ILanguageModelsService private readonly _chatProviderService: ILanguageModelsService,
		@ILogService private readonly _logService: ILogService,
		@IAuthenticationService private readonly _authenticationService: IAuthenticationService,
		@IAuthenticationAccessService private readonly _authenticationAccessService: IAuthenticationAccessService,
		@IExtensionService private readonly _extensionService: IExtensionService,
		@ILanguageModelIgnoredFilesService private readonly _ignoredFilesService: ILanguageModelIgnoredFilesService,
	) {
		this._proxy = extHostContext.getProxy(ExtHostContext.ExtHostChatProvider);
	}

	dispose(): void {
		this._lmProviderChange.dispose();
		this._providerRegistrations.dispose();
		this._ignoredFileProviderRegistrations.dispose();
		this._store.dispose();
	}

	/**
	 * Exposes an extension-host language model provider for `vendor` to the workbench.
	 * Every call is forwarded to the extension host through the proxy.
	 * If registration throws, the partially built registration is disposed and the error rethrown.
	 */
	$registerLanguageModelProvider(vendor: string): void {
		const disposables = new DisposableStore();
		try {
			disposables.add(this._chatProviderService.registerLanguageModelProvider(vendor, {
				// Only changes for this vendor matter. `Event.signal` retypes the event as `Event<void>`.
				onDidChange: Event.signal(Event.filter(this._lmProviderChange.event, e => e.vendor === vendor, disposables)),
				provideLanguageModelChatInfo: async (options, token) => {
					const modelsAndIdentifiers = await this._proxy.$provideLanguageModelChatInfo(vendor, options, token);
					// Models that declare auth metadata get a gating auth provider, tied to this registration's lifetime.
					modelsAndIdentifiers.forEach(m => {
						if (m.metadata.auth) {
							disposables.add(this._registerAuthenticationProvider(m.metadata.extension, m.metadata.auth));
						}
					});
					return modelsAndIdentifiers;
				},
				sendChatRequest: async (modelId, messages, from, options, token) => {
					const requestId = ++this._requestIdPool;
					const defer = new DeferredPromise<unknown>();
					const stream = new AsyncIterableSource<IChatResponsePart | IChatResponsePart[]>();

					try {
						// Register before starting the request. `$reportResponsePart` drops
						// parts for ids that are not in the map.
						this._pendingProgress.set(requestId, { defer, stream });
						// Downscale attached images in place before sending them to the
						// extension host. Note: this mutates the caller's `messages`.
						await Promise.all(
							messages.flatMap(msg => msg.content)
								.filter(part => part.type === 'image_url')
								.map(async part => {
									part.value.data = VSBuffer.wrap(await resizeImage(part.value.data.buffer));
								})
						);
						await this._proxy.$startChatRequest(modelId, requestId, from, new SerializableObjectWithBuffers(messages), options, token);
					} catch (err) {
						this._pendingProgress.delete(requestId);
						throw err;
					}

					// Both are settled later by `$reportResponsePart` / `$reportResponseDone`.
					return {
						result: defer.p,
						stream: stream.asyncIterable
					} satisfies ILanguageModelChatResponse;
				},
				provideTokenCount: (modelId, str, token) => {
					return this._proxy.$provideTokenLength(modelId, str, token);
				},
			}));
			this._providerRegistrations.set(vendor, disposables);
		} catch (err) {
			disposables.dispose();
			throw err;
		}
	}

	/** Signals that the models offered by `vendor` changed. */
	$onLMProviderChange(vendor: string): void {
		this._lmProviderChange.fire({ vendor });
	}

	/** Forwards one streamed chunk of a pending request. Chunks for unknown ids are traced and dropped. */
	async $reportResponsePart(requestId: number, chunk: SerializableObjectWithBuffers<IChatResponsePart | IChatResponsePart[]>): Promise<void> {
		const data = this._pendingProgress.get(requestId);
		this._logService.trace('[LM] report response PART', Boolean(data), requestId, chunk);
		if (data) {
			data.stream.emitOne(chunk.value);
		}
	}

	/**
	 * Completes a pending request and removes it from the pending map.
	 * With `err`, both the stream and the result promise reject with the deserialized error.
	 * Without `err`, both resolve.
	 */
	async $reportResponseDone(requestId: number, err: SerializedError | undefined): Promise<void> {
		const data = this._pendingProgress.get(requestId);
		this._logService.trace('[LM] report response DONE', Boolean(data), requestId, err);
		if (data) {
			this._pendingProgress.delete(requestId);
			if (err) {
				// Prefer the API-level LanguageModelError when the payload is one; otherwise use the generic form.
				const error = LanguageModelError.tryDeserialize(err) ?? transformErrorFromSerialization(err);
				data.stream.reject(error);
				data.defer.error(error);
			} else {
				data.stream.resolve();
				data.defer.complete(undefined);
			}
		}
	}

	$unregisterProvider(vendor: string): void {
		this._providerRegistrations.deleteAndDispose(vendor);
	}

	$selectChatModels(selector: ILanguageModelChatSelector): Promise<string[]> {
		return this._chatProviderService.selectLanguageModels(selector);
	}

	/**
	 * Starts a chat request on behalf of `extension` and streams the response back
	 * to the extension host under `requestId`.
	 *
	 * The returned promise rejects only if the request cannot be started.
	 * Streaming errors and result errors are reported to the extension host
	 * through `$acceptResponseDone`, which is called exactly once per request.
	 */
	async $tryStartChatRequest(extension: ExtensionIdentifier, modelIdentifier: string, requestId: number, messages: SerializableObjectWithBuffers<IChatMessage[]>, options: {}, token: CancellationToken): Promise<void> {
		this._logService.trace('[CHAT] request STARTED', extension.value, requestId);

		let response: ILanguageModelChatResponse;
		try {
			response = await this._chatProviderService.sendChatRequest(modelIdentifier, extension, messages.value, options, token);
		} catch (err) {
			this._logService.error('[CHAT] request FAILED', extension.value, requestId, err);
			throw err;
		}

		// !!! IMPORTANT !!!
		// This method must return before the response is done (has streamed all parts)
		// and because of that we consume the stream without awaiting
		// !!! IMPORTANT !!!

		// Set once a stream failure has been sent to the EH, so it is not then also told "done" successfully.
		let streamErrorReported = false;
		const streaming = (async () => {
			try {
				for await (const part of response.stream) {
					this._logService.trace('[CHAT] request PART', extension.value, requestId, part);
					await this._proxy.$acceptResponsePart(requestId, new SerializableObjectWithBuffers(part));
				}
				this._logService.trace('[CHAT] request DONE', extension.value, requestId);
			} catch (err) {
				streamErrorReported = true;
				this._logService.error('[CHAT] extension request ERRORED in STREAM', toErrorMessage(err, true), extension.value, requestId);
				this._proxy.$acceptResponseDone(requestId, transformErrorForSerialization(err));
			}
		})();

		// When the response is done (signaled via its result) we tell the EH.
		// `allSettled` never rejects, so a failed `result` must be read from its settled
		// outcome. Otherwise the failure would be reported to the EH as a success.
		Promise.allSettled([response.result, streaming]).then(([result]) => {
			if (streamErrorReported) {
				return;
			}
			if (result.status === 'rejected') {
				this._logService.error('[CHAT] extension request ERRORED', toErrorMessage(result.reason, true), extension.value, requestId);
				this._proxy.$acceptResponseDone(requestId, transformErrorForSerialization(result.reason));
				return;
			}
			this._logService.debug('[CHAT] extension request DONE', extension.value, requestId);
			this._proxy.$acceptResponseDone(requestId, undefined);
		});
	}


	$countTokens(modelId: string, value: string | IChatMessage, token: CancellationToken): Promise<number> {
		return this._chatProviderService.computeTokenLength(modelId, value, token);
	}

	/**
	 * Registers a placeholder authentication provider for `extension`. It gates
	 * access to that extension's language models.
	 *
	 * If a provider with the same id already exists, nothing is registered and
	 * `Disposable.None` is returned. Otherwise the returned disposable unregisters the
	 * provider and stops pushing access-list updates to the extension host.
	 */
	private _registerAuthenticationProvider(extension: ExtensionIdentifier, auth: { providerLabel: string; accountLabel?: string | undefined }): IDisposable {
		// This needs to be done in both MainThread & ExtHost ChatProvider
		const authProviderId = INTERNAL_AUTH_PROVIDER_PREFIX + extension.value;

		// Only register one auth provider per extension
		if (this._authenticationService.getProviderIds().includes(authProviderId)) {
			return Disposable.None;
		}

		const accountLabel = auth.accountLabel ?? localize('languageModelsAccountId', 'Language Models');
		const disposables = new DisposableStore();
		this._authenticationService.registerAuthenticationProvider(authProviderId, new LanguageModelAccessAuthProvider(authProviderId, auth.providerLabel, accountLabel));
		disposables.add(toDisposable(() => {
			this._authenticationService.unregisterAuthenticationProvider(authProviderId);
		}));
		// On any access change, rebuild the full list of extensions allowed to use this
		// provider's models and push it to the extension host.
		disposables.add(this._authenticationAccessService.onDidChangeExtensionSessionAccess(async (e) => {
			const allowedExtensions = this._authenticationAccessService.readAllowedExtensions(authProviderId, accountLabel);
			const accessList = [];
			for (const allowedExtension of allowedExtensions) {
				const from = await this._extensionService.getExtension(allowedExtension.id);
				// Entries whose extension cannot be resolved are skipped.
				if (from) {
					accessList.push({
						from: from.identifier,
						to: extension,
						// An entry without an explicit `allowed` flag counts as enabled.
						enabled: allowedExtension.allowed ?? true
					});
				}
			}
			this._proxy.$updateModelAccesslist(accessList);
		}));
		return disposables;
	}

	$fileIsIgnored(uri: UriComponents, token: CancellationToken): Promise<boolean> {
		return this._ignoredFilesService.fileIsIgnored(URI.revive(uri), token);
	}

	/** Registers an ignored-file provider whose checks are delegated to the extension host under `handle`. */
	$registerFileIgnoreProvider(handle: number): void {
		this._ignoredFileProviderRegistrations.set(handle, this._ignoredFilesService.registerIgnoredFileProvider({
			isFileIgnored: async (uri: URI, token: CancellationToken) => this._proxy.$isFileIgnored(handle, uri, token)
		}));
	}

	$unregisterFileIgnoreProvider(handle: number): void {
		this._ignoredFileProviderRegistrations.deleteAndDispose(handle);
	}
}

// The fake AuthenticationProvider that will be used to gate access to the Language Model.
// There will be one per provider.
//
// It holds at most one synthetic session ('fake-session'). That session is created
// lazily, the first time an extension asks for sessions with scopes, and every
// create/remove is announced through `onDidChangeSessions`.
class LanguageModelAccessAuthProvider implements IAuthenticationProvider {
	supportsMultipleAccounts = false;

	// Important for updating the UI
	private readonly _onDidChangeSessions = new Emitter<AuthenticationSessionsChangeEvent>();
	readonly onDidChangeSessions = this._onDidChangeSessions.event;

	/** The single synthetic session, if one has been handed out. */
	private _session: AuthenticationSession | undefined;

	constructor(readonly id: string, readonly label: string, private readonly _accountLabel: string) { }

	/**
	 * Returns the existing session if there is one.
	 * With no session and no scopes, returns nothing.
	 * Otherwise creates the session on demand.
	 */
	async getSessions(scopes?: string[] | undefined): Promise<readonly AuthenticationSession[]> {
		const existing = this._session;
		if (existing) {
			return [existing];
		}
		// No session yet and no scopes means no extension has asked for access. The user is
		// just opening the Accounts menu, so no "sessions" should be listed.
		if (scopes === undefined) {
			return [];
		}
		const created = await this.createSession(scopes ?? []);
		return [created];
	}

	/** Replaces any current session with a new synthetic one and announces it as added. */
	async createSession(scopes: string[]): Promise<AuthenticationSession> {
		const fresh = this._createFakeSession(scopes);
		this._session = fresh;
		this._onDidChangeSessions.fire({ added: [fresh], changed: [], removed: [] });
		return this._session;
	}

	/** Drops the current session, if any, after announcing its removal. `sessionId` is not consulted. */
	removeSession(sessionId: string): Promise<void> {
		const current = this._session;
		if (!current) {
			return Promise.resolve();
		}
		this._onDidChangeSessions.fire({ added: [], changed: [], removed: [current] });
		this._session = undefined;
		return Promise.resolve();
	}

	/** Text shown when `extensionName` asks for access to the models behind this provider. */
	confirmation(extensionName: string, _recreatingSession: boolean): string {
		return localize('confirmLanguageModelAccess', "The extension '{0}' wants to access the language models provided by {1}.", extensionName, this.label);
	}

	/** Builds the placeholder session. The account id is the provider id; the token is a constant dummy. */
	private _createFakeSession(scopes: string[]): AuthenticationSession {
		const account = { id: this.id, label: this._accountLabel };
		return { id: 'fake-session', account, accessToken: 'fake-access-token', scopes };
	}
}