Path: blob/main/extensions/copilot/src/platform/embeddings/common/embeddingsGrouper.ts
13401 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import type { Embedding, EmbeddingVector } from './embeddingsComputer';

/** A value paired with the embedding used to group it. */
export interface Node<T> {
	readonly value: T;
	readonly embedding: Embedding;
}

/** A group of nodes together with a representative direction. */
export interface Cluster<T> {
	/** Identifier of the form `cluster_<n>`, unique within one grouper instance. */
	readonly id: string;
	readonly nodes: readonly Node<T>[];
	/** L2-normalized mean of the member embeddings (left as-is when the mean has zero magnitude). */
	readonly centroid: EmbeddingVector;
}

export interface GroupingOptions {
	/** Similarity threshold for clustering (0.0-1.0). Higher values create tighter clusters. Default: 0.9 */
	readonly eps?: number;
	/** Minimum cluster size. Smaller clusters become singletons. Default: 2 */
	readonly minClusterSize?: number;
	/** Threshold for inserting new nodes into existing clusters. Default: same as clustering threshold */
	readonly insertThreshold?: number;
}

/** Default similarity threshold used when no `eps` is configured. */
const DEFAULT_EPS = 0.9;
/** Default minimum number of nodes in a non-singleton cluster. */
const DEFAULT_MIN_CLUSTER_SIZE = 2;
/** Default step/precision for the threshold search in `tuneThresholdForTargetClusters`. */
const DEFAULT_TUNE_PRECISION = 0.02;

/**
 * Groups embeddings using similarity-based clustering with cosine similarity.
 *
 * This approach finds cluster seeds (nodes with many similar neighbors) and builds
 * clusters around them. It avoids the transitive clustering issues of connected
 * components while being more suitable for cosine similarity than DBSCAN.
 */
export class EmbeddingsGrouper<T> {
	private nodes: Node<T>[] = [];
	private clusters: Cluster<T>[] = [];
	private nodeToClusterId = new Map<Node<T>, string>();
	private clusterCounter = 0;
	/** Unit-length copy of each node's embedding, filled eagerly on add and lazily on lookup. */
	private normalizedEmbeddings = new Map<Node<T>, EmbeddingVector>();
	/** Ascending list of all pairwise similarities; reset whenever the node set changes. */
	private cachedSimilarities: number[] | undefined;
	/**
	 * Threshold used by `addNode` when no `insertThreshold` is configured. Starts at the
	 * configured `eps` and is replaced by the most recent percentile-derived threshold.
	 */
	private lastUsedThreshold: number;
	private readonly options: {
		eps: number;
		minClusterSize: number;
		insertThreshold?: number;
	};

	/**
	 * @param options Optional overrides. Properties that are missing or explicitly
	 * `undefined` fall back to their defaults.
	 */
	constructor(options?: GroupingOptions) {
		// Resolve each field explicitly: an object spread would let `{ eps: undefined }`
		// overwrite the default with `undefined` and silently disable clustering.
		this.options = {
			eps: options?.eps ?? DEFAULT_EPS,
			minClusterSize: options?.minClusterSize ?? DEFAULT_MIN_CLUSTER_SIZE,
			insertThreshold: options?.insertThreshold,
		};
		// The documented default for inserts is "same as clustering threshold".
		this.lastUsedThreshold = this.options.eps;
	}

	/**
	 * Add a node to the grouper. Will attempt to assign to existing cluster
	 * or create a new singleton cluster.
	 */
	addNode(node: Node<T>): void {
		this.nodes.push(node);
		// Cache normalized embedding for this node
		this.normalizedEmbeddings.set(node, this.normalizeVector(node.embedding.value));
		// Invalidate cached similarities since we added a node
		this.cachedSimilarities = undefined;

		// If we have existing clusters, try to insert into the best matching one
		if (this.clusters.length > 0) {
			const insertThreshold = this.options.insertThreshold ?? this.lastUsedThreshold;
			const bestCluster = this.findBestClusterForNode(node, insertThreshold);

			if (bestCluster) {
				this.addNodeToCluster(node, bestCluster);
				return;
			}
		}

		// Create new singleton cluster
		this.createSingletonCluster(node);
	}

