Path: blob/main/extensions/copilot/src/extension/prompt/node/test/defaultIntentRequestHandler.spec.ts
13405 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { Raw, RenderPromptResult } from '@vscode/prompt-tsx';
import { afterEach, beforeEach, expect, suite, test, vi } from 'vitest';
import type { ChatLanguageModelToolReference, ChatPromptReference, ChatRequest, ExtendedChatResponsePart, LanguageModelChat } from 'vscode';
import { IChatMLFetcher } from '../../../../platform/chat/common/chatMLFetcher';
import { toTextPart } from '../../../../platform/chat/common/globalStringUtils';
import { StaticChatMLFetcher } from '../../../../platform/chat/test/common/staticChatMLFetcher';
import { MockEndpoint } from '../../../../platform/endpoint/test/node/mockEndpoint';
import { IResponseDelta } from '../../../../platform/networking/common/fetch';
import { IChatEndpoint } from '../../../../platform/networking/common/networking';
import { ITelemetryService } from '../../../../platform/telemetry/common/telemetry';
import { SpyingTelemetryService } from '../../../../platform/telemetry/node/spyingTelemetryService';
import { ITestingServicesAccessor } from '../../../../platform/test/node/services';
import { NullWorkspaceFileIndex } from '../../../../platform/workspaceChunkSearch/node/nullWorkspaceFileIndex';
import { IWorkspaceFileIndex } from '../../../../platform/workspaceChunkSearch/node/workspaceFileIndex';
import { ChatResponseStreamImpl } from '../../../../util/common/chatResponseStreamImpl';
import { CancellationToken } from '../../../../util/vs/base/common/cancellation';
import { isObject, isUndefinedOrNull } from '../../../../util/vs/base/common/types';
import { generateUuid } from '../../../../util/vs/base/common/uuid';
import { SyncDescriptor } from '../../../../util/vs/platform/instantiation/common/descriptors';
import { IInstantiationService } from '../../../../util/vs/platform/instantiation/common/instantiation';
import { ChatLocation, ChatResponseConfirmationPart, ChatResponseMarkdownPart, LanguageModelTextPart, LanguageModelToolResult, Uri } from '../../../../vscodeTypes';
import { ToolCallingLoop } from '../../../intents/node/toolCallingLoop';
import { ToolResultMetadata } from '../../../prompts/node/panel/toolCalling';
import { createExtensionUnitTestingServices } from '../../../test/node/services';
import { Conversation, Turn } from '../../common/conversation';
import { IBuildPromptContext } from '../../common/intents';
import { ToolCallRound } from '../../common/toolCallRound';
import { ChatTelemetryBuilder } from '../chatParticipantTelemetry';
import { DefaultIntentRequestHandler } from '../defaultIntentRequestHandler';
import { IIntent, IIntentInvocation, nullRenderPromptResult, promptResultMetadata } from '../intents';

/**
 * Tests for DefaultIntentRequestHandler.
 *
 * Each test drives the handler with a stub intent whose prompt comes from `promptResult`,
 * a StaticChatMLFetcher that replays `chatResponse`, and a SpyingTelemetryService whose
 * events are snapshotted after derandomization.
 */
