// Source: microsoft/vscode — extensions/copilot/test/base/cachingChatMLFetcher.ts
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
import { Raw } from '@vscode/prompt-tsx';
6
import { promises as fs } from 'fs';
7
import { tmpdir } from 'os';
8
import * as path from 'path';
9
import type { CancellationToken } from 'vscode';
10
import { AbstractChatMLFetcher } from '../../src/extension/prompt/node/chatMLFetcher';
11
import { IChatMLFetcher, IFetchMLOptions } from '../../src/platform/chat/common/chatMLFetcher';
12
import { ChatFetchResponseType, ChatResponses } from '../../src/platform/chat/common/commonTypes';
13
import { IConversationOptions } from '../../src/platform/chat/common/conversationOptions';
14
import { getTextPart } from '../../src/platform/chat/common/globalStringUtils';
15
import { LogLevel } from '../../src/platform/log/common/logService';
16
import { FinishedCallback, ICopilotToolCall, IResponseDelta, OptionalChatRequestParams } from '../../src/platform/networking/common/fetch';
17
import { ChoiceLogProbs, rawMessageToCAPI } from '../../src/platform/networking/common/openai';
18
import { LcsDiff, LineSequence } from '../../src/util/common/diff';
19
import { LockMap } from '../../src/util/common/lock';
20
import { BugIndicatingError } from '../../src/util/vs/base/common/errors';
21
import { SyncDescriptor } from '../../src/util/vs/platform/instantiation/common/descriptors';
22
import { IInstantiationService } from '../../src/util/vs/platform/instantiation/common/instantiation';
23
import { CHAT_ML_CACHE_SALT_PER_MODEL } from '../cacheSalt';
24
import { IJSONOutputPrinter } from '../jsonOutputPrinter';
25
import { OutputType } from '../simulation/shared/sharedTypes';
26
import { logger } from '../simulationLogger';
27
import { computeSHA256 } from './hash';
28
import { CacheMode, NoFetchChatMLFetcher } from './simulationContext';
29
import { ISimulationEndpointHealth } from './simulationEndpointHealth';
30
import { SimulationOutcomeImpl } from './simulationOutcome';
31
import { drainStdoutAndExit } from './stdout';
32
import { REPO_ROOT, SimulationTest } from './stest';
33
34
export class CacheableChatRequest {
35
public readonly hash: string;
36
private readonly obj: unknown;
37
38
constructor(
39
messages: Raw.ChatMessage[],
40
model: string,
41
requestOptions: OptionalChatRequestParams,
42
extraCacheProperties: any | undefined
43
) {
44
this.obj = { messages: rawMessageToCAPI(messages), model, requestOptions, extraCacheProperties };
45
const salt = CHAT_ML_CACHE_SALT_PER_MODEL[model] ?? CHAT_ML_CACHE_SALT_PER_MODEL['DEFAULT'];
46
this.hash = computeSHA256(salt + JSON.stringify(this.obj));
47
48
// To aid in reading cache entries, we will write objects to disk splitting each message by new lines
49
// We do this after the sha computation to avoid invalidating all the existing caches
50
(this.obj as any).messages = (this.obj as any).messages.map((m: Raw.ChatMessage) => {
51
return { ...m, content: getTextPart(m.content).split('\n') };
52
});
53
}
54
55
toJSON() {
56
return this.obj;
57
}
58
}
59
60
/** On-disk cache of chat responses, keyed by request hash and cache slot. */
export interface IChatMLCache {
	/** Optional: fetch the raw cached request object for a hash (used to diff against the closest entry on a cache miss). */
	getRequest?(hash: string): Promise<unknown | undefined>;
	/** Returns the cached response for the request/slot, or undefined on a cache miss. */
	get(req: CacheableChatRequest, cacheSlot: number): Promise<CachedResponse | undefined>;
	/** Stores a response for the given request/slot. */
	set(req: CacheableChatRequest, cacheSlot: number, cachedResponse: CachedResponse): Promise<void>;
}
65
66
export class CachedTestInfo {
67
public get testName() { return this.stest.fullName; }
68
69
constructor(
70
public readonly stest: SimulationTest,
71
public readonly cacheSlot: number = 0
72
) { }
73
}
74
75
/**
 * Bookkeeping recorded alongside a cached response: which test produced it,
 * when the original request was made, and how long it took.
 */
