github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/clustering/upgma.go (about) 1 package clustering 2 3 import ( 4 "fmt" 5 "math" 6 7 "github.com/fluhus/gostuff/gnum" 8 "github.com/fluhus/gostuff/heaps" 9 ) 10 11 // distPyramid is a distance half-matrix. 12 type distPyramid [][]float64 13 14 // dist returns the distance between a and b. 15 func (d distPyramid) dist(a, b int) float64 { 16 if a > b { 17 return d[a][b] 18 } 19 return d[b][a] 20 } 21 22 // makePyramid creates a distance half-matrix. 23 func makePyramid(n int, f func(int, int) float64) distPyramid { 24 nn := n * (n - 1) / 2 25 d := make([]float64, 0, nn) 26 for i := 1; i < n; i++ { 27 for j := 0; j < i; j++ { 28 d = append(d, f(j, i)) 29 } 30 } 31 result := make([][]float64, n) 32 j := 0 33 for i := range result { 34 result[i] = d[j : j+i] 35 j += i 36 } 37 return result 38 } 39 40 // upgma is an implementation of UPGMA clustering. The distance between clusters 41 // is the average distance between pairs of their individual elements. 42 func upgma(n int, f func(int, int) float64) *AggloResult { 43 pi := make([]int, n) // Index of first merge target of each element. 44 lambda := make([]float64, n) // Distance of first merge target of each element. 45 46 // Last cluster does not get matched with anyone -> max distance. 47 lambda[len(lambda)-1] = math.MaxFloat64 48 49 // Calculate raw distances. 50 d := makePyramid(n, f) 51 heapss := make([]*heaps.Heap[upgmaCluster], n) 52 for i := range heapss { 53 heapss[i] = heaps.New(compareUpgmaClusters) 54 } 55 for i := 1; i < n; i++ { 56 for j := 0; j < i; j++ { 57 heapss[i].Push(upgmaCluster{j, d[i][j]}) 58 heapss[j].Push(upgmaCluster{i, d[i][j]}) 59 } 60 } 61 62 // Clustering. 63 sizes := gnum.Ones[[]float64](n) // Cluster sizes 64 // The identifier of each cluster = highest index of an element 65 names := make([]int, n) 66 for i := range names { 67 names[i] = i 68 } 69 for i := 0; i < n-1; i++ { 70 // Find lowest distance. 71 fmin := math.MaxFloat64 72 a, b := -1, -1 73 for hi, h := range heapss { 74 if h == nil { 75 continue 76 } 77 // Clean up removed clusters. 78 if h.Len() == 0 { 79 panic(fmt.Sprintf("heap %d with length 0", hi)) 80 } 81 for heapss[h.Head().i] == nil { 82 h.Pop() 83 } 84 if h.Head().d < fmin { 85 a = hi 86 fmin = h.Head().d 87 b = h.Head().i 88 } 89 } 90 91 // Create agglo step. 92 nmin := min(names[a], names[b]) 93 nmax := max(names[a], names[b]) 94 pi[nmin] = nmax 95 lambda[nmin] = fmin 96 97 // Merge clusters. 98 names = append(names, nmax) 99 sizes = append(sizes, sizes[a]+sizes[b]) 100 heapss[a] = nil 101 heapss[b] = nil 102 var cdist []float64 103 cheap := heaps.New(compareUpgmaClusters) 104 for hi, h := range heapss { 105 if h == nil { 106 cdist = append(cdist, 0) 107 continue 108 } 109 da := d.dist(a, hi) * sizes[a] 110 db := d.dist(b, hi) * sizes[b] 111 dd := (da + db) / (sizes[a] + sizes[b]) 112 cdist = append(cdist, dd) 113 h.Push(upgmaCluster{len(sizes) - 1, dd}) 114 cheap.Push(upgmaCluster{hi, dd}) 115 } 116 d = append(d, cdist) 117 heapss = append(heapss, cheap) 118 } 119 120 return newAggloResult(pi, lambda) 121 } 122 123 // Cluster info in UPGMA. 124 type upgmaCluster struct { 125 i int // Cluster index 126 d float64 // Distance from cluster i 127 } 128 129 func compareUpgmaClusters(a, b upgmaCluster) bool { 130 return a.d < b.d 131 }