	/**
	 * Add multiple nodes in batch, deferring clustering until all nodes are added.
	 *
	 * @param nodes Array of nodes to add
	 * @param reclusterAfter Whether to recluster after adding all nodes. Default: true
	 */
	addNodes(nodes: Node<T>[], reclusterAfter: boolean = true): void {
		if (nodes.length === 0) {
			return;
		}

		// Batch add all nodes and cache their normalized embeddings
		for (const node of nodes) {
			this.nodes.push(node);
			this.normalizedEmbeddings.set(node, this.normalizeVector(node.embedding.value));
		}
		// Invalidate cached similarities since we added nodes
		this.cachedSimilarities = undefined;

		if (reclusterAfter) {
			// Rebuild every cluster from scratch over the full node set
			this.recluster();
		} else {
			// Create singleton clusters for all new nodes (fast path when clustering is deferred)
			for (const node of nodes) {
				this.createSingletonCluster(node);
			}
		}
	}

	/**
	 * Remove a node from the grouper. The node's cluster shrinks, and is deleted
	 * when it becomes empty.
	 *
	 * @returns `true` if the node was present and removed, `false` otherwise
	 */
	removeNode(node: Node<T>): boolean {
		const nodeIndex = this.nodes.indexOf(node);
		if (nodeIndex === -1) {
			return false;
		}

		this.nodes.splice(nodeIndex, 1);
		// Clean up cached normalized embedding
		this.normalizedEmbeddings.delete(node);
		// Invalidate cached similarities since we removed a node
		this.cachedSimilarities = undefined;

		const clusterId = this.nodeToClusterId.get(node);
		if (clusterId) {
			this.nodeToClusterId.delete(node);
			this.removeNodeFromCluster(node, clusterId);
		}

		return true;
	}

	/**
	 * Perform full reclustering of all nodes using similarity-based clustering.
	 * Existing clusters are discarded; cluster ids are not reused.
	 */
	recluster(): void {
		// Clear existing clusters
		this.clusters = [];
		this.nodeToClusterId.clear();

		if (this.nodes.length === 0) {
			return;
		}

		// Run similarity-based clustering that avoids transitive issues
		const clusterAssignments = this.runSimilarityBasedClustering(this.options.eps, this.options.minClusterSize);

		// Create clusters from results
		this.createClustersFromAssignments(clusterAssignments);
	}

	/**
	 * Get all current clusters
	 */
	getClusters(): readonly Cluster<T>[] {
		return this.clusters;
	}

	/**
	 * Get the cluster containing a specific node
	 */
	getClusterForNode(node: Node<T>): Cluster<T> | undefined {
		const clusterId = this.nodeToClusterId.get(node);
		return clusterId ? this.clusters.find(c => c.id === clusterId) : undefined;
	}

	/**
	 * Compute similarity threshold based on percentile.
	 * Higher percentiles result in stricter clustering (higher similarity required).
	 *
	 * Records the result as the last used threshold, except on the small-dataset
	 * fallbacks, which return 0.9 without recording it.
	 */
	private computeEpsFromPercentile(percentile: number): number {
		if (this.nodes.length < 2) {
			return DEFAULT_EPS; // High similarity for small datasets
		}

		const similarities = this.getSimilarities();
		if (similarities.length === 0) {
			return DEFAULT_EPS;
		}

		// Higher percentiles = higher similarity thresholds for tighter clusters.
		// Clamp on both sides so out-of-range percentiles cannot index outside the array.
		const rawIndex = Math.floor((percentile / 100) * similarities.length);
		const index = Math.max(0, Math.min(rawIndex, similarities.length - 1));
		const threshold = similarities[index];

		this.lastUsedThreshold = threshold;
		return threshold;
	}

