Path: blob/main/extensions/copilot/src/extension/byok/vscode-node/test/geminiNativeProvider.spec.ts
13405 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { beforeEach, describe, expect, it, vi } from 'vitest';
import * as vscode from 'vscode';
import { NoopOTelService, resolveOTelConfig } from '../../../../platform/otel/common/index';
import type { CapturingToken } from '../../../../platform/requestLogger/common/capturingToken';
import type { IRequestLogger } from '../../../../platform/requestLogger/common/requestLogger';
import { NullTelemetryService } from '../../../../platform/telemetry/common/nullTelemetryService';
import { TestLogService } from '../../../../platform/testing/common/testLogService';
import type { IBYOKStorageService } from '../byokStorageService';

/**
 * Stand-in for `handleAPIKeyUpdate` from `../../common/byokProvider`.
 * Each test programs the value it should resolve to.
 */
const mockHandleAPIKeyUpdate = vi.fn();

/**
 * Replaces the `@google/genai` SDK with an in-memory fake.
 *
 * The fake keeps its fixtures in static fields so a test can reach them through the
 * mocked module's `GoogleGenAI` export:
 * - `createdWithApiKeys` records the API key of every constructed client,
 * - `streamChunks` is what `models.generateContentStream()` yields,
 * - `listModelsResult` is what `models.list()` resolves to.
 */
vi.mock('@google/genai', () => {
	class MockGoogleGenAI {
		public static createdWithApiKeys: string[] = [];
		public static streamChunks: any[] = [];
		public static listModelsResult: AsyncIterable<any> = (async function* () { })();

		/** Restores the static fixtures to their initial, empty state. */
		public static resetMockState(): void {
			MockGoogleGenAI.createdWithApiKeys.length = 0;
			MockGoogleGenAI.streamChunks.length = 0;
			MockGoogleGenAI.listModelsResult = (async function* () { })();
		}

		public readonly apiKey: string;
		public readonly models: {
			list: () => Promise<AsyncIterable<any>>;
			generateContentStream: (params: unknown) => Promise<AsyncIterable<any>>;
		};

		constructor(opts: { apiKey: string }) {
			this.apiKey = opts.apiKey;
			MockGoogleGenAI.createdWithApiKeys.push(opts.apiKey);
			this.models = {
				// Read the static field at call time so a test can swap the result after construction.
				list: async () => MockGoogleGenAI.listModelsResult,
				generateContentStream: async () => (async function* () {
					for (const c of MockGoogleGenAI.streamChunks) {
						yield c;
					}
				})()
			};
		}
	}

	return {
		GoogleGenAI: MockGoogleGenAI,
		Type: { OBJECT: 'object' },
	};
});

/**
 * Keeps the real `byokProvider` module but swaps `handleAPIKeyUpdate` for
 * {@link mockHandleAPIKeyUpdate}.
 */
vi.mock('../../common/byokProvider', async (importOriginal) => {
	const actual = await importOriginal<typeof import('../../common/byokProvider')>();
	return {
		...actual,
		handleAPIKeyUpdate: mockHandleAPIKeyUpdate,
	};
});

type ProgressItem = vscode.LanguageModelResponsePart2;

/** `vscode.Progress` implementation that stores every reported part in `items`. */
class TestProgress implements vscode.Progress<ProgressItem> {
	public readonly items: ProgressItem[] = [];

	report(value: ProgressItem): void {
		this.items.push(value);
	}
}

/**
 * Builds an `IBYOKStorageService` whose methods are `vi.fn()` stubs.
 *
 * By default `getStoredModelConfigs` resolves to `{}` and every other method resolves to
 * `undefined`.
 *
 * @param overrides Members that replace the default stubs.
 * @returns The stubbed storage service.
 */
