Path: blob/main/src/vs/workbench/contrib/chat/browser/chatManagement/chatModelsViewModel.ts
4780 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { distinct } from '../../../../../base/common/arrays.js';
import { IMatch, IFilter, or, matchesCamelCase, matchesWords, matchesBaseContiguousSubString } from '../../../../../base/common/filters.js';
import { Emitter } from '../../../../../base/common/event.js';
import { ILanguageModelsService, IUserFriendlyLanguageModel, ILanguageModelChatMetadataAndIdentifier } from '../../../chat/common/languageModels.js';
import { IChatEntitlementService } from '../../../../services/chat/common/chatEntitlementService.js';
import { localize } from '../../../../../nls.js';
import { Disposable } from '../../../../../base/common/lifecycle.js';
import { ILanguageModelsProviderGroup, ILanguageModelsConfigurationService } from '../../common/languageModelsConfiguration.js';
import { Throttler } from '../../../../../base/common/async.js';
import Severity from '../../../../../base/common/severity.js';

export const MODEL_ENTRY_TEMPLATE_ID = 'model.entry.template';
export const VENDOR_ENTRY_TEMPLATE_ID = 'vendor.entry.template';
export const GROUP_ENTRY_TEMPLATE_ID = 'group.entry.template';

/** Tries a contiguous substring match first, then a word-by-word match of the whole query. */
const wordFilter = or(matchesBaseContiguousSubString, matchesWords);
/** `@capability:<name>`. Global so that every occurrence in a query is collected. */
const CAPABILITY_REGEX = /@capability:\s*([^\s]+)/gi;
/** `@visible:true|false`. Not global: only the first occurrence is honoured. */
const VISIBLE_REGEX = /@visible:\s*(true|false)/i;
/** `@provider:<name>` or `@provider:"<name with spaces>"`. Global so that every occurrence is collected. */
const PROVIDER_REGEX = /@provider:\s*((".+?")|([^\s]+))/gi;

/** Filter tokens recognised by {@link ChatModelsViewModel.filter}, grouped by kind. */
export const SEARCH_SUGGESTIONS = {
	FILTER_TYPES: [
		'@provider:',
		'@capability:',
		'@visible:'
	],
	CAPABILITIES: [
		'@capability:tools',
		'@capability:vision',
		'@capability:agent'
	],
	VISIBILITY: [
		'@visible:true',
		'@visible:false'
	]
};

/** A vendor together with the provider group a model was fetched from. */
export interface ILanguageModelProvider {
	vendor: IUserFriendlyLanguageModel;
	group: ILanguageModelsProviderGroup;
}

/** A language model annotated with the provider it belongs to. */
export interface ILanguageModel extends ILanguageModelChatMetadataAndIdentifier {
	provider: ILanguageModelProvider;
}

/** A model row, optionally carrying the highlight ranges produced by the current search. */
export interface ILanguageModelEntry {
	type: 'model';
	id: string;
	templateId: string;
	model: ILanguageModel;
	providerMatches?: IMatch[];
	modelNameMatches?: IMatch[];
	modelIdMatches?: IMatch[];
	/** Capability names matched by `@capability:` filters. */
	capabilityMatches?: string[];
}

/** A header row used when grouping by visibility. */
export interface ILanguageModelGroupEntry {
	type: 'group';
	id: string;
	label: string;
	collapsed: boolean;
	templateId: string;
}

/** A header row used when grouping by vendor / provider group. */
export interface ILanguageModelProviderEntry {
	type: 'vendor';
	id: string;
	label: string;
	templateId: string;
	collapsed: boolean;
	vendorEntry: ILanguageModelProvider;
}

/** A status message row reported for a provider group. */
export interface IStatusEntry {
	type: 'status';
	id: string;
	message: string;
	severity: Severity;
}

/** One group of the list: its header row, its models and an optional status row. */
export interface ILanguageModelEntriesGroup {
	group: ILanguageModelGroupEntry | ILanguageModelProviderEntry;
	models: ILanguageModel[];
	status?: IStatusEntry;
}

export function isLanguageModelProviderEntry(entry: IViewModelEntry): entry is ILanguageModelProviderEntry {
	return entry.type === 'vendor';
}

export function isLanguageModelGroupEntry(entry: IViewModelEntry): entry is ILanguageModelGroupEntry {
	return entry.type === 'group';
}

