Path: blob/main/extensions/copilot/test/base/spyingChatMLFetcher.ts
13389 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { Raw } from '@vscode/prompt-tsx';
import type { CancellationToken } from 'vscode';
import { AbstractChatMLFetcher } from '../../src/extension/prompt/node/chatMLFetcher';
import { IChatMLFetcher, IFetchMLOptions } from '../../src/platform/chat/common/chatMLFetcher';
import { ChatResponses } from '../../src/platform/chat/common/commonTypes';
import { IConversationOptions } from '../../src/platform/chat/common/conversationOptions';
import { roleToString } from '../../src/platform/chat/common/globalStringUtils';
import { FinishedCallback, ICopilotToolCall } from '../../src/platform/networking/common/fetch';
import { APIUsage } from '../../src/platform/networking/common/openai';
import { TaskQueue } from '../../src/util/common/async';
import { coalesce } from '../../src/util/vs/base/common/arrays';
import { isDisposable } from '../../src/util/vs/base/common/lifecycle';
import { StopWatch } from '../../src/util/vs/base/common/stopwatch';
import { SyncDescriptor } from '../../src/util/vs/platform/instantiation/common/descriptors';
import { IInstantiationService } from '../../src/util/vs/platform/instantiation/common/instantiation';
import { InterceptedRequest, ISerialisedChatResponse } from '../simulation/shared/sharedTypes';
import { CacheInfo, TestRunCacheInfo } from '../testExecutor';
import { ResponseWithMeta } from './cachingChatMLFetcher';

/**
 * Collects the requests intercepted by a {@link SpyingChatMLFetcher} and
 * exposes aggregates (token usage, content-filter count, cache information)
 * computed over them.
 */
export class FetchRequestCollector {
	public readonly _interceptedRequests: InterceptedRequest[] = [];

	/** Read-only view of the requests recorded so far. */
	public get interceptedRequests(): readonly InterceptedRequest[] {
		return this._interceptedRequests;
	}

	private readonly _pendingRequests = new TaskQueue();
	private readonly _scheduledRequests: Promise<void>[] = [];

	/**
	 * Schedules a task on the internal queue that awaits `requestPromise` and
	 * records its result once it resolves.
	 *
	 * @param requestPromise Promise for the intercepted request. A rejection
	 * is deliberately ignored here.
	 */
	public addInterceptedRequest(requestPromise: Promise<InterceptedRequest>): void {
		const scheduled = this._pendingRequests.schedule(async () => {
			try {
				this._interceptedRequests.push(await requestPromise);
			} catch {
				// ignore errors here - the error will be thrown out of the ChatMLFetcher and handled
			}
		});
		this._scheduledRequests.push(scheduled);
	}

	/**
	 * Intercepted requests are async. This method waits for all pending requests to complete.
	 */
	public async complete(): Promise<void> {
		await Promise.all(this._scheduledRequests);
	}

	/** Number of recorded requests whose response type is `'filtered'`. */
	public get contentFilterCount(): number {
		return this.interceptedRequests.filter(request => request.response.type === 'filtered').length;
	}

	/**
	 * Sum of the token usage over all recorded requests. A request without
	 * usage information contributes zero to every counter.
	 */
	public get usage(): APIUsage {
		// Explicitly typed so that `reduce` picks the APIUsage accumulator overload.
		// The object is only ever read, so it can double as the per-request fallback.
		const zeroUsage: APIUsage = { completion_tokens: 0, prompt_tokens: 0, total_tokens: 0, prompt_tokens_details: { cached_tokens: 0 } };
		return this.interceptedRequests.reduce((total, request): APIUsage => {
			const current = request.response.usage ?? zeroUsage;
			return {
				completion_tokens: total.completion_tokens + current.completion_tokens,
				prompt_tokens: total.prompt_tokens + current.prompt_tokens,
				total_tokens: total.total_tokens + current.total_tokens,
				prompt_tokens_details: {
					cached_tokens: (total.prompt_tokens_details?.cached_tokens ?? 0) + (current.prompt_tokens_details?.cached_tokens ?? 0),
				}
			};
		}, zeroUsage);
	}

