// Source: extensions/copilot/src/platform/embeddings/common/remoteEmbeddingsComputer.ts
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { RequestType } from '@vscode/copilot-api';
import type { CancellationToken } from 'vscode';
import { CallTracker, TelemetryCorrelationId } from '../../../util/common/telemetryCorrelationId';
import { Limiter } from '../../../util/vs/base/common/async';
import { generateUuid } from '../../../util/vs/base/common/uuid';
import { IInstantiationService } from '../../../util/vs/platform/instantiation/common/instantiation';
import { IAuthenticationService } from '../../authentication/common/authentication';
import { IEndpointProvider } from '../../endpoint/common/endpointProvider';
import { IEnvService } from '../../env/common/envService';
import { getGithubMetadataHeaders } from '../../github/common/githubApiFetcherService';
import { logExecTime } from '../../log/common/logExecTime';
import { ILogService } from '../../log/common/logService';
import { IEmbeddingsEndpoint, postRequest } from '../../networking/common/networking';
import { GenAiAttr, GenAiOperationName, GenAiProviderName } from '../../otel/common/genAiAttributes';
import { IOTelService, SpanKind, SpanStatusCode } from '../../otel/common/otelService';
import { ITelemetryService } from '../../telemetry/common/telemetry';
import { ComputeEmbeddingsOptions, Embedding, EmbeddingType, EmbeddingTypeInfo, EmbeddingVector, Embeddings, IEmbeddingsComputer, getWellKnownEmbeddingTypeInfo } from './embeddingsComputer';

/** Successful CAPI embeddings fetch: one vector per input, in input order. */
interface CAPIEmbeddingResults {
	readonly type: 'success';
	readonly embeddings: EmbeddingVector[];
}

/** Failed CAPI embeddings fetch with a human-readable reason. */
interface CAPIEmbeddingError {
	readonly type: 'failed';
	readonly reason: string;
}

/**
 * Computes text embeddings over the network, choosing one of two backends:
 *
 * - For "no auth" users, requests go to CAPI via the configured embeddings
 *   endpoint (`computeCAPIEmbeddings`), batched and rate-limited.
 * - Otherwise, requests go to the GitHub dotcom embeddings endpoint using the
 *   user's GitHub session token, in fixed batches of {@link batchSize}.
 *
 * Every public call is wrapped in an OpenTelemetry client span and timed via
 * `logExecTime`; failures are recorded on the span and on MSFT telemetry.
 */
export class RemoteEmbeddingsComputer implements IEmbeddingsComputer {

	declare readonly _serviceBrand: undefined;

	// Number of inputs sent per request on the GitHub (dotcom) path.
	// The CAPI path uses the endpoint-provided `maxBatchSize` instead.
	private readonly batchSize = 100;

	constructor(
		@IAuthenticationService private readonly _authService: IAuthenticationService,
		@IEnvService private readonly _envService: IEnvService,
		@ILogService private readonly _logService: ILogService,
		@ITelemetryService private readonly _telemetryService: ITelemetryService,
		@IEndpointProvider private readonly _endpointProvider: IEndpointProvider,
		@IInstantiationService private readonly _instantiationService: IInstantiationService,
		@IOTelService private readonly _otelService: IOTelService,
	) { }

