Path: blob/main/extensions/copilot/src/platform/embeddings/common/embeddingsComputer.ts
13401 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import type { CancellationToken } from 'vscode';
import { createServiceIdentifier } from '../../../util/common/services';
import { TelemetryCorrelationId } from '../../../util/common/telemetryCorrelationId';

/**
 * Identifies one specific kind of embedding by a single string id.
 *
 * The id encodes both the model and the vector dimensions (for example
 * `text-embedding-3-small-512`), so embeddings are only comparable when their
 * types compare equal via {@link EmbeddingType.equals}.
 */
export class EmbeddingType {
	/** `text-embedding-3-small` model, 512 dimensions. */
	public static readonly text3small_512 = new EmbeddingType('text-embedding-3-small-512');

	/** Metis model, 1024 dimensions (`metis-1024-I16-Binary`). */
	public static readonly metis_1024_I16_Binary = new EmbeddingType('metis-1024-I16-Binary');

	/**
	 * @param id Identifier naming both the model and the vector dimensions.
	 */
	constructor(
		public readonly id: string
	) { }

	/**
	 * Reports whether `other` denotes the same embedding type.
	 *
	 * Identity is decided purely by an exact (case-sensitive) id comparison.
	 */
	public equals(other: EmbeddingType): boolean {
		return other.id === this.id;
	}

	/** Returns the raw id, so the type reads naturally when interpolated into strings. */
	public toString(): string {
		return this.id;
	}
}
// WARNING
// These values are used in the request and are case sensitive. Do not change them unless advised by CAPI.
export const enum LEGACY_EMBEDDING_MODEL_ID {
	TEXT3SMALL = 'text-embedding-3-small',
	Metis_I16_Binary = 'metis-I16-Binary'
}

/** Numeric encoding used for a stored or transmitted embedding vector. */
type EmbeddingQuantization = 'float32' | 'float16' | 'binary';

/**
 * Static metadata describing a well-known {@link EmbeddingType}.
 */
export interface EmbeddingTypeInfo {
	/** Legacy model identifier; these values are used in requests. */
	readonly model: LEGACY_EMBEDDING_MODEL_ID;
	/** Number of dimensions in each vector. */
	readonly dimensions: number;
	/** Quantization used for query embeddings and for document embeddings respectively. */
	readonly quantization: {
		readonly query: EmbeddingQuantization;
		readonly document: EmbeddingQuantization;
	};
}

const wellKnownEmbeddingMetadata = Object.freeze<Record<string, EmbeddingTypeInfo>>({
	[EmbeddingType.text3small_512.id]: {
		model: LEGACY_EMBEDDING_MODEL_ID.TEXT3SMALL,
		dimensions: 512,
		quantization: {
			query: 'float32',
			document: 'float32'
		},
	},
	[EmbeddingType.metis_1024_I16_Binary.id]: {
		model: LEGACY_EMBEDDING_MODEL_ID.Metis_I16_Binary,
		dimensions: 1024,
		quantization: {
			query: 'float16',
			document: 'binary'
		},
	},
});

/**
 * Looks up the static metadata for a well-known embedding type.
 *
 * @param type The embedding type to look up.
 * @returns The metadata, or `undefined` if `type` is not one of the well-known types.
 */
export function getWellKnownEmbeddingTypeInfo(type: EmbeddingType): EmbeddingTypeInfo | undefined {
	// Only consult own keys: the metadata table is a plain object literal, so an
	// unchecked lookup of ids such as `toString` or `constructor` would return
	// members inherited from `Object.prototype` instead of `undefined`.
	return Object.prototype.hasOwnProperty.call(wellKnownEmbeddingMetadata, type.id)
		? wellKnownEmbeddingMetadata[type.id]
		: undefined;
}

/** The raw numeric components of an embedding. */
export type EmbeddingVector = readonly number[];

/** A single embedding vector tagged with the type that produced it. */
export interface Embedding {
	readonly type: EmbeddingType;
	readonly value: EmbeddingVector;
}

/**
 * Runtime check that `value` looks like an {@link Embedding}.
 *
 * This is a shallow check: it only verifies that `value` is a non-null object
 * whose `type` is truthy and whose `value` is a non-empty array. It does not
 * verify that `type` is an {@link EmbeddingType} instance or that the array
 * holds numbers.
 */
export function isValidEmbedding(value: unknown): value is Embedding {
	if (typeof value !== 'object' || value === null) {
		return false;
	}

	const asEmbedding = value as Embedding;
	if (!asEmbedding.type) {
		return false;
	}

	if (!Array.isArray(asEmbedding.value) || asEmbedding.value.length === 0) {
		return false;
	}

	return true;
}

/** A batch of embeddings that all share the same {@link EmbeddingType}. */
export interface Embeddings {
	readonly type: EmbeddingType;
	readonly values: readonly Embedding[];
}

/**
 * Result of comparing two embeddings.
 *
 * Despite the name, `value` is a similarity score: larger means more similar.
 */
