Path: blob/main/extensions/copilot/test/base/cachingCompletionsFetchService.ts
13388 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { outdent } from 'outdent';
import * as yaml from 'yaml';
import { IAuthenticationService } from '../../src/platform/authentication/common/authentication';
import * as fetcher from '../../src/platform/nesFetch/common/completionsFetchService';
import { ResponseStream } from '../../src/platform/nesFetch/common/responseStream';
import { CompletionsFetchService, FetchResponse, IFetchRequestParams } from '../../src/platform/nesFetch/node/completionsFetchServiceImpl';
import { getRequestId } from '../../src/platform/networking/common/fetch';
import { IFetcherService } from '../../src/platform/networking/common/fetcherService';
import { IRequestLogger } from '../../src/platform/requestLogger/common/requestLogger';
import { LockMap } from '../../src/util/common/lock';
import { Result } from '../../src/util/common/result';
import { AsyncIterableObject, DeferredPromise, IThrottledWorkerOptions, ThrottledWorker } from '../../src/util/vs/base/common/async';
import { CachedFunction } from '../../src/util/vs/base/common/cache';
import { CancellationToken } from '../../src/util/vs/base/common/cancellation';
import { assertType } from '../../src/util/vs/base/common/types';
import { OPENAI_FETCHER_CACHE_SALT } from '../cacheSalt';
import { IJSONOutputPrinter } from '../jsonOutputPrinter';
import { InterceptedRequest, ISerialisedChatResponse, OutputType } from '../simulation/shared/sharedTypes';
import { CachedResponseMetadata, CachedTestInfo } from './cachingChatMLFetcher';
import { emptyFetcherResponse, ICacheableCompletionsResponse, ICompletionsCache } from './completionsCache';
import { computeSHA256 } from './hash';
import { CacheMode } from './simulationContext';
import { FetchRequestCollector } from './spyingChatMLFetcher';
import { drainStdoutAndExit } from './stdout';

/**
 * Cache key for a completions request.
 *
 * The key is the SHA-256 of the URL-specific cache salt concatenated with the
 * JSON serialisation of `{ url, body }`.
 */
export class CacheableCompletionRequest {
	readonly hash: string;
	private readonly obj: unknown;

	constructor(url: string, options: fetcher.Completions.Internal.FetchOptions) {
		const cacheSalt = OPENAI_FETCHER_CACHE_SALT.getByUrl(url);
		this.obj = { url, body: options.body };
		this.hash = computeSHA256(cacheSalt + JSON.stringify(this.obj));
	}

	/** Serialises to the same `{ url, body }` object that was hashed. */
	toJSON() {
		return this.obj;
	}
}

/**
 * A {@link CompletionsFetchService} that serves responses from a completions cache
 * when possible, stores successful responses back into it, and reports every
 * request to a {@link FetchRequestCollector}.
 */
export class CachingCompletionsFetchService extends CompletionsFetchService {

	/** Locks keyed by request hash, used to serialise cache lookup + fetch for identical requests. */
	private static readonly Locks = new LockMap();

	/** Throttle per URL (currently set to send a request only once a second) */
	private static readonly throttlers = new CachedFunction(
		function createThrottler(url: string) {
			const delayMs = 1000; // milliseconds
			const options: IThrottledWorkerOptions = {
				maxBufferedWork: undefined, // We want to hold as many requests as possible
				maxWorkChunkSize: 1,
				waitThrottleDelayBetweenWorkUnits: true,
				throttleDelay: delayMs,
			};
			return new ThrottledWorker<() => Promise<void>>(options, async (tasks) => {
				for (const task of tasks) {
					// Deliberately not awaited: each task settles its own promise (see `_fetchFromUrlAndCache`).
					void task();
				}
			});
		}
	);

	/**
	 * Per-request caching info, keyed by request id.
	 *
	 * This is a dirty hack to pass info from the lower layer `_fetchFromUrl` up to `fetch` -- needs rewriting.
	 */
	private requests: Map<string /* requestId */, { request: CacheableCompletionRequest; hitsCache: boolean }> = new Map();

	constructor(
		private readonly nesCache: ICompletionsCache,
		private readonly testInfo: CachedTestInfo,
		private readonly cacheMode: CacheMode,
		private readonly requestCollector: FetchRequestCollector,
		private readonly isNoFetchModeEnabled: boolean,
		@IJSONOutputPrinter private readonly jsonOutputPrinter: IJSONOutputPrinter,
		@IAuthenticationService authService: IAuthenticationService,
		@IFetcherService fetcherService: IFetcherService,
		@IRequestLogger requestLogger: IRequestLogger,
	) {
		super(authService, fetcherService, requestLogger);
	}

