// Path: extensions/copilot/src/platform/endpoint/test/node/testEndpointProvider.ts
/*---------------------------------------------------------------------------------------------1* Copyright (c) Microsoft Corporation. All rights reserved.2* Licensed under the MIT License. See License.txt in the project root for license information.3*--------------------------------------------------------------------------------------------*/45import type { ChatRequest, LanguageModelChat } from 'vscode';6import { CacheableRequest, SQLiteCache } from '../../../../../test/base/cache';7import { TestingCacheSalts } from '../../../../../test/base/salts';8import { CurrentTestRunInfo } from '../../../../../test/base/simulationContext';9import { TokenizerType } from '../../../../util/common/tokenizer';10import { SequencerByKey } from '../../../../util/vs/base/common/async';11import { Event } from '../../../../util/vs/base/common/event';12import { IInstantiationService } from '../../../../util/vs/platform/instantiation/common/instantiation';13import { IAuthenticationService } from '../../../authentication/common/authentication';14import { CHAT_MODEL, IConfigurationService } from '../../../configuration/common/configurationService';15import { LEGACY_EMBEDDING_MODEL_ID } from '../../../embeddings/common/embeddingsComputer';16import { IEnvService } from '../../../env/common/envService';17import { IOctoKitService } from '../../../github/common/githubService';18import { ILogService } from '../../../log/common/logService';19import { IChatEndpoint, IEmbeddingsEndpoint } from '../../../networking/common/networking';20import { IRequestLogger } from '../../../requestLogger/common/requestLogger';21import { IExperimentationService } from '../../../telemetry/common/nullExperimentationService';22import { ChatEndpointFamily, EmbeddingsEndpointFamily, IChatModelInformation, ICompletionModelInformation, IEmbeddingModelInformation, IEndpointProvider } from '../../common/endpointProvider';23import { EmbeddingEndpoint } from '../../node/embeddingsEndpoint';24import { ModelMetadataFetcher } 
from '../../node/modelMetadataFetcher';25import { AzureTestEndpoint } from './azureEndpoint';26import { CAPITestEndpoint } from './capiEndpoint';27import { CustomNesEndpoint } from './customNesEndpoint';28import { IModelConfig, OpenAICompatibleTestEndpoint } from './openaiCompatibleEndpoint';293031async function getModelMetadataMap(modelMetadataFetcher: TestModelMetadataFetcher): Promise<Map<string, IChatModelInformation>> {32let metadataArray: IChatModelInformation[] = [];33try {34metadataArray = await modelMetadataFetcher.getAllChatModels();35} catch (e) {36metadataArray = [];37// We only want to catch errors for the model lab models, otherwise we have no models to test and should just throw the error38if (!modelMetadataFetcher.isModelLab) {39throw e;40}41}42const metadataMap = new Map<string, IChatModelInformation>();43metadataArray.forEach(metadata => {44metadataMap.set(metadata.id, metadata);45});46return metadataMap;47}4849type ModelMetadataType = 'prod' | 'modelLab';5051class ModelMetadataRequest implements CacheableRequest {52constructor(readonly hash: string) { }53}5455export class TestModelMetadataFetcher extends ModelMetadataFetcher {5657private static Queues = new SequencerByKey<ModelMetadataType>();5859get isModelLab(): boolean { return this._isModelLab; }6061private readonly cache: SQLiteCache<ModelMetadataRequest, IChatModelInformation[]>;6263constructor(64_isModelLab: boolean,65info: CurrentTestRunInfo | undefined,66private readonly _skipModelMetadataCache: boolean = false,67@IOctoKitService _octoKitService: IOctoKitService,68@IConfigurationService _configService: IConfigurationService,69@IExperimentationService _expService: IExperimentationService,70@IEnvService _envService: IEnvService,71@IAuthenticationService _authService: IAuthenticationService,72@ILogService _logService: ILogService,73@IRequestLogger _requestLogger: IRequestLogger,74@IInstantiationService _instantiationService: IInstantiationService,75) 
{76super(77_isModelLab,78_octoKitService,79_requestLogger,80_configService,81_expService,82_envService,83_authService,84_logService,85_instantiationService,86);8788this.cache = new SQLiteCache<ModelMetadataRequest, IChatModelInformation[]>('modelMetadata', TestingCacheSalts.modelMetadata, info);89}9091override async getAllChatModels(): Promise<IChatModelInformation[]> {92const type = this._isModelLab ? 'modelLab' : 'prod';93const req = new ModelMetadataRequest(type);9495return await TestModelMetadataFetcher.Queues.queue(type, async () => {96if (this._skipModelMetadataCache) {97return super.getAllChatModels();98}99const result = await this.cache.get(req);100if (result) {101return result;102}103104// If the cache doesn't have the result, we need to fetch it105const modelInfo = await super.getAllChatModels();106await this.cache.set(req, modelInfo);107return modelInfo;108});109}110}111112export class TestEndpointProvider implements IEndpointProvider {113114declare readonly _serviceBrand: undefined;115116readonly onDidModelsRefresh = Event.None;117118private _testEmbeddingEndpoint: IEmbeddingsEndpoint | undefined;119private _chatEndpoints: Map<string, IChatEndpoint> = new Map();120private _prodChatModelMetadata: Promise<Map<string, IChatModelInformation>>;121private _modelLabChatModelMetadata: Promise<Map<string, IChatModelInformation>>;122123constructor(124private readonly gpt4ModelToRunAgainst: string | undefined,125private readonly gpt4oMiniModelToRunAgainst: string | undefined,126_fastRewriteModelToRunAgainst: string | undefined,127info: CurrentTestRunInfo | undefined,128skipModelMetadataCache: boolean,129private readonly customModelConfigs: Map<string, IModelConfig> = new Map(),130@IInstantiationService private readonly _instantiationService: IInstantiationService131) {132const prodModelMetadata = this._instantiationService.createInstance(TestModelMetadataFetcher, false, info, skipModelMetadataCache);133const modelLabModelMetadata = 
this._instantiationService.createInstance(TestModelMetadataFetcher, true, info, skipModelMetadataCache);134this._prodChatModelMetadata = getModelMetadataMap(prodModelMetadata);135this._modelLabChatModelMetadata = getModelMetadataMap(modelLabModelMetadata);136}137138private async getChatEndpointInfo(model: string, modelLabMetadata: Map<string, IChatModelInformation>, prodMetadata: Map<string, IChatModelInformation>): Promise<IChatEndpoint> {139let chatEndpoint = this._chatEndpoints.get(model);140if (!chatEndpoint) {141const customModel = this.customModelConfigs.get(model);142if (customModel !== undefined) {143chatEndpoint = this._instantiationService.createInstance(OpenAICompatibleTestEndpoint, customModel);144} else if (model === CHAT_MODEL.CUSTOM_NES) {145chatEndpoint = this._instantiationService.createInstance(CustomNesEndpoint);146} else if (model === CHAT_MODEL.EXPERIMENTAL) {147chatEndpoint = this._instantiationService.createInstance(AzureTestEndpoint, model);148} else {149const isProdModel = prodMetadata.has(model);150const modelMetadata: IChatModelInformation | undefined = isProdModel ? 
prodMetadata.get(model) : modelLabMetadata.get(model);151if (!modelMetadata) {152throw new Error(`Model ${model} not found`);153}154chatEndpoint = this._instantiationService.createInstance(CAPITestEndpoint, modelMetadata, !isProdModel);155}156this._chatEndpoints.set(model, chatEndpoint);157}158return chatEndpoint;159}160161async getAllCompletionModels(forceRefresh?: boolean): Promise<ICompletionModelInformation[]> {162throw new Error('getAllCompletionModels is not implemented in TestEndpointProvider');163}164165async getAllChatEndpoints(): Promise<IChatEndpoint[]> {166const modelIDs: Set<string> = new Set([167CHAT_MODEL.CUSTOM_NES168]);169170if (this.customModelConfigs.size > 0) {171this.customModelConfigs.forEach(config => {172modelIDs.add(config.name);173});174}175176const modelLabMetadata: Map<string, IChatModelInformation> = await this._modelLabChatModelMetadata;177const prodMetadata: Map<string, IChatModelInformation> = await this._prodChatModelMetadata;178modelLabMetadata.forEach((modelMetadata) => {179modelIDs.add(modelMetadata.id);180});181prodMetadata.forEach((modelMetadata) => {182modelIDs.add(modelMetadata.id);183});184for (const model of modelIDs) {185this._chatEndpoints.set(model, await this.getChatEndpointInfo(model, modelLabMetadata, prodMetadata));186}187return Array.from(this._chatEndpoints.values());188}189async getChatEndpoint(requestOrFamilyOrModel: LanguageModelChat | ChatRequest | ChatEndpointFamily): Promise<IChatEndpoint> {190if (typeof requestOrFamilyOrModel !== 'string') {191requestOrFamilyOrModel = 'copilot-base';192}193if (requestOrFamilyOrModel === 'copilot-base') {194return await this.getChatEndpointInfo(this.gpt4ModelToRunAgainst ?? CHAT_MODEL.GPT41, await this._modelLabChatModelMetadata, await this._prodChatModelMetadata);195} else {196return await this.getChatEndpointInfo(this.gpt4oMiniModelToRunAgainst ?? 
CHAT_MODEL.GPT4OMINI, await this._modelLabChatModelMetadata, await this._prodChatModelMetadata);197}198}199async getEmbeddingsEndpoint(family?: EmbeddingsEndpointFamily): Promise<IEmbeddingsEndpoint> {200const id = LEGACY_EMBEDDING_MODEL_ID.TEXT3SMALL;201const modelInformation: IEmbeddingModelInformation = {202id: id,203vendor: 'Test Provider',204name: id,205version: '1.0',206model_picker_enabled: false,207is_chat_default: false,208billing: { is_premium: false, multiplier: 0 },209is_chat_fallback: false,210capabilities: {211type: 'embeddings',212tokenizer: TokenizerType.O200K,213family: 'test'214}215};216this._testEmbeddingEndpoint ??= this._instantiationService.createInstance(EmbeddingEndpoint, modelInformation);217return this._testEmbeddingEndpoint;218}219}220221222