Path: blob/main/src/vs/workbench/contrib/chat/test/common/languageModels.test.ts
3296 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import assert from 'assert';
import { AsyncIterableSource, DeferredPromise, timeout } from '../../../../../base/common/async.js';
import { CancellationToken, CancellationTokenSource } from '../../../../../base/common/cancellation.js';
import { DisposableStore } from '../../../../../base/common/lifecycle.js';
import { mock } from '../../../../../base/test/common/mock.js';
import { ensureNoDisposablesAreLeakedInTestSuite } from '../../../../../base/test/common/utils.js';
import { NullLogService } from '../../../../../platform/log/common/log.js';
import { ChatMessageRole, languageModelChatProviderExtensionPoint, LanguageModelsService, IChatMessage, IChatResponsePart } from '../../common/languageModels.js';
import { IExtensionService, nullExtensionDescription } from '../../../../services/extensions/common/extensions.js';
import { ExtensionsRegistry } from '../../../../services/extensions/common/extensionsRegistry.js';
import { DEFAULT_MODEL_PICKER_CATEGORY } from '../../common/modelPicker/modelPickerWidget.js';
import { ExtensionIdentifier } from '../../../../../platform/extensions/common/extensions.js';
import { TestStorageService } from '../../../../test/common/workbenchTestServices.js';
import { Event } from '../../../../../base/common/event.js';
import { MockContextKeyService } from '../../../../../platform/keybinding/test/common/mockKeybindingService.js';

/**
 * Tests for {@link LanguageModelsService}: selecting models by selector and
 * forwarding chat requests to a registered language model provider.
 */
suite('LanguageModels', function () {

	let languageModels: LanguageModelsService;

	// Provider registrations made by setup() and by individual tests; cleared in teardown().
	const store = new DisposableStore();

	// Names passed to the mocked IExtensionService.activateByEvent.
	// NOTE(review): collected but not asserted anywhere in this suite.
	const activationEvents = new Set<string>();

	/** Finds the language-model chat provider extension point in the global extensions registry. */
	function getChatProviderExtensionPoint() {
		return ExtensionsRegistry.getExtensionPoints().find(e => e.name === languageModelChatProviderExtensionPoint.name)!;
	}

	setup(function () {

		languageModels = new LanguageModelsService(
			new class extends mock<IExtensionService>() {
				override activateByEvent(name: string) {
					activationEvents.add(name);
					return Promise.resolve();
				}
			},
			new NullLogService(),
			new TestStorageService(),
			new MockContextKeyService()
		);

		// Declare both vendors used by this suite as users of the extension point.
		getChatProviderExtensionPoint().acceptUsers([{
			description: { ...nullExtensionDescription },
			value: { vendor: 'test-vendor' },
			collector: null!
		}, {
			description: { ...nullExtensionDescription },
			value: { vendor: 'actual-vendor' },
			collector: null!
		}]);

		// 'test-vendor' exposes two models; its request and token-count APIs throw if called.
		store.add(languageModels.registerLanguageModelProvider('test-vendor', {
			onDidChange: Event.None,
			provideLanguageModelChatInfo: async () => {
				const modelMetadata = [
					{
						extension: nullExtensionDescription.identifier,
						name: 'Pretty Name',
						vendor: 'test-vendor',
						family: 'test-family',
						version: 'test-version',
						modelPickerCategory: undefined,
						id: 'test-id-1',
						maxInputTokens: 100,
						maxOutputTokens: 100,
					},
					{
						extension: nullExtensionDescription.identifier,
						name: 'Pretty Name',
						vendor: 'test-vendor',
						family: 'test2-family',
						version: 'test2-version',
						modelPickerCategory: undefined,
						id: 'test-id-12',
						maxInputTokens: 100,
						maxOutputTokens: 100,
					}
				];
				return modelMetadata.map(m => ({
					metadata: m,
					identifier: m.id,
				}));
			},
			sendChatRequest: async () => {
				throw new Error();
			},
			provideTokenCount: async () => {
				throw new Error();
			}
		}));
	});

	teardown(function () {
		languageModels.dispose();
		activationEvents.clear();
		store.clear();
	});

	ensureNoDisposablesAreLeakedInTestSuite();

	test('empty selector returns all', async function () {
		const result = await languageModels.selectLanguageModels({});
		assert.deepStrictEqual(result, ['test-id-1', 'test-id-12']);
	});

	test('selector with id works properly', async function () {
		const result = await languageModels.selectLanguageModels({ id: 'test-id-1' });
		assert.deepStrictEqual(result, ['test-id-1']);
	});

	test('no warning that a matching model was not found #213716', async function () {
		const byVendor = await languageModels.selectLanguageModels({ vendor: 'test-vendor' });
		assert.strictEqual(byVendor.length, 2);

		const byUnknownFamily = await languageModels.selectLanguageModels({ vendor: 'test-vendor', family: 'FAKE' });
		assert.strictEqual(byUnknownFamily.length, 0);
	});

	test('sendChatRequest returns a response-stream', async function () {

		store.add(languageModels.registerLanguageModelProvider('actual-vendor', {
			onDidChange: Event.None,
			provideLanguageModelChatInfo: async () => {
				const modelMetadata = [
					{
						extension: nullExtensionDescription.identifier,
						name: 'Pretty Name',
						vendor: 'actual-vendor',
						family: 'actual-family',
						version: 'actual-version',
						id: 'actual-lm',
						maxInputTokens: 100,
						maxOutputTokens: 100,
						modelPickerCategory: DEFAULT_MODEL_PICKER_CATEGORY,
					}
				];
				return modelMetadata.map(m => ({
					metadata: m,
					identifier: m.id,
				}));
			},
			sendChatRequest: async (_modelId: string, _messages: IChatMessage[], _from: ExtensionIdentifier, _options: { [name: string]: any }, token: CancellationToken) => {
				const defer = new DeferredPromise();
				const stream = new AsyncIterableSource<IChatResponsePart>();

				// Emit a text part every 10ms until the request is cancelled. Then finish
				// the stream as well as the result, so consumers iterating the stream
				// terminate instead of waiting forever.
				void (async () => {
					while (!token.isCancellationRequested) {
						stream.emitOne({ type: 'text', value: Date.now().toString() });
						await timeout(10);
					}
					stream.resolve();
					defer.complete(undefined);
				})();

				return {
					stream: stream.asyncIterable,
					result: defer.p
				};
			},
			provideTokenCount: async () => {
				throw new Error();
			}
		}));

		// NOTE(review): setup() already accepts a user for 'actual-vendor'; this call
		// re-accepts it after the provider is registered (presumably re-triggering the
		// extension-point handler). Confirm whether it is still required.
		getChatProviderExtensionPoint().acceptUsers([{
			description: { ...nullExtensionDescription },
			value: { vendor: 'actual-vendor' },
			collector: null!
		}]);

		const models = await languageModels.selectLanguageModels({ id: 'actual-lm' });
		assert.strictEqual(models.length, 1);

		const [first] = models;

		const cts = new CancellationTokenSource();

		const request = await languageModels.sendChatRequest(first, nullExtensionDescription.identifier, [{ role: ChatMessageRole.User, content: [{ type: 'text', value: 'hello' }] }], {}, cts.token);

		assert.ok(request);

		// Cancelling stops the fake provider's emit loop, which then settles `result`.
		cts.dispose(true);

		await request.result;
	});
});