export interface EmbeddingDistance {
	readonly embeddingType: EmbeddingType;
	readonly value: number;
}

export const IEmbeddingsComputer = createServiceIdentifier<IEmbeddingsComputer>('IEmbeddingsComputer');

/** Whether the strings being embedded are documents or queries. */
export type EmbeddingInputType = 'document' | 'query';

export type ComputeEmbeddingsOptions = {
	readonly inputType?: EmbeddingInputType;
};

export interface IEmbeddingsComputer {

	readonly _serviceBrand: undefined;

	/**
	 * Computes embeddings for the given strings.
	 *
	 * @param type The embedding type (model and dimensions) to compute.
	 * @param inputs The strings to compute embeddings for.
	 * @param options Optional settings, such as whether the inputs are documents or queries.
	 * @param telemetryInfo Optional telemetry correlation id.
	 * @param token Optional cancellation token.
	 *
	 * @returns A promise for the computed embeddings.
	 * NOTE(review): the return type does not allow `undefined`; how an
	 * implementation reports failure (rejecting vs. returning an empty
	 * `values` array) is not visible here — confirm against implementations.
	 */
	computeEmbeddings(
		type: EmbeddingType,
		inputs: readonly string[],
		options?: ComputeEmbeddingsOptions,
		telemetryInfo?: TelemetryCorrelationId,
		token?: CancellationToken,
	): Promise<Embeddings>;
}

/**
 * Sums the element-wise products of `a` and `b`.
 *
 * Mismatched lengths are tolerated as a best effort: a warning is logged and
 * only the overlapping prefix of the two vectors contributes to the sum.
 */
function dotProduct(a: EmbeddingVector, b: EmbeddingVector): number {
	if (a.length !== b.length) {
		console.warn('Embeddings do not have same length for computing dot product');
	}

	let sum = 0;
	const len = Math.min(a.length, b.length);
	for (let i = 0; i < len; i++) {
		sum += a[i] * b[i];
	}
	return sum;
}

/**
 * Computes the similarity between two embeddings as the dot product of their vectors.
 *
 * For unit-normalized vectors this equals their cosine similarity (so it is at
 * most 1). The vectors are not normalized here, so the range otherwise depends
 * on the inputs. Larger values mean more similar.
 *
 * @param queryEmbedding The reference embedding.
 * @param otherEmbedding The embedding to compare against it.
 * @returns The score, tagged with the shared embedding type.
 * @throws Error if the two embeddings are of different types.
 */
export function distance(queryEmbedding: Embedding, otherEmbedding: Embedding): EmbeddingDistance {
	if (!queryEmbedding.type.equals(otherEmbedding.type)) {
		throw new Error(`Embeddings must be of the same type to compute similarity. Got: ${queryEmbedding.type.id} and ${otherEmbedding.type.id}`);
	}

	return {
		embeddingType: queryEmbedding.type,
		value: dotProduct(otherEmbedding.value, queryEmbedding.value),
	};
}

/**
 * Ranks items by the similarity of their embedding to a query embedding.
 *
 * Scores are computed with {@link distance} (a dot product, which equals cosine
 * similarity only for unit-normalized vectors).
 *
 * @param queryEmbedding Embedding every item is compared against.
 * @param items `[value, embedding]` pairs to rank.
 * @param maxResults Maximum number of results to return.
 * @param options.minDistance Items must score strictly above this value to be kept. Defaults to 0.
 * @param options.maxSpread If set, drops results scoring more than this fraction of the
 * top score's magnitude below the top score (e.g. `0.1` keeps scores within 10% of the best).
 * @returns Up to {@linkcode maxResults} items, ordered by descending score.
 * @throws Error if any item's embedding type differs from the query's.
 */
export function rankEmbeddings<T>(
	queryEmbedding: Embedding,
	items: ReadonlyArray<readonly [T, Embedding]>,
	maxResults: number,
	options?: {
		readonly minDistance?: number;
		readonly maxSpread?: number;
	}
): Array<{ readonly value: T; readonly distance: EmbeddingDistance }> {
	const minThreshold = options?.minDistance ?? 0;

	const results = items
		.map(([value, embedding]) => ({ distance: distance(queryEmbedding, embedding), value }))
		.filter(entry => entry.distance.value > minThreshold)
		.sort((a, b) => b.distance.value - a.distance.value)
		.slice(0, maxResults);

	const maxSpread = options?.maxSpread;
	const [top] = results;
	if (!top || typeof maxSpread !== 'number') {
		return results;
	}

	// Measure the allowed spread against the magnitude of the top score. For
	// positive scores this equals `top * (1 - maxSpread)`. When scores are
	// negative (possible with a negative `minDistance`), it still keeps the best
	// result instead of placing the cutoff above it.
	const topScore = top.distance.value;
	const minScore = topScore - Math.abs(topScore) * maxSpread;
	return results.filter(x => x.distance.value >= minScore);
}