Path: blob/main/extensions/copilot/src/platform/embeddings/test/node/embeddingsGrouper.spec.ts
13405 views
/*---------------------------------------------------------------------------------------------
 *  Copyright (c) Microsoft Corporation. All rights reserved.
 *  Licensed under the MIT License. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

import { beforeEach, describe, expect, it } from 'vitest';
import { Embedding, EmbeddingType } from '../../common/embeddingsComputer';
import { EmbeddingsGrouper, GroupingOptions, Node } from '../../common/embeddingsGrouper';

/** Minimal payload carried by every node in these tests. */
interface TestTool {
	name: string;
	category: string;
}

/**
 * Wraps raw numbers in an `Embedding` tagged as `EmbeddingType.text3small_512`.
 *
 * @param values Vector components; the tests pass short (1-4 component) vectors.
 */
function createEmbedding(values: number[]): Embedding {
	return {
		type: EmbeddingType.text3small_512,
		value: values
	};
}

/**
 * Builds a grouper node holding a `TestTool` payload and the given embedding.
 *
 * @param name Tool name stored on the payload.
 * @param category Tool category stored on the payload.
 * @param embedding Vector components forwarded to `createEmbedding`.
 */
function createNode(name: string, category: string, embedding: number[]): Node<TestTool> {
	return {
		value: { name, category },
		embedding: createEmbedding(embedding)
	};
}

/**
 * Returns five 4-component vectors: two pairs that point in roughly the same
 * direction (indices 0-1 and 2-3) plus one uniform vector intended as an outlier.
 */
function createSimilarEmbeddings(): number[][] {
	return [
		[1, 0.8, 0.2, 0.1], // Similar to group 1
		[0.9, 0.7, 0.1, 0.15], // Similar to group 1
		[0.1, 0.2, 1, 0.8], // Similar to group 2
		[0.15, 0.1, 0.9, 0.7], // Similar to group 2
		[0.5, 0.5, 0.5, 0.5] // Outlier
	];
}