export function isStatusEntry(entry: IViewModelEntry): entry is IStatusEntry {
	return entry.type === 'status';
}

/** Any row that can appear in {@link ChatModelsViewModel.viewModelEntries}. */
export type IViewModelEntry = ILanguageModelEntry | ILanguageModelProviderEntry | ILanguageModelGroupEntry | IStatusEntry;

/** Describes a splice of the flattened row list: `removed` rows at `at` were replaced by `added`. */
export interface IViewModelChangeEvent {
	at: number;
	removed: number;
	added: IViewModelEntry[];
}

export const enum ChatModelGroup {
	Vendor = 'vendor',
	Visibility = 'visibility'
}

/** The parsed form of a search query. */
interface IModelSearchQuery {
	/** Required `isUserSelectable` value from `@visible:`, if given. */
	readonly visible: boolean | undefined;
	/** Lower-cased, trimmed names from `@provider:` tokens. */
	readonly providers: string[];
	/** Lower-cased names from `@capability:` tokens. */
	readonly capabilities: string[];
	/** Remaining free text with surrounding quotes removed. */
	readonly text: string;
	/** `text` split on single spaces. */
	readonly words: string[];
	/** True when the free text was wrapped in double quotes and must match as one phrase. */
	readonly completeMatch: boolean;
}

/**
 * Orders models inside a group.
 * Copilot models come first, then models are ordered by provider group name and then by model name.
 */
function compareLanguageModels(a: ILanguageModel, b: ILanguageModel): number {
	const aIsCopilot = a.provider.vendor.vendor === 'copilot';
	const bIsCopilot = b.provider.vendor.vendor === 'copilot';
	if (aIsCopilot && bIsCopilot) {
		return a.metadata.name.localeCompare(b.metadata.name);
	}
	if (aIsCopilot) { return -1; }
	if (bIsCopilot) { return 1; }
	if (a.provider.group.name === b.provider.group.name) {
		return a.metadata.name.localeCompare(b.metadata.name);
	}
	return a.provider.group.name.localeCompare(b.provider.group.name);
}

/**
 * View model for the chat models list.
 *
 * Holds the models fetched from every vendor and groups them either by provider group or by
 * visibility. It applies the current search query and exposes the resulting flat rows through
 * {@link viewModelEntries}. Every change to those rows is announced through {@link onDidChange}.
 */
export class ChatModelsViewModel extends Disposable {

	private readonly _onDidChange = this._register(new Emitter<IViewModelChangeEvent>());
	readonly onDidChange = this._onDidChange.event;

	private readonly _onDidChangeGrouping = this._register(new Emitter<ChatModelGroup>());
	readonly onDidChangeGrouping = this._onDidChangeGrouping.event;

	private languageModels: ILanguageModel[];
	private languageModelGroupStatuses: Array<{ provider: ILanguageModelProvider; status: { severity: Severity; message: string } }> = [];
	private languageModelGroups: ILanguageModelEntriesGroup[] = [];

	/** Ids of group / vendor headers the user has collapsed. */
	private readonly collapsedGroups = new Set<string>();
	private searchValue: string = '';
	/** False when `languageModelGroups` is out of date and must be rebuilt before the next filter. */
	private modelsSorted: boolean = false;

	private _groupBy: ChatModelGroup = ChatModelGroup.Vendor;
	get groupBy(): ChatModelGroup { return this._groupBy; }
	/** Switching the grouping clears collapsed state, regroups, refilters and fires {@link onDidChangeGrouping}. */
	set groupBy(groupBy: ChatModelGroup) {
		if (this._groupBy !== groupBy) {
			this._groupBy = groupBy;
			this.collapsedGroups.clear();
			this.languageModelGroups = this.groupModels(this.languageModels);
			this.doFilter();
			this._onDidChangeGrouping.fire(groupBy);
		}
	}

	private readonly refreshThrottler = this._register(new Throttler());

	constructor(
		@ILanguageModelsService private readonly languageModelsService: ILanguageModelsService,
		@ILanguageModelsConfigurationService private readonly languageModelsConfigurationService: ILanguageModelsConfigurationService,
		@IChatEntitlementService private readonly chatEntitlementService: IChatEntitlementService
	) {
		super();
		this.languageModels = [];
		this._register(this.chatEntitlementService.onDidChangeEntitlement(() => this.refresh()));
		this._register(this.languageModelsConfigurationService.onDidChangeLanguageModelGroups(() => this.refresh()));
	}

