Path: blob/main/src/vs/workbench/contrib/chat/test/common/chatImageExtraction.test.ts
13406 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import assert from 'assert';
import { VSBuffer } from '../../../../../base/common/buffer.js';
import { URI } from '../../../../../base/common/uri.js';
import { ensureNoDisposablesAreLeakedInTestSuite } from '../../../../../base/test/common/utils.js';
import { IImageVariableEntry } from '../../common/attachments/chatVariableEntries.js';
import { IChatProgressResponseContent } from '../../common/model/chatModel.js';
import { IChatRequestViewModel, IChatResponseViewModel } from '../../common/model/chatViewModel.js';
import { IChatContentInlineReference, IChatToolInvocationSerialized, IToolResultOutputDetailsSerialized } from '../../common/chatService/chatService.js';
import { IToolResultInputOutputDetails } from '../../common/tools/languageModelToolsService.js';
import { extractImagesFromChatRequest, extractImagesFromChatResponse, extractImagesFromToolInvocationMessages } from '../../common/chatImageExtraction.js';

/**
 * Builds a serialized tool invocation that is already confirmed and complete.
 * The defaults use plain-string invocation/past-tense messages and no result details.
 * Any field in `overrides` replaces the default because it is spread last.
 */
function makeToolInvocation(overrides: Partial<IChatToolInvocationSerialized> = {}): IChatToolInvocationSerialized {
	return {
		kind: 'toolInvocationSerialized',
		toolCallId: 'call_1',
		toolId: 'test-tool',
		invocationMessage: 'Running tool',
		originMessage: undefined,
		pastTenseMessage: 'Ran tool',
		isConfirmed: true,
		isComplete: true,
		source: undefined,
		presentation: undefined,
		resultDetails: undefined,
		...overrides,
	};
}

/**
 * Builds an inline-reference response part that points at `uri`.
 * It carries an optional display `name`.
 */
function makeInlineReference(uri: URI, name?: string): IChatContentInlineReference {
	return {
		kind: 'inlineReference',
		inlineReference: uri,
		name,
	};
}

/**
 * Builds a minimal response view-model stub whose response value is `items`.
 *
 * `session.getItems()` returns a single request whose id matches the response's
 * `requestId`. Its message text is `requestMessageText`, which the tests below expect
 * to become the collection title. When `noMatchingRequest` is set, it returns no
 * items at all.
 *
 * Only a handful of members are populated; the double cast hides the rest from the
 * type checker. NOTE(review): presumably these are the only members the extraction
 * code reads — revisit if it starts reading more.
 */
function makeResponse(items: ReadonlyArray<IChatProgressResponseContent>, opts: {
	sessionResource?: URI;
	requestId?: string;
	id?: string;
	requestMessageText?: string;
	noMatchingRequest?: boolean;
} = {}): IChatResponseViewModel {
	const sessionResource = opts.sessionResource ?? URI.parse('chat-session://test/session');
	const requestId = opts.requestId ?? 'req-1';
	const responseId = opts.id ?? 'resp-1';
	const requestMessageText = opts.requestMessageText ?? 'Show me images';

	return {
		id: responseId,
		requestId,
		sessionResource,
		response: { value: items },
		session: {
			getItems: () => opts.noMatchingRequest ? [] : [{
				id: requestId,
				messageText: requestMessageText,
				message: { parts: [], text: requestMessageText },
			}],
		},
	} as unknown as IChatResponseViewModel;
}

/**
 * Stand-in file reader that never touches the file system.
 * Every URI resolves to a buffer containing `data-for-<path>`.
 */
const fakeReadFile = (uri: URI) => Promise.resolve(VSBuffer.fromString(`data-for-${uri.path}`));

/**
 * Builds a minimal request view-model stub that carries the given attachment `variables`.
 * Only the listed members are populated; the double cast hides any further members the
 * interface requires.
 */
function makeRequest(variables: IChatRequestViewModel['variables'], opts: { id?: string; messageText?: string } = {}): IChatRequestViewModel {
	return {
		id: opts.id ?? 'req-1',
		sessionResource: URI.parse('chat-session://test/session'),
		dataId: 'data-1',
		username: 'test-user',
		message: { text: opts.messageText ?? 'Show me images', parts: [] },
		messageText: opts.messageText ?? 'Show me images',
		attempt: 0,
		variables,
		currentRenderedHeight: undefined,
		shouldBeRemovedOnSend: undefined,
		isComplete: true,
		isCompleteAddedRequest: true,
		slashCommand: undefined,
		agentOrSlashCommandDetected: false,
		shouldBeBlocked: undefined!,
		timestamp: 0,
	} as unknown as IChatRequestViewModel;
}

