github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/clustering/kmeans.go (about) 1 // Package clustering provides basic clustering functions. 2 package clustering 3 4 import ( 5 "fmt" 6 "math" 7 "math/rand" 8 9 "github.com/fluhus/gostuff/gnum" 10 ) 11 12 // Kmeans performs k-means clustering on the given data. Each vector is an 13 // element in the clustering. Returns the generated means, and the tag each 14 // element was given. 15 func Kmeans(vecs [][]float64, k int) (means [][]float64, tags []int) { 16 // K must be at least 1. 17 if k < 1 { 18 panic(fmt.Sprint("Bad k:", k)) 19 } 20 21 // Must have at least 1 vector. 22 if len(vecs) == 0 { 23 panic("Cannot cluster 0 vectors.") 24 } 25 26 // If k is too large - that's ok just reduce to avoid out-of-range. 27 if k > len(vecs) { 28 k = len(vecs) 29 } 30 31 // First tagging. 32 means = initialMeans(vecs, k) 33 tags = tag(vecs, means, make([]int, len(vecs))) 34 dist := MeanSquaredError(vecs, means, tags) 35 distOld := 2 * dist 36 37 // Iterate until converged. 38 for dist > distOld || dist/distOld < 0.999 { 39 distOld = dist 40 means = findMeans(vecs, tags, k) 41 tags = tag(vecs, means, tags) 42 dist = MeanSquaredError(vecs, means, tags) 43 } 44 45 return 46 } 47 48 // tag tags each row with the index of its nearest centroid. The old tags are 49 // used for optimization. 50 func tag(vecs, means [][]float64, oldTags []int) []int { 51 if len(means) == 0 { 52 panic("Cannot tag on 0 centroids.") 53 } 54 55 // Create a distance matrix of means from one another. 56 meansd := make([][]float64, len(means)) 57 for i := range meansd { 58 meansd[i] = make([]float64, len(means)) 59 for j := range means { 60 meansd[i][j] = gnum.L2(means[i], means[j]) 61 } 62 } 63 64 tags := make([]int, len(vecs)) 65 66 // Go over vectors. 67 for i := range vecs { 68 // Find nearest centroid. 69 tags[i] = oldTags[i] 70 d := gnum.L2(means[oldTags[i]], vecs[i]) 71 72 for j := 0; j < len(means); j++ { 73 // Use triangle inequality to skip means that are too distant. 74 if j == tags[i] || meansd[j][tags[i]] >= 2*d { 75 continue 76 } 77 78 dj := gnum.L2(means[j], vecs[i]) 79 if dj < d { 80 d = dj 81 tags[i] = j 82 } 83 } 84 } 85 86 return tags 87 } 88 89 // findMeans calculates the new means, according to average of tagged rows in 90 // each group. 91 func findMeans(vecs [][]float64, tags []int, k int) [][]float64 { 92 // Initialize new arrays. 93 means := make([][]float64, k) 94 for i := range means { 95 means[i] = make([]float64, len(vecs[0])) 96 } 97 counts := make([]int, k) 98 99 // Sum all vectors according to tags. 100 for i := range vecs { 101 counts[tags[i]]++ 102 gnum.Add(means[tags[i]], vecs[i]) 103 } 104 105 // Divide by count. 106 for i := range means { 107 if counts[i] != 0 { 108 gnum.Mul1(means[i], 1/float64(counts[i])) 109 } 110 } 111 112 return means 113 } 114 115 // initialMeans picks the initial means with the K-means++ algorithm. 116 func initialMeans(vecs [][]float64, k int) [][]float64 { 117 result := make([][]float64, k) 118 perm := rand.Perm(len(vecs)) 119 numTrials := 2 + int(math.Log(float64(k))) 120 121 probs := make([]float64, len(vecs)) // Probability of each vector. 122 nearest := make([]int, len(vecs)) // Index of nearest mean to each vector. 123 distance := make([]float64, len(vecs)) // Distance to nearest mean. 124 mdistance := make([][]float64, k) // Distance between means. 125 for i := range mdistance { 126 mdistance[i] = make([]float64, k) 127 } 128 129 // Pick each mean. 130 for i := range result { 131 result[i] = make([]float64, len(vecs[0])) 132 133 // First mean is first vector. 134 if i == 0 { 135 copy(result[0], vecs[perm[0]]) 136 for _, j := range perm { 137 distance[j] = gnum.L2(vecs[j], result[0]) 138 } 139 continue 140 } 141 142 // Find next mean. 143 bestCandidate := -1 144 bestImprovement := -math.MaxFloat64 145 146 for t := 0; t < numTrials; t++ { // Make a few attempts. 147 sum := 0.0 148 for _, j := range perm { 149 probs[j] = distance[j] * distance[j] 150 sum += probs[j] 151 } 152 // Pick element with probability relative to d^2. 153 r := rand.Float64() * sum 154 newMean := 0 155 for r > probs[newMean] { 156 r -= probs[newMean] 157 newMean++ 158 } 159 copy(result[i], vecs[newMean]) 160 161 // Update distances from new mean to other means. 162 for j := range mdistance[:i] { 163 mdistance[j][i] = gnum.L2(result[i], result[j]) 164 mdistance[i][j] = mdistance[j][i] 165 } 166 167 // Check improvement. 168 newImprovement := 0.0 169 for j := range vecs { 170 if mdistance[i][nearest[j]] < 2*distance[j] { 171 d := gnum.L2(vecs[j], result[i]) 172 d = math.Min(distance[j], d) 173 newImprovement += distance[j] - d 174 } 175 } 176 if newImprovement > bestImprovement { 177 bestCandidate = newMean 178 bestImprovement = newImprovement 179 } 180 } 181 182 copy(result[i], vecs[bestCandidate]) 183 184 // Update distances. 185 for j := range mdistance[:i] { // From new mean to other means. 186 mdistance[j][i] = gnum.L2(result[i], result[j]) 187 mdistance[i][j] = mdistance[j][i] 188 } 189 for j := range vecs { // From vecs to nearest means. 190 if mdistance[i][nearest[j]] < 2*distance[j] { 191 d := gnum.L2(vecs[j], result[i]) 192 if d < distance[j] { 193 distance[j] = math.Min(distance[j], d) 194 nearest[j] = i 195 } 196 } 197 } 198 } 199 200 return result 201 } 202 203 // MeanSquaredError calculates the average squared-distance of elements from 204 // their assigned means. 205 func MeanSquaredError(vecs, means [][]float64, tags []int) float64 { 206 if len(tags) != len(vecs) { 207 panic(fmt.Sprintf("Non-matching lengths of matrix and tags: %d, %d", 208 len(vecs), len(tags))) 209 } 210 if len(vecs) == 0 { 211 return 0 212 } 213 214 d := 0.0 215 for i := range tags { 216 dist := gnum.L2(means[tags[i]], vecs[i]) 217 d += dist * dist 218 } 219 220 return d / float64(len(vecs)) 221 }