Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/extensions/copilot/test/base/cachingCompletionsFetchService.ts
13388 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
6
7
import { outdent } from 'outdent';
8
import * as yaml from 'yaml';
9
import { IAuthenticationService } from '../../src/platform/authentication/common/authentication';
10
import * as fetcher from '../../src/platform/nesFetch/common/completionsFetchService';
11
import { ResponseStream } from '../../src/platform/nesFetch/common/responseStream';
12
import { CompletionsFetchService, FetchResponse, IFetchRequestParams } from '../../src/platform/nesFetch/node/completionsFetchServiceImpl';
13
import { getRequestId } from '../../src/platform/networking/common/fetch';
14
import { IFetcherService } from '../../src/platform/networking/common/fetcherService';
15
import { IRequestLogger } from '../../src/platform/requestLogger/common/requestLogger';
16
import { LockMap } from '../../src/util/common/lock';
17
import { Result } from '../../src/util/common/result';
18
import { AsyncIterableObject, DeferredPromise, IThrottledWorkerOptions, ThrottledWorker } from '../../src/util/vs/base/common/async';
19
import { CachedFunction } from '../../src/util/vs/base/common/cache';
20
import { CancellationToken } from '../../src/util/vs/base/common/cancellation';
21
import { assertType } from '../../src/util/vs/base/common/types';
22
import { OPENAI_FETCHER_CACHE_SALT } from '../cacheSalt';
23
import { IJSONOutputPrinter } from '../jsonOutputPrinter';
24
import { InterceptedRequest, ISerialisedChatResponse, OutputType } from '../simulation/shared/sharedTypes';
25
import { CachedResponseMetadata, CachedTestInfo } from './cachingChatMLFetcher';
26
import { emptyFetcherResponse, ICacheableCompletionsResponse, ICompletionsCache } from './completionsCache';
27
import { computeSHA256 } from './hash';
28
import { CacheMode } from './simulationContext';
29
import { FetchRequestCollector } from './spyingChatMLFetcher';
30
import { drainStdoutAndExit } from './stdout';
31
32
/**
 * Identifies a completions request for caching purposes. The cache key is a
 * SHA-256 hash over a per-URL cache salt concatenated with the JSON-serialized
 * `{ url, body }` pair, so identical requests map to the same cache entry.
 */
export class CacheableCompletionRequest {
	/** Cache key: SHA-256 of the per-URL salt plus the serialized request payload. */
	readonly hash: string;

	/** The `{ url, body }` payload that participates in the hash; also the JSON form of this object. */
	private readonly obj: unknown;

	constructor(url: string, options: fetcher.Completions.Internal.FetchOptions) {
		const payload = { url, body: options.body };
		this.obj = payload;
		this.hash = computeSHA256(OPENAI_FETCHER_CACHE_SALT.getByUrl(url) + JSON.stringify(payload));
	}

