Path: blob/main/extensions/copilot/src/platform/endpoint/test/node/openaiCompatibleEndpoint.ts
13405 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { OpenAI } from '@vscode/prompt-tsx';
import { TokenizerType } from '../../../../util/common/tokenizer';
import { IInstantiationService } from '../../../../util/vs/platform/instantiation/common/instantiation';
import { IAuthenticationService } from '../../../authentication/common/authentication';
import { IChatMLFetcher } from '../../../chat/common/chatMLFetcher';
import { IConfigurationService } from '../../../configuration/common/configurationService';
import { IEnvService } from '../../../env/common/envService';
import { ILogService } from '../../../log/common/logService';
import { isOpenAiFunctionTool } from '../../../networking/common/fetch';
import { IFetcherService } from '../../../networking/common/fetcherService';
import { IChatEndpoint, ICreateEndpointBodyOptions, IEndpointBody } from '../../../networking/common/networking';
import { CAPIChatMessage, RawMessageConversionCallback } from '../../../networking/common/openai';
import { IChatWebSocketManager } from '../../../networking/node/chatWebSocketManager';
import { IExperimentationService } from '../../../telemetry/common/nullExperimentationService';
import { ITelemetryService } from '../../../telemetry/common/telemetry';
import { ITokenizerProvider } from '../../../tokenizer/node/tokenizer';
import { ICAPIClientService } from '../../common/capiClient';
import { IDomainService } from '../../common/domainService';
import { IChatModelInformation, ModelSupportedEndpoint } from '../../common/endpointProvider';
import { ChatEndpoint } from '../../node/chatEndpoint';

/**
 * Static description of an OpenAI-compatible model used to build an
 * {@link OpenAICompatibleTestEndpoint}.
 */
export type IModelConfig = {
	id: string;
	name: string;
	/** When non-empty, appended to {@link IModelConfig.url} as the `api-version` query parameter. */
	version: string;
	/** When true, system messages are re-labelled with the `developer` role before sending. */
	useDeveloperRole: boolean;
	type: 'openai' | 'azureOpenai';
	capabilities: {
		supports: {
			parallel_tool_calls: boolean;
			streaming: boolean;
			tool_calls: boolean;
			vision: boolean;
			prediction: boolean;
			thinking: boolean;
		};
		limits: {
			max_prompt_tokens: number;
			max_output_tokens: number;
			max_context_window_tokens?: number;
		};
	};
	/** Falls back to `[ModelSupportedEndpoint.ChatCompletions]` when empty or not an array. */
	supported_endpoints: readonly ModelSupportedEndpoint[];
	url: string;
	auth: {
		/**
		 * Use Bearer token for authentication
		 */
		useBearerHeader: boolean;
		/**
		 * Use API key for authentication
		 */
		useApiKeyHeader: boolean;
		/**
		 * The environment variable name for the API key
		 */
		apiKeyEnvName?: string;
	};
	overrides: {
		requestHeaders: Record<string, string>;
		// If any value is set to null, it will be deleted from the request body
		// if the value is undefined, it will not override any existing value in the request body
		// if the value is set, it will override the existing value in the request body
		temperature?: number | null;
		top_p?: number | null;
		snippy?: boolean | null;
		max_tokens?: number | null;
		// NOTE(review): declared here but not applied by OpenAICompatibleTestEndpoint.interceptBody - confirm whether that is intended.
		max_completion_tokens?: number | null;
		intent?: boolean | null;
	};
};

/**
 * Chat endpoint that talks to an OpenAI-compatible (OpenAI or Azure OpenAI)
 * service described by an {@link IModelConfig}, instead of the default endpoint.
 */
