Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/extensions/copilot/src/platform/embeddings/common/embeddingsGrouper.ts
13401 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
6
import { Embedding, EmbeddingVector } from './embeddingsComputer';
7
8
/**
 * A single item to be clustered: a caller-supplied value paired with its embedding.
 */
export interface Node<T> {
	/** The payload this node represents; opaque to the grouper. */
	readonly value: T;
	/** The embedding whose vector is used to measure similarity between nodes. */
	readonly embedding: Embedding;
}
12
13
/**
 * A group of nodes whose embeddings are mutually similar.
 */
export interface Cluster<T> {
	/** Stable identifier of the form `cluster_<n>`, unique within one grouper instance. */
	readonly id: string;
	/** The member nodes of this cluster. */
	readonly nodes: readonly Node<T>[];
	/** L2-normalized mean of the members' embedding vectors. */
	readonly centroid: EmbeddingVector;
}
18
19
/**
 * Tuning options for similarity-based grouping of embeddings.
 */
export interface GroupingOptions {
	/** Similarity threshold for clustering (0.0-1.0). Higher values create tighter clusters. Default: 0.9 */
	readonly eps?: number;
	/** Minimum cluster size. Smaller clusters become singletons. Default: 2 */
	readonly minClusterSize?: number;
	/** Threshold for inserting new nodes into existing clusters. Default: same as clustering threshold */
	readonly insertThreshold?: number;
}
27
28
/**
29
* Groups embeddings using similarity-based clustering with cosine similarity.
30
*
31
* This approach finds cluster seeds (nodes with many similar neighbors) and builds
32
* clusters around them. It avoids the transitive clustering issues of connected
33
* components while being more suitable for cosine similarity than DBSCAN.
34
*/
35
export class EmbeddingsGrouper<T> {
36
private nodes: Node<T>[] = [];
37
private clusters: Cluster<T>[] = [];
38
private nodeToClusterId = new Map<Node<T>, string>();
39
private clusterCounter = 0;
40
private normalizedEmbeddings = new Map<Node<T>, EmbeddingVector>();
41
private cachedSimilarities: number[] | undefined;
42
private readonly options: {
43
eps: number;
44
minClusterSize: number;
45
insertThreshold?: number;
46
};
47
48
constructor(options?: GroupingOptions) {
49
this.options = {
50
eps: 0.9, // Higher similarity threshold for cosine similarity
51
minClusterSize: 2,
52
...options,
53
};
54
}
55
56
/**
57
* Add a node to the grouper. Will attempt to assign to existing cluster
58
* or create a new singleton cluster.
59
*/
60
addNode(node: Node<T>): void {
61
this.nodes.push(node);
62
// Cache normalized embedding for this node
63
this.normalizedEmbeddings.set(node, this.normalizeVector(node.embedding.value));
64
// Invalidate cached similarities since we added a node
65
this.cachedSimilarities = undefined;
66
67
// If we have existing clusters, try to insert into the best matching one
68
if (this.clusters.length > 0) {
69
const insertThreshold = this.options.insertThreshold ?? this.lastUsedThreshold;
70
const bestCluster = this.findBestClusterForNode(node, insertThreshold);
71
72
if (bestCluster) {
73
this.addNodeToCluster(node, bestCluster);
74
return;
75
}
76
}
77
78
// Create new singleton cluster
79
this.createSingletonCluster(node);
80
}
81
82
/**
83
* Add multiple nodes efficiently in batch. This is much more efficient than
84
* calling addNode() multiple times as it defers clustering until all nodes are added.
85
*
86
* @param nodes Array of nodes to add
87
* @param reclusterAfter Whether to recluster after adding all nodes. Default: true
88
*/
89
addNodes(nodes: Node<T>[], reclusterAfter: boolean = true): void {
90
if (nodes.length === 0) {
91
return;
92
}
93
94
// Batch add all nodes and cache their normalized embeddings
95
for (const node of nodes) {
96
this.nodes.push(node);
97
}
98
// Invalidate cached similarities since we added nodes
99
this.cachedSimilarities = undefined;
100
101
if (reclusterAfter) {
102
// Perform full reclustering which is more efficient for bulk operations
103
this.recluster();
104
} else {
105
// Create singleton clusters for all new nodes (fast path when clustering is deferred)
106
for (const node of nodes) {
107
this.createSingletonCluster(node);
108
}
109
}
110
}
111
112
/**
113
* Remove a node from the grouper. May cause cluster splits or deletions.
114
*/
115
removeNode(node: Node<T>): boolean {
116
const nodeIndex = this.nodes.indexOf(node);
117
if (nodeIndex === -1) {
118
return false;
119
}
120
121
this.nodes.splice(nodeIndex, 1);
122
// Clean up cached normalized embedding
123
this.normalizedEmbeddings.delete(node);
124
// Invalidate cached similarities since we removed a node
125
this.cachedSimilarities = undefined;
126
127
const clusterId = this.nodeToClusterId.get(node);
128
if (clusterId) {
129
this.nodeToClusterId.delete(node);
130
this.removeNodeFromCluster(node, clusterId);
131
}
132
133
return true;
134
}
135
136
/**
137
* Perform full reclustering of all nodes using similarity-based clustering.
138
*/
139
recluster(): void {
140
if (this.nodes.length === 0) {
141
this.clusters = [];
142
this.nodeToClusterId.clear();
143
return;
144
}
145
146
// Clear existing clusters
147
this.clusters = [];
148
this.nodeToClusterId.clear();
149
150
// Run similarity-based clustering that avoids transitive issues
151
const clusterAssignments = this.runSimilarityBasedClustering(this.options.eps, this.options.minClusterSize);
152
153
// Create clusters from results
154
this.createClustersFromAssignments(clusterAssignments);
155
}
156
157
/**
158
* Get all current clusters
159
*/
160
getClusters(): readonly Cluster<T>[] {
161
return this.clusters;
162
}
163
164
/**
165
* Get the cluster containing a specific node
166
*/
167
getClusterForNode(node: Node<T>): Cluster<T> | undefined {
168
const clusterId = this.nodeToClusterId.get(node);
169
return clusterId ? this.clusters.find(c => c.id === clusterId) : undefined;
170
}
171
172
private lastUsedThreshold = 0.9; // Fallback default
173
174
/**
175
* Compute similarity threshold based on percentile.
176
* Higher percentiles result in stricter clustering (higher similarity required).
177
*/
178
private computeEpsFromPercentile(percentile: number): number {
179
if (this.nodes.length < 2) {
180
return 0.9; // High similarity for small datasets
181
}
182
183
const similarities = this.getSimilarities();
184
if (similarities.length === 0) {
185
return 0.9;
186
}
187
188
// Higher percentiles = higher similarity thresholds for tighter clusters
189
const index = Math.floor((percentile / 100) * similarities.length);
190
const threshold = similarities[Math.min(index, similarities.length - 1)];
191
192
this.lastUsedThreshold = threshold;
193
return threshold;
194
}
195
196
/**
197
* Run similarity-based clustering that avoids transitive clustering issues
198
* @param threshold Minimum similarity for nodes to be clustered together
199
* @param minClusterSize Minimum size for a valid cluster
200
* @returns Array where each index corresponds to a node and value is cluster ID (-1 for unassigned)
201
*/
202
private runSimilarityBasedClustering(threshold: number, minClusterSize: number): number[] {
203
const assignments: number[] = new Array(this.nodes.length).fill(-1);
204
const processed: boolean[] = new Array(this.nodes.length).fill(false);
205
let clusterId = 0;
206
207
// Find cluster seeds - nodes with high similarity to multiple others
208
const seeds = this.findClusterSeeds(threshold, minClusterSize);
209
210
// Assign nodes to clusters based on best similarity to seed
211
for (const seed of seeds) {
212
if (processed[seed]) {
213
continue;
214
}
215
216
const cluster = this.buildClusterAroundSeed(seed, threshold, processed);
217
if (cluster.length >= minClusterSize) {
218
for (const nodeIndex of cluster) {
219
assignments[nodeIndex] = clusterId;
220
processed[nodeIndex] = true;
221
}
222
clusterId++;
223
}
224
}
225
226
return assignments;
227
}
228
229
/**
230
* Find potential cluster seeds - nodes that are similar to many others
231
*/
232
private findClusterSeeds(threshold: number, minClusterSize: number): number[] {
233
const seeds: number[] = [];
234
const similarityCounts: number[] = new Array(this.nodes.length).fill(0);
235
236
// Count how many nodes each node is similar to
237
for (let i = 0; i < this.nodes.length; i++) {
238
for (let j = i + 1; j < this.nodes.length; j++) {
239
const similarity = this.cachedCosineSimilarity(this.nodes[i], this.nodes[j]);
240
if (similarity >= threshold) {
241
similarityCounts[i]++;
242
similarityCounts[j]++;
243
}
244
}
245
}
246
247
// Select nodes that could form clusters as seeds
248
for (let i = 0; i < this.nodes.length; i++) {
249
if (similarityCounts[i] >= minClusterSize - 1) {
250
seeds.push(i);
251
}
252
}
253
254
// Sort seeds by similarity count (most connected first)
255
seeds.sort((a, b) => similarityCounts[b] - similarityCounts[a]);
256
return seeds;
257
}
258
259
/**
260
* Build a cluster around a seed node by finding all nodes similar to the seed
261
*/
262
private buildClusterAroundSeed(seed: number, threshold: number, processed: boolean[]): number[] {
263
const cluster = [seed];
264
265
for (let i = 0; i < this.nodes.length; i++) {
266
if (i === seed || processed[i]) {
267
continue;
268
}
269
270
const similarity = this.cachedCosineSimilarity(this.nodes[seed], this.nodes[i]);
271
if (similarity >= threshold) {
272
cluster.push(i);
273
}
274
}
275
276
return cluster;
277
}
278
279
/**
280
* Create clusters from assignment results
281
*/
282
private createClustersFromAssignments(clusterAssignments: number[]): void {
283
const clusterMap = new Map<number, Node<T>[]>();
284
const unassigned: Node<T>[] = [];
285
286
// Group nodes by cluster ID
287
for (let i = 0; i < clusterAssignments.length; i++) {
288
const clusterId = clusterAssignments[i];
289
const node = this.nodes[i];
290
291
if (clusterId === -1) {
292
unassigned.push(node);
293
} else {
294
if (!clusterMap.has(clusterId)) {
295
clusterMap.set(clusterId, []);
296
}
297
clusterMap.get(clusterId)!.push(node);
298
}
299
}
300
301
// Create clusters from grouped nodes
302
for (const [, nodes] of clusterMap) {
303
if (nodes.length >= this.options.minClusterSize) {
304
this.createCluster(nodes);
305
} else {
306
// Small clusters become singletons
307
for (const node of nodes) {
308
this.createSingletonCluster(node);
309
}
310
}
311
}
312
313
// Handle unassigned points as singletons
314
for (const node of unassigned) {
315
this.createSingletonCluster(node);
316
}
317
}
318
319
/**
320
* Find the best existing cluster for a new node
321
*/
322
private findBestClusterForNode(node: Node<T>, threshold: number): Cluster<T> | undefined {
323
let bestCluster: Cluster<T> | undefined;
324
let bestSimilarity = -1;
325
326
for (const cluster of this.clusters) {
327
const similarity = this.dotProduct(
328
this.getNormalizedEmbedding(node),
329
cluster.centroid
330
);
331
if (similarity >= threshold && similarity > bestSimilarity) {
332
bestSimilarity = similarity;
333
bestCluster = cluster;
334
}
335
}
336
337
return bestCluster;
338
}
339
340
/**
341
* Add node to existing cluster and update centroid
342
*/
343
private addNodeToCluster(node: Node<T>, cluster: Cluster<T>): void {
344
const updatedNodes = [...cluster.nodes, node];
345
const updatedCentroid = this.computeCentroid(updatedNodes.map(n => n.embedding.value));
346
347
const updatedCluster: Cluster<T> = {
348
...cluster,
349
nodes: updatedNodes,
350
centroid: updatedCentroid
351
};
352
353
// Update clusters array
354
const clusterIndex = this.clusters.indexOf(cluster);
355
this.clusters[clusterIndex] = updatedCluster;
356
357
this.nodeToClusterId.set(node, cluster.id);
358
}
359
360
/**
361
* Remove node from cluster and handle potential cluster deletion
362
*/
363
private removeNodeFromCluster(node: Node<T>, clusterId: string): void {
364
const clusterIndex = this.clusters.findIndex(c => c.id === clusterId);
365
if (clusterIndex === -1) {
366
return;
367
}
368
369
const cluster = this.clusters[clusterIndex];
370
const updatedNodes = cluster.nodes.filter(n => n !== node);
371
372
if (updatedNodes.length === 0) {
373
// Remove empty cluster
374
this.clusters.splice(clusterIndex, 1);
375
} else {
376
// Update cluster with remaining nodes
377
const updatedCentroid = this.computeCentroid(updatedNodes.map(n => n.embedding.value));
378
const updatedCluster: Cluster<T> = {
379
...cluster,
380
nodes: updatedNodes,
381
centroid: updatedCentroid
382
};
383
this.clusters[clusterIndex] = updatedCluster;
384
385
// Update node mappings for remaining nodes
386
for (const remainingNode of updatedNodes) {
387
this.nodeToClusterId.set(remainingNode, clusterId);
388
}
389
}
390
}
391
392
/**
393
* Create a new cluster from nodes
394
*/
395
private createCluster(nodes: Node<T>[]): void {
396
const id = `cluster_${this.clusterCounter++}`;
397
const centroid = this.computeCentroid(nodes.map(n => n.embedding.value));
398
399
const cluster: Cluster<T> = {
400
id,
401
nodes,
402
centroid
403
};
404
405
this.clusters.push(cluster);
406
407
for (const node of nodes) {
408
this.nodeToClusterId.set(node, id);
409
}
410
}
411
412
/**
413
* Create a singleton cluster for a single node
414
*/
415
private createSingletonCluster(node: Node<T>): void {
416
this.createCluster([node]);
417
}
418
419
/**
420
* Compute centroid (mean) of embedding vectors
421
*/
422
private computeCentroid(embeddings: EmbeddingVector[]): EmbeddingVector {
423
if (embeddings.length === 0) {
424
return [];
425
}
426
427
if (embeddings.length === 1) {
428
return [...embeddings[0]]; // Copy to avoid mutations
429
}
430
431
const dimensions = embeddings[0].length;
432
const centroid = new Array(dimensions).fill(0);
433
434
// Sum all embeddings
435
for (const embedding of embeddings) {
436
for (let i = 0; i < dimensions; i++) {
437
centroid[i] += embedding[i];
438
}
439
}
440
441
// Divide by count to get mean
442
for (let i = 0; i < dimensions; i++) {
443
centroid[i] /= embeddings.length;
444
}
445
446
// L2 normalize the centroid
447
return this.normalizeVector(centroid);
448
}
449
450
/**
451
* Gets the sorted list of pairwise similarities between all nodes.
452
* The returned list is ordered by similarity, NOT in any particular node order.
453
*/
454
private getSimilarities() {
455
if (this.cachedSimilarities) {
456
return this.cachedSimilarities;
457
}
458
459
const similarities: number[] = [];
460
461
// Compute all pairwise similarities (upper triangle only)
462
for (let i = 0; i < this.nodes.length; i++) {
463
for (let j = i + 1; j < this.nodes.length; j++) {
464
const sim = this.cachedCosineSimilarity(this.nodes[i], this.nodes[j]);
465
similarities.push(sim);
466
}
467
}
468
469
// Sort for efficient percentile lookups
470
similarities.sort((a, b) => a - b);
471
this.cachedSimilarities = similarities;
472
return this.cachedSimilarities;
473
}
474
475
/**
476
* Optimize clustering by finding the best similarity threshold that results in
477
* a target number of clusters or fewer, aiming for the highest cluster count
478
* that doesn't exceed the maximum. Includes a "cliff effect" to avoid over-clustering.
479
*
480
* @param maxClusters Maximum desired number of clusters
481
* @param minThreshold Minimum similarity threshold to try (default: 0.7 - loose clustering)
482
* @param maxThreshold Maximum similarity threshold to try (default: 0.99 - very strict)
483
* @param precision How precise the search should be (default: 0.02)
484
* @param cliffThreshold Fraction of maxClusters that triggers cliff effect (default: 2/3)
485
* @param cliffGain Minimum additional clusters needed to continue past cliff (default: 20% of maxClusters)
486
* @returns The optimal threshold found and resulting cluster count
487
*/
488
tuneThresholdForTargetClusters(
489
maxClusters: number,
490
minThreshold: number = 0.7,
491
maxThreshold: number = 0.99,
492
precision: number = 0.02,
493
cliffThreshold: number = 2 / 3,
494
cliffGain: number = 0.2
495
): { percentile: number; clusterCount: number; threshold: number } {
496
if (this.nodes.length === 0) {
497
return { percentile: 90, clusterCount: 0, threshold: 0.9 };
498
}
499
500
const cliffPoint = Math.floor(maxClusters * cliffThreshold);
501
const minGainAfterCliff = Math.max(1, Math.floor(maxClusters * cliffGain));
502
503
let bestThreshold = maxThreshold;
504
let bestClusterCount = 1; // Start with worst case (very few clusters)
505
let cliffReached = false;
506
507
// Binary search for optimal threshold that maximizes clusters while staying under limit
508
let low = minThreshold;
509
let high = maxThreshold;
510
511
while (high - low > precision) {
512
const mid = (low + high) / 2;
513
const clusterCount = this.countClustersForThreshold(mid, this.options.minClusterSize);
514
515
if (clusterCount <= maxClusters) {
516
// Check if this is a meaningful improvement
517
let shouldUpdate = false;
518
519
if (!cliffReached && clusterCount >= cliffPoint) {
520
// We've reached the cliff point - this is good enough
521
cliffReached = true;
522
shouldUpdate = clusterCount > bestClusterCount;
523
} else if (cliffReached) {
524
// Past cliff - only update if we get significant additional clusters
525
shouldUpdate = clusterCount >= bestClusterCount + minGainAfterCliff;
526
} else {
527
// Before cliff - any improvement is good
528
shouldUpdate = clusterCount > bestClusterCount;
529
}
530
531
if (shouldUpdate) {
532
bestThreshold = mid;
533
bestClusterCount = clusterCount;
534
}
535
536
// Try going to lower threshold for potentially more clusters
537
low = mid + precision;
538
} else {
539
// Too many clusters, need higher threshold (stricter clustering)
540
high = mid - precision;
541
}
542
}
543
544
// Convert threshold to approximate percentile for compatibility
545
const similarities = this.getSimilarities();
546
let approximatePercentile = 90;
547
if (similarities.length > 0) {
548
const position = similarities.findIndex(s => s >= bestThreshold);
549
if (position >= 0) {
550
approximatePercentile = Math.round((position / similarities.length) * 100);
551
}
552
}
553
554
return {
555
percentile: approximatePercentile,
556
clusterCount: bestClusterCount,
557
threshold: bestThreshold
558
};
559
}
560
561
/**
562
* Apply a specific similarity threshold and recluster
563
*
564
* @param percentile The similarity percentile to convert to threshold
565
*/
566
applyPercentileAndRecluster(percentile: number): void {
567
// Convert percentile to similarity threshold
568
const eps = this.computeEpsFromPercentile(percentile);
569
// Temporarily override the eps option
570
const originalEps = this.options.eps;
571
(this.options as any).eps = eps;
572
573
try {
574
this.recluster();
575
} finally {
576
// Restore original eps
577
(this.options as any).eps = originalEps;
578
}
579
}
580
581
/**
582
* Count how many clusters would result from a given similarity threshold without actually clustering
583
*/
584
private countClustersForThreshold(threshold: number, minClusterSize: number): number {
585
if (this.nodes.length === 0) {
586
return 0;
587
}
588
589
// Run clustering with given parameters
590
const clusterAssignments = this.runSimilarityBasedClustering(threshold, minClusterSize);
591
592
// Count unique cluster IDs (excluding -1 which is unassigned)
593
const clusterIds = new Set<number>();
594
for (const clusterId of clusterAssignments) {
595
if (clusterId !== -1) {
596
clusterIds.add(clusterId);
597
}
598
}
599
600
// Add singleton clusters for unassigned points and small clusters
601
const clusterSizes = new Map<number, number>();
602
let unassignedCount = 0;
603
604
for (const clusterId of clusterAssignments) {
605
if (clusterId === -1) {
606
unassignedCount++;
607
} else {
608
clusterSizes.set(clusterId, (clusterSizes.get(clusterId) || 0) + 1);
609
}
610
}
611
612
// Count valid clusters (meeting minClusterSize) and singletons
613
let validClusters = 0;
614
let singletons = unassignedCount; // Unassigned points become singletons
615
616
for (const [, size] of clusterSizes) {
617
if (size >= minClusterSize) {
618
validClusters++;
619
} else {
620
singletons += size;
621
}
622
}
623
624
return validClusters + singletons;
625
}
626
627
/**
628
* Get cached normalized embedding for a node
629
*/
630
private getNormalizedEmbedding(node: Node<T>): EmbeddingVector {
631
let normalized = this.normalizedEmbeddings.get(node);
632
if (!normalized) {
633
normalized = this.normalizeVector(node.embedding.value);
634
this.normalizedEmbeddings.set(node, normalized);
635
}
636
return normalized;
637
}
638
639
/**
640
* Compute cosine similarity using cached normalized embeddings
641
*/
642
private cachedCosineSimilarity(nodeA: Node<T>, nodeB: Node<T>): number {
643
const normA = this.getNormalizedEmbedding(nodeA);
644
const normB = this.getNormalizedEmbedding(nodeB);
645
return this.dotProduct(normA, normB);
646
}
647
648
/**
649
* Optimized dot product computation
650
*/
651
private dotProduct(a: EmbeddingVector, b: EmbeddingVector): number {
652
let dotProduct = 0;
653
const len = Math.min(a.length, b.length);
654
// Unroll loop for better performance on small vectors
655
let i = 0;
656
for (; i < len - 3; i += 4) {
657
dotProduct += a[i] * b[i] + a[i + 1] * b[i + 1] + a[i + 2] * b[i + 2] + a[i + 3] * b[i + 3];
658
}
659
// Handle remaining elements
660
for (; i < len; i++) {
661
dotProduct += a[i] * b[i];
662
}
663
return dotProduct;
664
}
665
666
/**
667
* L2 normalize a vector
668
*/
669
private normalizeVector(vector: EmbeddingVector): EmbeddingVector {
670
const magnitude = Math.sqrt(vector.reduce((sum, val) => sum + val * val, 0));
671
672
if (magnitude === 0) {
673
return vector.slice(); // Return copy of zero vector
674
}
675
676
return vector.map(val => val / magnitude);
677
}
678
}
679
680