suite('defaultIntentRequestHandler', () => {
	let accessor: ITestingServicesAccessor;
	/** Every part written to the shared `responseStream`. Reset before each test. */
	let response: ExtendedChatResponsePart[];
	/** Canned model responses, consumed in order by the static fetcher. */
	let chatResponse: (string | IResponseDelta[])[] = [];
	/** Either a single prompt returned for every build, or a queue consumed one build at a time. */
	let promptResult: RenderPromptResult | RenderPromptResult[];
	let telemetry: SpyingTelemetryService;
	let fetcher: StaticChatMLFetcher;
	let endpoint: IChatEndpoint;
	// NOTE(review): turnIdCounter is only ever reset, never incremented, so every turn created by
	// makeHandler gets the id `turn-id-0`. Incrementing it would likely change snapshots; confirm intent.
	let turnIdCounter = 0;
	/** Every context passed to TestIntentInvocation.buildPrompt, in call order. */
	let builtPrompts: IBuildPromptContext[] = [];
	const sessionId = 'some-session-id';

	const getTurnId = () => `turn-id-${turnIdCounter}`;

	beforeEach(async () => {
		const services = createExtensionUnitTestingServices();
		telemetry = new SpyingTelemetryService();
		chatResponse = [];
		// The fetcher keeps a reference to this array, so tests populate it after construction.
		fetcher = new StaticChatMLFetcher(chatResponse);
		services.define(ITelemetryService, telemetry);
		services.define(IChatMLFetcher, fetcher);
		services.define(IWorkspaceFileIndex, new SyncDescriptor(NullWorkspaceFileIndex));

		accessor = services.createTestingAccessor();
		endpoint = accessor.get(IInstantiationService).createInstance(MockEndpoint, undefined);
		builtPrompts = [];
		response = [];
		promptResult = nullRenderPromptResult();
		turnIdCounter = 0;
		// Pin otherwise-varying ids and timestamps so results and telemetry can be snapshotted.
		(ToolCallingLoop as any).NextToolCallId = 0;
		// NOTE(review): this is a plain assignment, not a vi.spyOn, so vi.restoreAllMocks() in
		// afterEach does not undo it. It stays patched for the rest of this file.
		(ToolCallRound as any).generateID = () => 'static-id';
		vi.spyOn(Date, 'now').mockReturnValue(0);
	});

	afterEach(() => {
		vi.restoreAllMocks();
		accessor.dispose();
	});

	// Deliberately NOT global (`g`). RegExp.prototype.test on a global regex is stateful: it resumes
	// from `lastIndex` after a match. Reused across many strings, it can skip a UUID and leak a
	// random value into a snapshot.
	const uuidRegex = /[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}/;

	/**
	 * Returns a deep copy of the recorded telemetry events with nondeterministic values normalized:
	 * - any string containing a UUID becomes 'some-uuid';
	 * - any number stored under a key starting with `timeTo` becomes '<duration>'.
	 */
	function getDerandomizedTelemetry() {
		const evts = telemetry.getEvents();
		return cloneAndChangeWithKey(evts, (e, key) => {
			if (typeof e === 'string' && uuidRegex.test(e)) {
				return 'some-uuid';
			} else if (typeof e === 'number' && typeof key === 'string' && key.startsWith('timeTo')) {
				return '<duration>';
			}
		});
	}

	/** Minimal panel intent whose invocations read prompts from `promptResult`. */
	class TestIntent implements IIntent {
		id = 'test';
		description = 'test intent';
		locations = [ChatLocation.Panel];
		invoke(): Promise<IIntentInvocation> {
			return Promise.resolve(new TestIntentInvocation(this, this.locations[0], endpoint));
		}
	}

	/**
	 * Intent invocation that records each build context in `builtPrompts`. When `promptResult` is
	 * an array, each build dequeues the next prompt and throws once the queue is exhausted.
	 */
	class TestIntentInvocation implements IIntentInvocation {
		public readonly context: IBuildPromptContext[] = [];

		constructor(
			readonly intent: IIntent,
			readonly location: ChatLocation,
			readonly endpoint: IChatEndpoint,
		) { }

		async buildPrompt(context: IBuildPromptContext): Promise<RenderPromptResult> {
			builtPrompts.push(context);
			if (Array.isArray(promptResult)) {
				const next = promptResult.shift();
				if (!next) {
					throw new Error('ran out of prompts');
				}
				return next;
			}

			return promptResult;
		}
	}

	/** Panel chat request with the prompt 'hello world!' and fresh random ids. */
	class TestChatRequest implements ChatRequest {
		toolInvocationToken!: never;
		acceptedConfirmationData?: any[] | undefined;
		rejectedConfirmationData?: any[] | undefined;
		attempt = 1;
		enableCommandDetection = false;
		isParticipantDetected = false;
		location = ChatLocation.Panel;
		location2 = undefined;
		prompt = 'hello world!';
		command: string | undefined;
		references: readonly ChatPromptReference[] = [];
		toolReferences: readonly ChatLanguageModelToolReference[] = [];
		model: LanguageModelChat = { family: '' } as any;
		tools = new Map();
		id = generateUuid();
		sessionId = generateUuid();
		sessionResource = Uri.parse(`test://session/${this.sessionId}`);
		hasHooksEnabled = false;
	}

	// Created once per suite. It pushes into whatever array `response` currently refers to.
	const responseStream = new ChatResponseStreamImpl(p => response.push(p), () => { }, undefined, undefined, undefined, () => Promise.resolve(undefined));
	const maxToolCallIterations = 3;

	/**
	 * Builds a DefaultIntentRequestHandler for `request`.
	 *
	 * Appends a new user turn for the request to `turns`, mutating the passed array. Earlier
	 * entries in `turns` act as conversation history.
	 */
	const makeHandler = ({
		request = new TestChatRequest(),
		turns = []
	}: { request?: ChatRequest; turns?: Turn[] } = {}) => {
		turns.push(new Turn(
			getTurnId(),
			{ type: 'user', message: request.prompt },
			undefined,
		));

		const instaService = accessor.get(IInstantiationService);
		return instaService.createInstance(
			DefaultIntentRequestHandler,
			new TestIntent(),
			new Conversation(sessionId, turns),
			request,
			responseStream,
			CancellationToken.None,
			undefined,
			ChatLocation.Panel,
			instaService.createInstance(ChatTelemetryBuilder, Date.now(), sessionId, undefined, turns.length > 1, request, undefined),
			{ maxToolCallIterations },
			undefined,
		);
	};

	test('avoids requests when handler return is null', async () => {
		const handler = makeHandler();
		const result = await handler.getResult();
		expect(result).to.deep.equal({});
		expect(getDerandomizedTelemetry()).toMatchSnapshot();
	});

	test('makes a successful request with a single turn', async () => {
		const handler = makeHandler();
		chatResponse[0] = 'some response here :)';
		promptResult = {
			...nullRenderPromptResult(),
			messages: [{ role: Raw.ChatRole.User, content: [toTextPart('hello world!')] }],
		};

		const result = await handler.getResult();
		expect(result).toMatchSnapshot();
		// Wait for event loop to finish as we often fire off telemetry without properly awaiting it as it doesn't matter when it is sent
		await new Promise(setImmediate);
		expect(getDerandomizedTelemetry()).toMatchSnapshot();
	});

	test('propagates resolvedModel into result metadata from a successful response', async () => {
		fetcher.resolvedModel = 'gpt-4o-resolved';
		const handler = makeHandler();
		chatResponse[0] = 'some response here :)';
		promptResult = {
			...nullRenderPromptResult(),
			messages: [{ role: Raw.ChatRole.User, content: [toTextPart('hello world!')] }],
		};

		const result = await handler.getResult();
		expect(result.metadata?.resolvedModel).toBe('gpt-4o-resolved');
	});

	// NOTE(review): despite the title, this asserts that `ignoreStatefulMarker` is left unset and
	// that `modeChanged` is true. Confirm the title still describes the intended contract.
	test('ignores stateful marker when mode instructions changed on responses api requests', async () => {
		const request = new TestChatRequest();
		(request as any).modeInstructions2 = { name: 'Agent', content: 'agent instructions', isBuiltin: true };
		(endpoint as any).apiType = 'responses';
		const requestSpy = vi.spyOn(endpoint, 'makeChatRequest2');
		const previousTurn = new Turn(generateUuid(), { message: 'previous', type: 'user' }, undefined, [], undefined, undefined, false, { name: 'Plan', content: 'plan instructions', isBuiltin: true } as any);
		const handler = makeHandler({ request, turns: [previousTurn] });
		chatResponse[0] = 'some response here :)';
		promptResult = {
			...nullRenderPromptResult(),
			messages: [{ role: Raw.ChatRole.User, content: [toTextPart('hello world!')] }],
		};

		await handler.getResult();

		expect(requestSpy).toHaveBeenCalledOnce();
		expect(requestSpy.mock.calls[0][0].modeChanged).toBe(true);
		expect(requestSpy.mock.calls[0][0].ignoreStatefulMarker).toBeUndefined();
	});

	test('preserves default stateful marker behavior when mode instructions are unchanged on responses api requests', async () => {
		const request = new TestChatRequest();
		(request as any).modeInstructions2 = { name: 'Agent', content: 'agent instructions', isBuiltin: true };
		(endpoint as any).apiType = 'responses';
		const requestSpy = vi.spyOn(endpoint, 'makeChatRequest2');
		const previousTurn = new Turn(generateUuid(), { message: 'previous', type: 'user' }, undefined, [], undefined, undefined, false, { name: 'Agent', content: 'agent instructions', isBuiltin: true } as any);
		const handler = makeHandler({ request, turns: [previousTurn] });
		chatResponse[0] = 'some response here :)';
		promptResult = {
			...nullRenderPromptResult(),
			messages: [{ role: Raw.ChatRole.User, content: [toTextPart('hello world!')] }],
		};

		await handler.getResult();

		expect(requestSpy).toHaveBeenCalledOnce();
		expect(requestSpy.mock.calls[0][0].modeChanged).toBe(false);
		expect(requestSpy.mock.calls[0][0].ignoreStatefulMarker).toBeUndefined();
	});

	test('makes a tool call turn', async () => {
		const handler = makeHandler();
		chatResponse[0] = [{
			text: 'some response here :)',
			copilotToolCalls: [{
				arguments: 'some args here',
				name: 'my_func',
				id: 'tool_call_id',
			}],
		}];
		chatResponse[1] = 'response to tool call';

		const toolResult = new LanguageModelToolResult([new LanguageModelTextPart('tool-result')]);

		promptResult = {
			...nullRenderPromptResult(),
			messages: [{ role: Raw.ChatRole.User, content: [toTextPart('hello world!')] }],
			metadata: promptResultMetadata([new ToolResultMetadata('tool_call_id__vscode-0', toolResult)])
		};

		const result = await handler.getResult();
		expect(result).toMatchSnapshot();
		// Wait for event loop to finish as we often fire off telemetry without properly awaiting it as it doesn't matter when it is sent
		await new Promise(setImmediate);
		expect(getDerandomizedTelemetry()).toMatchSnapshot();

		expect(builtPrompts).toHaveLength(2);
		expect(builtPrompts[1].toolCallResults).toEqual({ 'tool_call_id__vscode-0': toolResult });
		expect(builtPrompts[1].toolCallRounds).toMatchObject([
			{
				toolCalls: [{ arguments: 'some args here', name: 'my_func', id: 'tool_call_id__vscode-0' }],
				toolInputRetry: 0,
				response: 'some response here :)',
			},
			{
				toolCalls: [],
				toolInputRetry: 0,
				response: 'response to tool call',
			},
		]);
	});

	/**
	 * Queues `insertN` model responses, each making one tool call. Also queues one prompt per
	 * response that reports the matching tool result, so the loop never ends on its own.
	 */
	function fillWithToolCalls(insertN = 20) {
		promptResult = [];
		for (let i = 0; i < insertN; i++) {
			chatResponse[i] = [{
				text: `response number ${i}`,
				copilotToolCalls: [{
					arguments: 'some args here',
					name: 'my_func',
					id: `tool_call_id_${i}`,
				}],
			}];
			const toolResult = new LanguageModelToolResult([new LanguageModelTextPart(`tool-result-${i}`)]);
			promptResult[i] = {
				...nullRenderPromptResult(),
				messages: [{ role: Raw.ChatRole.User, content: [toTextPart('hello world!')] }],
				metadata: promptResultMetadata([new ToolResultMetadata(`tool_call_id_${i}__vscode-${i}`, toolResult)])
			};
		}
	}

	/**
	 * Queues `turns` user turns of `roundsPerTurn` model responses each. Every response except
	 * the last of a turn makes one tool call. Every prompt reports the tool results seen so far.
	 */
	function setupMultiturnToolCalls(turns: number, roundsPerTurn: number) {
		// Matches the counter in ToolCallingLoop
		let toolCallCounter = 0;
		promptResult = [];
		const setupOneRound = (startIdx: number) => {
			const endIdx = startIdx + roundsPerTurn;
			for (let i = startIdx; i < endIdx; i++) {
				const isLast = i === endIdx - 1;
				chatResponse[i] = [{
					text: `response number ${i}`,
					copilotToolCalls: isLast ?
						undefined :
						[{
							arguments: 'some args here',
							name: 'my_func',
							id: `tool_call_id_${toolCallCounter++}`,
						}],
				}];

				// ToolResultMetadata is reported by the prompt for all tool calls, in history or called this round
				const promptMetadata: ToolResultMetadata[] = [];
				for (let toolResultIdx = 0; toolResultIdx <= toolCallCounter; toolResultIdx++) {
					// For each request in a round, all the previous and current ToolResultMetadata are reported
					const toolResult = new LanguageModelToolResult([new LanguageModelTextPart(`tool-result-${toolResultIdx}`)]);
					promptMetadata.push(new ToolResultMetadata(`tool_call_id_${toolResultIdx}__vscode-${toolResultIdx}`, toolResult));
				}
				(promptResult as RenderPromptResult[])[i] = {
					...nullRenderPromptResult(),
					messages: [{ role: Raw.ChatRole.User, content: [toTextPart('hello world!')] }],
					metadata: promptResultMetadata(promptMetadata)
				};
			}
		};

		for (let i = 0; i < turns; i++) {
			setupOneRound(i * roundsPerTurn);
		}
	}

	test('confirms on max tool call iterations, and continues to iterate', async () => {
		const handler = makeHandler();
		fillWithToolCalls();
		const result1 = await handler.getResult();
		expect(result1).toMatchSnapshot();

		const last = response.at(-1);
		expect(last).toBeInstanceOf(ChatResponseConfirmationPart);

		const request = new TestChatRequest();
		request.acceptedConfirmationData = [(last as ChatResponseConfirmationPart).data];
		const handler2 = makeHandler({ request });
		expect(await handler2.getResult()).toMatchSnapshot();

		expect(response).toMatchSnapshot();
		// Wait for event loop to finish as we often fire off telemetry without properly awaiting it as it doesn't matter when it is sent
		await new Promise(setImmediate);
		expect(getDerandomizedTelemetry()).toMatchSnapshot();
	});

	test('ChatResult metadata after multiple turns only has tool results from current turn', async () => {
		const request = new TestChatRequest();
		const handler = makeHandler();
		setupMultiturnToolCalls(2, maxToolCallIterations);
		const result1 = await handler.getResult();
		expect(result1.metadata).toMatchSnapshot();

		const turn1 = new Turn(generateUuid(), { message: request.prompt, type: 'user' }, undefined);
		const handler2 = makeHandler({ request, turns: [turn1] });
		const result2 = await handler2.getResult();
		expect(result2.metadata).toMatchSnapshot();
	});

	test('aborts on max tool call iterations', async () => {
		fillWithToolCalls();
		const handler = makeHandler();
		await handler.getResult();

		const last = response.at(-1);
		expect(last).toBeInstanceOf(ChatResponseConfirmationPart);

		const request = new TestChatRequest();
		request.rejectedConfirmationData = [(last as ChatResponseConfirmationPart).data];
		request.prompt = (last as ChatResponseConfirmationPart).buttons![1];
		const handler2 = makeHandler({ request });
		await handler2.getResult();

		const last2 = response.at(-1);
		expect(last2).toBeInstanceOf(ChatResponseMarkdownPart);
		expect((last2 as ChatResponseMarkdownPart).value.value).toMatchInlineSnapshot(`"Let me know if there's anything else I can help with!"`);
	});
});