	/**
	 * Computes embeddings for `inputs` using the model identified by
	 * `embeddingType`.
	 *
	 * @param embeddingType Expected embedding model; the response's model id is
	 *   validated against it and a mismatch throws.
	 * @param inputs Texts to embed. Batched internally; results preserve input order.
	 * @param options Optional settings (e.g. `inputType: 'document' | 'query'`,
	 *   defaulting to 'document' on the GitHub path).
	 * @param telemetryInfo Correlation id / call tracker attached to request
	 *   headers and telemetry events.
	 * @param cancellationToken Cancels in-flight requests.
	 * @returns The embeddings. On the CAPI (no-auth) path a failed or cancelled
	 *   computation yields `{ type, values: [] }` rather than throwing.
	 * @throws On the GitHub path: when no GitHub token is available, on any
	 *   non-OK HTTP response, on a model-id mismatch, or when the result count
	 *   does not match the batch size.
	 */
	public async computeEmbeddings(
		embeddingType: EmbeddingType,
		inputs: readonly string[],
		options?: ComputeEmbeddingsOptions,
		telemetryInfo?: TelemetryCorrelationId,
		cancellationToken?: CancellationToken,
	): Promise<Embeddings> {
		// One CLIENT span per call; status/exception are recorded on failure and
		// the span is always ended in the finally block below.
		const otelSpan = this._otelService.startSpan(`embeddings ${embeddingType.id}`, {
			kind: SpanKind.CLIENT,
			attributes: {
				[GenAiAttr.OPERATION_NAME]: GenAiOperationName.EMBEDDINGS,
				[GenAiAttr.PROVIDER_NAME]: GenAiProviderName.OPENAI,
				[GenAiAttr.REQUEST_MODEL]: embeddingType.id,
				'gen_ai.embeddings.input_count': inputs.length,
			},
		});
		try {
			return await logExecTime(this._logService, 'RemoteEmbeddingsComputer::computeEmbeddings', async () => {

				// Determine endpoint type: use CAPI for no-auth users, otherwise use GitHub
				const copilotToken = await this._authService.getCopilotToken();
				if (copilotToken.isNoAuthUser) {
					// NOTE(review): `embeddingType` is not forwarded here — the CAPI
					// path is hard-coded to text3small_512 in computeCAPIEmbeddings.
					const embeddings = await this.computeCAPIEmbeddings(inputs, options, cancellationToken);
					return embeddings ?? { type: embeddingType, values: [] };
				}

				const token = (await this._authService.getGitHubSession('any', { silent: true }))?.accessToken;
				if (!token) {
					throw new Error('No authentication token available');
				}

				const embeddingsOut: Embedding[] = [];
				for (let i = 0; i < inputs.length; i += this.batchSize) {
					const batch = inputs.slice(i, i + this.batchSize);
					if (!batch.length) {
						break;
					}

					const body: {
						inputs: readonly string[];
						input_type: 'document' | 'query';
						embedding_model: string;
					} = {
						inputs: batch,
						input_type: options?.inputType ?? 'document',
						embedding_model: embeddingType.id,
					};
					// NOTE(review): `body as any` sidesteps postRequest's body typing —
					// confirm whether postRequest can accept this shape directly.
					const response = await this._instantiationService.invokeFunction(postRequest, {
						endpointOrUrl: { type: RequestType.DotcomEmbeddings },
						secretKey: token,
						intent: 'copilot-panel',
						requestId: generateUuid(),
						body: body as any,
						additionalHeaders: getGithubMetadataHeaders(telemetryInfo?.callTracker ?? new CallTracker(), this._envService),
						cancelToken: cancellationToken,
					});
					if (!response.ok) {
						/* __GDPR__
							"remoteEmbeddingsComputer.computeEmbeddings.error" : {
								"owner": "mjbvz",
								"comment": "Reports a failed remote embeddings request",
								"source": { "classification": "SystemMetaData", "purpose": "FeatureInsight", "comment": "Caller" },
								"correlationId": { "classification": "SystemMetaData", "purpose": "FeatureInsight", "comment": "Correlation id" },
								"embeddingType": { "classification": "SystemMetaData", "purpose": "FeatureInsight", "comment": "Embedding type" },
								"totalInputLength": { "classification": "SystemMetaData", "purpose": "FeatureInsight", "isMeasurement": true, "comment": "Total length of the input" },
								"batchInputLength": { "classification": "SystemMetaData", "purpose": "FeatureInsight", "isMeasurement": true, "comment": "Total length of the batch" },
								"statusCode": { "classification": "SystemMetaData", "purpose": "FeatureInsight", "isMeasurement": true, "comment": "Status code of the response" }
							}
						*/
						this._telemetryService.sendMSFTTelemetryEvent('remoteEmbeddingsComputer.computeEmbeddings.error', {
							source: telemetryInfo?.callTracker.toString(),
							correlationId: telemetryInfo?.correlationId,
							embeddingType: embeddingType.id,
						}, {
							totalInputLength: inputs.length,
							batchInputLength: batch.length,
							statusCode: response.status,
						});
						throw new Error(`Error fetching embeddings: ${response.status}`);
					}

					type EmbeddingResponse = {
						embedding_model: string;
						embeddings: Array<{ embedding: number[] }>;
					};
					const jsonResponse: EmbeddingResponse = await response.json();

					// The service echoes back the model it actually used; refuse to
					// silently accept vectors from a different model.
					const resolvedType = new EmbeddingType(jsonResponse.embedding_model);
					if (!resolvedType.equals(embeddingType)) {
						throw new Error(`Unexpected embedding model. Got: ${resolvedType}. Expected: ${embeddingType}`);
					}

					if (batch.length !== jsonResponse.embeddings.length) {
						throw new Error(`Mismatched embedding result count. Expected: ${batch.length}. Got: ${jsonResponse.embeddings.length}`);
					}

					embeddingsOut.push(...jsonResponse.embeddings.map(embedding => ({
						type: resolvedType,
						value: embedding.embedding,
					})));
				}

				return { type: embeddingType, values: embeddingsOut };
			});
		} catch (err) {
			otelSpan.setStatus(SpanStatusCode.ERROR, err instanceof Error ? err.message : String(err));
			otelSpan.setAttribute('error.type', err instanceof Error ? err.constructor.name : 'Error');
			otelSpan.recordException(err);
			throw err;
		} finally {
			otelSpan.end();
		}
	}