	private readonly _viewModelEntries: IViewModelEntry[] = [];
	get viewModelEntries(): readonly IViewModelEntry[] {
		return this._viewModelEntries;
	}

	/** Replaces rows in place, re-resolves {@link selectedEntry} by id and fires {@link onDidChange}. */
	private splice(at: number, removed: number, added: IViewModelEntry[]): void {
		this._viewModelEntries.splice(at, removed, ...added);
		if (this.selectedEntry) {
			this.selectedEntry = this._viewModelEntries.find(entry => entry.id === this.selectedEntry?.id);
		}
		this._onDidChange.fire({ at, removed, added });
	}

	selectedEntry: IViewModelEntry | undefined;

	/** True when the groups are stale (e.g. after a visibility toggle while grouped by visibility). */
	public shouldRefilter(): boolean {
		return !this.modelsSorted;
	}

	/**
	 * Applies `searchValue` and returns the resulting rows.
	 * A different query clears the collapsed state. Stale groups are rebuilt first.
	 */
	filter(searchValue: string): readonly IViewModelEntry[] {
		if (searchValue !== this.searchValue) {
			this.collapsedGroups.clear();
		}
		this.searchValue = searchValue;
		if (!this.modelsSorted) {
			this.languageModelGroups = this.groupModels(this.languageModels);
		}

		this.doFilter();
		return this.viewModelEntries;
	}

	/**
	 * Rebuilds the flat row list from the current groups and search value.
	 * Headers are shown only when there is more than one group.
	 * Expanded groups with no status and no matching models are omitted.
	 */
	private doFilter(): void {
		const viewModelEntries: IViewModelEntry[] = [];
		const shouldShowGroupHeaders = this.languageModelGroups.length > 1;

		for (const group of this.languageModelGroups) {
			if (this.collapsedGroups.has(group.group.id)) {
				group.group.collapsed = true;
				if (shouldShowGroupHeaders) {
					viewModelEntries.push(group.group);
				}
				continue;
			}

			const groupEntries: IViewModelEntry[] = [];
			if (group.status) {
				groupEntries.push(group.status);
			}

			groupEntries.push(...this.filterModels(group.models, this.searchValue));

			if (groupEntries.length > 0) {
				group.group.collapsed = false;
				if (shouldShowGroupHeaders) {
					viewModelEntries.push(group.group);
				}
				viewModelEntries.push(...groupEntries);
			}
		}
		this.splice(0, this._viewModelEntries.length, viewModelEntries);
	}

	/**
	 * Splits a raw query into its `@visible:`, `@provider:` and `@capability:` filters and the remaining free text.
	 */
	private parseSearchValue(searchValue: string): IModelSearchQuery {
		let text = searchValue;

		let visible: boolean | undefined;
		const visibleMatch = VISIBLE_REGEX.exec(text);
		if (visibleMatch && visibleMatch[1]) {
			visible = visibleMatch[1].toLowerCase() === 'true';
			text = text.replace(VISIBLE_REGEX, '');
		}

		const providers: string[] = [];
		let providerMatch: RegExpExecArray | null;
		PROVIDER_REGEX.lastIndex = 0;
		while ((providerMatch = PROVIDER_REGEX.exec(text)) !== null) {
			// Group 2 is the quoted form: strip the surrounding quotes.
			const providerName = providerMatch[2] ? providerMatch[2].substring(1, providerMatch[2].length - 1) : providerMatch[3];
			providers.push(providerName.toLowerCase().trim());
		}
		if (providers.length > 0) {
			text = text.replace(PROVIDER_REGEX, '');
		}

		const capabilities: string[] = [];
		let capabilityMatch: RegExpExecArray | null;
		CAPABILITY_REGEX.lastIndex = 0;
		while ((capabilityMatch = CAPABILITY_REGEX.exec(text)) !== null) {
			capabilities.push(capabilityMatch[1].toLowerCase());
		}
		if (capabilities.length > 0) {
			text = text.replace(CAPABILITY_REGEX, '');
		}

		// Trim before looking for quotes. Removing the @-filters above can leave whitespace around the
		// text that would otherwise hide a quoted phrase such as `@provider:x "gpt 4"`.
		text = text.trim();
		const quoteAtFirstChar = text.charAt(0) === '"';
		const quoteAtLastChar = text.charAt(text.length - 1) === '"';
		if (quoteAtFirstChar) {
			text = text.substring(1);
		}
		if (quoteAtLastChar) {
			text = text.substring(0, text.length - 1);
		}
		text = text.trim();

		return {
			visible,
			providers,
			capabilities,
			text,
			words: text.split(' '),
			completeMatch: quoteAtFirstChar && quoteAtLastChar,
		};
	}

