github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/clustering/agglo.go (about) 1 package clustering 2 3 import ( 4 "fmt" 5 "math" 6 "sort" 7 ) 8 9 // How agglomerative clustering should calculate distance between clusters. 10 const ( 11 AggloMin = iota // Minimal distance between any pair of elements. 12 AggloMax // Maximal distance between any pair of elements. 13 AggloAverage // Average distance between any pair of elements. 14 ) 15 16 // Agglo performs agglomerative clustering on the indexes 0 to n-1. d should 17 // return the distance between the i'th and j'th element, such that 18 // d(i,j)=d(j,i) and d(i,i)=0. 19 // 20 // clusterDist should be one of AggloMin or AggloMax. 21 // 22 // Works in O(n^2) time and makes O(n^2) calls to d. 23 func Agglo(n int, clusterDist int, d func(int, int) float64) *AggloResult { 24 if n <= 0 { 25 panic(fmt.Sprintf("Bad n: %d, must be positive", n)) 26 } 27 28 switch clusterDist { 29 case AggloMin: 30 return slink(n, d) 31 case AggloMax: 32 return clink(n, d) 33 case AggloAverage: 34 return upgma(n, d) 35 default: 36 panic(fmt.Sprintf("Unsupported cluster distance: %v, "+ 37 "want AggloMin or AggloMax", clusterDist)) 38 } 39 } 40 41 // slink is an implementation of the SLINK algorithm. 42 // 43 // Copied from: 44 // https://www.cs.ucsb.edu/~veronika/MAE/SLINK_sibson.pdf 45 func slink(n int, d func(int, int) float64) *AggloResult { 46 // Implementation copied from paper, pardon the crap names. 47 m := make([]float64, n) // Distance of i'th element from elements/clusters. 48 pi := make([]int, n) // Index of first merge target of each element. 49 lambda := make([]float64, n) // Distance of first merge target of each element. 50 51 lambda[0] = math.MaxFloat64 52 53 for i := 1; i < n; i++ { 54 pi[i] = i 55 lambda[i] = math.MaxFloat64 56 57 for j := 0; j < i; j++ { 58 m[j] = d(i, j) 59 } 60 61 for j := 0; j < i; j++ { 62 if m[j] <= lambda[j] { 63 m[pi[j]] = math.Min(m[pi[j]], lambda[j]) 64 lambda[j] = m[j] 65 pi[j] = i 66 } else { 67 m[pi[j]] = math.Min(m[pi[j]], m[j]) 68 } 69 } 70 71 for j := 0; j < i; j++ { 72 if lambda[j] >= lambda[pi[j]] { 73 pi[j] = i 74 } 75 } 76 } 77 78 return newAggloResult(pi, lambda) 79 } 80 81 // clink is an implementation of the CLINK algorithm. 82 // 83 // Copied from: 84 // https://academic.oup.com/comjnl/article-pdf/20/4/364/1108735/200364.pdf 85 func clink(n int, d func(int, int) float64) *AggloResult { 86 // Implementation copied from paper, pardon the crap names. 87 m := make([]float64, n) // Distance of i'th element from elements/clusters. 88 pi := make([]int, n) // Index of first merge target of each element. 89 lambda := make([]float64, n) // Distance of first merge target of each element. 90 91 lambda[0] = math.MaxFloat64 92 93 for i := 1; i < n; i++ { 94 pi[i] = i 95 lambda[i] = math.MaxFloat64 96 97 for j := 0; j < i; j++ { 98 m[j] = d(i, j) 99 } 100 101 for j := 0; j < i; j++ { 102 if lambda[j] < m[j] { 103 m[pi[j]] = math.Max(m[pi[j]], m[j]) 104 m[j] = math.MaxFloat64 105 } 106 } 107 108 a := i - 1 109 for j := 0; j < i; j++ { 110 if lambda[i-j-1] >= m[pi[i-j-1]] { 111 if m[i-j-1] < m[a] { 112 a = i - j - 1 113 } 114 } else { 115 m[i-j-1] = math.MaxFloat64 116 } 117 } 118 119 b := pi[a] 120 c := lambda[a] 121 pi[a] = i 122 lambda[a] = m[a] 123 for a < i-1 { 124 if b < i-1 { 125 d := pi[b] 126 e := lambda[b] 127 pi[b] = i 128 lambda[b] = c 129 b = d 130 c = e 131 } else if b == i-1 { 132 pi[b] = i 133 lambda[b] = c 134 break 135 } 136 } 137 138 for j := 0; j < i; j++ { 139 if pi[pi[j]] == i && lambda[j] >= lambda[pi[j]] { 140 pi[j] = i 141 } 142 } 143 } 144 145 return newAggloResult(pi, lambda) 146 } 147 148 // AggloResult is an interactive agglomerative-clustering result. 149 type AggloResult struct { 150 pi []int 151 lambda []float64 152 perm []int 153 dict []string 154 } 155 156 // Dict returns the string representations of elements in the clustering. 157 func (r *AggloResult) Dict() []string { 158 return r.dict 159 } 160 161 // newAggloResult creates a new result. 162 func newAggloResult(pi []int, lambda []float64) *AggloResult { 163 result := &AggloResult{pi, lambda, make([]int, len(pi)), nil} 164 for i := range result.perm { 165 result.perm[i] = i 166 } 167 sort.Sort((*aggloSorter)(result)) 168 return result 169 } 170 171 // aggloSorter is a sorting interface for AggloResult, for sorting by distance. 172 // This actually sorts the agglomerative steps by their order of occurrence. 173 type aggloSorter AggloResult 174 175 // Len returns the number of elements in the sorter. 176 func (r *aggloSorter) Len() int { 177 return len(r.perm) 178 } 179 180 // Less compares two steps by their order of occurrence. 181 func (r *aggloSorter) Less(i, j int) bool { 182 return r.lambda[r.perm[i]] < r.lambda[r.perm[j]] 183 } 184 185 // Swap swaps two steps. 186 func (r *aggloSorter) Swap(i, j int) { 187 r.perm[i], r.perm[j] = r.perm[j], r.perm[i] 188 } 189 190 // SetDict sets the string representation of each element, for the String() 191 // function. Returns itself for chaining. 192 func (r *AggloResult) SetDict(dict []string) *AggloResult { 193 if len(dict) != len(r.perm) { 194 panic(fmt.Sprintf("Bad dictionary size: %d, expected %d", 195 len(dict), len(r.perm))) 196 } 197 r.dict = dict 198 return r 199 } 200 201 // String returns a representation of the clustering. If SetDict was not 202 // called, will use element numbers. 203 func (r *AggloResult) String() string { 204 strs := make([]string, len(r.perm)) 205 for i := range strs { 206 if r.dict == nil { 207 strs[i] = fmt.Sprint(i) 208 } else { 209 strs[i] = r.dict[i] 210 } 211 } 212 for _, i := range r.perm { 213 j := r.pi[i] 214 if i == j { // Reached the end. 215 return strs[i] 216 } 217 strs[j] = fmt.Sprintf("[%s, %s]", strs[i], strs[j]) 218 } 219 220 // TODO(amit): Panic if reached here. 221 return "" 222 } 223 224 // Len returns the number of steps in this clustering. Equals the number of 225 // elements - 1. 226 func (r *AggloResult) Len() int { 227 return len(r.perm) - 1 228 } 229 230 // An AggloStep is a single step in the clustering process. 231 // The index of a cluster is the greatest indexed element in it. 232 // C2 is always greater than C1. 233 type AggloStep struct { 234 C1 int // Index of the first merged cluster. 235 C2 int // Index of the second merged cluster. 236 D float64 // Distance between the clusters when merging. 237 } 238 239 // Step returns the i'th step in the clustering. 240 func (r *AggloResult) Step(i int) AggloStep { 241 return AggloStep{r.perm[i], r.pi[r.perm[i]], r.lambda[r.perm[i]]} 242 }