// Path: blob/main/src/vs/workbench/api/common/extHostLanguageModels.ts
// 3296 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import type * as vscode from 'vscode';
import { AsyncIterableObject, AsyncIterableSource, RunOnceScheduler } from '../../../base/common/async.js';
import { VSBuffer } from '../../../base/common/buffer.js';
import { CancellationToken } from '../../../base/common/cancellation.js';
import { SerializedError, transformErrorForSerialization, transformErrorFromSerialization } from '../../../base/common/errors.js';
import { Emitter, Event } from '../../../base/common/event.js';
import { Iterable } from '../../../base/common/iterator.js';
import { IDisposable, toDisposable } from '../../../base/common/lifecycle.js';
import { URI, UriComponents } from '../../../base/common/uri.js';
import { localize } from '../../../nls.js';
import { ExtensionIdentifier, ExtensionIdentifierMap, ExtensionIdentifierSet, IExtensionDescription } from '../../../platform/extensions/common/extensions.js';
import { createDecorator } from '../../../platform/instantiation/common/instantiation.js';
import { ILogService } from '../../../platform/log/common/log.js';
import { Progress } from '../../../platform/progress/common/progress.js';
import { IChatMessage, IChatResponsePart, ILanguageModelChatMetadata, ILanguageModelChatMetadataAndIdentifier } from '../../contrib/chat/common/languageModels.js';
import { DEFAULT_MODEL_PICKER_CATEGORY } from '../../contrib/chat/common/modelPicker/modelPickerWidget.js';
import { INTERNAL_AUTH_PROVIDER_PREFIX } from '../../services/authentication/common/authentication.js';
import { checkProposedApiEnabled, isProposedApiEnabled } from '../../services/extensions/common/extensions.js';
import { SerializableObjectWithBuffers } from '../../services/extensions/common/proxyIdentifier.js';
import { ExtHostLanguageModelsShape, MainContext, MainThreadLanguageModelsShape } from './extHost.protocol.js';
import { IExtHostAuthentication } from './extHostAuthentication.js';
import { IExtHostRpcService } from './extHostRpcService.js';
import * as typeConvert from './extHostTypeConverters.js';
import * as extHostTypes from './extHostTypes.js';

export interface IExtHostLanguageModels extends ExtHostLanguageModels { }

export const IExtHostLanguageModels = createDecorator<IExtHostLanguageModels>('IExtHostLanguageModels');

/** A registered chat provider together with the extension that contributed it. */
type LanguageModelProviderData = {
	readonly extension: IExtensionDescription;
	readonly provider: vscode.LanguageModelChatProvider;
};

/** Every part kind that can flow through a language model response stream. */
type LMResponsePart = vscode.LanguageModelTextPart | vscode.LanguageModelToolCallPart | vscode.LanguageModelDataPart | vscode.LanguageModelThinkingPart;


/**
 * Holds the state of one chat response on the requesting side: response parts
 * are pushed in via {@link handleResponsePart} and surfaced to the caller
 * through the async iterables on {@link apiObject}.
 */
class LanguageModelResponse {

	readonly apiObject: vscode.LanguageModelChatResponse;

	private readonly _defaultStream = new AsyncIterableSource<LMResponsePart>();
	// Set once the response was resolved or rejected; later parts are dropped.
	private _isDone: boolean = false;

	constructor() {

		const that = this;
		this.apiObject = {
			// result: promise,
			get stream() {
				return that._defaultStream.asyncIterable;
			},
			get text() {
				// Project the part stream onto its text parts only; non-text parts
				// map to `undefined` and are removed by `coalesce`.
				return AsyncIterableObject.map(that._defaultStream.asyncIterable, part => {
					if (part instanceof extHostTypes.LanguageModelTextPart) {
						return part.value;
					} else {
						return undefined;
					}
				}).coalesce();
			},
		};
	}

