// extensions/copilot/test/base/cachingChatMLFetcher.ts
/*---------------------------------------------------------------------------------------------1* Copyright (c) Microsoft Corporation. All rights reserved.2* Licensed under the MIT License. See License.txt in the project root for license information.3*--------------------------------------------------------------------------------------------*/4import { Raw } from '@vscode/prompt-tsx';5import { promises as fs } from 'fs';6import { tmpdir } from 'os';7import * as path from 'path';8import type { CancellationToken } from 'vscode';9import { AbstractChatMLFetcher } from '../../src/extension/prompt/node/chatMLFetcher';10import { IChatMLFetcher, IFetchMLOptions } from '../../src/platform/chat/common/chatMLFetcher';11import { ChatFetchResponseType, ChatResponses } from '../../src/platform/chat/common/commonTypes';12import { IConversationOptions } from '../../src/platform/chat/common/conversationOptions';13import { getTextPart } from '../../src/platform/chat/common/globalStringUtils';14import { LogLevel } from '../../src/platform/log/common/logService';15import { FinishedCallback, ICopilotToolCall, IResponseDelta, OptionalChatRequestParams } from '../../src/platform/networking/common/fetch';16import { ChoiceLogProbs, rawMessageToCAPI } from '../../src/platform/networking/common/openai';17import { LcsDiff, LineSequence } from '../../src/util/common/diff';18import { LockMap } from '../../src/util/common/lock';19import { BugIndicatingError } from '../../src/util/vs/base/common/errors';20import { SyncDescriptor } from '../../src/util/vs/platform/instantiation/common/descriptors';21import { IInstantiationService } from '../../src/util/vs/platform/instantiation/common/instantiation';22import { CHAT_ML_CACHE_SALT_PER_MODEL } from '../cacheSalt';23import { IJSONOutputPrinter } from '../jsonOutputPrinter';24import { OutputType } from '../simulation/shared/sharedTypes';25import { logger } from '../simulationLogger';26import { computeSHA256 } from './hash';27import { CacheMode, 
NoFetchChatMLFetcher } from './simulationContext';28import { ISimulationEndpointHealth } from './simulationEndpointHealth';29import { SimulationOutcomeImpl } from './simulationOutcome';30import { drainStdoutAndExit } from './stdout';31import { REPO_ROOT, SimulationTest } from './stest';3233export class CacheableChatRequest {34public readonly hash: string;35private readonly obj: unknown;3637constructor(38messages: Raw.ChatMessage[],39model: string,40requestOptions: OptionalChatRequestParams,41extraCacheProperties: any | undefined42) {43this.obj = { messages: rawMessageToCAPI(messages), model, requestOptions, extraCacheProperties };44const salt = CHAT_ML_CACHE_SALT_PER_MODEL[model] ?? CHAT_ML_CACHE_SALT_PER_MODEL['DEFAULT'];45this.hash = computeSHA256(salt + JSON.stringify(this.obj));4647// To aid in reading cache entries, we will write objects to disk splitting each message by new lines48// We do this after the sha computation to avoid invalidating all the existing caches49(this.obj as any).messages = (this.obj as any).messages.map((m: Raw.ChatMessage) => {50return { ...m, content: getTextPart(m.content).split('\n') };51});52}5354toJSON() {55return this.obj;56}57}5859export interface IChatMLCache {60getRequest?(hash: string): Promise<unknown | undefined>;61get(req: CacheableChatRequest, cacheSlot: number): Promise<CachedResponse | undefined>;62set(req: CacheableChatRequest, cacheSlot: number, cachedResponse: CachedResponse): Promise<void>;63}6465export class CachedTestInfo {66public get testName() { return this.stest.fullName; }6768constructor(69public readonly stest: SimulationTest,70public readonly cacheSlot: number = 071) { }72}7374export interface CachedResponseMetadata {75requestDuration: number;76requestTime: string;77testName: string;78}7980export namespace CachedResponseMetadata {81export function isCachedResponseMetadata(obj: any): obj is CachedResponseMetadata {82return (83typeof obj === 'object' &&84obj !== null &&85'requestDuration' in obj &&86typeof 
(obj as any).requestDuration === 'number' &&87'requestTime' in obj &&88typeof (obj as any).requestTime === 'string' &&89'testName' in obj &&90typeof (obj as any).testName === 'string'91);92}93}9495export type CachedExtraData = { cacheMetadata: CachedResponseMetadata | undefined; copilotFunctionCalls?: ICopilotToolCall[]; logprobs?: ChoiceLogProbs };96export type CachedResponse = ChatResponses & CachedExtraData;9798export type ResponseWithMeta = ChatResponses & {99isCacheHit?: boolean; // set when the cache was checked100cacheKey?: string; // set when the cache was used or updated101cacheMetadata?: CachedResponseMetadata; // set when the cache was used or updated102};103104105export class CachingChatMLFetcher extends AbstractChatMLFetcher {106107private static readonly Locks = new LockMap();108109private readonly fetcher: IChatMLFetcher;110private isDisposed = false;111112constructor(113fetcherOrDescriptor: SyncDescriptor<IChatMLFetcher> | IChatMLFetcher,114private readonly cache: IChatMLCache,115private readonly testInfo: CachedTestInfo,116private readonly extraCacheProperties: any | undefined = undefined,117private readonly cacheMode = CacheMode.Default,118@IJSONOutputPrinter private readonly jsonOutputPrinter: IJSONOutputPrinter,119@ISimulationEndpointHealth private readonly simulationEndpointHealth: ISimulationEndpointHealth,120@IInstantiationService private readonly instantiationService: IInstantiationService,121@IConversationOptions options: IConversationOptions,122) {123super(options);124125this.fetcher = (fetcherOrDescriptor instanceof SyncDescriptor ? 
instantiationService.createInstance(fetcherOrDescriptor) : fetcherOrDescriptor);126}127128override dispose() {129super.dispose();130this.isDisposed = true;131}132133override async fetchMany(opts: IFetchMLOptions, token: CancellationToken): Promise<ResponseWithMeta> {134135if (this.isDisposed) {136throw new BugIndicatingError('The CachingChatMLFetcher has been disposed and cannot be used anymore.');137}138139if (!this.testInfo.testName) {140throw new Error(`Illegal usage of the ChatMLFetcher! You should only use the ChatMLFetcher that is passed to your test and not an ambient one!`);141}142143if (this.cacheMode === CacheMode.Require) {144for (const message of opts.messages) {145if (containsRepoPath(getTextPart(message.content))) {146const message = `You should not use the repository root (${REPO_ROOT}) in your ChatML messages because this leads to cache misses! This request is generated by test "${this.testInfo.testName}`;147console.error(`\n\n${message}\n\n`);148this.printTerminatedWithRequireCache(message);149await drainStdoutAndExit(1);150throw new Error(message);151}152}153}154155const finalReqOptions = this.preparePostOptions(opts.requestOptions);156const req = new CacheableChatRequest(opts.messages, opts.endpoint.model, finalReqOptions, this.extraCacheProperties);157// console.log(`request with hash: ${req.hash}`);158159return CachingChatMLFetcher.Locks.withLock(req.hash, async () => {160let isCacheHit: boolean | undefined = undefined;161if (this.cacheMode !== CacheMode.Disable) {162const cacheValue = await this.cache.get(req, this.testInfo.cacheSlot);163if (cacheValue) {164if (cacheValue.type === ChatFetchResponseType.Success) {165await opts.finishedCb?.(cacheValue.value[0], 0, { text: cacheValue.value[0], copilotToolCalls: cacheValue.copilotFunctionCalls, logprobs: cacheValue.logprobs });166} else if (cacheValue.type === ChatFetchResponseType.Length) {167await opts.finishedCb?.(cacheValue.truncatedValue, 0, { text: cacheValue.truncatedValue, 
copilotToolCalls: cacheValue.copilotFunctionCalls, logprobs: cacheValue.logprobs });168}169return { ...cacheValue, isCacheHit: true, cacheKey: req.hash };170}171isCacheHit = false;172}173174if (this.cacheMode === CacheMode.Require) {175let diff: { newRequest: string; oldRequest: string } | undefined;176try {177diff = await this.suggestDiffCommandForCacheMiss(req);178} catch (err) {179console.log(err);180}181182console.log(JSON.stringify(opts.messages, (key, value) => {183if (typeof value === 'string') {184const split = value.split(/\n/g);185return split.length > 1 ? split : value;186}187return value;188}, 4));189190let message = `\n✗ Cache entry not found for a request generated by test "${this.testInfo.testName}"!191- Valid cache entries are currently required for all requests!192- The missing request has the hash: ${req.hash} (cache slot ${this.testInfo.cacheSlot}, make sure to call simulate -- -n=10).193`;194if (diff) {195message += `- Compare with the closest cache entry using \`code-insiders --diff "${diff.oldRequest}" "${diff.newRequest}"\`\n`;196}197198console.log(message);199this.printTerminatedWithRequireCache(message);200await drainStdoutAndExit(1);201throw new Error(message);202}203204const callbackWrapper = new FinishedCallbackWrapper(opts.finishedCb);205const start = Date.now();206if (logger.shouldLog(LogLevel.Trace)) {207logger.trace(`Making request:\n` + opts.messages.map(m => ` ${m.role}: ${getTextPart(m.content)}`).join('\n'));208}209const result = await this.fetcher.fetchMany({ ...opts, finishedCb: callbackWrapper.getCb() }, token);210const fetchingResponseTimeInMs = Date.now() - start;211// Don't cache failed results212if (213result.type === ChatFetchResponseType.OffTopic214|| result.type === ChatFetchResponseType.Filtered215|| result.type === ChatFetchResponseType.PromptFiltered216|| result.type === ChatFetchResponseType.Length217|| result.type === ChatFetchResponseType.Success218) {219const cacheMetadata: CachedResponseMetadata = {220testName: 
this.testInfo.testName,221requestDuration: fetchingResponseTimeInMs,222requestTime: new Date().toISOString()223};224const cachedResponse: CachedResponse = {225...result,226cacheMetadata,227copilotFunctionCalls: callbackWrapper.copilotFunctionCalls,228logprobs: callbackWrapper.logprobs,229};230if (!(this.fetcher instanceof NoFetchChatMLFetcher)) {231try {232await this.cache.set(req, this.testInfo.cacheSlot, cachedResponse);233} catch (err) {234if (/Key already exists/.test(err.message)) {235console.log(JSON.stringify(opts.messages, (key, value) => {236if (typeof value === 'string') {237const split = value.split(/\n/g);238return split.length > 1 ? split : value;239}240return value;241}, 4));242console.log(`\n✗ ${err.message}`);243await drainStdoutAndExit(1);244}245246throw err;247}248return { ...result, cacheMetadata, isCacheHit, cacheKey: req.hash };249}250} else {251// A request failed, so we don't want to cache it.252// But we should warn the developer that they need to rerun253this.simulationEndpointHealth.markFailure(this.testInfo, result);254}255return { ...result, isCacheHit };256});257}258259private async suggestDiffCommandForCacheMiss(req: CacheableChatRequest) {260const outcome = await this.instantiationService.createInstance(SimulationOutcomeImpl, false).get(this.testInfo.stest);261if (!outcome?.requests.length) {262return;263}264265const newRequest = path.join(tmpdir(), `${req.hash}-new.json`);266await fs.writeFile(newRequest, JSON.stringify(req.toJSON(), null, '\t'));267268let best: unknown | undefined;269let bestScore = Infinity;270for (const requestHash of outcome.requests) {271const request = await this.cache.getRequest!(requestHash);272if (!request) {273continue;274}275276const diff = new LcsDiff(277new LineSequence(JSON.stringify(request, null, '\t').split('\n')),278new LineSequence(JSON.stringify(req.toJSON(), null, '\t').split('\n')),279).ComputeDiff();280281let score = 0;282for (const d of diff) {283score += d.modifiedLength + 
d.originalLength;284}285286if (score < bestScore) {287best = request;288bestScore = score;289}290}291292if (!best) {293return;294}295296const oldRequest = path.join(tmpdir(), `${req.hash}-previous.json`);297await fs.writeFile(oldRequest, JSON.stringify(best, null, '\t'));298return {299newRequest,300oldRequest,301get isWhitespaceOnly() {302let whitespaceOnly = false;303if (best) {304const bestCast = best as { messages: { content: string[] }[] };305const currentCast = req.toJSON() as { messages: { content: string[] }[] };306if (bestCast.messages.length === currentCast.messages.length && bestCast.messages.every(307(v, i) => v.content.join('').replace(/\n\n+/, '\n').trim() === currentCast.messages[i].content.join('').replace(/\n\n+/, '\n').trim())) {308whitespaceOnly = true;309}310}311312return whitespaceOnly;313}314};315}316317private printTerminatedWithRequireCache(message: string) {318return this.jsonOutputPrinter.print({ type: OutputType.terminated, reason: `Terminated because of --require-cache\n${message}` });319}320}321322const repoRootRegex = new RegExp(REPO_ROOT.replace(/[/\\]/g, '[/\\\\]'), 'i');323324function containsRepoPath(testString: string): boolean {325return repoRootRegex.test(testString);326}327328class FinishedCallbackWrapper {329public readonly copilotFunctionCalls: ICopilotToolCall[] = [];330public logprobs: ChoiceLogProbs | undefined;331332constructor(333private readonly original: FinishedCallback | undefined) { }334335public getCb(): FinishedCallback {336return async (text: string, index: number, delta: IResponseDelta): Promise<number | undefined> => {337if (delta.copilotToolCalls) {338this.copilotFunctionCalls.push(...delta.copilotToolCalls);339}340if (delta.logprobs) {341if (!this.logprobs) {342this.logprobs = { ...delta.logprobs };343} else {344this.logprobs.content.push(...delta.logprobs.content);345}346}347348return this.original?.(text, index, delta);349};350}351}352353354