/**
 * Builds an image attachment entry. `value` (the image bytes) is required.
 * The other fields default to id `img-1`, name `cat.png` and mime type `image/png`.
 * Any provided override wins because the remaining overrides are spread last.
 */
function makeImageVariableEntry(overrides: Partial<IImageVariableEntry> & Pick<IImageVariableEntry, 'value'>): IImageVariableEntry {
	const { value, ...rest } = overrides;
	return {
		id: 'img-1',
		kind: 'image',
		name: 'cat.png',
		value,
		mimeType: 'image/png',
		...rest,
	};
}

suite('extractImagesFromChatResponse', () => {
	ensureNoDisposablesAreLeakedInTestSuite();

	test('returns empty images when response has no items', async () => {
		const response = makeResponse([]);
		const result = await extractImagesFromChatResponse(response, fakeReadFile);
		assert.deepStrictEqual(result, {
			id: response.sessionResource.toString() + '_' + response.id,
			title: 'Show me images',
			images: [],
		});
	});

	test('uses default title when no matching request is found', async () => {
		const response = makeResponse([], { noMatchingRequest: true });
		const result = await extractImagesFromChatResponse(response, fakeReadFile);
		assert.strictEqual(result.title, 'Images');
	});

	test('extracts image from tool invocation with IToolResultOutputDetails', async () => {
		const resultDetails: IToolResultOutputDetailsSerialized = {
			output: { type: 'data', mimeType: 'image/png', base64Data: 'AQID' },
		};
		const toolInvocation = makeToolInvocation({
			toolCallId: 'call_img',
			toolId: 'screenshot-tool',
			pastTenseMessage: 'Took a screenshot',
			resultDetails,
		});

		const response = makeResponse([toolInvocation]);
		const result = await extractImagesFromChatResponse(response, fakeReadFile);

		assert.strictEqual(result.images.length, 1);
		assert.strictEqual(result.images[0].id, 'call_img_0');
		assert.strictEqual(result.images[0].mimeType, 'image/png');
		assert.ok(result.images[0].source.includes('screenshot-tool'));
		assert.strictEqual(result.images[0].caption, 'Took a screenshot');
	});

	test('extracts multiple images from tool invocation with IToolResultInputOutputDetails', async () => {
		const resultDetails: IToolResultInputOutputDetails = {
			input: '',
			output: [
				{ type: 'embed', mimeType: 'image/png', value: 'AQID', isText: false },
				{ type: 'embed', mimeType: 'text/plain', value: 'text', isText: true },
				{ type: 'embed', mimeType: 'image/jpeg', value: 'BAUG', isText: false },
			],
		};
		const toolInvocation = makeToolInvocation({
			toolCallId: 'call_multi',
			toolId: 'multi-tool',
			pastTenseMessage: 'Generated images',
			resultDetails,
		});

		const response = makeResponse([toolInvocation]);
		const result = await extractImagesFromChatResponse(response, fakeReadFile);

		assert.strictEqual(result.images.length, 2);
		assert.strictEqual(result.images[0].id, 'call_multi_0');
		assert.strictEqual(result.images[0].mimeType, 'image/png');
		// Ids use the index within the output array, so the skipped text entry (index 1) leaves a gap.
		assert.strictEqual(result.images[1].id, 'call_multi_2');
		assert.strictEqual(result.images[1].mimeType, 'image/jpeg');
	});

	test('skips tool invocations without image results', async () => {
		const resultDetails: IToolResultOutputDetailsSerialized = {
			output: { type: 'data', mimeType: 'text/plain', base64Data: 'aGVsbG8=' },
		};
		const toolInvocation = makeToolInvocation({ resultDetails });

		const response = makeResponse([toolInvocation]);
		const result = await extractImagesFromChatResponse(response, fakeReadFile);
		assert.strictEqual(result.images.length, 0);
	});

	test('extracts image from inline reference URI when readFile is provided', async () => {
		const imageUri = URI.file('/photos/cat.png');
		const inlineRef = makeInlineReference(imageUri, 'cat.png');

		const response = makeResponse([inlineRef]);
		const result = await extractImagesFromChatResponse(response, fakeReadFile);

		assert.strictEqual(result.images.length, 1);
		assert.strictEqual(result.images[0].uri.toString(), imageUri.toString());
		assert.strictEqual(result.images[0].name, 'cat.png');
		assert.strictEqual(result.images[0].mimeType, 'image/png');
		assert.strictEqual(result.images[0].source, 'File');
	});

	test('extracts image from inline reference Location', async () => {
		// Inline references can also be Location-shaped ({ uri, range }).
		// The extracted image should carry the location's uri.
		const imageUri = URI.file('/photos/dog.jpg');
		const inlineRef: IChatContentInlineReference = {
			kind: 'inlineReference',
			inlineReference: { uri: imageUri, range: { startLineNumber: 1, startColumn: 1, endLineNumber: 1, endColumn: 1 } },
		};

		const response = makeResponse([inlineRef]);
		const result = await extractImagesFromChatResponse(response, fakeReadFile);

		assert.strictEqual(result.images.length, 1);
		assert.strictEqual(result.images[0].uri.toString(), imageUri.toString());
	});

	test('skips non-image inline references', async () => {
		const codeUri = URI.file('/src/main.ts');
		const inlineRef = makeInlineReference(codeUri);

		const response = makeResponse([inlineRef]);
		const result = await extractImagesFromChatResponse(response, fakeReadFile);
		assert.strictEqual(result.images.length, 0);
	});

	test('uses filename from URI path when name is not provided', async () => {
		const imageUri = URI.file('/assets/banner.gif');
		const inlineRef = makeInlineReference(imageUri);

		const response = makeResponse([inlineRef]);
		const result = await extractImagesFromChatResponse(response, fakeReadFile);

		assert.strictEqual(result.images.length, 1);
		assert.strictEqual(result.images[0].name, 'banner.gif');
	});

	test('preserves interleaved order of tool and inline reference images', async () => {
		const toolInvocation = makeToolInvocation({
			toolCallId: 'call_first',
			toolId: 'tool-1',
			resultDetails: {
				output: { type: 'data', mimeType: 'image/png', base64Data: 'AQID' },
			} satisfies IToolResultOutputDetailsSerialized,
		});

		const inlineRef = makeInlineReference(URI.file('/middle.png'), 'middle.png');

		const toolInvocation2 = makeToolInvocation({
			toolCallId: 'call_last',
			toolId: 'tool-2',
			resultDetails: {
				output: { type: 'data', mimeType: 'image/jpeg', base64Data: 'BAUG' },
			} satisfies IToolResultOutputDetailsSerialized,
		});

		const response = makeResponse([toolInvocation, inlineRef, toolInvocation2]);
		const result = await extractImagesFromChatResponse(response, fakeReadFile);

		assert.strictEqual(result.images.length, 3);
		assert.strictEqual(result.images[0].id, 'call_first_0');
		assert.strictEqual(result.images[1].name, 'middle.png');
		assert.strictEqual(result.images[2].id, 'call_last_0');
	});

	test('collection id combines sessionResource and response id', async () => {
		const sessionResource = URI.parse('chat-session://test/my-session');
		const response = makeResponse([], { sessionResource, id: 'response-42' });
		const result = await extractImagesFromChatResponse(response, fakeReadFile);
		assert.strictEqual(result.id, sessionResource.toString() + '_response-42');
	});

	test('skips inline reference when readFile fails', async () => {
		const imageUri = URI.file('/photos/missing.png');
		const inlineRef = makeInlineReference(imageUri, 'missing.png');
		const failingReadFile = (_uri: URI) => Promise.reject(new Error('File not found'));

		const response = makeResponse([inlineRef]);
		const result = await extractImagesFromChatResponse(response, failingReadFile);
		assert.strictEqual(result.images.length, 0);
	});

	test('extracts images from tool invocation message URIs', async () => {
		const imageUri = URI.file('/screenshots/result.png');
		const toolInvocation = makeToolInvocation({
			toolCallId: 'call_msg',
			toolId: 'screenshot-tool',
			pastTenseMessage: { value: 'Took a screenshot', isTrusted: false, uris: { '0': imageUri.toJSON() } },
		});

		const response = makeResponse([toolInvocation]);
		const result = await extractImagesFromChatResponse(response, fakeReadFile);

		assert.strictEqual(result.images.length, 1);
		assert.strictEqual(result.images[0].uri.toString(), imageUri.toString());
		assert.strictEqual(result.images[0].name, 'result.png');
		assert.strictEqual(result.images[0].mimeType, 'image/png');
		assert.strictEqual(result.images[0].caption, 'Took a screenshot');
	});

	test('combines output details images and message URI images', async () => {
		const imageUri = URI.file('/screenshots/msg-image.jpg');
		const resultDetails: IToolResultOutputDetailsSerialized = {
			output: { type: 'data', mimeType: 'image/png', base64Data: 'AQID' },
		};
		const toolInvocation = makeToolInvocation({
			toolCallId: 'call_both',
			toolId: 'combo-tool',
			pastTenseMessage: { value: 'Ran combo tool', isTrusted: false, uris: { '0': imageUri.toJSON() } },
			resultDetails,
		});

		const response = makeResponse([toolInvocation]);
		const result = await extractImagesFromChatResponse(response, fakeReadFile);

		// Output-details images come first, followed by images referenced from the message.
		assert.strictEqual(result.images.length, 2);
		assert.strictEqual(result.images[0].id, 'call_both_0');
		assert.strictEqual(result.images[1].uri.toString(), imageUri.toString());
	});
});

