Path: blob/main/extensions/copilot/test/base/cachingEmbeddingsFetcher.ts
13388 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
import type { CancellationToken } from 'vscode';
import { IAuthenticationService } from '../../src/platform/authentication/common/authentication';
import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, EmbeddingVector, Embeddings, LEGACY_EMBEDDING_MODEL_ID, getWellKnownEmbeddingTypeInfo } from '../../src/platform/embeddings/common/embeddingsComputer';
import { RemoteEmbeddingsComputer } from '../../src/platform/embeddings/common/remoteEmbeddingsComputer';
import { IEndpointProvider } from '../../src/platform/endpoint/common/endpointProvider';
import { IEnvService } from '../../src/platform/env/common/envService';
import { ILogService } from '../../src/platform/log/common/logService';
import { IOTelService } from '../../src/platform/otel/common/otelService';
import { ITelemetryService } from '../../src/platform/telemetry/common/telemetry';
import { TelemetryCorrelationId } from '../../src/util/common/telemetryCorrelationId';
import { IInstantiationService } from '../../src/util/vs/platform/instantiation/common/instantiation';
import { computeSHA256 } from './hash';

/**
 * Cache key for a single embedding computation: one query string paired with
 * the embedding model it is computed with.
 */
export class CacheableEmbeddingRequest {
	/** Result of `computeSHA256` over the query text concatenated with the model id. */
	public readonly hash: string;
	/** The text to embed. */
	public readonly query: string;
	/** The embedding model the query is computed with. */
	public readonly model: LEGACY_EMBEDDING_MODEL_ID;

	constructor(
		embeddingQuery: string,
		model: LEGACY_EMBEDDING_MODEL_ID
	) {
		this.query = embeddingQuery;
		this.model = model;
		this.hash = computeSHA256(this.query + model);
	}

	/**
	 * Serializes the request as `{ query, model }`. The `hash` is not part of
	 * the serialized form.
	 */
	toJSON() {
		return {
			query: this.query,
			model: this.model,
		};
	}
}

/**
 * Storage used by {@link CachingEmbeddingsComputer} to look up and persist
 * embedding vectors.
 */
export interface IEmbeddingsCache {
	/** Resolves to the stored vector for the request, or `undefined` when there is none. */
	get(queryHash: CacheableEmbeddingRequest): Promise<EmbeddingVector | undefined>;
	/** Stores the vector for the request. */
	set(queryHash: CacheableEmbeddingRequest, embedding: EmbeddingVector): Promise<void>;
}

/**
 * A {@link RemoteEmbeddingsComputer} that consults an {@link IEmbeddingsCache}
 * first and only delegates the inputs that are not cached to the base class.
 * Freshly computed embeddings are written back to the cache.
 */
export class CachingEmbeddingsComputer extends RemoteEmbeddingsComputer {
	constructor(
		private readonly cache: IEmbeddingsCache,
		@IAuthenticationService authService: IAuthenticationService,
		@IEnvService envService: IEnvService,
		@ILogService logService: ILogService,
		@ITelemetryService telemetryService: ITelemetryService,
		@IEndpointProvider endpointProvider: IEndpointProvider,
		@IInstantiationService instantiationService: IInstantiationService,
		@IOTelService otelService: IOTelService,
	) {
		super(
			authService,
			envService,
			logService,
			telemetryService,
			endpointProvider,
			instantiationService,
			otelService,
		);
	}

	/**
	 * Computes embeddings for `inputs`, serving cached entries from the cache
	 * and computing the remainder through the base class.
	 *
	 * @param type The embedding type; it must map to a well-known model.
	 * @param inputs The strings to embed. Duplicates are looked up and computed once.
	 * @param options Forwarded unchanged to the base class.
	 * @param telemetryInfo Forwarded unchanged to the base class.
	 * @param token Forwarded unchanged to the base class.
	 * @returns The embeddings in the order of `inputs`.
	 * @throws Error when `type` has no well-known model, or when the base class
	 * returns fewer embeddings than were requested.
	 */
	public override async computeEmbeddings(
		type: EmbeddingType,
		inputs: string[],
		options: ComputeEmbeddingsOptions,
		telemetryInfo?: TelemetryCorrelationId,
		token?: CancellationToken,
	): Promise<Embeddings> {
		const model = getWellKnownEmbeddingTypeInfo(type)?.model;
		if (!model) {
			throw new Error(`Unknown embedding type: ${type.id}`);
		}

		const embeddingEntries = new Map<string, Embedding>();
		const nonCached: string[] = [];

		// Results are keyed by query string, so a repeated input needs only one
		// cache lookup and at most one slot in the remote request.
		const seen = new Set<string>();
		for (const input of inputs) {
			if (seen.has(input)) {
				continue;
			}
			seen.add(input);

			const cacheEntry = await this.cache.get(new CacheableEmbeddingRequest(input, model));
			if (cacheEntry) {
				embeddingEntries.set(input, { type, value: cacheEntry });
			} else {
				nonCached.push(input);
			}
		}

		if (nonCached.length) {
			const embeddingsResult = await super.computeEmbeddings(type, nonCached, options, telemetryInfo, token);

			// Validate before touching the cache so that a short response fails
			// with a clear message and leaves the cache unmodified.
			if (embeddingsResult.values.length < nonCached.length) {
				throw new Error(`Expected ${nonCached.length} embeddings but received ${embeddingsResult.values.length}`);
			}

			// Update the cache with the newest entries
			for (let i = 0; i < nonCached.length; i++) {
				const embeddingRequest = new CacheableEmbeddingRequest(nonCached[i], model);
				const embedding = embeddingsResult.values[i];
				embeddingEntries.set(embeddingRequest.query, embedding);
				await this.cache.set(embeddingRequest, embedding.value);
			}
		}

		// This reconstructs the output array such that each embedding is at the right index to match the input array
		const out: Embedding[] = [];
		for (const input of inputs) {
			const embedding = embeddingEntries.get(input);
			if (embedding) {
				out.push(embedding);
			}
		}
		return { type, values: out };
	}
}