	/**
	 * Returns the model rows from `modelEntries` that satisfy every filter in `searchValue`.
	 * Each row is annotated with the highlight ranges that were found.
	 */
	private filterModels(modelEntries: ILanguageModel[], searchValue: string): IViewModelEntry[] {
		const query = this.parseSearchValue(searchValue);
		const result: IViewModelEntry[] = [];

		for (const modelEntry of modelEntries) {
			if (query.visible !== undefined && (modelEntry.metadata.isUserSelectable ?? false) !== query.visible) {
				continue;
			}

			if (query.providers.length > 0) {
				const vendorId = modelEntry.provider.vendor.vendor.toLowerCase();
				const vendorName = modelEntry.provider.vendor.displayName.toLowerCase();
				if (!query.providers.some(provider => provider === vendorId || provider === vendorName)) {
					continue;
				}
			}

			let matchedCapabilities: string[] = [];
			if (query.capabilities.length > 0) {
				const required = this.getRequiredCapabilityMatches(modelEntry, query.capabilities);
				if (!required) {
					continue;
				}
				matchedCapabilities = required;
			}

			let modelMatches: ModelItemMatches | undefined;
			if (query.text) {
				modelMatches = new ModelItemMatches(modelEntry, query.text, query.words, query.completeMatch);
				if (!modelMatches.modelNameMatches && !modelMatches.modelIdMatches && !modelMatches.providerMatches && !modelMatches.capabilityMatches) {
					continue;
				}
			}

			result.push({
				type: 'model',
				id: this.getModelId(modelEntry),
				templateId: MODEL_ENTRY_TEMPLATE_ID,
				model: modelEntry,
				modelNameMatches: modelMatches?.modelNameMatches || undefined,
				modelIdMatches: modelMatches?.modelIdMatches || undefined,
				providerMatches: modelMatches?.providerMatches || undefined,
				capabilityMatches: matchedCapabilities.length ? matchedCapabilities : undefined,
			});
		}
		return result;
	}

	/**
	 * Returns the distinct capability names that satisfy every requested capability.
	 * Returns `undefined` when the model has no capabilities or any requested capability is unmatched.
	 */
	private getRequiredCapabilityMatches(modelEntry: ILanguageModel, capabilities: string[]): string[] | undefined {
		if (!modelEntry.metadata.capabilities) {
			return undefined;
		}
		const matched: string[] = [];
		for (const capability of capabilities) {
			const matchedForThisCapability = this.getMatchingCapabilities(modelEntry, capability);
			if (matchedForThisCapability.length === 0) {
				return undefined;
			}
			matched.push(...matchedForThisCapability);
		}
		return distinct(matched);
	}

	/**
	 * Maps one lower-cased capability query to the capability names of `modelEntry` it matches.
	 * Unknown names are matched as substrings against the model's edit tools.
	 */
	private getMatchingCapabilities(modelEntry: ILanguageModel, capability: string): string[] {
		const matchedCapabilities: string[] = [];
		if (!modelEntry.metadata.capabilities) {
			return matchedCapabilities;
		}

		switch (capability) {
			case 'tools':
			case 'toolcalling':
				if (modelEntry.metadata.capabilities.toolCalling === true) {
					matchedCapabilities.push('toolCalling');
				}
				break;
			case 'vision':
				if (modelEntry.metadata.capabilities.vision === true) {
					matchedCapabilities.push('vision');
				}
				break;
			case 'agent':
			case 'agentmode':
				if (modelEntry.metadata.capabilities.agentMode === true) {
					matchedCapabilities.push('agentMode');
				}
				break;
			default:
				if (modelEntry.metadata.capabilities.editTools) {
					for (const tool of modelEntry.metadata.capabilities.editTools) {
						if (tool.toLowerCase().includes(capability)) {
							matchedCapabilities.push(tool);
						}
					}
				}
				break;
		}
		return matchedCapabilities;
	}