	/**
	 * Converts one or many serialized response parts into their API counterparts
	 * and emits them on the stream. Does nothing once the response is done.
	 */
	handleResponsePart(parts: IChatResponsePart | IChatResponsePart[]): void {
		if (this._isDone) {
			return;
		}

		const lmResponseParts: LMResponsePart[] = [];

		for (const part of Iterable.wrap(parts)) {

			let out: LMResponsePart;
			if (part.type === 'text') {
				out = new extHostTypes.LanguageModelTextPart(part.value, part.audience);
			} else if (part.type === 'thinking') {
				out = new extHostTypes.LanguageModelThinkingPart(part.value, part.id, part.metadata);

			} else if (part.type === 'data') {
				out = new extHostTypes.LanguageModelDataPart(part.data.buffer, part.mimeType, part.audience);
			} else {
				// every remaining part type is treated as a tool call
				out = new extHostTypes.LanguageModelToolCallPart(part.toolCallId, part.name, part.parameters);
			}
			lmResponseParts.push(out);
		}

		this._defaultStream.emitMany(lmResponseParts);
	}

	/** Marks the response as done and fails the stream with `err`. */
	reject(err: Error): void {
		this._isDone = true;
		this._defaultStream.reject(err);
	}

	/** Marks the response as done and completes the stream. */
	resolve(): void {
		this._isDone = true;
		this._defaultStream.resolve();
	}
}

/**
 * Extension host side of the language models feature. It keeps track of the
 * chat providers registered by extensions, caches the models they report,
 * forwards chat requests made by extensions to the main thread, and serves
 * the requests the main thread routes to locally registered providers.
 */