	/**
	 * Run similarity-based clustering that avoids transitive clustering issues
	 * @param threshold Minimum similarity for nodes to be clustered together
	 * @param minClusterSize Minimum size for a valid cluster
	 * @returns Array where each index corresponds to a node and value is cluster ID (-1 for unassigned)
	 */
	private runSimilarityBasedClustering(threshold: number, minClusterSize: number): number[] {
		const assignments: number[] = new Array<number>(this.nodes.length).fill(-1);
		const processed: boolean[] = new Array<boolean>(this.nodes.length).fill(false);
		let clusterId = 0;

		// Find cluster seeds - nodes with high similarity to multiple others
		const seeds = this.findClusterSeeds(threshold, minClusterSize);

		// Visit seeds most-connected first; a seed already claimed by an earlier cluster is skipped
		for (const seed of seeds) {
			if (processed[seed]) {
				continue;
			}

			const cluster = this.buildClusterAroundSeed(seed, threshold, processed);
			if (cluster.length >= minClusterSize) {
				for (const nodeIndex of cluster) {
					assignments[nodeIndex] = clusterId;
					processed[nodeIndex] = true;
				}
				clusterId++;
			}
			// Otherwise the seed stays unprocessed and may still join a later seed's cluster.
		}

		return assignments;
	}

	/**
	 * Find potential cluster seeds - nodes that are similar to many others.
	 *
	 * @returns Node indices with at least `minClusterSize - 1` neighbors at or above
	 * `threshold`, ordered by neighbor count descending
	 */
	private findClusterSeeds(threshold: number, minClusterSize: number): number[] {
		const seeds: number[] = [];
		const similarityCounts: number[] = new Array<number>(this.nodes.length).fill(0);

		// Count how many nodes each node is similar to
		for (let i = 0; i < this.nodes.length; i++) {
			for (let j = i + 1; j < this.nodes.length; j++) {
				const similarity = this.cachedCosineSimilarity(this.nodes[i], this.nodes[j]);
				if (similarity >= threshold) {
					similarityCounts[i]++;
					similarityCounts[j]++;
				}
			}
		}

		// Select nodes that could form clusters as seeds
		for (let i = 0; i < this.nodes.length; i++) {
			if (similarityCounts[i] >= minClusterSize - 1) {
				seeds.push(i);
			}
		}

		// Sort seeds by similarity count (most connected first)
		seeds.sort((a, b) => similarityCounts[b] - similarityCounts[a]);
		return seeds;
	}

	/**
	 * Build a cluster around a seed node by finding all not-yet-processed nodes
	 * similar to the seed. Members are compared to the seed only, not to each other.
	 */
	private buildClusterAroundSeed(seed: number, threshold: number, processed: boolean[]): number[] {
		const cluster = [seed];

		for (let i = 0; i < this.nodes.length; i++) {
			if (i === seed || processed[i]) {
				continue;
			}

			const similarity = this.cachedCosineSimilarity(this.nodes[seed], this.nodes[i]);
			if (similarity >= threshold) {
				cluster.push(i);
			}
		}

		return cluster;
	}

	/**
	 * Create clusters from assignment results. Unassigned nodes (-1) and members of
	 * groups smaller than `minClusterSize` each become their own singleton cluster.
	 */
	private createClustersFromAssignments(clusterAssignments: number[]): void {
		const clusterMap = new Map<number, Node<T>[]>();
		const unassigned: Node<T>[] = [];

		// Group nodes by cluster ID
		for (let i = 0; i < clusterAssignments.length; i++) {
			const clusterId = clusterAssignments[i];
			const node = this.nodes[i];

			if (clusterId === -1) {
				unassigned.push(node);
				continue;
			}

			const members = clusterMap.get(clusterId);
			if (members) {
				members.push(node);
			} else {
				clusterMap.set(clusterId, [node]);
			}
		}

		// Create clusters from grouped nodes
		for (const [, nodes] of clusterMap) {
			if (nodes.length >= this.options.minClusterSize) {
				this.createCluster(nodes);
			} else {
				// Small clusters become singletons
				for (const node of nodes) {
					this.createSingletonCluster(node);
				}
			}
		}

		// Handle unassigned points as singletons
		for (const node of unassigned) {
			this.createSingletonCluster(node);
		}
	}