/**
 * Deep-clones `obj`, letting `changer` substitute values along the way.
 *
 * `changer` is called on every non-null/undefined value, including arrays and objects, before
 * recursing into them. It receives the value and the index/key it was found under (`undefined`
 * for the root). Returning anything other than `undefined` replaces that value, and its
 * subtree, in the copy.
 *
 * @throws Error if the structure contains a cycle.
 */
function cloneAndChangeWithKey(obj: any, changer: (orig: any, key?: string | number) => any): any {
	return _cloneAndChangeWithKey(obj, changer, new Set(), undefined);
}

/**
 * Recursive worker for {@link cloneAndChangeWithKey}.
 *
 * `seen` holds the arrays/objects on the current recursion path. It detects cycles while still
 * allowing the same object to appear in several sibling positions.
 */
function _cloneAndChangeWithKey(obj: any, changer: (orig: any, key?: string | number) => any, seen: Set<any>, key: string | number | undefined): any {
	if (isUndefinedOrNull(obj)) {
		return obj;
	}

	const changed = changer(obj, key);
	if (typeof changed !== 'undefined') {
		return changed;
	}

	if (Array.isArray(obj)) {
		// Arrays take part in cycle detection too; otherwise a self-referencing array overflows the stack.
		if (seen.has(obj)) {
			throw new Error('Cannot clone recursive data-structure');
		}
		seen.add(obj);
		const clonedArray: any[] = obj.map((element, index) => _cloneAndChangeWithKey(element, changer, seen, index));
		seen.delete(obj);
		return clonedArray;
	}

	if (isObject(obj)) {
		if (seen.has(obj)) {
			throw new Error('Cannot clone recursive data-structure');
		}
		seen.add(obj);
		const clonedObject: Record<string, any> = {};
		// Object.keys yields the same own enumerable string keys, in the same order, as for..in + hasOwnProperty.
		for (const prop of Object.keys(obj)) {
			clonedObject[prop] = _cloneAndChangeWithKey((obj as any)[prop], changer, seen, prop);
		}
		seen.delete(obj);
		return clonedObject;
	}

	// Values isObject does not classify as plain objects are returned by reference.
	return obj;
}