export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {

	declare _serviceBrand: undefined;

	private static _idPool = 1;

	private readonly _proxy: MainThreadLanguageModelsShape;
	private readonly _onDidChangeModelAccess = new Emitter<{ from: ExtensionIdentifier; to: ExtensionIdentifier }>();
	private readonly _onDidChangeProviders = new Emitter<void>();
	readonly onDidChangeProviders = this._onDidChangeProviders.event;

	// Providers keyed by vendor
	private readonly _languageModelProviders = new Map<string, LanguageModelProviderData>();
	// TODO @lramos15 - Remove the need for both info and metadata as it's a lot of redundancy. Should just need one
	// Models keyed by their `${vendor}/${id}` identifier
	private readonly _localModels = new Map<string, { metadata: ILanguageModelChatMetadata; info: vscode.LanguageModelChatInformation }>();
	// For each requesting extension, the set of extensions whose models it may use
	private readonly _modelAccessList = new ExtensionIdentifierMap<ExtensionIdentifierSet>();
	// Requests issued from this extension host that have not completed yet, keyed by request id
	private readonly _pendingRequest = new Map<number, { languageModelId: string; res: LanguageModelResponse }>();
	private readonly _ignoredFileProviders = new Map<number, vscode.LanguageModelIgnoredFileProvider>();

	constructor(
		@IExtHostRpcService extHostRpc: IExtHostRpcService,
		@ILogService private readonly _logService: ILogService,
		@IExtHostAuthentication private readonly _extHostAuthentication: IExtHostAuthentication,
	) {
		this._proxy = extHostRpc.getProxy(MainContext.MainThreadLanguageModels);
	}

	dispose(): void {
		this._onDidChangeModelAccess.dispose();
		this._onDidChangeProviders.dispose();
	}

	/**
	 * Registers `provider` under `vendor` and announces it to the main thread.
	 * Change events of the provider, when it exposes any, are forwarded as well.
	 *
	 * @returns A disposable that removes the provider, drops its cached models
	 * and unregisters it from the main thread.
	 */
	registerLanguageModelChatProvider(extension: IExtensionDescription, vendor: string, provider: vscode.LanguageModelChatProvider): IDisposable {

		this._languageModelProviders.set(vendor, { extension: extension, provider });
		this._proxy.$registerLanguageModelProvider(vendor);

		let providerChangeEventDisposable: IDisposable | undefined;
		if (provider.onDidChangeLanguageModelChatInformation) {
			providerChangeEventDisposable = provider.onDidChangeLanguageModelChatInformation(() => {
				this._proxy.$onLMProviderChange(vendor);
			});
		}

		return toDisposable(() => {
			this._languageModelProviders.delete(vendor);
			this._clearModelCache(vendor);
			providerChangeEventDisposable?.dispose();
			this._proxy.$unregisterProvider(vendor);
		});
	}

	// Helper function to clear the local cache for a specific vendor. There's no lookup, so this involves iterating over all models.
	private _clearModelCache(vendor: string): void {
		this._localModels.forEach((value, key) => {
			if (value.metadata.vendor === vendor) {
				this._localModels.delete(key);
			}
		});
	}

	/**
	 * Asks the provider registered for `vendor` for its models, replaces the
	 * cached models of that vendor with the result and returns the metadata
	 * in the shape the main thread expects.
	 *
	 * @returns An empty array when no provider is registered for `vendor`.
	 */
	async $provideLanguageModelChatInfo(vendor: string, options: { silent: boolean }, token: CancellationToken): Promise<ILanguageModelChatMetadataAndIdentifier[]> {
		const data = this._languageModelProviders.get(vendor);
		if (!data) {
			return [];
		}
		this._clearModelCache(vendor);
		// TODO @lramos15 - Remove this old prepare method support in debt week
		const modelInformation: vscode.LanguageModelChatInformation[] = (data.provider.provideLanguageModelChatInformation ? await data.provider.provideLanguageModelChatInformation(options, token) : await (data.provider as any).prepareLanguageModelChatInformation(options, token)) ?? [];
		const modelMetadataAndIdentifier: ILanguageModelChatMetadataAndIdentifier[] = modelInformation.map(m => {
			let auth;
			if (m.requiresAuthorization && isProposedApiEnabled(data.extension, 'chatProvider')) {
				auth = {
					providerLabel: data.extension.displayName || data.extension.name,
					accountLabel: typeof m.requiresAuthorization === 'object' ? m.requiresAuthorization.label : undefined
				};
			}
			return {
				metadata: {
					extension: data.extension.identifier,
					id: m.id,
					vendor,
					name: m.name ?? '',
					family: m.family ?? '',
					detail: m.detail,
					tooltip: m.tooltip,
					version: m.version,
					maxInputTokens: m.maxInputTokens,
					maxOutputTokens: m.maxOutputTokens,
					auth,
					isDefault: m.isDefault,
					isUserSelectable: m.isUserSelectable,
					statusIcon: m.statusIcon,
					modelPickerCategory: m.category ?? DEFAULT_MODEL_PICKER_CATEGORY,
					capabilities: m.capabilities ? {
						vision: m.capabilities.imageInput,
						toolCalling: !!m.capabilities.toolCalling,
						agentMode: !!m.capabilities.toolCalling
					} : undefined,
				},
				identifier: `${vendor}/${m.id}`,
			};
		});

		// Both arrays are index-aligned because one was mapped from the other
		for (let i = 0; i < modelMetadataAndIdentifier.length; i++) {

			this._localModels.set(modelMetadataAndIdentifier[i].identifier, {
				metadata: modelMetadataAndIdentifier[i].metadata,
				info: modelInformation[i]
			});
		}

		return modelMetadataAndIdentifier;
	}

	/**
	 * Starts a chat request against a locally registered provider. Response
	 * parts reported by the provider are batched and sent to the main thread;
	 * completion (or failure) of the provider call is reported via
	 * `$reportResponseDone`. The returned promise settles once the provider has
	 * been invoked, it does not wait for the response to finish.
	 *
	 * @throws When the model or its provider is unknown, or when the provider
	 * throws synchronously.
	 */
	async $startChatRequest(modelId: string, requestId: number, from: ExtensionIdentifier, messages: SerializableObjectWithBuffers<IChatMessage[]>, options: vscode.LanguageModelChatRequestOptions, token: CancellationToken): Promise<void> {
		const knownModel = this._localModels.get(modelId);
		if (!knownModel) {
			throw new Error('Model not found');
		}

		const data = this._languageModelProviders.get(knownModel.metadata.vendor);
		if (!data) {
			throw new Error(`Language model provider for '${knownModel.metadata.id}' not found.`);
		}

		const queue: IChatResponsePart[] = [];
		const sendNow = () => {
			if (queue.length > 0) {
				this._proxy.$reportResponsePart(requestId, new SerializableObjectWithBuffers(queue));
				queue.length = 0;
			}
		};
		const queueScheduler = new RunOnceScheduler(sendNow, 30);
		const sendSoon = (part: IChatResponsePart) => {
			const newLen = queue.push(part);
			// flush/send if things pile up more than expected
			if (newLen > 30) {
				sendNow();
				queueScheduler.cancel();
			} else {
				queueScheduler.schedule();
			}
		};

		const progress = new Progress<LMResponsePart>(async fragment => {
			if (token.isCancellationRequested) {
				this._logService.warn(`[CHAT](${data.extension.identifier.value}) CANNOT send progress because the REQUEST IS CANCELLED`);
				return;
			}

			let part: IChatResponsePart | undefined;
			if (fragment instanceof extHostTypes.LanguageModelToolCallPart) {
				part = { type: 'tool_use', name: fragment.name, parameters: fragment.input, toolCallId: fragment.callId };
			} else if (fragment instanceof extHostTypes.LanguageModelTextPart) {
				part = { type: 'text', value: fragment.value, audience: fragment.audience };
			} else if (fragment instanceof extHostTypes.LanguageModelDataPart) {
				part = { type: 'data', mimeType: fragment.mimeType, data: VSBuffer.wrap(fragment.data), audience: fragment.audience };
			} else if (fragment instanceof extHostTypes.LanguageModelThinkingPart) {
				part = { type: 'thinking', value: fragment.value, id: fragment.id, metadata: fragment.metadata };
			}

			if (!part) {
				this._logService.warn(`[CHAT](${data.extension.identifier.value}) UNKNOWN part ${JSON.stringify(fragment)}`);
				return;
			}

			sendSoon(part);
		});

		// A synchronous failure of the provider propagates to the caller as a rejection of this method
		const value: unknown = data.provider.provideLanguageModelChatResponse(
			knownModel.info,
			messages.value.map(typeConvert.LanguageModelChatMessage2.to),
			{ ...options, modelOptions: options.modelOptions ?? {}, requestInitiator: ExtensionIdentifier.toKey(from), toolMode: options.toolMode ?? extHostTypes.LanguageModelChatToolMode.Auto },
			progress,
			token
		);

		// Deliberately not awaited: flush what is still queued, then signal the end of the response
		Promise.resolve(value).then(() => {
			sendNow();
			this._proxy.$reportResponseDone(requestId, undefined);
		}, err => {
			sendNow();
			this._proxy.$reportResponseDone(requestId, transformErrorForSerialization(err));
		});
	}

	//#region --- token counting

	/**
	 * Counts tokens of `value` using the provider of the given model.
	 *
	 * @returns `0` when the model or its provider is unknown.
	 */
	$provideTokenLength(modelId: string, value: string, token: CancellationToken): Promise<number> {
		const knownModel = this._localModels.get(modelId);
		if (!knownModel) {
			return Promise.resolve(0);
		}
		const data = this._languageModelProviders.get(knownModel.metadata.vendor);
		if (!data) {
			return Promise.resolve(0);
		}
		return Promise.resolve(data.provider.provideTokenCount(knownModel.info, value, token));
	}


	//#region --- making request

	/**
	 * Returns the cached model that is flagged as default. When none is cached,
	 * the models are resolved once via {@link selectLanguageModels} and the
	 * lookup is repeated.
	 *
	 * @param forceResolveModels When `true`, models are resolved before the lookup.
	 * @returns The default model, or `undefined` when there is none even after
	 * resolving the models.
	 */
	async getDefaultLanguageModel(extension: IExtensionDescription, forceResolveModels?: boolean): Promise<vscode.LanguageModelChat | undefined> {
		let defaultModelId: string | undefined;

		if (forceResolveModels) {
			await this.selectLanguageModels(extension, {});
		}

		for (const [modelIdentifier, modelData] of this._localModels) {
			if (modelData.metadata.isDefault) {
				defaultModelId = modelIdentifier;
				break;
			}
		}
		if (!defaultModelId) {
			if (forceResolveModels) {
				// Models have just been resolved and still none is the default.
				// Give up instead of retrying forever.
				return undefined;
			}
			// Maybe the default wasn't cached so we will try again with resolving the models too
			return this.getDefaultLanguageModel(extension, true);
		}
		return this.getLanguageModelByIdentifier(extension, defaultModelId);
	}

	/**
	 * Creates the frozen API object for the cached model with the given
	 * identifier, on behalf of `extension`.
	 *
	 * @returns `undefined` when the model is not in the local cache.
	 */
	async getLanguageModelByIdentifier(extension: IExtensionDescription, modelId: string): Promise<vscode.LanguageModelChat | undefined> {

		const model = this._localModels.get(modelId);
		if (!model) {
			// model gone? is this an error on us?
			return;
		}

		// make sure auth information is correct
		if (this._isUsingAuth(extension.identifier, model.metadata)) {
			await this._fakeAuthPopulate(model.metadata);
		}

		const that = this;
		const apiObject: vscode.LanguageModelChat = {
			id: model.info.id,
			vendor: model.metadata.vendor,
			family: model.info.family,
			version: model.info.version,
			name: model.info.name,
			capabilities: {
				supportsImageToText: model.metadata.capabilities?.vision ?? false,
				supportsToolCalling: !!model.metadata.capabilities?.toolCalling,
			},
			maxInputTokens: model.metadata.maxInputTokens,
			countTokens(text, token) {
				// the model may have been removed from the cache since this object was created
				if (!that._localModels.has(modelId)) {
					throw extHostTypes.LanguageModelError.NotFound(modelId);
				}
				return that._computeTokenLength(modelId, text, token ?? CancellationToken.None);
			},
			sendRequest(messages, options, token) {
				if (!that._localModels.has(modelId)) {
					throw extHostTypes.LanguageModelError.NotFound(modelId);
				}
				return that._sendChatRequest(extension, modelId, messages, options ?? {}, token ?? CancellationToken.None);
			}
		};

		return Object.freeze(apiObject);
	}

	/**
	 * Asks the main thread for the identifiers of the models matching `selector`
	 * and returns API objects for those that are present in the local cache.
	 */
	async selectLanguageModels(extension: IExtensionDescription, selector: vscode.LanguageModelChatSelector) {

		// this triggers extension activation
		const models = await this._proxy.$selectChatModels({ ...selector, extension: extension.identifier });

		const result: vscode.LanguageModelChat[] = [];

		const modelPromises = models.map(identifier => this.getLanguageModelByIdentifier(extension, identifier));
		const modelResults = await Promise.all(modelPromises);
		for (const model of modelResults) {
			if (model) {
				result.push(model);
			}
		}

		return result;
	}

	/**
	 * Sends a chat request for `languageModelId` on behalf of `extension`.
	 *
	 * @returns The response object; its stream is fed by `$acceptResponsePart`
	 * and completed by `$acceptResponseDone`.
	 * @throws `LanguageModelError.NotFound` for an unknown model,
	 * `LanguageModelError.NoPermissions` when access is not granted, or the
	 * error with which starting the request failed.
	 */
	private async _sendChatRequest(extension: IExtensionDescription, languageModelId: string, messages: vscode.LanguageModelChatMessage2[], options: vscode.LanguageModelChatRequestOptions, token: CancellationToken) {

		const internalMessages: IChatMessage[] = this._convertMessages(extension, messages);

		const from = extension.identifier;
		const metadata = this._localModels.get(languageModelId)?.metadata;

		if (!metadata) {
			throw extHostTypes.LanguageModelError.NotFound(`Language model '${languageModelId}' is unknown.`);
		}

		if (this._isUsingAuth(from, metadata)) {
			const success = await this._getAuthAccess(extension, { identifier: metadata.extension, displayName: metadata.auth.providerLabel }, options.justification, false);

			if (!success || !this._modelAccessList.get(from)?.has(metadata.extension)) {
				throw extHostTypes.LanguageModelError.NoPermissions(`Language model '${languageModelId}' cannot be used by '${from.value}'.`);
			}
		}

		// Random ids can repeat. Never reuse the id of a request that is still
		// pending because its response would be overwritten.
		let requestId: number;
		do {
			requestId = (Math.random() * 1e6) | 0;
		} while (this._pendingRequest.has(requestId));

		const res = new LanguageModelResponse();
		this._pendingRequest.set(requestId, { languageModelId, res });

		try {
			await this._proxy.$tryStartChatRequest(from, languageModelId, requestId, new SerializableObjectWithBuffers(internalMessages), options, token);

		} catch (error) {
			// error'ing here means that the request could NOT be started/made, e.g. wrong model, no access, etc, but
			// later the response can fail as well. Those failures are communicated via the stream-object
			this._pendingRequest.delete(requestId);
			throw extHostTypes.LanguageModelError.tryDeserialize(error) ?? error;
		}

		return res.apiObject;
	}

	/**
	 * Converts API messages into their serializable form. System messages
	 * require the `languageModelSystem` proposed API.
	 */
	private _convertMessages(extension: IExtensionDescription, messages: vscode.LanguageModelChatMessage2[]) {
		const internalMessages: IChatMessage[] = [];
		for (const message of messages) {
			if (message.role as number === extHostTypes.LanguageModelChatMessageRole.System) {
				checkProposedApiEnabled(extension, 'languageModelSystem');
			}
			internalMessages.push(typeConvert.LanguageModelChatMessage2.from(message));
		}
		return internalMessages;
	}

	/** Feeds response parts into the pending request with the given id, if it is still known. */
	async $acceptResponsePart(requestId: number, chunk: SerializableObjectWithBuffers<IChatResponsePart | IChatResponsePart[]>): Promise<void> {
		const data = this._pendingRequest.get(requestId);
		if (data) {
			data.res.handleResponsePart(chunk.value);
		}
	}

	/** Completes the pending request with the given id, either successfully or with `error`. */
	async $acceptResponseDone(requestId: number, error: SerializedError | undefined): Promise<void> {
		const data = this._pendingRequest.get(requestId);
		if (!data) {
			return;
		}
		this._pendingRequest.delete(requestId);
		if (error) {
			// we error the stream because that's the only way to signal
			// that the request has failed
			data.res.reject(extHostTypes.LanguageModelError.tryDeserialize(error) ?? transformErrorFromSerialization(error));
		} else {
			data.res.resolve();
		}
	}

	// BIG HACK: Using AuthenticationProviders to check access to Language Models
	/**
	 * Checks, and unless `silent` is set requests, access of extension `from`
	 * to the models of extension `to`. Granted access is recorded in the model
	 * access list.
	 *
	 * @returns `true` when a session exists or could be created, `false` otherwise.
	 */
	private async _getAuthAccess(from: IExtensionDescription, to: { identifier: ExtensionIdentifier; displayName: string }, justification: string | undefined, silent: boolean | undefined): Promise<boolean> {
		// This needs to be done in both MainThread & ExtHost ChatProvider
		const providerId = INTERNAL_AUTH_PROVIDER_PREFIX + to.identifier.value;
		const session = await this._extHostAuthentication.getSession(from, providerId, [], { silent: true });

		if (session) {
			this.$updateModelAccesslist([{ from: from.identifier, to: to.identifier, enabled: true }]);
			return true;
		}

		if (silent) {
			return false;
		}

		try {
			const detail = justification
				? localize('chatAccessWithJustification', "Justification: {1}", to.displayName, justification)
				: undefined;
			await this._extHostAuthentication.getSession(from, providerId, [], { forceNewSession: { detail } });
			this.$updateModelAccesslist([{ from: from.identifier, to: to.identifier, enabled: true }]);
			return true;

		} catch (err) {
			// ignore
			return false;
		}
	}

	/**
	 * Type guard telling whether requests from `from` to the model described by
	 * `toMetadata` are subject to the auth check.
	 */
	private _isUsingAuth(from: ExtensionIdentifier, toMetadata: ILanguageModelChatMetadata): toMetadata is ILanguageModelChatMetadata & { auth: NonNullable<ILanguageModelChatMetadata['auth']> } {
		// If the 'to' extension uses an auth check
		return !!toMetadata.auth
			// And we're asking from a different extension
			&& !ExtensionIdentifier.equals(toMetadata.extension, from);
	}

	/**
	 * Silently probes access to the model's extension for every extension that
	 * asked for language model access information. Failures are logged.
	 */
	private async _fakeAuthPopulate(metadata: ILanguageModelChatMetadata): Promise<void> {

		if (!metadata.auth) {
			return;
		}

		for (const from of this._languageAccessInformationExtensions) {
			try {
				await this._getAuthAccess(from, { identifier: metadata.extension, displayName: '' }, undefined, true);
			} catch (err) {
				this._logService.error('Fake Auth request failed');
				this._logService.error(err);
			}
		}
	}

	/**
	 * Counts tokens of `value` using the provider of the given model.
	 *
	 * @returns The count, or `0` when no provider is registered for the model's vendor.
	 * @throws `LanguageModelError.NotFound` when the model is not cached.
	 */
	private async _computeTokenLength(modelId: string, value: string | vscode.LanguageModelChatMessage2, token: vscode.CancellationToken): Promise<number> {

		const data = this._localModels.get(modelId);
		if (!data) {
			throw extHostTypes.LanguageModelError.NotFound(`Language model '${modelId}' is unknown.`);
		}
		return this._languageModelProviders.get(data.metadata.vendor)?.provider.provideTokenCount(data.info, value, token) ?? 0;
		// return this._proxy.$countTokens(languageModelId, (typeof value === 'string' ? value : typeConvert.LanguageModelChatMessage2.from(value)), token);
	}

	/**
	 * Applies access changes to the model access list and fires
	 * `_onDidChangeModelAccess` for every entry whose value actually changed.
	 */
	$updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void {
		for (const { from, to, enabled } of data) {
			const set = this._modelAccessList.get(from) ?? new ExtensionIdentifierSet();
			const oldValue = set.has(to);
			if (oldValue !== enabled) {
				if (enabled) {
					set.add(to);
				} else {
					set.delete(to);
				}
				this._modelAccessList.set(from, set);
				this._onDidChangeModelAccess.fire({ from, to });
			}
		}
	}

	// Extensions for which access information was created; used by `_fakeAuthPopulate`
	private readonly _languageAccessInformationExtensions = new Set<Readonly<IExtensionDescription>>();

	/**
	 * Creates the access information object for extension `from`. Its change
	 * event fires when access of `from` changes or when providers change.
	 */
	createLanguageModelAccessInformation(from: Readonly<IExtensionDescription>): vscode.LanguageModelAccessInformation {

		this._languageAccessInformationExtensions.add(from);

		// const that = this;
		const _onDidChangeAccess = Event.signal(Event.filter(this._onDidChangeModelAccess.event, e => ExtensionIdentifier.equals(e.from, from.identifier)));
		const _onDidAddRemove = Event.signal(this._onDidChangeProviders.event);

		return {
			get onDidChange() {
				return Event.any(_onDidChangeAccess, _onDidAddRemove);
			},
			canSendRequest(chat: vscode.LanguageModelChat): boolean | undefined {
				return true;
				// TODO @lramos15 - Fix

				// let metadata: ILanguageModelChatMetadata | undefined;

				// out: for (const [_, value] of that._allLanguageModelData) {
				// 	for (const candidate of value.apiObjects.values()) {
				// 		if (candidate === chat) {
				// 			metadata = value.metadata;
				// 			break out;
				// 		}
				// 	}
				// }
				// if (!metadata) {
				// 	return undefined;
				// }
				// if (!that._isUsingAuth(from.identifier, metadata)) {
				// 	return true;
				// }

				// const list = that._modelAccessList.get(from.identifier);
				// if (!list) {
				// 	return undefined;
				// }
				// return list.has(metadata.extension);
			}
		};
	}

	/** Asks the main thread whether `uri` is ignored. Requires the `chatParticipantAdditions` proposed API. */
	fileIsIgnored(extension: IExtensionDescription, uri: vscode.Uri, token: vscode.CancellationToken = CancellationToken.None): Promise<boolean> {
		checkProposedApiEnabled(extension, 'chatParticipantAdditions');

		return this._proxy.$fileIsIgnored(uri, token);
	}

	/**
	 * Asks the ignored-file provider registered under `handle` about `uri`.
	 *
	 * @returns The provider's answer, `false` when it returns nothing.
	 * @throws When no provider is registered under `handle`.
	 */
	async $isFileIgnored(handle: number, uri: UriComponents, token: CancellationToken): Promise<boolean> {
		const provider = this._ignoredFileProviders.get(handle);
		if (!provider) {
			throw new Error('Unknown LanguageModelIgnoredFileProvider');
		}

		return (await provider.provideFileIgnored(URI.revive(uri), token)) ?? false;
	}

	/**
	 * Registers an ignored-file provider under a fresh handle. Requires the
	 * `chatParticipantPrivate` proposed API.
	 *
	 * @returns A disposable that unregisters the provider again.
	 */
	registerIgnoredFileProvider(extension: IExtensionDescription, provider: vscode.LanguageModelIgnoredFileProvider): vscode.Disposable {
		checkProposedApiEnabled(extension, 'chatParticipantPrivate');

		const handle = ExtHostLanguageModels._idPool++;
		this._proxy.$registerFileIgnoreProvider(handle);
		this._ignoredFileProviders.set(handle, provider);
		return toDisposable(() => {
			this._proxy.$unregisterFileIgnoreProvider(handle);
			this._ignoredFileProviders.delete(handle);
		});
	}
}