	/**
	 * Mean of `cacheMetadata.requestDuration` over the recorded requests that
	 * carry one. Evaluates to `NaN` (0 / 0) when no request has a duration.
	 */
	public get averageRequestDuration(): number {
		const requestDurations = coalesce(this.interceptedRequests.map(request => request.response.cacheMetadata?.requestDuration));
		const totalDuration = requestDurations.reduce((sum, duration) => sum + duration, 0);
		return totalDuration / requestDurations.length;
	}

	/** `true` when at least one recorded response has `isCacheHit === false`. */
	public get hasCacheMiss(): boolean {
		return this.interceptedRequests.some(request => request.response.isCacheHit === false);
	}

	/** One `'request'` cache-info entry per recorded request that has a cache key. */
	public get cacheInfo(): TestRunCacheInfo {
		const cacheKeys = coalesce(this.interceptedRequests.map(request => request.cacheKey));
		return cacheKeys.map(key => ({ type: 'request', key } satisfies CacheInfo));
	}
}

/**
 * Chat ML fetcher that delegates every request to a wrapped fetcher and
 * reports each request/response pair to a {@link FetchRequestCollector}.
 */
export class SpyingChatMLFetcher extends AbstractChatMLFetcher {

	private readonly fetcher: IChatMLFetcher;

	/** Requests recorded by the underlying collector. */
	public get interceptedRequests(): readonly InterceptedRequest[] {
		return this.requestCollector.interceptedRequests;
	}

	/** Content-filter count reported by the underlying collector. */
	public get contentFilterCount(): number {
		return this.requestCollector.contentFilterCount;
	}

	/**
	 * @param requestCollector Collector that receives every intercepted request.
	 * @param fetcherDesc Descriptor of the wrapped fetcher; it is instantiated
	 * through `instantiationService`.
	 */
	constructor(
		public readonly requestCollector: FetchRequestCollector,
		fetcherDesc: SyncDescriptor<IChatMLFetcher>,
		@IInstantiationService instantiationService: IInstantiationService,
		@IConversationOptions options: IConversationOptions,
	) {
		super(options);
		this.fetcher = instantiationService.createInstance(fetcherDesc);
	}

	/** Disposes this fetcher and, when it is disposable, the wrapped fetcher. */
	public override dispose(): void {
		super.dispose();
		if (isDisposable(this.fetcher)) {
			this.fetcher.dispose();
		}
	}

	/**
	 * Forwards the request to the wrapped fetcher and hands the collector a
	 * promise for the corresponding {@link InterceptedRequest}.
	 *
	 * The caller's `finishedCb` is wrapped so that tool calls found in the
	 * deltas are captured; they are attached to the response as
	 * `copilotFunctionCalls` before it is recorded.
	 *
	 * @returns The responses produced by the wrapped fetcher.
	 */
	override async fetchMany(opts: IFetchMLOptions, token: CancellationToken): Promise<ChatResponses> {

		const toolCalls: ICopilotToolCall[] = [];
		const captureToolCallsCb: FinishedCallback = async (text, idx, delta) => {
			if (delta.copilotToolCalls) {
				toolCalls.push(...delta.copilotToolCalls);
			}
			if (opts.finishedCb) {
				return opts.finishedCb(text, idx, delta);
			}
		};

		// Start timing before the request is issued so that any synchronous work
		// done by the wrapped fetcher is part of the measured duration.
		const sw = new StopWatch(false);
		const respPromise = this.fetcher.fetchMany({ ...opts, finishedCb: captureToolCallsCb }, token);

		this.requestCollector.addInterceptedRequest(respPromise.then(resp => {
			let cacheKey: string | undefined;
			if (typeof (resp as ResponseWithMeta).cacheKey === 'string') {
				cacheKey = (resp as ResponseWithMeta).cacheKey;
			}
			(resp as ISerialisedChatResponse).copilotFunctionCalls = toolCalls;

			const messages = opts.messages.map(message => {
				return {
					role: roleToString(message.role),
					content: message.content,
					tool_call_id: message.role === Raw.ChatRole.Tool ? message.toolCallId : undefined,
					tool_calls: message.role === Raw.ChatRole.Assistant ? message.toolCalls : undefined,
					name: message.name,
				};
			});
			return new InterceptedRequest(messages, opts.requestOptions, resp, cacheKey, opts.endpoint.model, sw.elapsed());
		}));

		return await respPromise;
	}
}