	/**
	 * CAPI (no-auth) embeddings path. Always uses the well-known
	 * text3small_512 embedding type and the 'text3small' endpoint.
	 *
	 * NOTE(review): `options` (e.g. inputType) is accepted but never used on
	 * this path — confirm whether it should be forwarded to the fetch.
	 *
	 * @returns The embeddings, or `undefined` on oversize input, failure, or
	 *   cancellation (callers substitute an empty result).
	 */
	private async computeCAPIEmbeddings(
		inputs: readonly string[],
		options?: ComputeEmbeddingsOptions,
		cancellationToken?: CancellationToken,
	) {
		const typeInfo = getWellKnownEmbeddingTypeInfo(EmbeddingType.text3small_512);
		if (!typeInfo) {
			throw new Error(`Embeddings type info not found: ${EmbeddingType.text3small_512}`);
		}
		const endpoint = await this._endpointProvider.getEmbeddingsEndpoint('text3small');
		const batchSize = endpoint.maxBatchSize;
		// Open AI seems to allow 1 less than max tokens for the model requests. So if the max tokens is 8192, we can only send 8191 tokens.
		const maxTokens = endpoint.modelMaxPromptTokens - 1;
		return this.fetchResponseWithBatches(typeInfo, endpoint, inputs, cancellationToken, maxTokens, batchSize);
	}

	/**
	 * A recursive helper that drives the public `fetchResponse` function. This allows accepting a batch and supports backing off the endpoint.
	 * @param inputs The inputs to get embeddings for
	 * @param cancellationToken A cancellation token to allow cancelling the requests
	 * @param maxTokens Per-input token budget; any single input exceeding it aborts the whole call
	 * @param batchSize The batch size to calculate
	 * @param parallelism Max concurrent batch requests (default 1, i.e. sequential)
	 * @returns The embeddings, or `undefined` on oversize input, any batch failure, cancellation, or an empty result
	 */
	private async fetchResponseWithBatches(
		type: EmbeddingTypeInfo,
		endpoint: IEmbeddingsEndpoint,
		inputs: readonly string[],
		cancellationToken: CancellationToken | undefined,
		maxTokens: number,
		batchSize: number,
		parallelism = 1,
	): Promise<Embeddings | undefined> {
		// First we loop through all inputs and count their token length, if one exceeds max tokens then we fail
		for (const input of inputs) {
			const inputTokenLength = await endpoint.acquireTokenizer().tokenLength(input);
			if (inputTokenLength > maxTokens) {
				return undefined;
			}
		}

		let embeddings: EmbeddingVector[] = [];
		const promises: Promise<CAPIEmbeddingResults | undefined>[] = [];
		// Limiter caps how many batch requests are in flight at once.
		const limiter = new Limiter<CAPIEmbeddingResults | undefined>(parallelism);
		try {
			for (let i = 0; i < inputs.length; i += batchSize) {
				const currentBatch = inputs.slice(i, i + batchSize);
				promises.push(limiter.queue(async () => {
					// Batches queued after cancellation resolve to undefined and
					// contribute no vectors.
					if (cancellationToken?.isCancellationRequested) {
						return;
					}

					const r = await this.rawEmbeddingsFetchWithTelemetry(type, endpoint, generateUuid(), currentBatch, cancellationToken);
					if (r.type === 'failed') {
						// Throwing here makes Promise.all below reject, so one failed
						// batch fails the entire computation.
						throw new Error('Embeddings request failed ' + r.reason);
					}
					return r;
				}));
			}

			embeddings = (await Promise.all(promises)).flatMap(response => response?.embeddings ?? []);
		} catch (e) {
			// Deliberate best-effort: any batch failure collapses to undefined.
			return undefined;
		} finally {
			limiter.dispose();
		}

		if (cancellationToken?.isCancellationRequested) {
			return undefined;
		}

		// If there are no embeddings, return undefined
		if (embeddings.length === 0) {
			return undefined;
		}
		return { type: EmbeddingType.text3small_512, values: embeddings.map((value): Embedding => ({ type: EmbeddingType.text3small_512, value })) };
	}