	/**
	 * Find the best existing cluster for a new node: the cluster whose centroid has the
	 * highest similarity to the node, provided that similarity reaches `threshold`.
	 */
	private findBestClusterForNode(node: Node<T>, threshold: number): Cluster<T> | undefined {
		let bestCluster: Cluster<T> | undefined;
		let bestSimilarity = -1;
		const normalizedNode = this.getNormalizedEmbedding(node);

		for (const cluster of this.clusters) {
			// Centroids are stored L2-normalized, so this dot product is a cosine similarity.
			const similarity = this.dotProduct(normalizedNode, cluster.centroid);
			if (similarity >= threshold && similarity > bestSimilarity) {
				bestSimilarity = similarity;
				bestCluster = cluster;
			}
		}

		return bestCluster;
	}

	/**
	 * Add node to existing cluster and update centroid. The cluster object is replaced
	 * rather than mutated; its id is preserved.
	 */
	private addNodeToCluster(node: Node<T>, cluster: Cluster<T>): void {
		const updatedNodes = [...cluster.nodes, node];
		const updatedCentroid = this.computeCentroid(updatedNodes.map(n => n.embedding.value));

		const updatedCluster: Cluster<T> = {
			...cluster,
			nodes: updatedNodes,
			centroid: updatedCentroid
		};

		// Update clusters array
		const clusterIndex = this.clusters.indexOf(cluster);
		this.clusters[clusterIndex] = updatedCluster;

		this.nodeToClusterId.set(node, cluster.id);
	}

	/**
	 * Remove node from cluster and handle potential cluster deletion
	 */
	private removeNodeFromCluster(node: Node<T>, clusterId: string): void {
		const clusterIndex = this.clusters.findIndex(c => c.id === clusterId);
		if (clusterIndex === -1) {
			return;
		}

		const cluster = this.clusters[clusterIndex];
		const updatedNodes = cluster.nodes.filter(n => n !== node);

		if (updatedNodes.length === 0) {
			// Remove empty cluster
			this.clusters.splice(clusterIndex, 1);
			return;
		}

		// Replace the cluster with one holding the remaining nodes. Their id mapping is
		// unchanged because the cluster keeps its id.
		this.clusters[clusterIndex] = {
			...cluster,
			nodes: updatedNodes,
			centroid: this.computeCentroid(updatedNodes.map(n => n.embedding.value))
		};
	}

	/**
	 * Create a new cluster from nodes, assign it a fresh id and map every node to it.
	 */
	private createCluster(nodes: Node<T>[]): void {
		const id = `cluster_${this.clusterCounter++}`;
		const centroid = this.computeCentroid(nodes.map(n => n.embedding.value));

		const cluster: Cluster<T> = {
			id,
			nodes,
			centroid
		};

		this.clusters.push(cluster);

		for (const node of nodes) {
			this.nodeToClusterId.set(node, id);
		}
	}

	/**
	 * Create a singleton cluster for a single node
	 */
	private createSingletonCluster(node: Node<T>): void {
		this.createCluster([node]);
	}

	/**
	 * Compute the L2-normalized centroid (mean) of embedding vectors.
	 *
	 * The single-vector case is normalized as well, so every centroid can be compared
	 * to a normalized embedding with a plain dot product.
	 *
	 * @returns A new vector; `[]` when `embeddings` is empty
	 */
	private computeCentroid(embeddings: EmbeddingVector[]): EmbeddingVector {
		if (embeddings.length === 0) {
			return [];
		}

		if (embeddings.length === 1) {
			// normalizeVector always returns a fresh array, so the input is never aliased
			return this.normalizeVector(embeddings[0]);
		}

		// The first vector defines the dimensionality used for the sum.
		const dimensions = embeddings[0].length;
		const centroid: number[] = new Array<number>(dimensions).fill(0);

		// Sum all embeddings
		for (const embedding of embeddings) {
			for (let i = 0; i < dimensions; i++) {
				centroid[i] += embedding[i];
			}
		}

		// Divide by count to get mean
		for (let i = 0; i < dimensions; i++) {
			centroid[i] /= embeddings.length;
		}

		// L2 normalize the centroid
		return this.normalizeVector(centroid);
	}