suite('extractImagesFromToolInvocationMessages', () => {
	ensureNoDisposablesAreLeakedInTestSuite();

	test('returns empty when message is undefined', async () => {
		const toolInvocation = makeToolInvocation({
			pastTenseMessage: undefined,
			invocationMessage: undefined,
		});
		const result = await extractImagesFromToolInvocationMessages(toolInvocation, fakeReadFile);
		assert.deepStrictEqual(result, []);
	});

	test('returns empty when message is a string', async () => {
		const toolInvocation = makeToolInvocation({
			pastTenseMessage: 'some string message',
		});
		const result = await extractImagesFromToolInvocationMessages(toolInvocation, fakeReadFile);
		assert.deepStrictEqual(result, []);
	});

	test('returns empty when message has no uris', async () => {
		const toolInvocation = makeToolInvocation({
			pastTenseMessage: { value: 'No URIs here', isTrusted: false },
		});
		const result = await extractImagesFromToolInvocationMessages(toolInvocation, fakeReadFile);
		assert.deepStrictEqual(result, []);
	});

	test('returns empty when message uris are empty', async () => {
		const toolInvocation = makeToolInvocation({
			pastTenseMessage: { value: 'Empty URIs', isTrusted: false, uris: {} },
		});
		const result = await extractImagesFromToolInvocationMessages(toolInvocation, fakeReadFile);
		assert.deepStrictEqual(result, []);
	});

	test('skips non-image URIs', async () => {
		const toolInvocation = makeToolInvocation({
			pastTenseMessage: { value: 'Code file', isTrusted: false, uris: { '0': URI.file('/src/main.ts').toJSON() } },
		});
		const result = await extractImagesFromToolInvocationMessages(toolInvocation, fakeReadFile);
		assert.deepStrictEqual(result, []);
	});

	test('extracts image from message URI', async () => {
		const imageUri = URI.file('/screenshots/capture.png');
		const toolInvocation = makeToolInvocation({
			toolCallId: 'call_uri',
			toolId: 'screenshot-tool',
			pastTenseMessage: { value: 'Captured screenshot', isTrusted: false, uris: { '0': imageUri.toJSON() } },
		});

		const result = await extractImagesFromToolInvocationMessages(toolInvocation, fakeReadFile);

		assert.strictEqual(result.length, 1);
		assert.strictEqual(result[0].uri.toString(), imageUri.toString());
		assert.strictEqual(result[0].name, 'capture.png');
		assert.strictEqual(result[0].mimeType, 'image/png');
		assert.strictEqual(result[0].caption, 'Captured screenshot');
		assert.ok(result[0].source.includes('screenshot-tool'));
	});

	test('extracts multiple images from message URIs', async () => {
		const uri1 = URI.file('/img/a.png');
		const uri2 = URI.file('/img/b.jpg');
		const toolInvocation = makeToolInvocation({
			pastTenseMessage: {
				value: 'Generated images',
				isTrusted: false,
				uris: { '0': uri1.toJSON(), '1': uri2.toJSON() },
			},
		});

		const result = await extractImagesFromToolInvocationMessages(toolInvocation, fakeReadFile);

		assert.strictEqual(result.length, 2);
		assert.strictEqual(result[0].mimeType, 'image/png');
		// NOTE(review): the expected MIME for '.jpg' is 'image/jpg', not 'image/jpeg'.
		// Presumably this mirrors the extension-based lookup the implementation uses — confirm there.
		assert.strictEqual(result[1].mimeType, 'image/jpg');
	});

	test('continues when readFile fails for one URI', async () => {
		const goodUri = URI.file('/img/good.png');
		const badUri = URI.file('/img/bad.png');
		const failingReadFile = (uri: URI) => {
			if (uri.path.includes('bad')) {
				return Promise.reject(new Error('File not found'));
			}
			return Promise.resolve(VSBuffer.fromString('image-data'));
		};
		// The failing URI is listed first, which shows that one failure does not stop later URIs from being read.
		const toolInvocation = makeToolInvocation({
			pastTenseMessage: {
				value: 'Mixed results',
				isTrusted: false,
				uris: { '0': badUri.toJSON(), '1': goodUri.toJSON() },
			},
		});

		const result = await extractImagesFromToolInvocationMessages(toolInvocation, failingReadFile);

		assert.strictEqual(result.length, 1);
		assert.strictEqual(result[0].uri.toString(), goodUri.toString());
	});

	test('falls back to invocationMessage when pastTenseMessage is undefined', async () => {
		const imageUri = URI.file('/img/fallback.png');
		const toolInvocation = makeToolInvocation({
			pastTenseMessage: undefined,
			invocationMessage: { value: 'Running tool', isTrusted: false, uris: { '0': imageUri.toJSON() } },
		});

		const result = await extractImagesFromToolInvocationMessages(toolInvocation, fakeReadFile);

		assert.strictEqual(result.length, 1);
		assert.strictEqual(result[0].caption, 'Running tool');
	});
});