	/**
	 * Delegates to the base `fetch` and registers an {@link InterceptedRequest} with the
	 * request collector describing the outcome (success, response error, or fetch failure).
	 *
	 * @returns The unmodified result of the base `fetch`.
	 * @throws If `_fetchFromUrl` did not record caching info for `requestId`.
	 */
	public override async fetch(url: string, secretKey: string, params: IFetchRequestParams, requestId: string, ct: CancellationToken, headerOverrides?: Record<string, string>): Promise<Result<ResponseStream, fetcher.Completions.CompletionsFetchFailure>> {
		const interceptedRequest = new DeferredPromise<InterceptedRequest>();
		this.requestCollector.addInterceptedRequest(interceptedRequest.p);
		const r = await super.fetch(url, secretKey, params, requestId, ct, headerOverrides);

		const request = params.prompt;

		const requestOptions = {
			...params,
			request
		};

		const requestCachingInfo = this.requests.get(requestId);
		this.requests.delete(requestId);
		assertType(requestCachingInfo, 'request must be set');

		const requestHitsCache = requestCachingInfo.hitsCache;
		const cacheKey = requestCachingInfo.request.hash;

		const model = inventModelFromURI(url);

		const completeWith = (response: ISerialisedChatResponse) =>
			interceptedRequest.complete(new InterceptedRequest(request, requestOptions, response, cacheKey, model));

		if (r.isOk()) {
			const startTime = new Date();
			const requestTime = startTime.toISOString();
			r.val.response.then(response => {
				const elapsedTime = Date.now() - startTime.valueOf();
				const cacheMetadata = {
					requestDuration: elapsedTime,
					requestTime
				};
				const serializedResponse: ISerialisedChatResponse =
					response.isOk()
						? {
							type: 'success',
							cacheKey,
							isCacheHit: requestHitsCache,
							cacheMetadata,
							requestId,
							// `choices` may be empty; fall back to an empty completion instead of throwing.
							value: [response.val.choices[0]?.text ?? ''],
						}
						: {
							type: response.err.name,
							cacheKey,
							isCacheHit: requestHitsCache,
							requestId,
							value: [response.err.stack || response.err.message],
						};
				completeWith(serializedResponse);
			}).catch((err: unknown) => {
				// Always settle the intercepted request, otherwise the collector's promise
				// stays pending and the rejection goes unhandled.
				completeWith({
					type: err instanceof Error ? err.name : 'Error',
					cacheKey,
					isCacheHit: requestHitsCache,
					requestId,
					value: [err instanceof Error ? (err.stack || err.message) : String(err)],
				});
			});
		} else {
			completeWith({
				type: r.err.kind,
				cacheKey,
				isCacheHit: requestHitsCache,
				requestId,
				value: [r.err.kind],
			});
		}

		return r;
	}

	/**
	 * Resolves a request from the cache when allowed, otherwise performs the real fetch.
	 *
	 * - `CacheMode.Disable`: always fetches; the cache is not consulted.
	 * - Otherwise the lookup and fetch run under a lock keyed by the request hash.
	 * - `CacheMode.Require`: a cache miss prints the request body and terminates via `throwCacheMissing`.
	 *
	 * Records whether the request hit the cache in `this.requests` for `fetch` to pick up.
	 */
	protected override async _fetchFromUrl(
		url: string,
		options: fetcher.Completions.Internal.FetchOptions,
		ct: CancellationToken
	): Promise<Result<FetchResponse, fetcher.Completions.CompletionsFetchFailure>> {

		const request = new CacheableCompletionRequest(url, options);

		if (this.cacheMode === CacheMode.Disable) {
			this.requests.set(options.requestId, { request, hitsCache: false });
			return this._fetchFromUrlAndCache(request, url, options, ct);
		}

		return CachingCompletionsFetchService.Locks.withLock(request.hash, async () => {
			const cachedValue = await this.nesCache.get(request, this.testInfo.cacheSlot);
			if (cachedValue) {
				this.requests.set(options.requestId, { request, hitsCache: true });
				return Result.ok(ICacheableCompletionsResponse.toFetchResponse(cachedValue));
			}

			if (this.cacheMode === CacheMode.Require) {
				prettyPrintJsonEncodedObject(options.body);
				await this.throwCacheMissing(request);
			}

			this.requests.set(options.requestId, { request, hitsCache: false });
			return this._fetchFromUrlAndCache(request, url, options, ct);
		});
	}

