Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/extensions/copilot/src/platform/endpoint/test/node/testEndpointProvider.ts
13405 views
1
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import type { ChatRequest, LanguageModelChat } from 'vscode';
7
import { CacheableRequest, SQLiteCache } from '../../../../../test/base/cache';
8
import { TestingCacheSalts } from '../../../../../test/base/salts';
9
import { CurrentTestRunInfo } from '../../../../../test/base/simulationContext';
10
import { TokenizerType } from '../../../../util/common/tokenizer';
11
import { SequencerByKey } from '../../../../util/vs/base/common/async';
12
import { Event } from '../../../../util/vs/base/common/event';
13
import { IInstantiationService } from '../../../../util/vs/platform/instantiation/common/instantiation';
14
import { IAuthenticationService } from '../../../authentication/common/authentication';
15
import { CHAT_MODEL, IConfigurationService } from '../../../configuration/common/configurationService';
16
import { LEGACY_EMBEDDING_MODEL_ID } from '../../../embeddings/common/embeddingsComputer';
17
import { IEnvService } from '../../../env/common/envService';
18
import { IOctoKitService } from '../../../github/common/githubService';
19
import { ILogService } from '../../../log/common/logService';
20
import { IChatEndpoint, IEmbeddingsEndpoint } from '../../../networking/common/networking';
21
import { IRequestLogger } from '../../../requestLogger/common/requestLogger';
22
import { IExperimentationService } from '../../../telemetry/common/nullExperimentationService';
23
import { ChatEndpointFamily, EmbeddingsEndpointFamily, IChatModelInformation, ICompletionModelInformation, IEmbeddingModelInformation, IEndpointProvider } from '../../common/endpointProvider';
24
import { EmbeddingEndpoint } from '../../node/embeddingsEndpoint';
25
import { ModelMetadataFetcher } from '../../node/modelMetadataFetcher';
26
import { AzureTestEndpoint } from './azureEndpoint';
27
import { CAPITestEndpoint } from './capiEndpoint';
28
import { CustomNesEndpoint } from './customNesEndpoint';
29
import { IModelConfig, OpenAICompatibleTestEndpoint } from './openaiCompatibleEndpoint';
30
31
32
/**
 * Fetches all chat models from the given fetcher and indexes them by model id.
 *
 * Fetch failures are swallowed (yielding an empty map) only for model-lab
 * fetchers: model-lab models are optional extras for testing. For prod
 * fetchers the error is rethrown, since without prod models there is nothing
 * to test against.
 *
 * @param modelMetadataFetcher the test fetcher to pull chat model metadata from
 * @returns a map from model id to its metadata
 */
async function getModelMetadataMap(modelMetadataFetcher: TestModelMetadataFetcher): Promise<Map<string, IChatModelInformation>> {
	let metadataArray: IChatModelInformation[] = [];
	try {
		metadataArray = await modelMetadataFetcher.getAllChatModels();
	} catch (e) {
		metadataArray = [];
		// We only want to catch errors for the model lab models, otherwise we have no models to test and should just throw the error
		if (!modelMetadataFetcher.isModelLab) {
			throw e;
		}
	}
	const metadataMap = new Map<string, IChatModelInformation>();
	metadataArray.forEach(metadata => {
		metadataMap.set(metadata.id, metadata);
	});
	return metadataMap;
}
49
50
/** Which CAPI model list a metadata request targets: production or the model lab. */
type ModelMetadataType = 'prod' | 'modelLab';

