Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/extensions/copilot/src/platform/embeddings/common/embeddingsComputer.ts
13401 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
6
import type { CancellationToken } from 'vscode';
7
import { createServiceIdentifier } from '../../../util/common/services';
8
import { TelemetryCorrelationId } from '../../../util/common/telemetryCorrelationId';
9
10
/**
 * Identifies an embedding format by a single string id that encodes both the model and the
 * number of dimensions (for example `text-embedding-3-small-512`).
 *
 * Two instances describe the same format when their ids match; use {@link EmbeddingType.equals}
 * rather than reference equality.
 */
export class EmbeddingType {
	/** `text-embedding-3-small` truncated to 512 dimensions. */
	public static readonly text3small_512 = new EmbeddingType('text-embedding-3-small-512');
	/** Metis, 1024 dimensions, I16/binary variant. */
	public static readonly metis_1024_I16_Binary = new EmbeddingType('metis-1024-I16-Binary');

	/** The fully qualified id of this embedding type. */
	public readonly id: string;

	constructor(id: string) {
		this.id = id;
	}

	/** @returns The raw id, so instances print naturally in templates and logs. */
	public toString(): string {
		return this.id;
	}

	/**
	 * @param other The embedding type to compare with.
	 * @returns `true` when both types carry the same id.
	 */
	public equals(other: EmbeddingType): boolean {
		const otherId = other.id;
		return otherId === this.id;
	}
}
31
32
// WARNING
// These values are used in the request and are case sensitive. Do not change them unless advised by CAPI.
/**
 * Model identifiers for the well-known embedding types.
 *
 * Unlike {@link EmbeddingType} ids, these strings do not include the dimension count
 * (compare `text-embedding-3-small` with `text-embedding-3-small-512`).
 */
export const enum LEGACY_EMBEDDING_MODEL_ID {
	/** Model id used for {@link EmbeddingType.text3small_512}. */
	TEXT3SMALL = 'text-embedding-3-small',
	/** Model id used for {@link EmbeddingType.metis_1024_I16_Binary}. */
	Metis_I16_Binary = 'metis-I16-Binary'
}
38
39
/** Numeric encoding of the components of an embedding vector. */
type EmbeddingQuantization = 'float32' | 'float16' | 'binary';

/**
 * Static description of a well-known {@link EmbeddingType}.
 */
export interface EmbeddingTypeInfo {
	/** Legacy model id associated with this embedding type. */
	readonly model: LEGACY_EMBEDDING_MODEL_ID;
	/** Number of components in each embedding vector. */
	readonly dimensions: number;
	/**
	 * Encoding used for query embeddings and for document embeddings. The two may differ
	 * (e.g. the Metis entry uses `float16` queries and `binary` documents).
	 */
	readonly quantization: {
		readonly query: EmbeddingQuantization;
		readonly document: EmbeddingQuantization;
	};
}
49
50
/**
 * Metadata for each well-known embedding type, keyed by {@link EmbeddingType.id}.
 *
 * `Object.freeze` is shallow: it prevents adding or replacing top-level entries, but the nested
 * info objects are protected only by their `readonly` typing.
 */
const wellKnownEmbeddingMetadata = Object.freeze<Record<string, EmbeddingTypeInfo>>({
	[EmbeddingType.text3small_512.id]: {
		model: LEGACY_EMBEDDING_MODEL_ID.TEXT3SMALL,
		dimensions: 512,
		quantization: {
			query: 'float32',
			document: 'float32'
		},
	},
	[EmbeddingType.metis_1024_I16_Binary.id]: {
		model: LEGACY_EMBEDDING_MODEL_ID.Metis_I16_Binary,
		dimensions: 1024,
		quantization: {
			// Queries and documents are encoded differently for this type.
			query: 'float16',
			document: 'binary'
		},
	},
});
68
69
/**
 * Looks up the static metadata (model id, dimensions, quantization) of a well-known embedding type.
 *
 * @param type The embedding type to look up; matched by its {@link EmbeddingType.id}.
 * @returns The metadata for `type`, or `undefined` when `type` is not a well-known type.
 */