	/**
	 * Performs the fetch (throttled per URL) or, in no-fetch mode, fabricates an empty
	 * 200 response. On a 200 result the response body is wrapped in a stream that
	 * forwards chunks unchanged and, when caching is enabled, stores the full body in
	 * the cache once the stream has been consumed to the end.
	 *
	 * Failures and non-200 responses are logged and returned without being cached.
	 */
	private async _fetchFromUrlAndCache(
		request: CacheableCompletionRequest,
		url: string,
		options: fetcher.Completions.Internal.FetchOptions,
		ct: CancellationToken,
	): Promise<Result<FetchResponse, fetcher.Completions.CompletionsFetchFailure>> {

		const throttler = CachingCompletionsFetchService.throttlers.get(url);

		let startTime: number | undefined;
		const fetchResult: Result<FetchResponse, fetcher.Completions.CompletionsFetchFailure> =
			this.isNoFetchModeEnabled
				? Result.ok({
					requestId: getRequestId(new Headers()),
					status: 200,
					statusText: '',
					headers: new Headers(),
					body: AsyncIterableObject.fromArray(['']),
					response: emptyFetcherResponse(new Headers()),
				} satisfies FetchResponse)
				: await new Promise<Result<FetchResponse, fetcher.Completions.CompletionsFetchFailure>>((resolve, reject) => {
					throttler.work([
						async () => {
							try {
								startTime = Date.now();
								const r = await super._fetchFromUrl(url, options, ct);
								resolve(r);
							} catch (e) {
								reject(e);
							}
						}
					]);
				});

		if (fetchResult.isError() || fetchResult.val.status !== 200) { // don't cache a failure
			console.log('Fetch failed', JSON.stringify(fetchResult, null, '\t'));
			return fetchResult;
		}

		const response = fetchResult.val;
		const stream = response.body;

		const isCachingEnabled = this.cacheMode !== CacheMode.Disable && !this.isNoFetchModeEnabled;

		let body = '';
		const cachingStream = new AsyncIterableObject<string>(async (emitter) => {
			// I specifically don't wrap in try-catch to not cache if this throws
			for await (const chunk of stream) {
				body += chunk.toString();
				emitter.emitOne(chunk);
			}
			if (isCachingEnabled) {
				// `startTime` is always set here: caching is only enabled when a real fetch ran.
				const fetchingResponseTimeInMs = Date.now() - startTime!;
				const cacheMetadata: CachedResponseMetadata = {
					testName: this.testInfo.testName,
					requestDuration: fetchingResponseTimeInMs,
					requestTime: new Date().toISOString()
				};
				this.nesCache
					.set(request, this.testInfo.cacheSlot, ICacheableCompletionsResponse.create(options.requestId, cacheMetadata, response.status, response.statusText, body))
					.catch(err => {
						console.error(err);
						console.log('Failed to cache response', JSON.stringify(fetchResult, null, '\t'));
					});
			}
		});

		// Replace response.body with the caching stream
		response.body = cachingStream;

		return fetchResult;
	}

	/**
	 * Reports a cache miss (console + JSON output printer) and terminates the process
	 * with exit code 1 via `drainStdoutAndExit`.
	 */
	private throwCacheMissing(request: CacheableCompletionRequest) {
		const message = outdent`
			ā Cache entry not found for a request generated by test "${this.testInfo.testName}"!
			- Valid cache entries are currently required for all requests!
			- The missing request has the hash: ${request.hash} (cache slot ${this.testInfo.cacheSlot}, make sure to call simulate -- -n=10).`;

		console.log(message);
		// Print the request itself; previously the YAML was computed and discarded.
		console.log(yaml.stringify(request));

		const reason = outdent`
			Terminated because of --require-cache
			${message}`;

		this.jsonOutputPrinter.print({ type: OutputType.terminated, reason });

		return drainStdoutAndExit(1);
	}
}

/**
 * Derives a model label from a request URI by taking everything after the
 * second-to-last `/` (e.g. `.../engines/foo/completions` yields `foo/completions`).
 * A URI with fewer than two slashes is returned unchanged.
 */
function inventModelFromURI(uri: string): string | undefined {
	const lastSlash = uri.lastIndexOf('/');
	if (lastSlash === -1) {
		return uri;
	}
	const secondLastSlash = uri.lastIndexOf('/', lastSlash - 1);
	return uri.substring(secondLastSlash + 1);
}

/**
 * Logs a JSON-encoded string as indented JSON, splitting multi-line string values
 * into arrays of lines so they stay readable. If `obj` is not valid JSON the raw
 * string is logged instead, so this diagnostic helper never throws.
 */
function prettyPrintJsonEncodedObject(obj: string) {
	let parsed: unknown;
	try {
		parsed = JSON.parse(obj, (_key, value) => {
			if (typeof value === 'string') {
				const split = value.split(/\n/g);
				return split.length > 1 ? split : value;
			}
			return value;
		});
	} catch {
		console.log(obj);
		return;
	}
	console.log(JSON.stringify(parsed, null, 4));
}