Path: blob/main/extensions/copilot/src/extension/prompt/vscode-node/endpointProviderImpl.ts
13399 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { LanguageModelChat, type ChatRequest } from 'vscode';
import { IAuthenticationService } from '../../../platform/authentication/common/authentication';
import { IConfigurationService } from '../../../platform/configuration/common/configurationService';
import { ChatEndpointFamily, EmbeddingsEndpointFamily, IChatModelInformation, ICompletionModelInformation, IEmbeddingModelInformation, IEndpointProvider } from '../../../platform/endpoint/common/endpointProvider';
import { AutoChatEndpoint } from '../../../platform/endpoint/node/autoChatEndpoint';
import { IAutomodeService } from '../../../platform/endpoint/node/automodeService';
import { CopilotChatEndpoint } from '../../../platform/endpoint/node/copilotChatEndpoint';
import { EmbeddingEndpoint } from '../../../platform/endpoint/node/embeddingsEndpoint';
import { IModelMetadataFetcher, ModelMetadataFetcher } from '../../../platform/endpoint/node/modelMetadataFetcher';
import { ExtensionContributedChatEndpoint } from '../../../platform/endpoint/vscode-node/extChatEndpoint';
import { ILogService } from '../../../platform/log/common/logService';
import { IChatEndpoint, IEmbeddingsEndpoint } from '../../../platform/networking/common/networking';
import { Emitter, Event } from '../../../util/vs/base/common/event';
import { Disposable } from '../../../util/vs/base/common/lifecycle';
import { IInstantiationService } from '../../../util/vs/platform/instantiation/common/instantiation';

/**
 * The single embedding model this provider requests. It is used both as the
 * argument to the metadata fetcher and as the cache key, so the two always agree.
 */
const EMBEDDING_MODEL_ID = 'text-embedding-3-small';

/**
 * Endpoint provider that resolves chat, embedding and completion models through
 * a {@link ModelMetadataFetcher} and caches the endpoint instances it creates.
 *
 * Both caches are cleared whenever the fetcher reports refreshed models, after
 * which {@link onDidModelsRefresh} fires so consumers can re-resolve endpoints.
 */
export class ProductionEndpointProvider extends Disposable implements IEndpointProvider {

	declare readonly _serviceBrand: undefined;

	private readonly _onDidModelsRefresh = this._register(new Emitter<void>());
	/** Fires after the local endpoint caches were cleared because the fetcher reported refreshed models. */
	readonly onDidModelsRefresh: Event<void> = this._onDidModelsRefresh.event;

	/** Chat endpoints created by this provider, keyed by model id. */
	private _chatEndpoints: Map<string, IChatEndpoint> = new Map();
	/** Embedding endpoints created by this provider, keyed by {@link EMBEDDING_MODEL_ID}. */
	private _embeddingEndpoints: Map<string, IEmbeddingsEndpoint> = new Map();
	private readonly _modelFetcher: IModelMetadataFetcher;

	constructor(
		@IAutomodeService private readonly _autoModeService: IAutomodeService,
		@ILogService protected readonly _logService: ILogService,
		@IConfigurationService protected readonly _configService: IConfigurationService,
		@IInstantiationService protected readonly _instantiationService: IInstantiationService,
		@IAuthenticationService protected readonly _authService: IAuthenticationService,
	) {
		super();

		// NOTE(review): the meaning of the `false` argument is defined by the
		// ModelMetadataFetcher constructor, which is not visible in this file.
		this._modelFetcher = this._instantiationService.createInstance(ModelMetadataFetcher,
			false,
		);

		// When new models come in from CAPI, clear the local caches so that endpoints
		// are recreated from the (possibly changed) model information.
		this._register(this._modelFetcher.onDidModelsRefresh(() => {
			this._chatEndpoints.clear();
			this._embeddingEndpoints.clear();
			this._onDidModelsRefresh.fire();
		}));
	}

	/**
	 * Returns the cached chat endpoint for the given model, creating and caching
	 * a {@link CopilotChatEndpoint} on first use.
	 *
	 * @param modelMetadata Metadata of the model; its `id` is the cache key.
	 */
	private getOrCreateChatEndpointInstance(modelMetadata: IChatModelInformation): IChatEndpoint {
		const modelId = modelMetadata.id;
		let chatEndpoint = this._chatEndpoints.get(modelId);
		if (!chatEndpoint) {
			chatEndpoint = this._instantiationService.createInstance(CopilotChatEndpoint, modelMetadata);
			this._chatEndpoints.set(modelId, chatEndpoint);
		}
		return chatEndpoint;
	}