	/**
	 * Gets the sorted list of pairwise similarities between all nodes.
	 * The returned list is ordered by similarity, NOT in any particular node order.
	 */
	private getSimilarities(): number[] {
		if (this.cachedSimilarities) {
			return this.cachedSimilarities;
		}

		const similarities: number[] = [];

		// Compute all pairwise similarities (upper triangle only)
		for (let i = 0; i < this.nodes.length; i++) {
			for (let j = i + 1; j < this.nodes.length; j++) {
				similarities.push(this.cachedCosineSimilarity(this.nodes[i], this.nodes[j]));
			}
		}

		// Sort ascending for percentile lookups
		similarities.sort((a, b) => a - b);
		this.cachedSimilarities = similarities;
		return similarities;
	}

	/**
	 * Optimize clustering by finding the best similarity threshold that results in
	 * a target number of clusters or fewer, aiming for the highest cluster count
	 * that doesn't exceed the maximum. Includes a "cliff effect" to avoid over-clustering.
	 *
	 * This only evaluates thresholds; it does not change the current clusters.
	 *
	 * If no probed threshold is accepted, the result keeps its initial values
	 * (`threshold = maxThreshold`, `clusterCount = 1`), which do not describe an
	 * evaluated clustering.
	 *
	 * @param maxClusters Maximum desired number of clusters
	 * @param minThreshold Minimum similarity threshold to try (default: 0.7 - loose clustering)
	 * @param maxThreshold Maximum similarity threshold to try (default: 0.99 - very strict)
	 * @param precision How precise the search should be (default: 0.02). Values that are not
	 * positive finite numbers fall back to the default, since the search could not terminate with them.
	 * @param cliffThreshold Fraction of maxClusters that triggers cliff effect (default: 2/3)
	 * @param cliffGain Minimum additional clusters needed to continue past cliff (default: 20% of maxClusters)
	 * @returns The optimal threshold found, its approximate percentile among the pairwise
	 * similarities, and the resulting cluster count
	 */
	tuneThresholdForTargetClusters(
		maxClusters: number,
		minThreshold: number = 0.7,
		maxThreshold: number = 0.99,
		precision: number = 0.02,
		cliffThreshold: number = 2 / 3,
		cliffGain: number = 0.2
	): { percentile: number; clusterCount: number; threshold: number } {
		if (this.nodes.length === 0) {
			return { percentile: 90, clusterCount: 0, threshold: 0.9 };
		}

		// A non-positive step never shrinks the search interval below itself.
		const step = Number.isFinite(precision) && precision > 0 ? precision : DEFAULT_TUNE_PRECISION;

		const cliffPoint = Math.floor(maxClusters * cliffThreshold);
		const minGainAfterCliff = Math.max(1, Math.floor(maxClusters * cliffGain));

		let bestThreshold = maxThreshold;
		let bestClusterCount = 1; // Start with worst case (very few clusters)
		let cliffReached = false;

		// Binary search for optimal threshold that maximizes clusters while staying under limit
		let low = minThreshold;
		let high = maxThreshold;

		while (high - low > step) {
			const mid = (low + high) / 2;
			const clusterCount = this.countClustersForThreshold(mid, this.options.minClusterSize);

			if (clusterCount <= maxClusters) {
				// Check if this is a meaningful improvement
				let shouldUpdate: boolean;

				if (!cliffReached && clusterCount >= cliffPoint) {
					// We've reached the cliff point - this is good enough
					cliffReached = true;
					shouldUpdate = clusterCount > bestClusterCount;
				} else if (cliffReached) {
					// Past cliff - only update if we get significant additional clusters
					shouldUpdate = clusterCount >= bestClusterCount + minGainAfterCliff;
				} else {
					// Before cliff - any improvement is good
					shouldUpdate = clusterCount > bestClusterCount;
				}

				if (shouldUpdate) {
					bestThreshold = mid;
					bestClusterCount = clusterCount;
				}

				// Still within the limit: probe stricter (higher) thresholds, which tend to
				// split nodes into more clusters
				low = mid + step;
			} else {
				// Too many clusters: probe looser (lower) thresholds
				high = mid - step;
			}
		}

		// Convert threshold to approximate percentile for compatibility
		const similarities = this.getSimilarities();
		let approximatePercentile = 90;
		if (similarities.length > 0) {
			const position = similarities.findIndex(s => s >= bestThreshold);
			if (position >= 0) {
				approximatePercentile = Math.round((position / similarities.length) * 100);
			}
		}

		return {
			percentile: approximatePercentile,
			clusterCount: bestClusterCount,
			threshold: bestThreshold
		};
	}