export interface CachedResponseMetadata {
	requestDuration: number; // wall-clock duration of the original request, in milliseconds
	requestTime: string; // ISO 8601 timestamp taken when the original request completed
	testName: string; // full name of the stest that generated the request
}
80
81
export namespace CachedResponseMetadata {
82
export function isCachedResponseMetadata(obj: any): obj is CachedResponseMetadata {
83
return (
84
typeof obj === 'object' &&
85
obj !== null &&
86
'requestDuration' in obj &&
87
typeof (obj as any).requestDuration === 'number' &&
88
'requestTime' in obj &&
89
typeof (obj as any).requestTime === 'string' &&
90
'testName' in obj &&
91
typeof (obj as any).testName === 'string'
92
);
93
}
94
}
95
96
/** Extra bookkeeping persisted in the cache next to a response: provenance metadata plus any streamed tool calls / logprobs. */
export type CachedExtraData = { cacheMetadata: CachedResponseMetadata | undefined; copilotFunctionCalls?: ICopilotToolCall[]; logprobs?: ChoiceLogProbs };
/** The shape of a response as stored in (and read back from) the cache. */
export type CachedResponse = ChatResponses & CachedExtraData;

/** A fetch result augmented with information about how the cache was involved. */
export type ResponseWithMeta = ChatResponses & {
	isCacheHit?: boolean; // set when the cache was checked
	cacheKey?: string; // set when the cache was used or updated
	cacheMetadata?: CachedResponseMetadata; // set when the cache was used or updated
};
104
105
106
export class CachingChatMLFetcher extends AbstractChatMLFetcher {
107
108
private static readonly Locks = new LockMap();
109
110
private readonly fetcher: IChatMLFetcher;
111
private isDisposed = false;
112
113
constructor(
114
fetcherOrDescriptor: SyncDescriptor<IChatMLFetcher> | IChatMLFetcher,
115
private readonly cache: IChatMLCache,
116
private readonly testInfo: CachedTestInfo,
117
private readonly extraCacheProperties: any | undefined = undefined,
118
private readonly cacheMode = CacheMode.Default,
119
@IJSONOutputPrinter private readonly jsonOutputPrinter: IJSONOutputPrinter,
120
@ISimulationEndpointHealth private readonly simulationEndpointHealth: ISimulationEndpointHealth,
121
@IInstantiationService private readonly instantiationService: IInstantiationService,
122
@IConversationOptions options: IConversationOptions,
123
) {
124
super(options);
125
126
this.fetcher = (fetcherOrDescriptor instanceof SyncDescriptor ? instantiationService.createInstance(fetcherOrDescriptor) : fetcherOrDescriptor);
127
}
128
129
override dispose() {
130
super.dispose();
131
this.isDisposed = true;
132
}
133
134
override async fetchMany(opts: IFetchMLOptions, token: CancellationToken): Promise<ResponseWithMeta> {
135
136
if (this.isDisposed) {
137
throw new BugIndicatingError('The CachingChatMLFetcher has been disposed and cannot be used anymore.');
138
}
139
140
if (!this.testInfo.testName) {
141
throw new Error(`Illegal usage of the ChatMLFetcher! You should only use the ChatMLFetcher that is passed to your test and not an ambient one!`);
142
}
143
144
if (this.cacheMode === CacheMode.Require) {
145
for (const message of opts.messages) {
146
if (containsRepoPath(getTextPart(message.content))) {
147
const message = `You should not use the repository root (${REPO_ROOT}) in your ChatML messages because this leads to cache misses! This request is generated by test "${this.testInfo.testName}`;
148
console.error(`\n\n${message}\n\n`);
149
this.printTerminatedWithRequireCache(message);
150
await drainStdoutAndExit(1);
151
throw new Error(message);
152
}
153
}
154
}
155
156
const finalReqOptions = this.preparePostOptions(opts.requestOptions);
157
const req = new CacheableChatRequest(opts.messages, opts.endpoint.model, finalReqOptions, this.extraCacheProperties);
158
// console.log(`request with hash: ${req.hash}`);
159
160
return CachingChatMLFetcher.Locks.withLock(req.hash, async () => {
161
let isCacheHit: boolean | undefined = undefined;
162
if (this.cacheMode !== CacheMode.Disable) {
163
const cacheValue = await this.cache.get(req, this.testInfo.cacheSlot);
164
if (cacheValue) {
165
if (cacheValue.type === ChatFetchResponseType.Success) {
166
await opts.finishedCb?.(cacheValue.value[0], 0, { text: cacheValue.value[0], copilotToolCalls: cacheValue.copilotFunctionCalls, logprobs: cacheValue.logprobs });
167
} else if (cacheValue.type === ChatFetchResponseType.Length) {
168
await opts.finishedCb?.(cacheValue.truncatedValue, 0, { text: cacheValue.truncatedValue, copilotToolCalls: cacheValue.copilotFunctionCalls, logprobs: cacheValue.logprobs });
169
}
170
return { ...cacheValue, isCacheHit: true, cacheKey: req.hash };
171
}
172
isCacheHit = false;
173
}
174
175
if (this.cacheMode === CacheMode.Require) {
176
let diff: { newRequest: string; oldRequest: string } | undefined;
177
try {
178
diff = await this.suggestDiffCommandForCacheMiss(req);
179
} catch (err) {
180
console.log(err);
181
}
182
183
console.log(JSON.stringify(opts.messages, (key, value) => {
184
if (typeof value === 'string') {
185
const split = value.split(/\n/g);
186
return split.length > 1 ? split : value;
187
}
188
return value;
189
}, 4));
190
191
let message = `\n✗ Cache entry not found for a request generated by test "${this.testInfo.testName}"!
192
- Valid cache entries are currently required for all requests!
193
- The missing request has the hash: ${req.hash} (cache slot ${this.testInfo.cacheSlot}, make sure to call simulate -- -n=10).
194
`;
195
if (diff) {
196
message += `- Compare with the closest cache entry using \`code-insiders --diff "${diff.oldRequest}" "${diff.newRequest}"\`\n`;
197
}
198
199
console.log(message);
200
this.printTerminatedWithRequireCache(message);
201
await drainStdoutAndExit(1);
202
throw new Error(message);
203
}
204
205
const callbackWrapper = new FinishedCallbackWrapper(opts.finishedCb);
206
const start = Date.now();
207
if (logger.shouldLog(LogLevel.Trace)) {
208
logger.trace(`Making request:\n` + opts.messages.map(m => ` ${m.role}: ${getTextPart(m.content)}`).join('\n'));
209
}
210
const result = await this.fetcher.fetchMany({ ...opts, finishedCb: callbackWrapper.getCb() }, token);
211
const fetchingResponseTimeInMs = Date.now() - start;
212
// Don't cache failed results
213
if (
214
result.type === ChatFetchResponseType.OffTopic
215
|| result.type === ChatFetchResponseType.Filtered
216
|| result.type === ChatFetchResponseType.PromptFiltered
217
|| result.type === ChatFetchResponseType.Length
218
|| result.type === ChatFetchResponseType.Success
219
) {
220
const cacheMetadata: CachedResponseMetadata = {
221
testName: this.testInfo.testName,
222
requestDuration: fetchingResponseTimeInMs,
223
requestTime: new Date().toISOString()
224
};
225
const cachedResponse: CachedResponse = {
226
...result,
227
cacheMetadata,
228
copilotFunctionCalls: callbackWrapper.copilotFunctionCalls,
229
logprobs: callbackWrapper.logprobs,
230
};
231
if (!(this.fetcher instanceof NoFetchChatMLFetcher)) {
232
try {
233
await this.cache.set(req, this.testInfo.cacheSlot, cachedResponse);
234
} catch (err) {
235
if (/Key already exists/.test(err.message)) {
236
console.log(JSON.stringify(opts.messages, (key, value) => {
237
if (typeof value === 'string') {
238
const split = value.split(/\n/g);
239
return split.length > 1 ? split : value;
240
}
241
return value;
242
}, 4));
243
console.log(`\n✗ ${err.message}`);
244
await drainStdoutAndExit(1);
245
}
246
247
throw err;
248
}
249
return { ...result, cacheMetadata, isCacheHit, cacheKey: req.hash };
250
}
251
} else {
252
// A request failed, so we don't want to cache it.
253
// But we should warn the developer that they need to rerun
254
this.simulationEndpointHealth.markFailure(this.testInfo, result);
255
}
256
return { ...result, isCacheHit };
257
});
258
}
259
260
private async suggestDiffCommandForCacheMiss(req: CacheableChatRequest) {
261
const outcome = await this.instantiationService.createInstance(SimulationOutcomeImpl, false).get(this.testInfo.stest);
262
if (!outcome?.requests.length) {
263
return;
264
}
265
266
const newRequest = path.join(tmpdir(), `${req.hash}-new.json`);
267
await fs.writeFile(newRequest, JSON.stringify(req.toJSON(), null, '\t'));
268
269
let best: unknown | undefined;
270
let bestScore = Infinity;
271
for (const requestHash of outcome.requests) {
272
const request = await this.cache.getRequest!(requestHash);
273
if (!request) {
274
continue;
275
}
276
277
const diff = new LcsDiff(
278
new LineSequence(JSON.stringify(request, null, '\t').split('\n')),
279
new LineSequence(JSON.stringify(req.toJSON(), null, '\t').split('\n')),
280
).ComputeDiff();
281
282
let score = 0;
283
for (const d of diff) {
284
score += d.modifiedLength + d.originalLength;
285
}
286
287
if (score < bestScore) {
288
best = request;
289
bestScore = score;
290
}
291
}
292
293
if (!best) {
294
return;
295
}
296
297
const oldRequest = path.join(tmpdir(), `${req.hash}-previous.json`);
298
await fs.writeFile(oldRequest, JSON.stringify(best, null, '\t'));
299
return {
300
newRequest,
301
oldRequest,
302
get isWhitespaceOnly() {
303
let whitespaceOnly = false;
304
if (best) {
305
const bestCast = best as { messages: { content: string[] }[] };
306
const currentCast = req.toJSON() as { messages: { content: string[] }[] };
307
if (bestCast.messages.length === currentCast.messages.length && bestCast.messages.every(
308
(v, i) => v.content.join('').replace(/\n\n+/, '\n').trim() === currentCast.messages[i].content.join('').replace(/\n\n+/, '\n').trim())) {
309
whitespaceOnly = true;
310
}
311
}
312
313
return whitespaceOnly;
314
}
315
};
316
}
317
318
private printTerminatedWithRequireCache(message: string) {
319
return this.jsonOutputPrinter.print({ type: OutputType.terminated, reason: `Terminated because of --require-cache\n${message}` });
320
}
321
}
322
323
// Pattern matching the repo root case-insensitively, with '/' and '\' treated
// as interchangeable path separators.
const repoRootPattern = REPO_ROOT.replace(/[/\\]/g, '[/\\\\]');
const repoRootRegex = new RegExp(repoRootPattern, 'i');

/** Whether the given string mentions the repository root path. */
function containsRepoPath(testString: string): boolean {
	return repoRootRegex.test(testString);
}
328
329
class FinishedCallbackWrapper {
330
public readonly copilotFunctionCalls: ICopilotToolCall[] = [];
331
public logprobs: ChoiceLogProbs | undefined;
332
333
constructor(
334
private readonly original: FinishedCallback | undefined) { }
335
336
public getCb(): FinishedCallback {
337
return async (text: string, index: number, delta: IResponseDelta): Promise<number | undefined> => {
338
if (delta.copilotToolCalls) {
339
this.copilotFunctionCalls.push(...delta.copilotToolCalls);
340
}
341
if (delta.logprobs) {
342
if (!this.logprobs) {
343
this.logprobs = { ...delta.logprobs };
344
} else {
345
this.logprobs.content.push(...delta.logprobs.content);
346
}
347
}
348
349
return this.original?.(text, index, delta);
350
};
351
}
352
}
353
354