Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/extensions/copilot/src/platform/embeddings/common/remoteEmbeddingsComputer.ts
13401 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
6
import { RequestType } from '@vscode/copilot-api';
7
import type { CancellationToken } from 'vscode';
8
import { CallTracker, TelemetryCorrelationId } from '../../../util/common/telemetryCorrelationId';
9
import { Limiter } from '../../../util/vs/base/common/async';
10
import { generateUuid } from '../../../util/vs/base/common/uuid';
11
import { IInstantiationService } from '../../../util/vs/platform/instantiation/common/instantiation';
12
import { IAuthenticationService } from '../../authentication/common/authentication';
13
import { IEndpointProvider } from '../../endpoint/common/endpointProvider';
14
import { IEnvService } from '../../env/common/envService';
15
import { getGithubMetadataHeaders } from '../../github/common/githubApiFetcherService';
16
import { logExecTime } from '../../log/common/logExecTime';
17
import { ILogService } from '../../log/common/logService';
18
import { IEmbeddingsEndpoint, postRequest } from '../../networking/common/networking';
19
import { GenAiAttr, GenAiOperationName, GenAiProviderName } from '../../otel/common/genAiAttributes';
20
import { IOTelService, SpanKind, SpanStatusCode } from '../../otel/common/otelService';
21
import { ITelemetryService } from '../../telemetry/common/telemetry';
22
import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, EmbeddingTypeInfo, EmbeddingVector, Embeddings, IEmbeddingsComputer, getWellKnownEmbeddingTypeInfo } from './embeddingsComputer';
23
24
/** Successful result from the CAPI embeddings endpoint: one vector per input, in input order. */
interface CAPIEmbeddingResults {
	readonly type: 'success';
	readonly embeddings: EmbeddingVector[];
}