	/**
	 * Wraps {@link rawEmbeddingsFetch} with MSFT telemetry: an error event on
	 * failure; batch size, input token count, and elapsed time on success.
	 * The raw result is returned unchanged either way.
	 */
	private async rawEmbeddingsFetchWithTelemetry(
		type: EmbeddingTypeInfo,
		endpoint: IEmbeddingsEndpoint,
		requestId: string,
		inputs: readonly string[],
		cancellationToken: CancellationToken | undefined
	) {
		const startTime = Date.now();
		const rawRequest = await this.rawEmbeddingsFetch(type, endpoint, requestId, inputs, cancellationToken);
		if (rawRequest.type === 'failed') {
			this._telemetryService.sendMSFTTelemetryErrorEvent('embedding.error', {
				type: rawRequest.type,
				reason: rawRequest.reason
			});
			return rawRequest;
		}

		// Token counting happens only on success, after the request completed,
		// so it does not add latency to the failure path.
		const tokenizer = endpoint.acquireTokenizer();
		const tokenCounts = await Promise.all(inputs.map(input => tokenizer.tokenLength(input)));
		const inputTokenCount = tokenCounts.reduce((acc, count) => acc + count, 0);
		this._telemetryService.sendMSFTTelemetryEvent('embedding.success', {}, {
			batchSize: inputs.length,
			inputTokenCount,
			timeToComplete: Date.now() - startTime
		});
		return rawRequest;
	}

	/**
	 * The function which actually makes the request to the API and handles failures.
	 * This is separated out from fetchResponse as fetchResponse does some manipulation to the input and handles errors differently
	 *
	 * @returns A discriminated result: `{ type: 'success', embeddings }` or
	 *   `{ type: 'failed', reason }`. Never throws.
	 */
	public async rawEmbeddingsFetch(
		type: EmbeddingTypeInfo,
		endpoint: IEmbeddingsEndpoint,
		requestId: string,
		inputs: readonly string[],
		cancellationToken: CancellationToken | undefined
	): Promise<CAPIEmbeddingResults | CAPIEmbeddingError> {
		try {
			const token = await this._authService.getCopilotToken();

			const body = { input: inputs, model: type.model, dimensions: type.dimensions };
			// Endpoints may rewrite the request body (e.g. provider-specific fields).
			endpoint.interceptBody?.(body);
			const response = await this._instantiationService.invokeFunction(postRequest, {
				endpointOrUrl: endpoint,
				secretKey: token.token,
				intent: 'copilot-panel',
				requestId,
				body,
				cancelToken: cancellationToken,
			});
			// Non-200 responses are read as text, not JSON.
			const jsonResponse = response.status === 200 ? await response.json() : await response.text();

			type EmbeddingResponse = {
				object: string;
				index: number;
				embedding: number[];
			};
			if (response.status === 200 && jsonResponse.data) {
				return { type: 'success', embeddings: jsonResponse.data.map((d: EmbeddingResponse) => d.embedding) };
			} else {
				// NOTE(review): when status !== 200, jsonResponse is a string and
				// `.error` is undefined, so `reason` is undefined — confirm whether
				// the raw text body should be used as the reason instead.
				return { type: 'failed', reason: jsonResponse.error };
			}
		} catch (e) {
			let errorMessage = (e as Error)?.message ?? 'Unknown error';
			// Timeouts = JSON parse errors because the response is incomplete
			if (errorMessage.match(/Unexpected.*JSON/i)) {
				errorMessage = 'timeout';
			}
			return { type: 'failed', reason: errorMessage };

		}
	}
}