export class OpenAICompatibleTestEndpoint extends ChatEndpoint {
	constructor(
		private readonly modelConfig: IModelConfig,
		@IDomainService domainService: IDomainService,
		@ICAPIClientService capiClientService: ICAPIClientService,
		@IFetcherService fetcherService: IFetcherService,
		@IEnvService envService: IEnvService,
		@ITelemetryService telemetryService: ITelemetryService,
		@IAuthenticationService authService: IAuthenticationService,
		@IChatMLFetcher chatMLFetcher: IChatMLFetcher,
		@ITokenizerProvider tokenizerProvider: ITokenizerProvider,
		@IInstantiationService private instantiationService: IInstantiationService,
		@IConfigurationService configurationService: IConfigurationService,
		@IExperimentationService experimentationService: IExperimentationService,
		@IChatWebSocketManager chatWebSocketService: IChatWebSocketManager,
		@ILogService logService: ILogService
	) {
		const { capabilities } = modelConfig;
		const hasSupportedEndpoints = Array.isArray(modelConfig.supported_endpoints) && modelConfig.supported_endpoints.length > 0;

		const modelInfo: IChatModelInformation = {
			id: modelConfig.id,
			vendor: 'OpenAI Compatible',
			name: modelConfig.name,
			version: modelConfig.version,
			model_picker_enabled: false,
			is_chat_default: false,
			is_chat_fallback: false,
			capabilities: {
				type: 'chat',
				family: modelConfig.type === 'azureOpenai' ? 'azure' : 'openai',
				tokenizer: TokenizerType.O200K,
				supports: {
					parallel_tool_calls: capabilities.supports.parallel_tool_calls,
					streaming: capabilities.supports.streaming,
					tool_calls: capabilities.supports.tool_calls,
					vision: capabilities.supports.vision,
					prediction: capabilities.supports.prediction,
					thinking: capabilities.supports.thinking ?? false
				},
				limits: {
					max_prompt_tokens: capabilities.limits.max_prompt_tokens,
					max_output_tokens: capabilities.limits.max_output_tokens,
					max_context_window_tokens: capabilities.limits.max_context_window_tokens
				}
			},
			supported_endpoints: hasSupportedEndpoints
				? modelConfig.supported_endpoints
				: [ModelSupportedEndpoint.ChatCompletions]
		};

		super(
			modelInfo,
			domainService,
			chatMLFetcher,
			tokenizerProvider,
			instantiationService,
			configurationService,
			experimentationService,
			chatWebSocketService,
			logService
		);
	}

	/**
	 * The request URL: the configured `url`, with `api-version=<version>`
	 * appended when a version is configured.
	 */
	override get urlOrRequestMetadata(): string {
		const { url, version } = this.modelConfig;
		if (!version) {
			return url;
		}
		// Use '&' when the configured URL already carries a query string so the result stays well-formed.
		const separator = url.includes('?') ? '&' : '?';
		return `${url}${separator}api-version=${version}`;
	}

	/**
	 * Builds the request headers: JSON content type, the configured auth
	 * header(s), then any `overrides.requestHeaders` (which win on conflict).
	 *
	 * @throws Error when an auth header is requested but `apiKeyEnvName` is
	 * missing, or the named environment variable is unset or empty.
	 */
	public override getExtraHeaders(): Record<string, string> {
		const headers: Record<string, string> = {
			'Content-Type': 'application/json'
		};

		const { auth, overrides } = this.modelConfig;

		if (auth.useBearerHeader || auth.useApiKeyHeader) {
			if (!auth.apiKeyEnvName) {
				throw new Error('API key environment variable name is not set in the model configuration');
			}
			const apiKey = process.env[auth.apiKeyEnvName];
			if (!apiKey) {
				throw new Error(`API key environment variable ${auth.apiKeyEnvName} is not set`);
			}

			if (auth.useBearerHeader) {
				headers['Authorization'] = `Bearer ${apiKey}`;
			}

			if (auth.useApiKeyHeader) {
				headers['api-key'] = apiKey;
			}
		}

		if (overrides.requestHeaders) {
			for (const [key, value] of Object.entries(overrides.requestHeaders)) {
				headers[key] = value;
			}
		}

		return headers;
	}

	/**
	 * Creates the request body via the base class. When the Responses API is
	 * in use, the body is additionally adjusted: `store` is enabled, `n` and
	 * `stream_options` are cleared, and `reasoning` is cleared unless the
	 * model config declares thinking support.
	 */
	override createRequestBody(options: ICreateEndpointBodyOptions): IEndpointBody {
		if (!this.useResponsesApi) {
			return super.createRequestBody(options);
		}

		// Handle Responses API: customize the body directly
		options.ignoreStatefulMarker = false;
		const body = super.createRequestBody(options);
		body.store = true;
		body.n = undefined;
		body.stream_options = undefined;
		if (!this.modelConfig.capabilities.supports.thinking) {
			body.reasoning = undefined;
		}
		return body;
	}