suite('extractImagesFromChatRequest', () => {
	ensureNoDisposablesAreLeakedInTestSuite();

	test('extracts image attachment from Uint8Array', () => {
		const request = makeRequest([
			makeImageVariableEntry({ value: new Uint8Array([1, 2, 3]) }),
		]);

		const result = extractImagesFromChatRequest(request);

		assert.strictEqual(result.length, 1);
		assert.strictEqual(result[0].name, 'cat.png');
		assert.strictEqual(result[0].mimeType, 'image/png');
		assert.deepStrictEqual([...result[0].data.buffer], [1, 2, 3]);
	});

	test('extracts image attachment from ArrayBuffer', () => {
		const request = makeRequest([
			makeImageVariableEntry({ value: new Uint8Array([4, 5, 6]).buffer }),
		]);

		const result = extractImagesFromChatRequest(request);

		assert.strictEqual(result.length, 1);
		assert.deepStrictEqual([...result[0].data.buffer], [4, 5, 6]);
	});

	test('extracts restored image attachment from plain object bytes', () => {
		// A byte array restored from serialized state shows up as an index-keyed plain object.
		const request = makeRequest([
			makeImageVariableEntry({ value: { 0: 7, 1: 8, 2: 9 } }),
		]);

		const result = extractImagesFromChatRequest(request);

		assert.strictEqual(result.length, 1);
		assert.deepStrictEqual([...result[0].data.buffer], [7, 8, 9]);
	});

	test('extracts restored image attachment from reordered plain object bytes', () => {
		// NOTE(review): JavaScript enumerates integer-like keys in ascending numeric order,
		// however the literal is written. At runtime this object is therefore indistinguishable
		// from { 0: 7, 1: 8, 2: 9 }, and the test does not actually exercise reordering.
		const request = makeRequest([
			makeImageVariableEntry({ value: { 2: 9, 0: 7, 1: 8 } }),
		]);

		const result = extractImagesFromChatRequest(request);

		assert.strictEqual(result.length, 1);
		assert.deepStrictEqual([...result[0].data.buffer], [7, 8, 9]);
	});

	test('uses attachment resource URI when available', () => {
		const uri = URI.file('/tmp/cat.png');
		const request = makeRequest([
			makeImageVariableEntry({ value: new Uint8Array([1]), references: [{ kind: 'reference', reference: uri }] }),
		]);

		const result = extractImagesFromChatRequest(request);

		assert.strictEqual(result.length, 1);
		assert.strictEqual(result[0].uri.toString(), uri.toString());
	});
});