	/**
	 * Resolves the chat endpoint for a model family name, a chat request, or a
	 * language model.
	 *
	 * Resolution order:
	 * 1. A string is treated as a family and looked up through the model fetcher.
	 * 2. Otherwise the model is taken from the request's `model` property, or the
	 *    argument itself is used as the model. A missing model resolves to the
	 *    `'copilot-base'` family.
	 * 3. Models whose vendor is not `'copilot'` get a fresh, uncached
	 *    {@link ExtensionContributedChatEndpoint}.
	 * 4. The auto pseudo-model is resolved by the auto-mode service; if that fails,
	 *    `'copilot-base'` is used.
	 * 5. Any other model is looked up through the model fetcher, falling back to
	 *    `'copilot-base'` when no metadata is found.
	 *
	 * @param requestOrFamilyOrModel Family name, chat request, or language model to resolve.
	 * @returns The resolved chat endpoint.
	 */
	async getChatEndpoint(requestOrFamilyOrModel: LanguageModelChat | ChatRequest | ChatEndpointFamily): Promise<IChatEndpoint> {
		this._logService.trace(`Resolving chat model`);

		if (typeof requestOrFamilyOrModel === 'string') {
			const modelMetadata = await this._modelFetcher.getChatModelFromFamily(requestOrFamilyOrModel);
			// NOTE(review): the non-null assertion assumes the family lookup always
			// yields metadata — confirm against getChatModelFromFamily's contract.
			return this.getOrCreateChatEndpointInstance(modelMetadata!);
		}

		const model = 'model' in requestOrFamilyOrModel ? requestOrFamilyOrModel.model : requestOrFamilyOrModel;

		if (!model) {
			return this.getChatEndpoint('copilot-base');
		}

		if (model.vendor !== 'copilot') {
			return this._instantiationService.createInstance(ExtensionContributedChatEndpoint, model);
		}

		if (model.id === AutoChatEndpoint.pseudoModelId) {
			try {
				const allEndpoints = await this.getAllChatEndpoints();
				// `await` is required here: returning the un-awaited result would let a
				// rejected promise escape this try block and skip the fallback below.
				// NOTE(review): a bare LanguageModelChat can also reach this branch; the
				// cast assumes a ChatRequest — confirm the auto-mode service tolerates it.
				return await this._autoModeService.resolveAutoModeEndpoint(requestOrFamilyOrModel as ChatRequest, allEndpoints);
			} catch (err) {
				// Best-effort fallback, but leave a trace of why auto mode was not used.
				const reason = err instanceof Error ? err.message : String(err);
				this._logService.trace(`Failed to resolve auto mode endpoint, falling back to copilot-base: ${reason}`);
				return this.getChatEndpoint('copilot-base');
			}
		}

		const modelMetadata = await this._modelFetcher.getChatModelFromApiModel(model);
		// If the model cannot be resolved, fall back to copilot-base. This should not
		// normally happen because the model picker is powered by the same service.
		return modelMetadata ? this.getOrCreateChatEndpointInstance(modelMetadata) : this.getChatEndpoint('copilot-base');
	}

	/**
	 * Resolves the embeddings endpoint.
	 *
	 * @param family Currently ignored: the `'text-embedding-3-small'` model is
	 * always requested regardless of this argument.
	 * @returns The (possibly cached) embeddings endpoint.
	 */
	async getEmbeddingsEndpoint(family?: EmbeddingsEndpointFamily): Promise<IEmbeddingsEndpoint> {
		this._logService.trace(`Resolving embedding model`);
		const modelMetadata = await this._modelFetcher.getEmbeddingsModel(EMBEDDING_MODEL_ID);
		const model = await this.getOrCreateEmbeddingEndpointInstance(modelMetadata);
		this._logService.trace(`Resolved embedding model`);
		return model;
	}

	/**
	 * Returns the cached embeddings endpoint, creating and caching an
	 * {@link EmbeddingEndpoint} from the given metadata on first use.
	 *
	 * The cache key is the fixed {@link EMBEDDING_MODEL_ID}, not a value read from
	 * `modelMetadata`, so at most one embeddings endpoint is cached at a time.
	 *
	 * @param modelMetadata Metadata used to construct the endpoint on a cache miss.
	 */
	private async getOrCreateEmbeddingEndpointInstance(modelMetadata: IEmbeddingModelInformation): Promise<IEmbeddingsEndpoint> {
		const modelId = EMBEDDING_MODEL_ID;
		let embeddingEndpoint = this._embeddingEndpoints.get(modelId);
		if (!embeddingEndpoint) {
			embeddingEndpoint = this._instantiationService.createInstance(EmbeddingEndpoint, modelMetadata);
			this._embeddingEndpoints.set(modelId, embeddingEndpoint);
		}
		return embeddingEndpoint;
	}

	/**
	 * Returns information about all completion models known to the model fetcher.
	 *
	 * @param forceRefresh Forwarded to the fetcher; defaults to `false` when omitted.
	 */
	async getAllCompletionModels(forceRefresh?: boolean): Promise<ICompletionModelInformation[]> {
		return this._modelFetcher.getAllCompletionModels(forceRefresh ?? false);
	}

	/**
	 * Returns a chat endpoint for every chat model reported by the model fetcher,
	 * reusing cached endpoint instances where they exist.
	 */
	async getAllChatEndpoints(): Promise<IChatEndpoint[]> {
		const models: IChatModelInformation[] = await this._modelFetcher.getAllChatModels();
		return models.map(model => this.getOrCreateChatEndpointInstance(model));
	}
}