	/**
	 * Builds the groups for the current {@link groupBy} mode and sorts the models inside each group.
	 * Marks the groups as up to date.
	 */
	private groupModels(languageModels: ILanguageModel[]): ILanguageModelEntriesGroup[] {
		let result: ILanguageModelEntriesGroup[] = [];
		if (this.groupBy === ChatModelGroup.Visibility) {
			result = this.groupModelsByVisibility(languageModels);
		} else if (this.groupBy === ChatModelGroup.Vendor) {
			result = this.groupModelsByVendor(languageModels);
		}
		for (const group of result) {
			group.models.sort(compareLanguageModels);
		}
		this.modelsSorted = true;
		return result;
	}

	/** Splits models into a "Visible" and a "Hidden" group based on `isUserSelectable`. */
	private groupModelsByVisibility(languageModels: ILanguageModel[]): ILanguageModelEntriesGroup[] {
		const visible: ILanguageModel[] = [];
		const hidden: ILanguageModel[] = [];
		for (const model of languageModels) {
			(model.metadata.isUserSelectable ? visible : hidden).push(model);
		}
		return [
			{
				group: {
					type: 'group',
					id: 'visible',
					label: localize('visible', "Visible"),
					templateId: GROUP_ENTRY_TEMPLATE_ID,
					collapsed: this.collapsedGroups.has('visible')
				},
				models: visible
			},
			{
				group: {
					type: 'group',
					id: 'hidden',
					label: localize('hidden', "Hidden"),
					templateId: GROUP_ENTRY_TEMPLATE_ID,
					collapsed: this.collapsedGroups.has('hidden'),
				},
				models: hidden
			}
		];
	}

	/**
	 * Creates one group per provider group and attaches any reported status.
	 * Status-only groups are created too. Copilot groups sort first, the rest by label.
	 */
	private groupModelsByVendor(languageModels: ILanguageModel[]): ILanguageModelEntriesGroup[] {
		// Map keeps insertion order, matching the first-seen order of the original linear search.
		const groupsById = new Map<string, ILanguageModelEntriesGroup>();
		const getOrCreateGroup = (provider: ILanguageModelProvider): ILanguageModelEntriesGroup => {
			const groupId = this.getProviderGroupId(provider.group);
			let group = groupsById.get(groupId);
			if (!group) {
				group = { group: this.createLanguageModelProviderEntry(provider), models: [] };
				groupsById.set(groupId, group);
			}
			return group;
		};

		for (const model of languageModels) {
			getOrCreateGroup(model.provider).models.push(model);
		}
		for (const statusGroup of this.languageModelGroupStatuses) {
			const group = getOrCreateGroup(statusGroup.provider);
			group.status = {
				id: `status.${group.group.id}`,
				type: 'status',
				...statusGroup.status,
			};
		}

		const result = [...groupsById.values()];
		result.sort((a, b) => {
			const aIsCopilot = this.isCopilotGroup(a);
			const bIsCopilot = this.isCopilotGroup(b);
			// Consistent comparator: two copilot groups fall through to label order instead of both returning -1.
			if (aIsCopilot !== bIsCopilot) {
				return aIsCopilot ? -1 : 1;
			}
			return a.group.label.localeCompare(b.group.label);
		});
		return result;
	}

	/** Whether a group belongs to the copilot vendor. Uses the header's vendor, so status-only groups qualify. */
	private isCopilotGroup(group: ILanguageModelEntriesGroup): boolean {
		const header = group.group;
		if (isLanguageModelProviderEntry(header)) {
			return header.vendorEntry.vendor.vendor === 'copilot';
		}
		return group.models[0]?.provider.vendor.vendor === 'copilot';
	}

	private createLanguageModelProviderEntry(provider: ILanguageModelProvider): ILanguageModelProviderEntry {
		const id = this.getProviderGroupId(provider.group);
		return {
			type: 'vendor',
			id,
			label: provider.group.name,
			templateId: VENDOR_ENTRY_TEMPLATE_ID,
			collapsed: this.collapsedGroups.has(id),
			vendorEntry: {
				group: provider.group,
				vendor: provider.vendor
			},
		};
	}

	/** All registered vendors: copilot first, the rest by display name. */
	getVendors(): IUserFriendlyLanguageModel[] {
		return [...this.languageModelsService.getVendors()].sort((a, b) => {
			if (a.vendor === 'copilot') { return -1; }
			if (b.vendor === 'copilot') { return 1; }
			return a.displayName.localeCompare(b.displayName);
		});
	}