function createStorageService(overrides?: Partial<IBYOKStorageService>): IBYOKStorageService {
	return {
		getAPIKey: vi.fn().mockResolvedValue(undefined),
		storeAPIKey: vi.fn().mockResolvedValue(undefined),
		deleteAPIKey: vi.fn().mockResolvedValue(undefined),
		getStoredModelConfigs: vi.fn().mockResolvedValue({}),
		saveModelConfig: vi.fn().mockResolvedValue(undefined),
		removeModelConfig: vi.fn().mockResolvedValue(undefined),
		...overrides,
	};
}

/**
 * Builds a no-op `IRequestLogger`.
 *
 * `captureInvocation` simply runs the supplied callback, `getRequests` returns an empty
 * array, and the remaining members do nothing.
 *
 * @returns The stubbed request logger.
 */
function createRequestLogger(): IRequestLogger {
	const didChangeEmitter = new vscode.EventEmitter<void>();
	return {
		_serviceBrand: undefined,
		promptRendererTracing: false,
		captureInvocation: async <T>(_request: CapturingToken, fn: () => Promise<T>) => fn(),
		logToolCall: () => undefined,
		logModelListCall: () => undefined,
		logChatRequest: () => ({
			markTimeToFirstToken: () => undefined,
			resolveWithCancelation: () => undefined,
			resolve: () => undefined,
		}),
		addPromptTrace: () => undefined,
		addEntry: () => undefined,
		onDidChangeRequests: didChangeEmitter.event,
		getRequests: () => [],
		enableWorkspaceEditTracing: () => undefined,
		disableWorkspaceEditTracing: () => undefined,
	} as unknown as IRequestLogger;
}