describe('EmbeddingsGrouper', () => {
	// Fresh default-configured grouper for every test.
	let grouper: EmbeddingsGrouper<TestTool>;

	beforeEach(() => {
		grouper = new EmbeddingsGrouper<TestTool>();
	});

	describe('constructor and initial state', () => {
		it('should initialize with empty clusters', () => {
			expect(grouper.getClusters()).toHaveLength(0);
		});

		it('should accept custom options', () => {
			const options: GroupingOptions = {
				eps: 0.85, // High similarity threshold
				minClusterSize: 3,
				insertThreshold: 0.7
			};
			const customGrouper = new EmbeddingsGrouper<TestTool>(options);
			expect(customGrouper.getClusters()).toHaveLength(0);
		});
	});

	describe('addNode', () => {
		it('should create singleton cluster for first node', () => {
			const node = createNode('tool1', 'category1', [1, 0, 0, 0]);
			grouper.addNode(node);

			const clusters = grouper.getClusters();
			expect(clusters).toHaveLength(1);
			expect(clusters[0].nodes).toHaveLength(1);
			expect(clusters[0].nodes[0]).toBe(node);
		});

		it('should create separate clusters for dissimilar nodes', () => {
			const node1 = createNode('tool1', 'category1', [1, 0, 0, 0]);
			const node2 = createNode('tool2', 'category2', [0, 1, 0, 0]);

			grouper.addNode(node1);
			grouper.addNode(node2);

			const clusters = grouper.getClusters();
			expect(clusters).toHaveLength(2);
			expect(clusters.every(c => c.nodes.length === 1)).toBe(true);
		});

		it('should add similar nodes to existing clusters when possible', () => {
			const embeddings = createSimilarEmbeddings();
			const nodes = [
				createNode('tool1', 'category1', embeddings[0]),
				createNode('tool2', 'category1', embeddings[1])
			];

			grouper.addNode(nodes[0]);
			grouper.addNode(nodes[1]);

			// Incremental insertion alone might not merge the two nodes,
			// so force a full recluster before checking membership.
			grouper.recluster();

			const clusters = grouper.getClusters();
			expect(clusters.length).toBeGreaterThan(0);

			// Similar embeddings should be in same cluster
			const cluster = grouper.getClusterForNode(nodes[0]);
			expect(cluster).toBeDefined();
			expect(cluster!.nodes.some(n => n === nodes[1])).toBe(true);
		});
	});

	describe('addNodes (bulk)', () => {
		it('should handle empty array efficiently', () => {
			grouper.addNodes([]);
			expect(grouper.getClusters()).toHaveLength(0);
		});

		it('should add multiple nodes and cluster them efficiently', () => {
			const embeddings = createSimilarEmbeddings();
			const nodes = [
				createNode('search', 'lookup', embeddings[0]),
				createNode('find', 'lookup', embeddings[1]), // Similar to search
				createNode('create', 'generate', embeddings[2]),
				createNode('make', 'generate', embeddings[3]), // Similar to create
				createNode('random', 'misc', embeddings[4]) // Outlier
			];

			grouper.addNodes(nodes);

			const clusters = grouper.getClusters();
			expect(clusters.length).toBeGreaterThanOrEqual(2);

			// Verify all nodes are assigned
			const totalNodesInClusters = clusters.reduce((sum, cluster) => sum + cluster.nodes.length, 0);
			expect(totalNodesInClusters).toBe(nodes.length);
		});

		it('should allow deferring clustering with reclusterAfter=false', () => {
			const nodes = [
				createNode('tool1', 'cat1', [1, 0, 0, 0]),
				createNode('tool2', 'cat1', [0.9, 0.1, 0, 0]),
				createNode('tool3', 'cat1', [0, 1, 0, 0])
			];

			grouper.addNodes(nodes, false);

			// Should create singleton clusters without clustering
			const clusters = grouper.getClusters();
			expect(clusters).toHaveLength(3);
			expect(clusters.every(c => c.nodes.length === 1)).toBe(true);

			// Manual recluster may merge similar nodes; it must never add clusters.
			grouper.recluster();
			const clustersAfterRecluster = grouper.getClusters();
			expect(clustersAfterRecluster.length).toBeLessThanOrEqual(clusters.length);
		});
	});

	describe('removeNode', () => {
		it('should return false for non-existent node', () => {
			const node = createNode('tool1', 'category1', [1, 0, 0, 0]);
			const result = grouper.removeNode(node);
			expect(result).toBe(false);
		});

		it('should remove node and return true for existing node', () => {
			const node = createNode('tool1', 'category1', [1, 0, 0, 0]);
			grouper.addNode(node);

			const result = grouper.removeNode(node);
			expect(result).toBe(true);
			expect(grouper.getClusters()).toHaveLength(0);
		});

		it('should remove empty clusters when last node is removed', () => {
			const node = createNode('tool1', 'category1', [1, 0, 0, 0]);
			grouper.addNode(node);
			expect(grouper.getClusters()).toHaveLength(1);

			grouper.removeNode(node);
			expect(grouper.getClusters()).toHaveLength(0);
		});

		it('should update cluster when node is removed from multi-node cluster', () => {
			const embeddings = createSimilarEmbeddings();
			const nodes = [
				createNode('tool1', 'category1', embeddings[0]),
				createNode('tool2', 'category1', embeddings[1]),
				createNode('tool3', 'category1', embeddings[0]) // Very similar to first
			];

			nodes.forEach(node => grouper.addNode(node));
			grouper.recluster();

			const initialClusters = grouper.getClusters();
			const clusterWithMultipleNodes = initialClusters.find(c => c.nodes.length > 1);

			// NOTE(review): when no multi-node cluster forms, this test asserts nothing.
			if (clusterWithMultipleNodes) {
				// Capture both nodes before removal so the lookup below does not depend
				// on whether removeNode updates the cluster's node array in place.
				const [nodeToRemove, remainingNode] = clusterWithMultipleNodes.nodes;
				grouper.removeNode(nodeToRemove);

				const updatedCluster = grouper.getClusterForNode(remainingNode);
				expect(updatedCluster).toBeDefined();
				expect(updatedCluster!.nodes.some(n => n === nodeToRemove)).toBe(false);
			}
		});
	});

	describe('recluster', () => {
		it('should handle empty node list', () => {
			grouper.recluster();
			expect(grouper.getClusters()).toHaveLength(0);
		});

		it('should create single cluster for single node', () => {
			const node = createNode('tool1', 'category1', [1, 0, 0, 0]);
			grouper.addNode(node);
			grouper.recluster();

			const clusters = grouper.getClusters();
			expect(clusters).toHaveLength(1);
			expect(clusters[0].nodes).toHaveLength(1);
		});

		it('should group similar embeddings together', () => {
			const embeddings = createSimilarEmbeddings();
			const nodes = [
				createNode('search', 'lookup', embeddings[0]),
				createNode('find', 'lookup', embeddings[1]), // Similar to search
				createNode('create', 'generate', embeddings[2]),
				createNode('make', 'generate', embeddings[3]), // Similar to create
				createNode('random', 'misc', embeddings[4]) // Outlier
			];

			nodes.forEach(node => grouper.addNode(node));
			grouper.recluster();

			const clusters = grouper.getClusters();
			expect(clusters.length).toBeGreaterThanOrEqual(2);

			// Find clusters for similar nodes
			const searchCluster = grouper.getClusterForNode(nodes[0]);
			const findCluster = grouper.getClusterForNode(nodes[1]);
			const createCluster = grouper.getClusterForNode(nodes[2]);
			const makeCluster = grouper.getClusterForNode(nodes[3]);

			// Each paired node must resolve to some cluster.
			// NOTE(review): co-membership of the pairs is not asserted here.
			expect(searchCluster).toBeDefined();
			expect(findCluster).toBeDefined();
			expect(createCluster).toBeDefined();
			expect(makeCluster).toBeDefined();
		});

		it('should respect minimum cluster size option', () => {
			// Dedicated instance so the suite-level `grouper` is not shadowed.
			const sizedGrouper = new EmbeddingsGrouper<TestTool>({ minClusterSize: 3 });
			const embeddings = createSimilarEmbeddings();
			const nodes = [
				createNode('tool1', 'cat1', embeddings[0]),
				createNode('tool2', 'cat1', embeddings[1])
			];

			nodes.forEach(node => sizedGrouper.addNode(node));
			sizedGrouper.recluster();

			// Only two nodes exist, so no cluster can reach the minimum size of 3.
			const clusters = sizedGrouper.getClusters();
			expect(clusters.every(c => c.nodes.length < 3)).toBe(true);
		});
	});

	describe('getClusterForNode', () => {
		it('should return undefined for non-existent node', () => {
			const node = createNode('tool1', 'category1', [1, 0, 0, 0]);
			const cluster = grouper.getClusterForNode(node);
			expect(cluster).toBeUndefined();
		});

		it('should return correct cluster for existing node', () => {
			const node = createNode('tool1', 'category1', [1, 0, 0, 0]);
			grouper.addNode(node);

			const cluster = grouper.getClusterForNode(node);
			expect(cluster).toBeDefined();
			expect(cluster!.nodes).toContain(node);
		});
	});

	describe('clustering quality', () => {
		it('should handle identical embeddings', () => {
			const identicalEmbedding = [1, 0, 0, 0];
			const nodes = [
				createNode('tool1', 'cat1', identicalEmbedding),
				createNode('tool2', 'cat1', identicalEmbedding),
				createNode('tool3', 'cat1', identicalEmbedding)
			];

			nodes.forEach(node => grouper.addNode(node));
			grouper.recluster();

			// All identical embeddings should be in same cluster
			const clusters = grouper.getClusters();
			expect(clusters).toHaveLength(1);
			expect(clusters[0].nodes).toHaveLength(3);
		});

		it('should handle zero vectors', () => {
			const zeroEmbedding = [0, 0, 0, 0];
			const node = createNode('tool1', 'cat1', zeroEmbedding);
			grouper.addNode(node);
			grouper.recluster();

			const clusters = grouper.getClusters();
			expect(clusters).toHaveLength(1);
			expect(clusters[0].centroid).toEqual([0, 0, 0, 0]);
		});

		it('should handle varied similarity distributions', () => {
			// Create embeddings with different similarity levels
			const nodes = [
				createNode('very_similar_1', 'cat1', [1, 0.1, 0, 0]),
				createNode('very_similar_2', 'cat1', [0.9, 0.1, 0, 0]),
				createNode('somewhat_similar', 'cat1', [0.7, 0.3, 0.1, 0]),
				createNode('different_1', 'cat2', [0, 0, 1, 0.1]),
				createNode('different_2', 'cat2', [0.1, 0, 0.9, 0.1]),
				createNode('outlier', 'cat3', [0.25, 0.25, 0.25, 0.25])
			];

			nodes.forEach(node => grouper.addNode(node));
			grouper.recluster();

			const clusters = grouper.getClusters();
			expect(clusters.length).toBeGreaterThan(1);
			expect(clusters.length).toBeLessThanOrEqual(nodes.length);

			// The summed cluster membership must equal the number of nodes added.
			const allNodesInClusters = clusters.flatMap(c => c.nodes);
			expect(allNodesInClusters).toHaveLength(nodes.length);
		});
	});

	describe('adaptive threshold behavior', () => {
		it('should work with different similarity percentiles', () => {
			const strictGrouper = new EmbeddingsGrouper<TestTool>({ eps: 0.95 }); // High similarity required
			const lenientGrouper = new EmbeddingsGrouper<TestTool>({ eps: 0.8 }); // Lower similarity required

			const embeddings = createSimilarEmbeddings();
			const nodes = embeddings.map((emb, i) =>
				createNode(`tool${i}`, 'category', emb)
			);

			// Add shallow copies so the two groupers never share a node object.
			nodes.forEach(node => {
				strictGrouper.addNode({ ...node });
				lenientGrouper.addNode({ ...node });
			});

			strictGrouper.recluster();
			lenientGrouper.recluster();

			const strictClusters = strictGrouper.getClusters();
			const lenientClusters = lenientGrouper.getClusters();

			// Stricter threshold should generally create more clusters
			expect(strictClusters.length).toBeGreaterThanOrEqual(lenientClusters.length);
		});
	});

	describe('centroid computation', () => {
		it('should compute correct centroids for clusters', () => {
			const nodes = [
				createNode('tool1', 'cat1', [1, 0, 0, 0]),
				createNode('tool2', 'cat1', [0, 1, 0, 0])
			];

			nodes.forEach(node => grouper.addNode(node));
			grouper.recluster();

			const clusters = grouper.getClusters();
			clusters.forEach(cluster => {
				expect(cluster.centroid).toBeDefined();
				expect(cluster.centroid.length).toBeGreaterThan(0);

				// Centroid should be normalized (magnitude ≈ 1 for non-zero vectors)
				const magnitude = Math.sqrt(cluster.centroid.reduce((sum, val) => sum + val * val, 0));
				if (magnitude > 0) {
					expect(magnitude).toBeCloseTo(1, 5);
				}
			});
		});
	});

	describe('threshold tuning optimization', () => {
		it('should find optimal percentile for target cluster count', () => {
			const embeddings = createSimilarEmbeddings();
			const nodes = [
				createNode('search', 'lookup', embeddings[0]),
				createNode('find', 'lookup', embeddings[1]), // Similar to search
				createNode('create', 'generate', embeddings[2]),
				createNode('make', 'generate', embeddings[3]), // Similar to create
				createNode('random1', 'misc', embeddings[4]), // Outlier
				createNode('random2', 'misc', [0.3, 0.3, 0.3, 0.1]), // Another outlier
				createNode('random3', 'misc', [0.1, 0.3, 0.3, 0.3]) // Another outlier
			];

			grouper.addNodes(nodes, false); // Don't cluster initially

			// Find optimal percentile for max 4 clusters
			const result = grouper.tuneThresholdForTargetClusters(4);

			expect(result.clusterCount).toBeLessThanOrEqual(4);
			expect(result.percentile).toBeGreaterThanOrEqual(80);
			expect(result.percentile).toBeLessThanOrEqual(99);
			expect(result.threshold).toBeGreaterThan(0);
		});

		it('should apply percentile and recluster efficiently', () => {
			const embeddings = createSimilarEmbeddings();
			const nodes = embeddings.map((emb, i) =>
				createNode(`tool${i}`, 'category', emb)
			);

			grouper.addNodes(nodes);

			// Apply a stricter percentile (should create more clusters)
			grouper.applyPercentileAndRecluster(98);
			const strictClusterCount = grouper.getClusters().length;

			// Apply a more lenient percentile (should create fewer clusters)
			grouper.applyPercentileAndRecluster(85);
			const lenientClusterCount = grouper.getClusters().length;

			// Stricter threshold should generally create more or equal clusters
			expect(strictClusterCount).toBeGreaterThanOrEqual(lenientClusterCount);

			// All nodes should still be assigned
			const allClusters = grouper.getClusters();
			const totalNodes = allClusters.reduce((sum, cluster) => sum + cluster.nodes.length, 0);
			expect(totalNodes).toBe(nodes.length);
		});

		it('should cache similarities for efficient repeated tuning', () => {
			// Create embeddings with more predictable clustering behavior
			const nodes: Node<TestTool>[] = [];

			// Create 3 distinct groups of 10 nodes each using deterministic trigonometric patterns
			for (let group = 0; group < 3; group++) {
				for (let i = 0; i < 10; i++) {
					// Each group occupies a different region of the 4D space using trigonometry
					const baseAngle = group * (2 * Math.PI / 3); // 120° separation between groups
					const variation = i * 0.1; // Small variation within group

					const embedding = [
						Math.cos(baseAngle + variation) * 0.8 + 0.2,
						Math.sin(baseAngle + variation) * 0.8 + 0.2,
						Math.cos(baseAngle * 2 + variation * 0.5) * 0.3 + 0.1,
						Math.sin(baseAngle * 2 + variation * 0.5) * 0.3 + 0.1
					];
					nodes.push(createNode(`tool${group}_${i}`, `cat${group}`, embedding));
				}
			}

			grouper.addNodes(nodes, false);

			const result1 = grouper.tuneThresholdForTargetClusters(15); // Very lenient
			const result2 = grouper.tuneThresholdForTargetClusters(8); // Moderate
			const result3 = grouper.tuneThresholdForTargetClusters(5); // Strict

			// With deterministic trigonometric data, these should be more predictable
			expect(result1.clusterCount).toBeLessThanOrEqual(15);
			expect(result2.clusterCount).toBeLessThanOrEqual(8);

			// For the strictest case, we have 3 natural groups, so expect something reasonable
			expect(result3.clusterCount).toBeLessThanOrEqual(5);

			// Verify the algorithm works in the right direction (more restrictive = fewer clusters)
			expect(result1.clusterCount).toBeGreaterThanOrEqual(result2.clusterCount);
			expect(result2.clusterCount).toBeGreaterThanOrEqual(result3.clusterCount);
		});

		it('should handle edge cases in threshold tuning', () => {
			// Empty grouper
			const emptyResult = grouper.tuneThresholdForTargetClusters(5);
			expect(emptyResult.clusterCount).toBe(0);

			// Single node
			grouper.addNode(createNode('single', 'cat', [1, 0, 0, 0]));
			const singleResult = grouper.tuneThresholdForTargetClusters(5);
			expect(singleResult.clusterCount).toBe(1);

			// Target higher than possible clusters
			const nodes = [
				createNode('tool1', 'cat1', [1, 0, 0, 0]),
				createNode('tool2', 'cat1', [0, 1, 0, 0])
			];
			grouper.addNodes(nodes);
			const highTargetResult = grouper.tuneThresholdForTargetClusters(10);
			expect(highTargetResult.clusterCount).toBeLessThanOrEqual(3); // At most 3 nodes total
		});
	});

	describe('edge cases', () => {
		it('should handle single dimension embeddings', () => {
			const nodes = [
				createNode('tool1', 'cat1', [1]),
				createNode('tool2', 'cat1', [0.9]),
				createNode('tool3', 'cat1', [0.1])
			];

			nodes.forEach(node => grouper.addNode(node));
			grouper.recluster();

			const clusters = grouper.getClusters();
			expect(clusters.length).toBeGreaterThan(0);
		});

		it('should handle large number of nodes efficiently', () => {
			const nodes: Node<TestTool>[] = [];
			for (let i = 0; i < 100; i++) {
				// Create diverse embeddings that will form distinct clusters
				const groupId = Math.floor(i / 10); // 10 groups of 10 nodes each
				const withinGroupVariation = (i % 10) * 0.1;

				// Groups 0-2 each dominate their own axis; groups 3-9 all share
				// the fourth axis, so only four distinct directions are produced.
				const embedding = [
					groupId === 0 ? 1 - withinGroupVariation * 0.2 : withinGroupVariation * 0.2,
					groupId === 1 ? 1 - withinGroupVariation * 0.2 : withinGroupVariation * 0.2,
					groupId === 2 ? 1 - withinGroupVariation * 0.2 : withinGroupVariation * 0.2,
					groupId >= 3 ? 1 - withinGroupVariation * 0.2 : withinGroupVariation * 0.2
				];
				nodes.push(createNode(`tool${i}`, `cat${groupId}`, embedding));
			}

			nodes.forEach(node => grouper.addNode(node));
			grouper.recluster();

			const clusters = grouper.getClusters();
			// Should create multiple clusters but not more than the number of nodes
			expect(clusters.length).toBeGreaterThanOrEqual(2);
			expect(clusters.length).toBeLessThanOrEqual(nodes.length);

			// Verify all nodes are assigned
			const totalNodesInClusters = clusters.reduce((sum, cluster) => sum + cluster.nodes.length, 0);
			expect(totalNodesInClusters).toBe(nodes.length);
		});
	});
});