	/**
	 * Convert a similarity percentile into a threshold and recluster with it.
	 * The configured `eps` is restored afterwards, even if reclustering throws.
	 *
	 * @param percentile The similarity percentile to convert to threshold
	 */
	applyPercentileAndRecluster(percentile: number): void {
		// Convert percentile to similarity threshold
		const eps = this.computeEpsFromPercentile(percentile);
		// Temporarily override the eps option
		const originalEps = this.options.eps;
		this.options.eps = eps;

		try {
			this.recluster();
		} finally {
			// Restore original eps
			this.options.eps = originalEps;
		}
	}

	/**
	 * Count how many clusters would result from a given similarity threshold without
	 * changing the current clusters.
	 */
	private countClustersForThreshold(threshold: number, minClusterSize: number): number {
		if (this.nodes.length === 0) {
			return 0;
		}

		// Run clustering with given parameters
		const clusterAssignments = this.runSimilarityBasedClustering(threshold, minClusterSize);

		// Tally group sizes; -1 marks an unassigned node
		const clusterSizes = new Map<number, number>();
		let unassignedCount = 0;

		for (const clusterId of clusterAssignments) {
			if (clusterId === -1) {
				unassignedCount++;
			} else {
				clusterSizes.set(clusterId, (clusterSizes.get(clusterId) ?? 0) + 1);
			}
		}

		// Count valid clusters (meeting minClusterSize) and singletons
		let validClusters = 0;
		let singletons = unassignedCount; // Unassigned points become singletons

		for (const [, size] of clusterSizes) {
			if (size >= minClusterSize) {
				validClusters++;
			} else {
				// Every member of an undersized group counts as its own cluster
				singletons += size;
			}
		}

		return validClusters + singletons;
	}

	/**
	 * Get cached normalized embedding for a node, computing and caching it on a miss
	 */
	private getNormalizedEmbedding(node: Node<T>): EmbeddingVector {
		let normalized = this.normalizedEmbeddings.get(node);
		if (!normalized) {
			normalized = this.normalizeVector(node.embedding.value);
			this.normalizedEmbeddings.set(node, normalized);
		}
		return normalized;
	}

	/**
	 * Compute cosine similarity using cached normalized embeddings
	 */
	private cachedCosineSimilarity(nodeA: Node<T>, nodeB: Node<T>): number {
		const normA = this.getNormalizedEmbedding(nodeA);
		const normB = this.getNormalizedEmbedding(nodeB);
		return this.dotProduct(normA, normB);
	}

	/**
	 * Dot product over the shared prefix of two vectors (the shorter length wins).
	 */
	private dotProduct(a: EmbeddingVector, b: EmbeddingVector): number {
		let dotProduct = 0;
		const len = Math.min(a.length, b.length);
		// Process four components per iteration
		let i = 0;
		for (; i < len - 3; i += 4) {
			dotProduct += a[i] * b[i] + a[i + 1] * b[i + 1] + a[i + 2] * b[i + 2] + a[i + 3] * b[i + 3];
		}
		// Handle remaining elements
		for (; i < len; i++) {
			dotProduct += a[i] * b[i];
		}
		return dotProduct;
	}

	/**
	 * L2 normalize a vector. Always returns a new array; a zero-magnitude vector is
	 * returned as an unchanged copy.
	 */
	private normalizeVector(vector: EmbeddingVector): EmbeddingVector {
		let sumOfSquares = 0;
		for (let i = 0; i < vector.length; i++) {
			sumOfSquares += vector[i] * vector[i];
		}
		const magnitude = Math.sqrt(sumOfSquares);

		if (magnitude === 0) {
			return vector.slice(); // Return copy of zero vector
		}

		const normalized: number[] = new Array<number>(vector.length);
		for (let i = 0; i < vector.length; i++) {
			normalized[i] = vector[i] / magnitude;
		}
		return normalized;
	}
}