	/** Re-fetches all model groups. Calls are serialized through a `Throttler`. */
	refresh(): Promise<void> {
		return this.refreshThrottler.queue(() => this.doRefresh());
	}

	/**
	 * Rebuilds the model list and group statuses from scratch, one vendor at a time.
	 * Regroups and refilters after each vendor so rows appear as vendors finish loading.
	 * The copilot "auto" model is skipped.
	 */
	private async doRefresh(): Promise<void> {
		this.languageModels = [];
		this.languageModelGroupStatuses = [];
		for (const vendor of this.getVendors()) {
			const models: ILanguageModel[] = [];
			const languageModelsGroups = await this.languageModelsService.fetchLanguageModelGroups(vendor.vendor);
			for (const group of languageModelsGroups) {
				const provider: ILanguageModelProvider = {
					// Models without an explicit group fall under a default group named after the vendor.
					group: group.group ?? {
						vendor: vendor.vendor,
						name: vendor.displayName
					},
					vendor
				};
				if (group.status) {
					this.languageModelGroupStatuses.push({
						provider,
						status: {
							message: group.status.message,
							severity: group.status.severity
						}
					});
				}
				for (const model of group.models) {
					if (vendor.vendor === 'copilot' && model.metadata.id === 'auto') {
						continue;
					}
					models.push({
						identifier: model.identifier,
						metadata: model.metadata,
						provider,
					});
				}
			}
			this.languageModels.push(...models.sort((a, b) => a.metadata.name.localeCompare(b.metadata.name)));
			this.languageModelGroups = this.groupModels(this.languageModels);
			this.doFilter();
		}
	}

	/**
	 * Flips a model's model-picker visibility, reloads its metadata and replaces its row in place.
	 * When grouped by visibility, the groups are marked stale (see {@link shouldRefilter}).
	 */
	toggleVisibility(model: ILanguageModelEntry): void {
		const isVisible = model.model.metadata.isUserSelectable ?? false;
		this.languageModelsService.updateModelPickerPreference(model.model.identifier, !isVisible);
		const metadata = this.languageModelsService.lookupLanguageModel(model.model.identifier);
		const index = this.viewModelEntries.indexOf(model);
		if (metadata && index !== -1) {
			// Update the metadata before recomputing the id. The id encodes `isUserSelectable`, and
			// computing it from the old metadata would leave a stale id that later refilters cannot re-find.
			model.model.metadata = metadata;
			model.id = this.getModelId(model.model);
			if (this.groupBy === ChatModelGroup.Visibility) {
				this.modelsSorted = false;
			}
			this.splice(index, 1, [model]);
		}
	}

	/**
	 * Row id for a model.
	 * Includes the visibility flag, presumably so a row is treated as new when its visibility changes.
	 */
	private getModelId(modelEntry: ILanguageModel): string {
		return `${modelEntry.provider.group.name}.${modelEntry.identifier}.${modelEntry.metadata.version}-visible:${modelEntry.metadata.isUserSelectable}`;
	}

	private getProviderGroupId(group: ILanguageModelsProviderGroup): string {
		return `${group.vendor}-${group.name}`;
	}

	/** Toggles the collapsed state of a group or vendor header. Other entry kinds are ignored. */
	toggleCollapsed(viewModelEntry: IViewModelEntry): void {
		if (!isLanguageModelGroupEntry(viewModelEntry) && !isLanguageModelProviderEntry(viewModelEntry)) {
			return;
		}
		const id = viewModelEntry.id;
		if (!id) {
			return;
		}
		this.selectedEntry = viewModelEntry;
		if (!this.collapsedGroups.delete(id)) {
			this.collapsedGroups.add(id);
		}
		this.doFilter();
	}

	/** Collapses every header currently shown in the list. */
	collapseAll(): void {
		this.collapsedGroups.clear();
		for (const entry of this.viewModelEntries) {
			if (isLanguageModelProviderEntry(entry) || isLanguageModelGroupEntry(entry)) {
				this.collapsedGroups.add(entry.id);
			}
		}
		this.filter(this.searchValue);
	}