	/**
	 * Post-processes the outgoing request body in place: drops an empty
	 * `tools` array, strips `copilot_cache_control` from messages, applies
	 * the configured overrides, fills in missing tool parameter schemas and
	 * applies the OpenAI / developer-role specific rewrites.
	 */
	override interceptBody(body: IEndpointBody | undefined): void {
		super.interceptBody(body);

		if (!body) {
			return;
		}

		if (body.tools?.length === 0) {
			delete body.tools;
		}

		if (body.messages) {
			body.messages.forEach((message: any) => {
				if (message.copilot_cache_control) {
					delete message.copilot_cache_control;
				}
			});
		}

		this.applyOverrides(body);

		if (body.tools) {
			// Function tools declared without parameters get an empty object schema.
			for (const tool of body.tools) {
				if (isOpenAiFunctionTool(tool) && tool.function.parameters === undefined) {
					tool.function.parameters = { type: 'object', properties: {} };
				}
			}
		}

		if (this.modelConfig.type === 'openai') {
			if (!this.useResponsesApi) {
				// we need to set this to ensure usage stats are logged
				body['stream_options'] = { 'include_usage': true };
			}
			// OpenAI requires the model name to be set in the body
			body.model = this.modelConfig.name;

			// Handle messages reformatting if messages exist: system messages are sent as user messages.
			// NOTE(review): because this runs first, the developer-role rewrite below finds no
			// system messages when type is 'openai' - confirm that this is intended.
			if (body.messages) {
				const newMessages: CAPIChatMessage[] = body.messages.map((message: CAPIChatMessage): CAPIChatMessage =>
					message.role === OpenAI.ChatRole.System
						? { role: OpenAI.ChatRole.User, content: message.content }
						: message
				);
				body['messages'] = newMessages;
			}
		}

		// Guarded on `body.messages` so bodies without a messages array are left untouched instead of throwing.
		if (this.modelConfig.useDeveloperRole && body.messages) {
			const newMessages = body.messages.map((message: CAPIChatMessage) => {
				if (message.role === OpenAI.ChatRole.System) {
					return { role: 'developer' as OpenAI.ChatRole.System, content: message.content };
				}
				return message;
			});
			// Every other field is removed; only the rewritten messages remain on the body.
			Object.keys(body).forEach(key => delete (body as any)[key]);
			body.messages = newMessages;
		}
	}

	/**
	 * Applies `modelConfig.overrides` to the body: `null` deletes the field,
	 * `undefined` leaves it alone, any other value replaces it. Values are
	 * compared against `undefined` rather than tested for truthiness so that
	 * `0` and `false` are honoured as real overrides.
	 */
	private applyOverrides(body: IEndpointBody): void {
		const { snippy, intent, temperature, top_p, max_tokens } = this.modelConfig.overrides;

		if (snippy === null) {
			delete body.snippy;
		} else if (snippy !== undefined) {
			body.snippy = { enabled: snippy };
		}

		if (intent === null) {
			delete body.intent;
		} else if (intent !== undefined) {
			body.intent = intent;
		}

		if (temperature === null) {
			delete body.temperature;
		} else if (temperature !== undefined) {
			body.temperature = temperature;
		}

		if (top_p === null) {
			delete body.top_p;
		} else if (top_p !== undefined) {
			body.top_p = top_p;
		}

		if (max_tokens === null) {
			delete body.max_tokens;
		} else if (max_tokens !== undefined) {
			body.max_tokens = max_tokens;
		}
	}

	/**
	 * Returns a fresh endpoint built from the same model config.
	 * The requested token limit is not applied: the clone keeps the limits
	 * from the original config.
	 */
	override cloneWithTokenOverride(_modelMaxPromptTokens: number): IChatEndpoint {
		return this.instantiationService.createInstance(OpenAICompatibleTestEndpoint, this.modelConfig);
	}

	/**
	 * Returns a callback that copies `data.id` to `out.cot_id` and `data.text`
	 * (joined when it is an array) to `out.cot_summary`; data without an id is ignored.
	 */
	protected override getCompletionsCallback(): RawMessageConversionCallback | undefined {
		return (out, data) => {
			if (data && data.id) {
				out.cot_id = data.id;
				out.cot_summary = Array.isArray(data.text) ? data.text.join('') : data.text;
			}
		};
	}
}