/** Cache key wrapper so model metadata responses can be stored in the SQLite cache. */
class ModelMetadataRequest implements CacheableRequest {
	constructor(readonly hash: string) { }
}
55
56
export class TestModelMetadataFetcher extends ModelMetadataFetcher {
57
58
private static Queues = new SequencerByKey<ModelMetadataType>();
59
60
get isModelLab(): boolean { return this._isModelLab; }
61
62
private readonly cache: SQLiteCache<ModelMetadataRequest, IChatModelInformation[]>;
63
64
constructor(
65
_isModelLab: boolean,
66
info: CurrentTestRunInfo | undefined,
67
private readonly _skipModelMetadataCache: boolean = false,
68
@IOctoKitService _octoKitService: IOctoKitService,
69
@IConfigurationService _configService: IConfigurationService,
70
@IExperimentationService _expService: IExperimentationService,
71
@IEnvService _envService: IEnvService,
72
@IAuthenticationService _authService: IAuthenticationService,
73
@ILogService _logService: ILogService,
74
@IRequestLogger _requestLogger: IRequestLogger,
75
@IInstantiationService _instantiationService: IInstantiationService,
76
) {
77
super(
78
_isModelLab,
79
_octoKitService,
80
_requestLogger,
81
_configService,
82
_expService,
83
_envService,
84
_authService,
85
_logService,
86
_instantiationService,
87
);
88
89
this.cache = new SQLiteCache<ModelMetadataRequest, IChatModelInformation[]>('modelMetadata', TestingCacheSalts.modelMetadata, info);
90
}
91
92
override async getAllChatModels(): Promise<IChatModelInformation[]> {
93
const type = this._isModelLab ? 'modelLab' : 'prod';
94
const req = new ModelMetadataRequest(type);
95
96
return await TestModelMetadataFetcher.Queues.queue(type, async () => {
97
if (this._skipModelMetadataCache) {
98
return super.getAllChatModels();
99
}
100
const result = await this.cache.get(req);
101
if (result) {
102
return result;
103
}
104
105
// If the cache doesn't have the result, we need to fetch it
106
const modelInfo = await super.getAllChatModels();
107
await this.cache.set(req, modelInfo);
108
return modelInfo;
109
});
110
}
111
}
112
113
export class TestEndpointProvider implements IEndpointProvider {
114
115
declare readonly _serviceBrand: undefined;
116
117
readonly onDidModelsRefresh = Event.None;
118
119
private _testEmbeddingEndpoint: IEmbeddingsEndpoint | undefined;
120
private _chatEndpoints: Map<string, IChatEndpoint> = new Map();
121
private _prodChatModelMetadata: Promise<Map<string, IChatModelInformation>>;
122
private _modelLabChatModelMetadata: Promise<Map<string, IChatModelInformation>>;
123
124
constructor(
125
private readonly gpt4ModelToRunAgainst: string | undefined,
126
private readonly gpt4oMiniModelToRunAgainst: string | undefined,
127
_fastRewriteModelToRunAgainst: string | undefined,
128
info: CurrentTestRunInfo | undefined,
129
skipModelMetadataCache: boolean,
130
private readonly customModelConfigs: Map<string, IModelConfig> = new Map(),
131
@IInstantiationService private readonly _instantiationService: IInstantiationService
132
) {
133
const prodModelMetadata = this._instantiationService.createInstance(TestModelMetadataFetcher, false, info, skipModelMetadataCache);
134
const modelLabModelMetadata = this._instantiationService.createInstance(TestModelMetadataFetcher, true, info, skipModelMetadataCache);
135
this._prodChatModelMetadata = getModelMetadataMap(prodModelMetadata);
136
this._modelLabChatModelMetadata = getModelMetadataMap(modelLabModelMetadata);
137
}
138
139
private async getChatEndpointInfo(model: string, modelLabMetadata: Map<string, IChatModelInformation>, prodMetadata: Map<string, IChatModelInformation>): Promise<IChatEndpoint> {
140
let chatEndpoint = this._chatEndpoints.get(model);
141
if (!chatEndpoint) {
142
const customModel = this.customModelConfigs.get(model);
143
if (customModel !== undefined) {
144
chatEndpoint = this._instantiationService.createInstance(OpenAICompatibleTestEndpoint, customModel);
145
} else if (model === CHAT_MODEL.CUSTOM_NES) {
146
chatEndpoint = this._instantiationService.createInstance(CustomNesEndpoint);
147
} else if (model === CHAT_MODEL.EXPERIMENTAL) {
148
chatEndpoint = this._instantiationService.createInstance(AzureTestEndpoint, model);
149
} else {
150
const isProdModel = prodMetadata.has(model);
151
const modelMetadata: IChatModelInformation | undefined = isProdModel ? prodMetadata.get(model) : modelLabMetadata.get(model);
152
if (!modelMetadata) {
153
throw new Error(`Model ${model} not found`);
154
}
155
chatEndpoint = this._instantiationService.createInstance(CAPITestEndpoint, modelMetadata, !isProdModel);
156
}
157
this._chatEndpoints.set(model, chatEndpoint);
158
}
159
return chatEndpoint;
160
}
161
162
async getAllCompletionModels(forceRefresh?: boolean): Promise<ICompletionModelInformation[]> {
163
throw new Error('getAllCompletionModels is not implemented in TestEndpointProvider');
164
}
165
166
async getAllChatEndpoints(): Promise<IChatEndpoint[]> {
167
const modelIDs: Set<string> = new Set([
168
CHAT_MODEL.CUSTOM_NES
169
]);
170
171
if (this.customModelConfigs.size > 0) {
172
this.customModelConfigs.forEach(config => {
173
modelIDs.add(config.name);
174
});
175
}
176
177
const modelLabMetadata: Map<string, IChatModelInformation> = await this._modelLabChatModelMetadata;
178
const prodMetadata: Map<string, IChatModelInformation> = await this._prodChatModelMetadata;
179
modelLabMetadata.forEach((modelMetadata) => {
180
modelIDs.add(modelMetadata.id);
181
});
182
prodMetadata.forEach((modelMetadata) => {
183
modelIDs.add(modelMetadata.id);
184
});
185
for (const model of modelIDs) {
186
this._chatEndpoints.set(model, await this.getChatEndpointInfo(model, modelLabMetadata, prodMetadata));
187
}
188
return Array.from(this._chatEndpoints.values());
189
}
190
async getChatEndpoint(requestOrFamilyOrModel: LanguageModelChat | ChatRequest | ChatEndpointFamily): Promise<IChatEndpoint> {
191
if (typeof requestOrFamilyOrModel !== 'string') {
192
requestOrFamilyOrModel = 'copilot-base';
193
}
194
if (requestOrFamilyOrModel === 'copilot-base') {
195
return await this.getChatEndpointInfo(this.gpt4ModelToRunAgainst ?? CHAT_MODEL.GPT41, await this._modelLabChatModelMetadata, await this._prodChatModelMetadata);
196
} else {
197
return await this.getChatEndpointInfo(this.gpt4oMiniModelToRunAgainst ?? CHAT_MODEL.GPT4OMINI, await this._modelLabChatModelMetadata, await this._prodChatModelMetadata);
198
}
199
}
200
async getEmbeddingsEndpoint(family?: EmbeddingsEndpointFamily): Promise<IEmbeddingsEndpoint> {
201
const id = LEGACY_EMBEDDING_MODEL_ID.TEXT3SMALL;
202
const modelInformation: IEmbeddingModelInformation = {
203
id: id,
204
vendor: 'Test Provider',
205
name: id,
206
version: '1.0',
207
model_picker_enabled: false,
208
is_chat_default: false,
209
billing: { is_premium: false, multiplier: 0 },
210
is_chat_fallback: false,
211
capabilities: {
212
type: 'embeddings',
213
tokenizer: TokenizerType.O200K,
214
family: 'test'
215
}
216
};
217
this._testEmbeddingEndpoint ??= this._instantiationService.createInstance(EmbeddingEndpoint, modelInformation);
218
return this._testEmbeddingEndpoint;
219
}
220
}
221
222