Path: blob/main/extensions/copilot/src/extension/byok/node/openAIEndpoint.ts
13399 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
import type { CancellationToken } from 'vscode';
import { IChatMLFetcher } from '../../../platform/chat/common/chatMLFetcher';
import { ChatFetchResponseType, ChatResponse } from '../../../platform/chat/common/commonTypes';
import { IConfigurationService } from '../../../platform/configuration/common/configurationService';
import { IDomainService } from '../../../platform/endpoint/common/domainService';
import { IChatModelInformation } from '../../../platform/endpoint/common/endpointProvider';
import { ChatEndpoint } from '../../../platform/endpoint/node/chatEndpoint';
import { ILogService } from '../../../platform/log/common/logService';
import { isOpenAiFunctionTool } from '../../../platform/networking/common/fetch';
import { createCapiRequestBody, IChatEndpoint, ICreateEndpointBodyOptions, IEndpointBody, IMakeChatRequestOptions } from '../../../platform/networking/common/networking';
import { RawMessageConversionCallback } from '../../../platform/networking/common/openai';
import { IChatWebSocketManager } from '../../../platform/networking/node/chatWebSocketManager';
import { IExperimentationService } from '../../../platform/telemetry/common/nullExperimentationService';
import { ITokenizerProvider } from '../../../platform/tokenizer/node/tokenizer';
import { IInstantiationService } from '../../../util/vs/platform/instantiation/common/instantiation';

/**
 * Rewrites BYOK failure responses so that the surfaced `reason` carries the provider's error details.
 *
 * - `Failed` responses that carry a `streamError` are replaced by a response whose `reason` is the
 *   JSON-serialized stream error. Only `type`, `requestId` and `serverRequestId` are carried over.
 * - `RateLimited` responses get the reason 'Rate limit exceeded', followed by the JSON-serialized
 *   `capiError` when one is present. `rateLimitKey`, `retryAfter` and `isAuto` are reset.
 * - Every other response is returned unchanged.
 *
 * @param response The response produced by the underlying chat request.
 * @returns The (possibly rewritten) response.
 */
function hydrateBYOKErrorMessages(response: ChatResponse): ChatResponse {
	if (response.type === ChatFetchResponseType.Failed && response.streamError) {
		return {
			type: response.type,
			requestId: response.requestId,
			serverRequestId: response.serverRequestId,
			reason: JSON.stringify(response.streamError),
		};
	} else if (response.type === ChatFetchResponseType.RateLimited) {
		return {
			type: response.type,
			requestId: response.requestId,
			serverRequestId: response.serverRequestId,
			reason: response.capiError ? 'Rate limit exceeded\n\n' + JSON.stringify(response.capiError) : 'Rate limit exceeded',
			rateLimitKey: '',
			retryAfter: undefined,
			isAuto: false,
			capiError: response.capiError
		};
	}
	return response;
}

/**
 * Checks to see if a given endpoint is a BYOK model.
 * @param endpoint The endpoint to check if it's a BYOK model
 * @returns 1 if client side byok, 2 if server side byok, -1 if not a byok model
 */
export function isBYOKModel(endpoint: IChatEndpoint | undefined): number {
	if (!endpoint) {
		return -1;
	}
	return (endpoint instanceof OpenAIEndpoint || endpoint.isExtensionContributed) ? 1 : (endpoint.customModel ? 2 : -1);
}

/**
 * Chat endpoint for an OpenAI-compatible BYOK model reached at a user-supplied URL.
 *
 * Authentication uses the supplied API key: it is sent as an `api-key` header when the URL contains
 * 'openai.azure', and as `Authorization: Bearer …` otherwise. User-configured request headers
 * (`requestHeaders` on the model metadata) are sanitized once, at construction.
 */