export function getWellKnownEmbeddingTypeInfo(type: EmbeddingType): EmbeddingTypeInfo | undefined {
	// Only consult own properties: a plain-object lookup would otherwise resolve ids such as
	// 'constructor' or 'toString' to members inherited from Object.prototype.
	return Object.prototype.hasOwnProperty.call(wellKnownEmbeddingMetadata, type.id)
		? wellKnownEmbeddingMetadata[type.id]
		: undefined;
}
72
73
/** The numeric components of an embedding. */
export type EmbeddingVector = readonly number[];

/** A single embedding vector tagged with the {@link EmbeddingType} it belongs to. */
export interface Embedding {
	/** Format of {@link value}; only embeddings of equal type can be compared. */
	readonly type: EmbeddingType;
	/** The vector itself. */
	readonly value: EmbeddingVector;
}
79
80
/**
 * Runtime check that `value` looks like an {@link Embedding}.
 *
 * Accepts any non-null object whose `type` is truthy and whose `value` is a non-empty array.
 * Note that this is a shallow check: `type` is not verified to be an {@link EmbeddingType}
 * and the vector's elements are not inspected.
 *
 * @param value The candidate to check.
 * @returns `true` when `value` passes the checks above.
 */
export function isValidEmbedding(value: unknown): value is Embedding {
	if (value === null || typeof value !== 'object') {
		return false;
	}

	const candidate = value as Embedding;
	return Boolean(candidate.type)
		&& Array.isArray(candidate.value)
		&& candidate.value.length > 0;
}
96
97
/** A batch of embeddings, e.g. as returned by {@link IEmbeddingsComputer.computeEmbeddings}. */
export interface Embeddings {
	/**
	 * Embedding type of the batch.
	 * NOTE(review): presumably equal to each entry's own `type`; this interface does not enforce it.
	 */
	readonly type: EmbeddingType;
	/** The embeddings in the batch. */
	readonly values: readonly Embedding[];
}

/** A similarity score between two embeddings, as produced by {@link distance}. */
export interface EmbeddingDistance {
	/** Type of the embeddings that were compared. */
	readonly embeddingType: EmbeddingType;
	/** The score; {@link rankEmbeddings} treats higher values as more similar. */
	readonly value: number;
}
106
107
/** Service identifier for {@link IEmbeddingsComputer}, created via `createServiceIdentifier`. */
export const IEmbeddingsComputer = createServiceIdentifier<IEmbeddingsComputer>('IEmbeddingsComputer');

/**
 * Whether the inputs being embedded are documents or queries. The well-known metadata allows a
 * different quantization for each role (see `EmbeddingTypeInfo.quantization`).
 */
export type EmbeddingInputType = 'document' | 'query';

/** Optional settings for {@link IEmbeddingsComputer.computeEmbeddings}. */
export type ComputeEmbeddingsOptions = {
	/** Role of the inputs. NOTE(review): the default when omitted is implementation-defined — confirm. */
	readonly inputType?: EmbeddingInputType;
};

/** Computes embedding vectors for strings. */
export interface IEmbeddingsComputer {

	readonly _serviceBrand: undefined;

	/**
	 * Computes embeddings for the given strings.
	 *
	 * @param type The embedding type (model and dimensions) to compute.
	 * @param inputs The strings to compute embeddings for.
	 * @param options Optional settings, such as whether the inputs are documents or queries.
	 * @param telemetryInfo Optional correlation id, presumably attached to telemetry for this request.
	 * @param token Optional cancellation token.
	 *
	 * @returns A promise for the computed embeddings. NOTE(review): the declared type does not allow
	 * `undefined`; how failures or empty results surface (rejection vs. empty `values`) is up to the
	 * implementation — confirm before relying on either.
	 */
	computeEmbeddings(
		type: EmbeddingType,
		inputs: readonly string[],
		options?: ComputeEmbeddingsOptions,
		telemetryInfo?: TelemetryCorrelationId,
		token?: CancellationToken,
	): Promise<Embeddings>;
}
134
135
/**
 * Sums the element-wise products of two vectors.
 *
 * Mismatched lengths are tolerated rather than rejected: a warning is logged and only the
 * overlapping prefix (the first `min(a.length, b.length)` elements) contributes to the sum.
 */
