Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/extensions/copilot/test/base/cachingEmbeddingsFetcher.ts
13388 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
import type { CancellationToken } from 'vscode';
6
import { IAuthenticationService } from '../../src/platform/authentication/common/authentication';
7
import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, EmbeddingVector, Embeddings, LEGACY_EMBEDDING_MODEL_ID, getWellKnownEmbeddingTypeInfo } from '../../src/platform/embeddings/common/embeddingsComputer';
8
import { RemoteEmbeddingsComputer } from '../../src/platform/embeddings/common/remoteEmbeddingsComputer';
9
import { IEndpointProvider } from '../../src/platform/endpoint/common/endpointProvider';
10
import { IEnvService } from '../../src/platform/env/common/envService';
11
import { ILogService } from '../../src/platform/log/common/logService';
12
import { IOTelService } from '../../src/platform/otel/common/otelService';
13
import { ITelemetryService } from '../../src/platform/telemetry/common/telemetry';
14
import { TelemetryCorrelationId } from '../../src/util/common/telemetryCorrelationId';
15
import { IInstantiationService } from '../../src/util/vs/platform/instantiation/common/instantiation';
16
import { computeSHA256 } from './hash';
17
18
/**
 * An embedding request in a form that can be used as a cache key.
 *
 * Bundles the text to embed with the embedding model id and exposes a hash
 * derived from both.
 */
export class CacheableEmbeddingRequest {
	/** Result of `computeSHA256` applied to the query text immediately followed by the model id. */
	public readonly hash: string;
	/** The text whose embedding is requested. */
	public readonly query: string;
	/** The embedding model this request targets. */
	public readonly model: LEGACY_EMBEDDING_MODEL_ID;

	/**
	 * @param embeddingQuery The text to embed.
	 * @param model The id of the embedding model.
	 */
	constructor(embeddingQuery: string, model: LEGACY_EMBEDDING_MODEL_ID) {
		this.query = embeddingQuery;
		this.model = model;
		// The hash input is the plain concatenation of query and model, with no separator.
		this.hash = computeSHA256(embeddingQuery + model);
	}

	/**
	 * Serialized form of the request. Only `query` and `model` are included;
	 * the hash is left out.
	 */
	toJSON() {
		const { query, model } = this;
		return { query, model };
	}
}
39
40
/**
 * Asynchronous store of embedding vectors, addressed by embedding request.
 */
export interface IEmbeddingsCache {
	/**
	 * Looks up the vector stored for a request.
	 *
	 * @param queryHash The request to look up. Despite the parameter name, this is
	 * the whole {@link CacheableEmbeddingRequest}, not a hash string.
	 * @returns The stored vector, or `undefined` when the cache has no entry for the request.
	 */
	get(queryHash: CacheableEmbeddingRequest): Promise<EmbeddingVector | undefined>;

	/**
	 * Stores a vector for a request.
	 *
	 * @param queryHash The request the vector belongs to (the whole
	 * {@link CacheableEmbeddingRequest}, not a hash string).
	 * @param embedding The vector to store.
	 */
	set(queryHash: CacheableEmbeddingRequest, embedding: EmbeddingVector): Promise<void>;
}
44
45
/**
 * A {@link RemoteEmbeddingsComputer} that consults an {@link IEmbeddingsCache}
 * before delegating to the base implementation, and writes every freshly
 * computed embedding back into that cache.
 */
export class CachingEmbeddingsComputer extends RemoteEmbeddingsComputer {
	/**
	 * @param cache The cache consulted before, and updated after, computing embeddings.
	 * The remaining parameters are injected services that are forwarded unchanged to
	 * the base class constructor.
	 */
	constructor(
		private readonly cache: IEmbeddingsCache,
		@IAuthenticationService authService: IAuthenticationService,
		@IEnvService envService: IEnvService,
		@ILogService logService: ILogService,
		@ITelemetryService telemetryService: ITelemetryService,
		@IEndpointProvider endpointProvider: IEndpointProvider,
		@IInstantiationService instantiationService: IInstantiationService,
		@IOTelService otelService: IOTelService,
	) {
		super(
			authService,
			envService,
			logService,
			telemetryService,
			endpointProvider,
			instantiationService,
			otelService,
		);
	}

	/**
	 * Computes embeddings for `inputs`, using cached vectors where available and
	 * delegating only the cache misses to the base class.
	 *
	 * Each distinct input is looked up in the cache once. Distinct inputs that are
	 * not cached are computed in a single base-class call and then stored in the cache.
	 *
	 * @param type The embedding type; its well-known model id is part of the cache key.
	 * @param inputs The texts to embed. May contain repeated entries.
	 * @param options Forwarded unchanged to the base implementation.
	 * @param telemetryInfo Forwarded unchanged to the base implementation.
	 * @param token Forwarded unchanged to the base implementation.
	 * @returns The embeddings ordered to match `inputs`.
	 * @throws Error if `type` has no well-known model, or if the base implementation
	 * returns fewer embeddings than were requested.
	 */
	public override async computeEmbeddings(
		type: EmbeddingType,
		inputs: string[],
		options: ComputeEmbeddingsOptions,
		telemetryInfo?: TelemetryCorrelationId,
		token?: CancellationToken,
	): Promise<Embeddings> {
		const model = getWellKnownEmbeddingTypeInfo(type)?.model;
		if (!model) {
			throw new Error(`Unknown embedding type: ${type.id}`);
		}

		const embeddingEntries = new Map<string, Embedding>();
		// A Set, so that an input repeated in `inputs` is requested from the remote only once.
		const nonCached = new Set<string>();

		for (const input of inputs) {
			// Skip inputs already resolved (or already known to be missing) in this call.
			if (embeddingEntries.has(input) || nonCached.has(input)) {
				continue;
			}

			const cacheEntry = await this.cache.get(new CacheableEmbeddingRequest(input, model));
			if (cacheEntry) {
				embeddingEntries.set(input, { type, value: cacheEntry });
			} else {
				nonCached.add(input);
			}
		}

		if (nonCached.size) {
			const toCompute = Array.from(nonCached);
			const embeddingsResult = await super.computeEmbeddings(type, toCompute, options, telemetryInfo, token);

			// Results are matched to requests by index. Fail before touching the cache if
			// some are missing, rather than hitting an opaque TypeError halfway through.
			if (embeddingsResult.values.length < toCompute.length) {
				throw new Error(`Expected ${toCompute.length} embeddings but received ${embeddingsResult.values.length}`);
			}

			// Update the cache with the newest entries
			for (let i = 0; i < toCompute.length; i++) {
				const embedding = embeddingsResult.values[i];
				embeddingEntries.set(toCompute[i], embedding);
				await this.cache.set(new CacheableEmbeddingRequest(toCompute[i], model), embedding.value);
			}
		}

		// This reconstructs the output array such that each embedding is at the right index to match the input array
		const out: Embedding[] = [];
		for (const input of inputs) {
			const embedding = embeddingEntries.get(input);
			if (embedding) {
				out.push(embedding);
			}
		}
		return { type, values: out };
	}
}
115
116