export class OpenAIEndpoint extends ChatEndpoint {
	// Reserved headers that cannot be overridden for security and functionality reasons
	// Including forbidden request headers: https://developer.mozilla.org/en-US/docs/Glossary/Forbidden_request_header
	private static readonly _reservedHeaders: ReadonlySet<string> = new Set([
		// Forbidden Request Headers
		'accept-charset',
		'accept-encoding',
		'access-control-request-headers',
		'access-control-request-method',
		'connection',
		'content-length',
		'cookie',
		'date',
		'dnt',
		'expect',
		'host',
		'keep-alive',
		'origin',
		'permissions-policy',
		'referer',
		'te',
		'trailer',
		'transfer-encoding',
		'upgrade',
		'user-agent',
		'via',
		// Forwarding & Routing
		'forwarded',
		'x-forwarded-for',
		'x-forwarded-host',
		'x-forwarded-proto',
		// Others
		'api-key',
		'authorization',
		'content-type',
		'openai-intent',
		'x-github-api-version',
		'x-initiator',
		'x-interaction-id',
		'x-interaction-type',
		'x-onbehalf-extension-id',
		'x-request-id',
		'x-vscode-user-agent-library-version',
		// Pattern-based forbidden headers are checked separately:
		// - 'proxy-*' headers (handled in sanitization logic)
		// - 'sec-*' headers (handled in sanitization logic)
		// - 'x-http-method*' with forbidden methods CONNECT, TRACE, TRACK (handled in sanitization logic)
	]);

	// Headers that can tunnel an HTTP method override (lower-cased), and the methods they may not carry.
	private static readonly _methodOverrideHeaders: ReadonlySet<string> = new Set(['x-http-method', 'x-http-method-override', 'x-method-override']);
	private static readonly _forbiddenOverrideMethods: ReadonlySet<string> = new Set(['connect', 'trace', 'track']);