function dotProduct(a: EmbeddingVector, b: EmbeddingVector): number {
	const sameLength = a.length === b.length;
	if (!sameLength) {
		console.warn('Embeddings do not have same length for computing dot product');
	}

	const overlap = Math.min(a.length, b.length);
	let sum = 0;
	for (let index = 0; index < overlap; index += 1) {
		sum += a[index] * b[index];
	}
	return sum;
}
147
148
/**
 * Scores the similarity of two embeddings as the dot product of their vectors.
 *
 * The vectors are not normalized here: for unit-length vectors the score equals their cosine
 * similarity and lies in [-1, 1]; otherwise it is unbounded. Higher scores mean more similar.
 *
 * @param queryEmbedding The embedding being searched for; its type is reported in the result.
 * @param otherEmbedding The embedding to compare against.
 * @returns The score, tagged with the (shared) embedding type.
 * @throws Error when the two embeddings are of different {@link EmbeddingType}s.
 */
export function distance(queryEmbedding: Embedding, otherEmbedding: Embedding): EmbeddingDistance {
	if (!queryEmbedding.type.equals(otherEmbedding.type)) {
		throw new Error(`Embeddings must be of the same type to compute similarity. Got: ${queryEmbedding.type.id} and ${otherEmbedding.type.id}`);
	}

	return {
		embeddingType: queryEmbedding.type,
		value: dotProduct(queryEmbedding.value, otherEmbedding.value),
	};
}
161
162
/**
 * Ranks items by the similarity score that {@link distance} gives them against a query,
 * highest score first.
 *
 * @param queryEmbedding The embedding to compare every item against.
 * @param items Pairs of `[value, embedding]`; every embedding must have the query's type.
 * @param maxResults Maximum number of items to return (applied before `maxSpread`).
 * @param options.minDistance Items must score strictly above this value (default `0`).
 * @param options.maxSpread When set, also drops items whose score falls more than this fraction
 *   of the top score's magnitude below the top score (e.g. `0.2` keeps scores within 20% of it).
 * @returns The top {@linkcode maxResults} items with their scores, best first.
 * @throws Error (from {@link distance}) when an item's embedding type differs from the query's.
 */
export function rankEmbeddings<T>(
	queryEmbedding: Embedding,
	items: ReadonlyArray<readonly [T, Embedding]>,
	maxResults: number,
	options?: {
		readonly minDistance?: number;
		readonly maxSpread?: number;
	}
): Array<{ readonly value: T; readonly distance: EmbeddingDistance }> {
	const minThreshold = options?.minDistance ?? 0;
	const maxSpread = options?.maxSpread;

	const results = items
		.map(([value, embedding]): { readonly distance: EmbeddingDistance; readonly value: T } => ({
			// Pass the query first so errors and the reported type follow distance()'s parameter order.
			distance: distance(queryEmbedding, embedding),
			value,
		}))
		.filter(entry => entry.distance.value > minThreshold)
		.sort((a, b) => b.distance.value - a.distance.value)
		.slice(0, maxResults);

	const top = results[0];
	if (top === undefined || typeof maxSpread !== 'number') {
		return results;
	}

	// Move the cutoff *below* the top score regardless of its sign. For a non-negative top score this
	// is exactly `top * (1 - maxSpread)`; for a negative one that formula would exceed the top score
	// and discard every result.
	const best = top.distance.value;
	const minScore = best >= 0 ? best * (1.0 - maxSpread) : best * (1.0 + maxSpread);
	return results.filter(entry => entry.distance.value >= minScore);
}
200
201