	toJSON() {
		return this.obj;
	}
}
46
47
export class CachingCompletionsFetchService extends CompletionsFetchService {
48
49
private static readonly Locks = new LockMap();
50
51
/** Throttle per URL (currently set to send a request only once a second) */
52
private static readonly throttlers = new CachedFunction(
53
function createThrottler(url: string) {
54
const delayMs = 1000; // milliseconds
55
const options: IThrottledWorkerOptions = {
56
maxBufferedWork: undefined, // We want to hold as many requests as possible
57
maxWorkChunkSize: 1,
58
waitThrottleDelayBetweenWorkUnits: true,
59
throttleDelay: delayMs,
60
};
61
return new ThrottledWorker<() => Promise<void>>(options, async (tasks) => {
62
for (const task of tasks) {
63
task();
64
}
65
});
66
}
67
);
68
69
private requests: Map<string /* requestId */, { request: CacheableCompletionRequest; hitsCache: boolean }> = new Map(); // this's dirty hack to pass info from lower layer _fetchFromUrl to _fetch -- needs rewriting
70
71
constructor(
72
private readonly nesCache: ICompletionsCache,
73
private readonly testInfo: CachedTestInfo,
74
private readonly cacheMode: CacheMode,
75
private readonly requestCollector: FetchRequestCollector,
76
private readonly isNoFetchModeEnabled: boolean,
77
@IJSONOutputPrinter private readonly jsonOutputPrinter: IJSONOutputPrinter,
78
@IAuthenticationService authService: IAuthenticationService,
79
@IFetcherService fetcherService: IFetcherService,
80
@IRequestLogger requestLogger: IRequestLogger,
81
) {
82
super(authService, fetcherService, requestLogger);
83
}
84
85
public override async fetch(url: string, secretKey: string, params: IFetchRequestParams, requestId: string, ct: CancellationToken, headerOverrides?: Record<string, string>): Promise<Result<ResponseStream, fetcher.Completions.CompletionsFetchFailure>> {
86
const interceptedRequest = new DeferredPromise<InterceptedRequest>();
87
this.requestCollector.addInterceptedRequest(interceptedRequest.p);
88
const r = await super.fetch(url, secretKey, params, requestId, ct, headerOverrides);
89
90
const request = params.prompt;
91
92
const requestOptions = {
93
...params,
94
request
95
};
96
97
const requestCachingInfo = this.requests.get(requestId);
98
this.requests.delete(requestId);
99
assertType(requestCachingInfo, 'request must be set');
100
101
const requestHitsCache = requestCachingInfo.hitsCache;
102
const cacheKey = requestCachingInfo.request.hash;
103
104
const model = inventModelFromURI(url);
105
106
if (r.isOk()) {
107
const startTime = new Date();
108
const requestTime = startTime.toISOString();
109
r.val.response.then(response => {
110
const elapsedTime = Date.now() - startTime.valueOf();
111
const cacheMetadata = {
112
requestDuration: elapsedTime,
113
requestTime
114
};
115
const serializedResponse: ISerialisedChatResponse =
116
response.isOk()
117
? {
118
type: 'success',
119
cacheKey,
120
isCacheHit: requestHitsCache,
121
cacheMetadata,
122
requestId,
123
value: [response.val.choices[0].text ?? ''],
124
}
125
: {
126
type: response.err.name,
127
cacheKey,
128
isCacheHit: requestHitsCache,
129
requestId,
130
value: [response.err.stack ? response.err.stack : response.err.message],
131
};
132
interceptedRequest.complete(new InterceptedRequest(request, requestOptions, serializedResponse, cacheKey, model));
133
});
134
} else {
135
const response: ISerialisedChatResponse = {
136
type: r.err.kind,
137
cacheKey,
138
isCacheHit: requestHitsCache,
139
requestId,
140
value: [r.err.kind],
141
};
142
interceptedRequest.complete(new InterceptedRequest(request, requestOptions, response, cacheKey, model));
143
}
144
145
return r;
146
}
147
148
protected override async _fetchFromUrl(
149
url: string,
150
options: fetcher.Completions.Internal.FetchOptions,
151
ct: CancellationToken
152
): Promise<Result<FetchResponse, fetcher.Completions.CompletionsFetchFailure>> {
153
154
const request = new CacheableCompletionRequest(url, options);
155
156
if (this.cacheMode === CacheMode.Disable) {
157
this.requests.set(options.requestId, { request, hitsCache: false });
158
return this._fetchFromUrlAndCache(request, url, options, ct);
159
}
160
161
return CachingCompletionsFetchService.Locks.withLock(request.hash, async () => {
162
const cachedValue = await this.nesCache.get(request, this.testInfo.cacheSlot);
163
if (cachedValue) {
164
this.requests.set(options.requestId, { request, hitsCache: true });
165
return Result.ok(ICacheableCompletionsResponse.toFetchResponse(cachedValue));
166
}
167
168
if (this.cacheMode === CacheMode.Require) {
169
prettyPrintJsonEncodedObject(options.body);
170
await this.throwCacheMissing(request);
171
}
172
173
try {
174
this.requests.set(options.requestId, { request, hitsCache: false });
175
} catch (err) {
176
if (/Key already exists/.test(err.message)) {
177
prettyPrintJsonEncodedObject(options.body);
178
console.log(`\nāœ— ${err.message}`);
179
await drainStdoutAndExit(1);
180
}
181
182
throw err;
183
}
184
return this._fetchFromUrlAndCache(request, url, options, ct);
185
});
186
}
187
188
private async _fetchFromUrlAndCache(
189
request: CacheableCompletionRequest,
190
url: string,
191
options: fetcher.Completions.Internal.FetchOptions,
192
ct: CancellationToken,
193
): Promise<Result<FetchResponse, fetcher.Completions.CompletionsFetchFailure>> {
194
195
const throttler = CachingCompletionsFetchService.throttlers.get(url);
196
197
let startTime: number | undefined;
198
const fetchResult: Result<FetchResponse, fetcher.Completions.CompletionsFetchFailure> =
199
this.isNoFetchModeEnabled
200
? Result.ok({
201
requestId: getRequestId(new Headers()),
202
status: 200,
203
statusText: '',
204
headers: new Headers(),
205
body: AsyncIterableObject.fromArray(['']),
206
response: emptyFetcherResponse(new Headers()),
207
} satisfies FetchResponse)
208
: await new Promise((resolve, reject) => {
209
throttler.work([
210
async () => {
211
try {
212
startTime = Date.now();
213
const r = await super._fetchFromUrl(url, options, ct);
214
resolve(r);
215
} catch (e) {
216
reject(e);
217
}
218
}
219
]);
220
});
221
222
if (fetchResult.isError() || fetchResult.val.status !== 200) { // don't cache a failure
223
console.log('Fetch failed', JSON.stringify(fetchResult, null, '\t'));
224
return fetchResult;
225
}
226
227
const response = fetchResult.val;
228
const stream = response.body;
229
230
const isCachingEnabled = this.cacheMode !== CacheMode.Disable && !this.isNoFetchModeEnabled;
231
232
let body = '';
233
const cachingStream = new AsyncIterableObject<string>(async (emitter) => {
234
// I specifically don't wrap in try-catch to not cache if this throws
235
for await (const chunk of stream) {
236
body += chunk.toString();
237
emitter.emitOne(chunk);
238
}
239
if (isCachingEnabled) {
240
const fetchingResponseTimeInMs = Date.now() - startTime!;
241
const cacheMetadata: CachedResponseMetadata = {
242
testName: this.testInfo.testName,
243
requestDuration: fetchingResponseTimeInMs,
244
requestTime: new Date().toISOString()
245
};
246
this.nesCache
247
.set(request, this.testInfo.cacheSlot, ICacheableCompletionsResponse.create(options.requestId, cacheMetadata, response.status, response.statusText, body))
248
.catch(err => {
249
console.error(err);
250
console.log('Failed to cache response', JSON.stringify(fetchResult, null, '\t'));
251
});
252
}
253
});
254
255
// Replace response.body with the caching stream
256
response.body = cachingStream;
257
258
return fetchResult;
259
}
260
261
private throwCacheMissing(request: CacheableCompletionRequest) {
262
const message = outdent`
263
āœ— Cache entry not found for a request generated by test "${this.testInfo.testName}"!
264
- Valid cache entries are currently required for all requests!
265
- The missing request has the hash: ${request.hash} (cache slot ${this.testInfo.cacheSlot}, make sure to call simulate -- -n=10).`;
266
267
console.log(message);
268
yaml.stringify(request);
269
270
const reason = outdent`
271
Terminated because of --require-cache
272
${message}`;
273
274
this.jsonOutputPrinter.print({ type: OutputType.terminated, reason });
275
276
return drainStdoutAndExit(1);
277
}
278
}
279
280
/**
 * Derives a best-effort model identifier from a request URI by keeping its last
 * two path segments (e.g. ".../deployments/my-model/completions" -> "my-model/completions").
 * A URI containing no '/' at all is returned unchanged.
 */
function inventModelFromURI(uri: string): string | undefined {
	const finalSlash = uri.lastIndexOf('/');
	if (finalSlash < 0) {
		return uri; // no path separators -- nothing to strip
	}
	// Find the slash before the last one; lastIndexOf yields -1 when there is no
	// '/' at or before that position, in which case we keep the string from index 0.
	const penultimateSlash = uri.lastIndexOf('/', finalSlash - 1);
	return uri.slice(penultimateSlash + 1);
}
288
289
function prettyPrintJsonEncodedObject(obj: string) {
290
console.log(
291
JSON.stringify(
292
JSON.parse(obj, (key, value) => {
293
if (typeof value === 'string') {
294
const split = value.split(/\n/g);
295
return split.length > 1 ? split : value;
296
}
297
return value;
298
}),
299
null,
300
4
301
)
302
);
303
}
304
305