	// RFC 7230 compliant header name pattern: token characters only
	private static readonly _validHeaderNamePattern = /^[!#$%&'*+\-.0-9A-Z^_`a-z|~]+$/;

	// Maximum limits to prevent abuse
	private static readonly _maxHeaderNameLength = 256;
	private static readonly _maxHeaderValueLength = 8192;
	private static readonly _maxCustomHeaderCount = 20;

	/** User-configured headers that survived sanitization; merged into every request by `getExtraHeaders`. */
	private readonly _customHeaders: Record<string, string>;

	constructor(
		_modelMetadata: IChatModelInformation,
		protected readonly _apiKey: string,
		protected readonly _modelUrl: string,
		@IDomainService domainService: IDomainService,
		@IChatMLFetcher chatMLFetcher: IChatMLFetcher,
		@ITokenizerProvider tokenizerProvider: ITokenizerProvider,
		@IInstantiationService protected instantiationService: IInstantiationService,
		@IConfigurationService configurationService: IConfigurationService,
		@IExperimentationService expService: IExperimentationService,
		@IChatWebSocketManager chatWebSocketService: IChatWebSocketManager,
		@ILogService protected logService: ILogService
	) {
		super(
			_modelMetadata,
			domainService,
			chatMLFetcher,
			tokenizerProvider,
			instantiationService,
			configurationService,
			expService,
			chatWebSocketService,
			logService
		);
		this._customHeaders = this._sanitizeCustomHeaders(_modelMetadata.requestHeaders);
	}

	/**
	 * Filters user-configured request headers down to a safe subset.
	 *
	 * A header is dropped (with a warning) when any of the following holds:
	 * - its trimmed name is empty, longer than `_maxHeaderNameLength`, or not an RFC 7230 token;
	 * - its name is reserved, or starts with `proxy-` or `sec-`;
	 * - it is a method-override header whose value is CONNECT, TRACE or TRACK;
	 * - its value fails `_sanitizeHeaderValue`.
	 * At most `_maxCustomHeaderCount` accepted headers are kept; rejected ones do not count toward the limit.
	 *
	 * @param headers Raw headers from the model metadata, if any.
	 * @returns The accepted headers, keyed by their trimmed (case-preserved) names, with trimmed values.
	 */
	private _sanitizeCustomHeaders(headers: Readonly<Record<string, string>> | undefined): Record<string, string> {
		if (!headers) {
			return {};
		}

		const entries = Object.entries(headers);

		if (entries.length > OpenAIEndpoint._maxCustomHeaderCount) {
			this.logService.warn(`[OpenAIEndpoint] Model '${this.modelMetadata.id}' has ${entries.length} custom headers, exceeding limit of ${OpenAIEndpoint._maxCustomHeaderCount}. Only first ${OpenAIEndpoint._maxCustomHeaderCount} will be processed.`);
		}

		const sanitized: Record<string, string> = {};
		let processedCount = 0;

		for (const [rawKey, rawValue] of entries) {
			if (processedCount >= OpenAIEndpoint._maxCustomHeaderCount) {
				break;
			}

			const key = rawKey.trim();
			if (!key) {
				this.logService.warn(`[OpenAIEndpoint] Model '${this.modelMetadata.id}' has empty header name, skipping.`);
				continue;
			}

			if (key.length > OpenAIEndpoint._maxHeaderNameLength) {
				this.logService.warn(`[OpenAIEndpoint] Model '${this.modelMetadata.id}' has header name exceeding ${OpenAIEndpoint._maxHeaderNameLength} characters, skipping.`);
				continue;
			}

			if (!OpenAIEndpoint._validHeaderNamePattern.test(key)) {
				this.logService.warn(`[OpenAIEndpoint] Model '${this.modelMetadata.id}' has invalid header name format: '${key}', Skipping.`);
				continue;
			}

			// Reserved/forbidden checks are case-insensitive, as HTTP header names are.
			const lowerKey = key.toLowerCase();
			if (OpenAIEndpoint._reservedHeaders.has(lowerKey)) {
				this.logService.warn(`[OpenAIEndpoint] Model '${this.modelMetadata.id}' attempted to override reserved header '${key}', skipping.`);
				continue;
			}

			// Check for pattern-based forbidden headers
			if (lowerKey.startsWith('proxy-') || lowerKey.startsWith('sec-')) {
				this.logService.warn(`[OpenAIEndpoint] Model '${this.modelMetadata.id}' attempted to set forbidden header pattern '${key}', skipping.`);
				continue;
			}

			// Check for X-HTTP-Method* headers with forbidden methods
			if (OpenAIEndpoint._methodOverrideHeaders.has(lowerKey)) {
				const methodValue = String(rawValue).toLowerCase().trim();
				if (OpenAIEndpoint._forbiddenOverrideMethods.has(methodValue)) {
					this.logService.warn(`[OpenAIEndpoint] Model '${this.modelMetadata.id}' attempted to set forbidden method '${methodValue}' in header '${key}', skipping.`);
					continue;
				}
			}

			const sanitizedValue = this._sanitizeHeaderValue(rawValue);
			if (sanitizedValue === undefined) {
				// Do not echo the rejected value: it may contain CR/LF or other control characters
				// (log injection) or a secret the user placed in a custom header.
				this.logService.warn(`[OpenAIEndpoint] Model '${this.modelMetadata.id}' has invalid value for header '${key}', skipping.`);
				continue;
			}

			sanitized[key] = sanitizedValue;
			processedCount++;
		}

		return sanitized;
	}

	/**
	 * Validates a single header value.
	 *
	 * @param value The raw value; anything other than a string is rejected.
	 * @returns The trimmed value, or `undefined` when it is longer than `_maxHeaderValueLength` or contains
	 *   control characters (0x00-0x1F, 0x7F), zero-width characters (U+200B-U+200D, U+FEFF) or
	 *   bidirectional embedding/override characters (U+202A-U+202E).
	 */
	private _sanitizeHeaderValue(value: unknown): string | undefined {
		if (typeof value !== 'string') {
			return undefined;
		}

		const trimmed = value.trim();

		if (trimmed.length > OpenAIEndpoint._maxHeaderValueLength) {
			return undefined;
		}

		// Disallow control characters including CR, LF, and others (0x00-0x1F, 0x7F)
		// This prevents HTTP header injection and response splitting attacks
		if (/[\x00-\x1F\x7F]/.test(trimmed)) {
			return undefined;
		}

		// Additional check for potential Unicode issues
		// Reject headers with bidirectional override characters or zero-width characters
		if (/[\u200B-\u200D\u202A-\u202E\uFEFF]/.test(trimmed)) {
			return undefined;
		}

		return trimmed;
	}

	/**
	 * Builds the request body.
	 *
	 * Responses API path: the base body is built with stateful markers enabled, then adjusted:
	 * - `store` is set to true;
	 * - `n` and `stream_options` are cleared;
	 * - `reasoning`/`include` are dropped for models without thinking support;
	 * - `previous_response_id` is dropped unless it starts with 'resp_' and zero data retention is off.
	 *
	 * Other path: the CAPI body is built with a callback that copies thinking data (`id`, `text`) onto
	 * `cot_id`/`cot_summary`.
	 */
	override createRequestBody(options: ICreateEndpointBodyOptions): IEndpointBody {
		if (this.useResponsesApi) {
			// Handle Responses API: customize the body directly.
			// Pass a copy rather than mutating the caller's options object.
			const body = super.createRequestBody({ ...options, ignoreStatefulMarker: false });
			body.store = true;
			body.n = undefined;
			body.stream_options = undefined;
			if (!this.modelMetadata.capabilities.supports.thinking) {
				body.reasoning = undefined;
				body.include = undefined;
			}
			if (body.previous_response_id && (!body.previous_response_id.startsWith('resp_') || this.modelMetadata.zeroDataRetentionEnabled)) {
				// Don't use a response ID from CAPI or when zero data retention is enabled
				body.previous_response_id = undefined;
			}
			return body;
		} else {
			// Handle CAPI: provide callback for thinking data processing
			const callback: RawMessageConversionCallback = (out, data) => {
				if (data && data.id) {
					out.cot_id = data.id;
					out.cot_summary = Array.isArray(data.text) ? data.text.join('') : data.text;
				}
			};
			return createCapiRequestBody(options, this.model, callback);
		}
	}

	/**
	 * Adjusts the outgoing body after the base class has processed it.
	 *
	 * - An empty `tools` array is removed.
	 * - Function tools without `parameters` get an empty object schema.
	 * - Thinking models lose `temperature` and have `max_tokens` copied to `max_completion_tokens`.
	 * - `max_tokens` is always removed.
	 * - Streaming Chat Completions requests ask for usage in the stream.
	 */
	override interceptBody(body: IEndpointBody | undefined): void {
		super.interceptBody(body);
		if (!body) {
			return;
		}

		// TODO @lramos15 - We should do this for all models and not just here
		if (body.tools?.length === 0) {
			delete body.tools;
		}

		if (body.tools) {
			body.tools = body.tools.map(tool => {
				if (isOpenAiFunctionTool(tool) && tool.function.parameters === undefined) {
					tool.function.parameters = { type: 'object', properties: {} };
				}
				return tool;
			});
		}

		if (this.modelMetadata.capabilities.supports.thinking) {
			delete body.temperature;
			body['max_completion_tokens'] = body.max_tokens;
		}
		// Removing max tokens defaults to the maximum which is what we want for BYOK
		delete body.max_tokens;
		if (!this.useResponsesApi && body.stream) {
			body['stream_options'] = { 'include_usage': true };
		}
	}

	/** The user-supplied model URL that requests are sent to. */
	override get urlOrRequestMetadata(): string {
		return this._modelUrl;
	}

	/**
	 * Request headers: JSON content type, the API key (`api-key` for URLs containing 'openai.azure',
	 * otherwise `Authorization: Bearer`), then the sanitized custom headers. Reserved names such as
	 * `content-type`, `api-key` and `authorization` were filtered out during sanitization, so custom
	 * headers cannot replace those entries.
	 */
	public override getExtraHeaders(): Record<string, string> {
		const headers: Record<string, string> = {
			'Content-Type': 'application/json'
		};
		if (this._modelUrl.includes('openai.azure')) {
			headers['api-key'] = this._apiKey;
		} else {
			headers['Authorization'] = `Bearer ${this._apiKey}`;
		}
		for (const [key, value] of Object.entries(this._customHeaders)) {
			headers[key] = value;
		}
		return headers;
	}

	/** Creates a new `OpenAIEndpoint` with the same key and URL, but with `maxInputTokens` overridden. */
	override cloneWithTokenOverride(modelMaxPromptTokens: number): IChatEndpoint {
		const newModelInfo = { ...this.modelMetadata, maxInputTokens: modelMaxPromptTokens };
		return this.instantiationService.createInstance(OpenAIEndpoint, newModelInfo, this._apiKey, this._modelUrl);
	}

	/** Sends the request with stateful markers enabled, then rewrites BYOK error responses. */
	public override async makeChatRequest2(options: IMakeChatRequestOptions, token: CancellationToken): Promise<ChatResponse> {
		// Apply ignoreStatefulMarker: false for initial request
		const modifiedOptions: IMakeChatRequestOptions = { ...options, ignoreStatefulMarker: false };
		const response = await super.makeChatRequest2(modifiedOptions, token);
		return hydrateBYOKErrorMessages(response);
	}
}