describe('GeminiNativeBYOKLMProvider', () => {
	beforeEach(async () => {
		vi.clearAllMocks();
		// clearAllMocks() only wipes call history; also drop any resolved value that a
		// previous test programmed on the shared module-level mock.
		mockHandleAPIKeyUpdate.mockReset();
		// The fake SDK keeps its fixtures in static fields, which outlive a single test.
		const genai = await import('@google/genai');
		(genai.GoogleGenAI as unknown as { resetMockState(): void }).resetMockState();
	});

	it.skip('throws a clear error when no API key is configured (no silent return)', async () => {
		const { GeminiNativeBYOKLMProvider } = await import('../geminiNativeProvider');
		const storage = createStorageService({ getAPIKey: vi.fn().mockResolvedValue(undefined) });
		const provider = new GeminiNativeBYOKLMProvider(undefined, storage, new TestLogService(), createRequestLogger(), new NullTelemetryService(), new NoopOTelService(resolveOTelConfig({ env: {}, extensionVersion: '1.0.0', sessionId: 'test' })));

		const model: vscode.LanguageModelChatInformation = {
			id: 'gemini-2.0-flash',
			name: 'Gemini 2.0 Flash',
			family: 'Gemini',
			version: '1.0.0',
			maxInputTokens: 1000,
			maxOutputTokens: 1000,
			capabilities: { toolCalling: false, imageInput: false }
		};
		const messages: vscode.LanguageModelChatMessage[] = [
			new vscode.LanguageModelChatMessage(vscode.LanguageModelChatMessageRole.User, 'hello')
		];

		const tokenSource = new vscode.CancellationTokenSource();
		const progress = new TestProgress();
		await expect(provider.provideLanguageModelChatResponse(
			model,
			messages,
			{ requestInitiator: 'test', tools: [], toolMode: vscode.LanguageModelChatToolMode.Auto },
			progress,
			tokenSource.token
		)).rejects.toThrow(/No API key configured/i);
	});

	// NOTE(review): the two disabled tests below still construct the provider with four
	// arguments, while the active tests in this file pass six. Update the constructor
	// calls before re-enabling them.

	// it.skip('initializes the Gemini client on API key update and can stream a response', async () => {
	// 	const { GeminiNativeBYOKLMProvider } = await import('../geminiNativeProvider');
	// 	const genai = await import('@google/genai');
	// 	const MockGoogleGenAI = genai.GoogleGenAI as unknown as { createdWithApiKeys: string[]; streamChunks: any[] };
	// 	MockGoogleGenAI.createdWithApiKeys.length = 0;
	// 	MockGoogleGenAI.streamChunks.length = 0;
	// 	MockGoogleGenAI.streamChunks.push({
	// 		candidates: [{
	// 			content: { parts: [{ text: 'Hello from Gemini' }] }
	// 		}]
	// 	});

	// 	mockHandleAPIKeyUpdate.mockResolvedValue({ apiKey: 'k_test', deleted: false, cancelled: false });

	// 	const storage = createStorageService({ getAPIKey: vi.fn().mockResolvedValue('k_test') });
	// 	const provider = new GeminiNativeBYOKLMProvider(undefined, storage, new TestLogService(), createRequestLogger());

	// 	await provider.updateAPIKey();
	// 	expect(MockGoogleGenAI.createdWithApiKeys).toEqual(['k_test']);

	// 	const model: vscode.LanguageModelChatInformation = {
	// 		id: 'gemini-2.0-flash',
	// 		name: 'Gemini 2.0 Flash',
	// 		family: 'Gemini',
	// 		version: '1.0.0',
	// 		maxInputTokens: 1000,
	// 		maxOutputTokens: 1000,
	// 		capabilities: { toolCalling: false, imageInput: false }
	// 	};
	// 	const messages: vscode.LanguageModelChatMessage[] = [
	// 		new vscode.LanguageModelChatMessage(vscode.LanguageModelChatMessageRole.User, 'hello')
	// 	];

	// 	const tokenSource = new vscode.CancellationTokenSource();
	// 	const progress = new TestProgress();
	// 	await provider.provideLanguageModelChatResponse(
	// 		model,
	// 		messages,
	// 		{ requestInitiator: 'test', tools: [], toolMode: vscode.LanguageModelChatToolMode.Auto },
	// 		progress,
	// 		tokenSource.token
	// 	);

	// 	expect(progress.items.some(p => p instanceof vscode.LanguageModelTextPart && p.value.includes('Hello from Gemini'))).toBe(true);
	// });

	// it.skip('clears the client when API key is deleted via update flow', async () => {
	// 	const { GeminiNativeBYOKLMProvider } = await import('../geminiNativeProvider');
	// 	const genai = await import('@google/genai');
	// 	const MockGoogleGenAI = genai.GoogleGenAI as unknown as { createdWithApiKeys: string[]; streamChunks: any[] };
	// 	MockGoogleGenAI.createdWithApiKeys.length = 0;
	// 	MockGoogleGenAI.streamChunks.length = 0;

	// 	const storage = createStorageService({ getAPIKey: vi.fn().mockResolvedValue(undefined) });
	// 	const provider = new GeminiNativeBYOKLMProvider(undefined, storage, new TestLogService(), createRequestLogger());

	// 	// First set a key
	// 	mockHandleAPIKeyUpdate.mockResolvedValueOnce({ apiKey: 'k_initial', deleted: false, cancelled: false });
	// 	await provider.updateAPIKey();
	// 	expect(MockGoogleGenAI.createdWithApiKeys).toEqual(['k_initial']);

	// 	// Then delete it
	// 	mockHandleAPIKeyUpdate.mockResolvedValueOnce({ apiKey: undefined, deleted: true, cancelled: false });
	// 	await provider.updateAPIKey();

	// 	const model: vscode.LanguageModelChatInformation = {
	// 		id: 'gemini-2.0-flash',
	// 		name: 'Gemini 2.0 Flash',
	// 		family: 'Gemini',
	// 		version: '1.0.0',
	// 		maxInputTokens: 1000,
	// 		maxOutputTokens: 1000,
	// 		capabilities: { toolCalling: false, imageInput: false }
	// 	};
	// 	const messages: vscode.LanguageModelChatMessage[] = [
	// 		new vscode.LanguageModelChatMessage(vscode.LanguageModelChatMessageRole.User, 'hello')
	// 	];

	// 	const tokenSource = new vscode.CancellationTokenSource();
	// 	const progress = new TestProgress();
	// 	await expect(provider.provideLanguageModelChatResponse(
	// 		model,
	// 		messages,
	// 		{ requestInitiator: 'test', tools: [], toolMode: vscode.LanguageModelChatToolMode.Auto },
	// 		progress,
	// 		tokenSource.token
	// 	)).rejects.toThrow(/No API key configured/i);
	// });

	it.skip('prompts for a new API key when listing models fails with an invalid key', async () => {
		const { GeminiNativeBYOKLMProvider } = await import('../geminiNativeProvider');
		const genai = await import('@google/genai');
		const MockGoogleGenAI = genai.GoogleGenAI as unknown as { listModelsResult: AsyncIterable<any> };
		// Simulate the models.list() call throwing an invalid API key error when iterated
		MockGoogleGenAI.listModelsResult = (async function* () {
			throw new Error('ApiError: {"error":{"message":"API key not valid. Please pass a valid API key.","details":[{"reason":"API_KEY_INVALID"}]}}');
		})();

		const storage = createStorageService({
			getAPIKey: vi.fn().mockResolvedValue('bad_key'),
		});

		// The simulated user dismisses the re-prompt.
		mockHandleAPIKeyUpdate.mockResolvedValue({ apiKey: undefined, deleted: false, cancelled: true });

		const provider = new GeminiNativeBYOKLMProvider(undefined, storage, new TestLogService(), createRequestLogger(), new NullTelemetryService(), new NoopOTelService(resolveOTelConfig({ env: {}, extensionVersion: '1.0.0', sessionId: 'test' })));
		const tokenSource = new vscode.CancellationTokenSource();
		const models = await provider.provideLanguageModelChatInformation({ silent: false }, tokenSource.token);

		// When the key is invalid, we should re-prompt for a new one
		// and handle the failure gracefully by returning an empty list.
		expect(models).toEqual([]);
		expect(mockHandleAPIKeyUpdate).toHaveBeenCalled();
	});

	it.skip('retries listing models after re-prompting with a valid API key', async () => {
		const { GeminiNativeBYOKLMProvider } = await import('../geminiNativeProvider');
		const genai = await import('@google/genai');
		const MockGoogleGenAI = genai.GoogleGenAI as unknown as { listModelsResult: AsyncIterable<any> };

		let iterationCount = 0;
		let hasThrown = false;
		const modelId = 'test-model';

		// Re-iterable on purpose: the first iteration throws, every later one yields a model.
		MockGoogleGenAI.listModelsResult = {
			async *[Symbol.asyncIterator]() {
				iterationCount++;
				if (!hasThrown) {
					hasThrown = true;
					throw new Error('ApiError: {"error":{"message":"API key not valid. Please pass a valid API key.","details":[{"reason":"API_KEY_INVALID"}]}}');
				}
				yield { name: modelId };
			}
		};

		const storage = createStorageService({
			getAPIKey: vi.fn().mockResolvedValue('bad_key'),
		});

		mockHandleAPIKeyUpdate.mockResolvedValue({ apiKey: 'k_new', deleted: false, cancelled: false });

		const knownModels = {
			[modelId]: {
				name: 'Test Model',
				maxInputTokens: 1000,
				maxOutputTokens: 1000,
				toolCalling: false,
				vision: false
			}
		};

		const provider = new GeminiNativeBYOKLMProvider(knownModels, storage, new TestLogService(), createRequestLogger(), new NullTelemetryService(), new NoopOTelService(resolveOTelConfig({ env: {}, extensionVersion: '1.0.0', sessionId: 'test' })));
		const tokenSource = new vscode.CancellationTokenSource();
		const models = await provider.provideLanguageModelChatInformation({ silent: false }, tokenSource.token);

		// First attempt should fail with invalid key, then after re-prompting
		// we should retry listing models and succeed with the new key.
		expect(models.map(m => m.id)).toEqual([modelId]);
		expect(iterationCount).toBe(2);
		expect(mockHandleAPIKeyUpdate).toHaveBeenCalled();
	});
});