/** Failed result from the CAPI embeddings endpoint; `reason` is a human-readable diagnostic. */
interface CAPIEmbeddingError {
	readonly type: 'failed';
	readonly reason: string;
}
32
33
export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {
34
35
declare readonly _serviceBrand: undefined;
36
37
private readonly batchSize = 100;
38
39
constructor(
40
@IAuthenticationService private readonly _authService: IAuthenticationService,
41
@IEnvService private readonly _envService: IEnvService,
42
@ILogService private readonly _logService: ILogService,
43
@ITelemetryService private readonly _telemetryService: ITelemetryService,
44
@IEndpointProvider private readonly _endpointProvider: IEndpointProvider,
45
@IInstantiationService private readonly _instantiationService: IInstantiationService,
46
@IOTelService private readonly _otelService: IOTelService,
47
) { }
48
49
public async computeEmbeddings(
50
embeddingType: EmbeddingType,
51
inputs: readonly string[],
52
options?: ComputeEmbeddingsOptions,
53
telemetryInfo?: TelemetryCorrelationId,
54
cancellationToken?: CancellationToken,
55
): Promise<Embeddings> {
56
const otelSpan = this._otelService.startSpan(`embeddings ${embeddingType.id}`, {
57
kind: SpanKind.CLIENT,
58
attributes: {
59
[GenAiAttr.OPERATION_NAME]: GenAiOperationName.EMBEDDINGS,
60
[GenAiAttr.PROVIDER_NAME]: GenAiProviderName.OPENAI,
61
[GenAiAttr.REQUEST_MODEL]: embeddingType.id,
62
'gen_ai.embeddings.input_count': inputs.length,
63
},
64
});
65
try {
66
return await logExecTime(this._logService, 'RemoteEmbeddingsComputer::computeEmbeddings', async () => {
67
68
// Determine endpoint type: use CAPI for no-auth users, otherwise use GitHub
69
const copilotToken = await this._authService.getCopilotToken();
70
if (copilotToken.isNoAuthUser) {
71
const embeddings = await this.computeCAPIEmbeddings(inputs, options, cancellationToken);
72
return embeddings ?? { type: embeddingType, values: [] };
73
}
74
75
const token = (await this._authService.getGitHubSession('any', { silent: true }))?.accessToken;
76
if (!token) {
77
throw new Error('No authentication token available');
78
}
79
80
const embeddingsOut: Embedding[] = [];
81
for (let i = 0; i < inputs.length; i += this.batchSize) {
82
const batch = inputs.slice(i, i + this.batchSize);
83
if (!batch.length) {
84
break;
85
}
86
87
const body: {
88
inputs: readonly string[];
89
input_type: 'document' | 'query';
90
embedding_model: string;
91
} = {
92
inputs: batch,
93
input_type: options?.inputType ?? 'document',
94
embedding_model: embeddingType.id,
95
};
96
const response = await this._instantiationService.invokeFunction(postRequest, {
97
endpointOrUrl: { type: RequestType.DotcomEmbeddings },
98
secretKey: token,
99
intent: 'copilot-panel',
100
requestId: generateUuid(),
101
body: body as any,
102
additionalHeaders: getGithubMetadataHeaders(telemetryInfo?.callTracker ?? new CallTracker(), this._envService),
103
cancelToken: cancellationToken,
104
});
105
if (!response.ok) {
106
/* __GDPR__
107
"remoteEmbeddingsComputer.computeEmbeddings.error" : {
108
"owner": "mjbvz",
109
"comment": "Total time for searchFileChunks to complete",
110
"source": { "classification": "SystemMetaData", "purpose": "FeatureInsight", "comment": "Caller" },
111
"correlationId": { "classification": "SystemMetaData", "purpose": "FeatureInsight", "comment": "Correlation id" },
112
"embeddingType": { "classification": "SystemMetaData", "purpose": "FeatureInsight", "comment": "Embedding type" },
113
"totalInputLength": { "classification": "SystemMetaData", "purpose": "FeatureInsight", "isMeasurement": true, "comment": "Total length of the input" },
114
"batchInputLength": { "classification": "SystemMetaData", "purpose": "FeatureInsight", "isMeasurement": true, "comment": "Total length of the batch" },
115
"statusCode": { "classification": "SystemMetaData", "purpose": "FeatureInsight", "isMeasurement": true, "comment": "Status code of the response" }
116
}
117
*/
118
this._telemetryService.sendMSFTTelemetryEvent('remoteEmbeddingsComputer.computeEmbeddings.error', {
119
source: telemetryInfo?.callTracker.toString(),
120
correlationId: telemetryInfo?.correlationId,
121
embeddingType: embeddingType.id,
122
}, {
123
totalInputLength: inputs.length,
124
batchInputLength: batch.length,
125
statusCode: response.status,
126
});
127
throw new Error(`Error fetching embeddings: ${response.status}`);
128
}
129
130
type EmbeddingResponse = {
131
embedding_model: string;
132
embeddings: Array<{ embedding: number[] }>;
133
};
134
const jsonResponse: EmbeddingResponse = await response.json();
135
136
const resolvedType = new EmbeddingType(jsonResponse.embedding_model);
137
if (!resolvedType.equals(embeddingType)) {
138
throw new Error(`Unexpected embedding model. Got: ${resolvedType}. Expected: ${embeddingType}`);
139
}
140
141
if (batch.length !== jsonResponse.embeddings.length) {
142
throw new Error(`Mismatched embedding result count. Expected: ${batch.length}. Got: ${jsonResponse.embeddings.length}`);
143
}
144
145
embeddingsOut.push(...jsonResponse.embeddings.map(embedding => ({
146
type: resolvedType,
147
value: embedding.embedding,
148
})));
149
}
150
151
return { type: embeddingType, values: embeddingsOut };
152
});
153
} catch (err) {
154
otelSpan.setStatus(SpanStatusCode.ERROR, err instanceof Error ? err.message : String(err));
155
otelSpan.setAttribute('error.type', err instanceof Error ? err.constructor.name : 'Error');
156
otelSpan.recordException(err);
157
throw err;
158
} finally {
159
otelSpan.end();
160
}
161
}
162
163
private async computeCAPIEmbeddings(
164
inputs: readonly string[],
165
options?: ComputeEmbeddingsOptions,
166
cancellationToken?: CancellationToken,
167
) {
168
const typeInfo = getWellKnownEmbeddingTypeInfo(EmbeddingType.text3small_512);
169
if (!typeInfo) {
170
throw new Error(`Embeddings type info not found: ${EmbeddingType.text3small_512}`);
171
}
172
const endpoint = await this._endpointProvider.getEmbeddingsEndpoint('text3small');
173
const batchSize = endpoint.maxBatchSize;
174
// Open AI seems to allow 1 less than max tokens for the model requests. So if the max tokens is 8192, we can only send 8191 tokens.
175
const maxTokens = endpoint.modelMaxPromptTokens - 1;
176
return this.fetchResponseWithBatches(typeInfo, endpoint, inputs, cancellationToken, maxTokens, batchSize);
177
}
178
179
/**
180
* A recursive helper that drives the public `fetchResponse` function. This allows accepting a batch and supports backing off the endpoint.
181
* @param inputs The inputs to get embeddings for
182
* @param cancellationToken A cancellation token to allow cancelling the requests
183
* @param batchSize The batch size to calculate
184
* @returns The embeddings
185
*/
186
private async fetchResponseWithBatches(
187
type: EmbeddingTypeInfo,
188
endpoint: IEmbeddingsEndpoint,
189
inputs: readonly string[],
190
cancellationToken: CancellationToken | undefined,
191
maxTokens: number,
192
batchSize: number,
193
parallelism = 1,
194
): Promise<Embeddings | undefined> {
195
// First we loop through all inputs and count their token length, if one exceeds max tokens then we fail
196
for (const input of inputs) {
197
const inputTokenLength = await endpoint.acquireTokenizer().tokenLength(input);
198
if (inputTokenLength > maxTokens) {
199
return undefined;
200
}
201
}
202
203
let embeddings: EmbeddingVector[] = [];
204
const promises: Promise<CAPIEmbeddingResults | undefined>[] = [];
205
const limiter = new Limiter<CAPIEmbeddingResults | undefined>(parallelism);
206
try {
207
for (let i = 0; i < inputs.length; i += batchSize) {
208
const currentBatch = inputs.slice(i, i + batchSize);
209
promises.push(limiter.queue(async () => {
210
if (cancellationToken?.isCancellationRequested) {
211
return;
212
}
213
214
const r = await this.rawEmbeddingsFetchWithTelemetry(type, endpoint, generateUuid(), currentBatch, cancellationToken);
215
if (r.type === 'failed') {
216
throw new Error('Embeddings request failed ' + r.reason);
217
}
218
return r;
219
}));
220
}
221
222
embeddings = (await Promise.all(promises)).flatMap(response => response?.embeddings ?? []);
223
} catch (e) {
224
return undefined;
225
} finally {
226
limiter.dispose();
227
}
228
229
if (cancellationToken?.isCancellationRequested) {
230
return undefined;
231
}
232
233
// If there are no embeddings, return undefined
234
if (embeddings.length === 0) {
235
return undefined;
236
}
237
return { type: EmbeddingType.text3small_512, values: embeddings.map((value): Embedding => ({ type: EmbeddingType.text3small_512, value })) };
238
}
239
240
private async rawEmbeddingsFetchWithTelemetry(
241
type: EmbeddingTypeInfo,
242
endpoint: IEmbeddingsEndpoint,
243
requestId: string,
244
inputs: readonly string[],
245
cancellationToken: CancellationToken | undefined
246
) {
247
const startTime = Date.now();
248
const rawRequest = await this.rawEmbeddingsFetch(type, endpoint, requestId, inputs, cancellationToken);
249
if (rawRequest.type === 'failed') {
250
this._telemetryService.sendMSFTTelemetryErrorEvent('embedding.error', {
251
type: rawRequest.type,
252
reason: rawRequest.reason
253
});
254
return rawRequest;
255
}
256
257
const tokenizer = endpoint.acquireTokenizer();
258
const tokenCounts = await Promise.all(inputs.map(input => tokenizer.tokenLength(input)));
259
const inputTokenCount = tokenCounts.reduce((acc, count) => acc + count, 0);
260
this._telemetryService.sendMSFTTelemetryEvent('embedding.success', {}, {
261
batchSize: inputs.length,
262
inputTokenCount,
263
timeToComplete: Date.now() - startTime
264
});
265
return rawRequest;
266
}
267
268
/**
269
* The function which actually makes the request to the API and handles failures.
270
* This is separated out from fetchResponse as fetchResponse does some manipulation to the input and handles errors differently
271
*/
272
public async rawEmbeddingsFetch(
273
type: EmbeddingTypeInfo,
274
endpoint: IEmbeddingsEndpoint,
275
requestId: string,
276
inputs: readonly string[],
277
cancellationToken: CancellationToken | undefined
278
): Promise<CAPIEmbeddingResults | CAPIEmbeddingError> {
279
try {
280
const token = await this._authService.getCopilotToken();
281
282
const body = { input: inputs, model: type.model, dimensions: type.dimensions };
283
endpoint.interceptBody?.(body);
284
const response = await this._instantiationService.invokeFunction(postRequest, {
285
endpointOrUrl: endpoint,
286
secretKey: token.token,
287
intent: 'copilot-panel',
288
requestId,
289
body,
290
cancelToken: cancellationToken,
291
});
292
const jsonResponse = response.status === 200 ? await response.json() : await response.text();
293
294
type EmbeddingResponse = {
295
object: string;
296
index: number;
297
embedding: number[];
298
};
299
if (response.status === 200 && jsonResponse.data) {
300
return { type: 'success', embeddings: jsonResponse.data.map((d: EmbeddingResponse) => d.embedding) };
301
} else {
302
return { type: 'failed', reason: jsonResponse.error };
303
}
304
} catch (e) {
305
let errorMessage = (e as Error)?.message ?? 'Unknown error';
306
// Timeouts = JSON parse errors because the response is incomplete
307
if (errorMessage.match(/Unexpected.*JSON/i)) {
308
errorMessage = 'timeout';
309
}
310
return { type: 'failed', reason: errorMessage };
311
312
}
313
}
314
}
315
316