	/**
	 * One provider per distinct provider group that has at least one model, in model order.
	 * Providers are keyed like the vendor groups, so same-named groups of different vendors stay distinct.
	 */
	getConfiguredVendors(): ILanguageModelProvider[] {
		const result: ILanguageModelProvider[] = [];
		const seenGroups = new Set<string>();
		for (const modelEntry of this.languageModels) {
			const groupId = this.getProviderGroupId(modelEntry.provider.group);
			if (!seenGroups.has(groupId)) {
				seenGroups.add(groupId);
				result.push(modelEntry.provider);
			}
		}
		return result;
	}
}

/**
 * Highlight ranges of a free-text query within one model's name, id, provider group name and capabilities.
 * A field is `null` when the query does not match it.
 */
class ModelItemMatches {

	readonly modelNameMatches: IMatch[] | null = null;
	readonly modelIdMatches: IMatch[] | null = null;
	readonly providerMatches: IMatch[] | null = null;
	readonly capabilityMatches: IMatch[] | null = null;

	/**
	 * @param searchValue Free text with quotes already removed.
	 * @param words `searchValue` split on spaces. Used for the word-by-word fallback.
	 * @param completeMatch When true (quoted query), only whole-phrase matches count and the word-by-word
	 *   fallback is skipped. Previously no matching was done at all, which hid every model.
	 */
	constructor(modelEntry: ILanguageModel, searchValue: string, words: string[], completeMatch: boolean) {
		const contiguousWords: IFilter = (word, wordToMatchAgainst) => matchesWords(word, wordToMatchAgainst, true);
		const wordsOrCamelCase = or(matchesWords, matchesCamelCase);

		this.modelNameMatches = modelEntry.metadata.name
			? this.matches(searchValue, modelEntry.metadata.name, contiguousWords, words, completeMatch)
			: null;

		this.modelIdMatches = this.matches(searchValue, modelEntry.metadata.id, wordsOrCamelCase, words, completeMatch);

		this.providerMatches = this.matches(searchValue, modelEntry.provider.group.name, contiguousWords, words, completeMatch);

		const capabilityString = ModelItemMatches.getCapabilityString(modelEntry);
		if (capabilityString) {
			this.capabilityMatches = this.matches(searchValue, capabilityString, wordsOrCamelCase, words, completeMatch);
		}
	}

	/** Space-separated searchable names of a model's capabilities, or '' when it has none. */
	private static getCapabilityString(modelEntry: ILanguageModel): string {
		const capabilities = modelEntry.metadata.capabilities;
		if (!capabilities) {
			return '';
		}
		const capabilityStrings: string[] = [];
		if (capabilities.toolCalling) {
			capabilityStrings.push('tools', 'toolCalling');
		}
		if (capabilities.vision) {
			capabilityStrings.push('vision');
		}
		if (capabilities.agentMode) {
			capabilityStrings.push('agent', 'agentMode');
		}
		if (capabilities.editTools) {
			capabilityStrings.push(...capabilities.editTools);
		}
		return capabilityStrings.join(' ');
	}

	/**
	 * Matches the whole query first.
	 * Unless `completeMatch` is set, falls back to requiring every word to match.
	 */
	private matches(searchValue: string | null, wordToMatchAgainst: string, wordMatchesFilter: IFilter, words: string[], completeMatch: boolean = false): IMatch[] | null {
		let matches = searchValue ? wordFilter(searchValue, wordToMatchAgainst) : null;
		if (!matches && !completeMatch) {
			matches = this.matchesAllWords(words, wordToMatchAgainst, wordMatchesFilter);
		}
		return matches ? this.filterAndSort(matches) : null;
	}

	/** Concatenated matches of every word, or `null` as soon as one word does not match. */
	private matchesAllWords(words: string[], wordToMatchAgainst: string, wordMatchesFilter: IFilter): IMatch[] | null {
		const matches: IMatch[] = [];
		for (const word of words) {
			const wordMatches = wordMatchesFilter(word, wordToMatchAgainst);
			if (!wordMatches) {
				return null;
			}
			matches.push(...wordMatches);
		}
		return matches;
	}

	/** Removes duplicates and ranges enclosed by another range, then sorts by start offset. */
	private filterAndSort(matches: IMatch[]): IMatch[] {
		return distinct(matches, (a => a.start + '.' + a.end))
			.filter(match => !matches.some(m => !(m.start === match.start && m.end === match.end) && (m.start <= match.start && m.end >= match.end)))
			.sort((a, b) => a.start - b.start);
	}
}