github.com/schollz/clusters@v0.0.0-20221201012527-c6c68863636f/kmeans.go (about) 1 package clusters 2 3 import ( 4 "math" 5 "math/rand" 6 "sync" 7 "time" 8 9 "gonum.org/v1/gonum/floats" 10 ) 11 12 const ( 13 changesThreshold = 2 14 ) 15 16 type kmeansClusterer struct { 17 iterations, number int 18 19 // variables keeping count of changes of points' membership every iteration. User as a stopping condition. 20 changes, oldchanges, counter, threshold int 21 22 // For online learning only 23 alpha float64 24 dimension int 25 26 distance DistanceFunc 27 28 // slices holding the cluster mapping and sizes. Access is synchronized to avoid read during computation. 29 mu sync.RWMutex 30 a, b []int 31 32 // slices holding values of centroids of each clusters 33 m, n [][]float64 34 35 // dataset 36 d [][]float64 37 } 38 39 // Implementation of k-means++ algorithm with online learning 40 func KMeans(iterations, clusters int, distance DistanceFunc) (HardClusterer, error) { 41 if iterations < 1 { 42 return nil, errZeroIterations 43 } 44 45 if clusters < 2 { 46 return nil, errOneCluster 47 } 48 49 var d DistanceFunc 50 { 51 if distance != nil { 52 d = distance 53 } else { 54 d = EuclideanDistance 55 } 56 } 57 58 return &kmeansClusterer{ 59 iterations: iterations, 60 number: clusters, 61 distance: d, 62 }, nil 63 } 64 65 func (c *kmeansClusterer) IsOnline() bool { 66 return true 67 } 68 69 func (c *kmeansClusterer) WithOnline(o Online) HardClusterer { 70 c.alpha = o.Alpha 71 c.dimension = o.Dimension 72 73 c.d = make([][]float64, 0, 100) 74 75 c.initializeMeans() 76 77 return c 78 } 79 80 func (c *kmeansClusterer) Learn(data [][]float64) error { 81 if len(data) == 0 { 82 return errEmptySet 83 } 84 85 c.mu.Lock() 86 87 c.d = data 88 89 c.a = make([]int, len(data)) 90 c.b = make([]int, c.number) 91 92 c.counter = 0 93 c.threshold = changesThreshold 94 c.changes = 0 95 c.oldchanges = 0 96 97 c.initializeMeansWithData() 98 99 for i := 0; i < c.iterations && c.counter != c.threshold; i++ { 100 c.run() 101 c.check() 102 } 103 104 c.n = nil 105 106 c.mu.Unlock() 107 108 return nil 109 } 110 111 func (c *kmeansClusterer) Sizes() []int { 112 c.mu.RLock() 113 defer c.mu.RUnlock() 114 115 return c.b 116 } 117 118 func (c *kmeansClusterer) Guesses() []int { 119 c.mu.RLock() 120 defer c.mu.RUnlock() 121 122 return c.a 123 } 124 125 func (c *kmeansClusterer) Predict(p []float64) int { 126 var ( 127 l int 128 d float64 129 m float64 = c.distance(p, c.m[0]) 130 ) 131 132 for i := 1; i < c.number; i++ { 133 if d = c.distance(p, c.m[i]); d < m { 134 m = d 135 l = i 136 } 137 } 138 139 return l 140 } 141 142 func (c *kmeansClusterer) Online(observations chan []float64, done chan struct{}) chan *HCEvent { 143 c.mu.Lock() 144 145 var ( 146 r chan *HCEvent = make(chan *HCEvent) 147 l, f int = len(c.m), len(c.m[0]) 148 h float64 = 1 - c.alpha 149 ) 150 151 c.b = make([]int, c.number) 152 153 /* The first step of online learning is adjusting the centroids by finding the one closes to new data point 154 * and modifying it's location using given alpha. Once the client quits sending new data, the actual clusters 155 * are computed and the mutex is unlocked. */ 156 157 go func() { 158 for { 159 select { 160 case o := <-observations: 161 var ( 162 k int 163 n float64 164 m float64 = math.Pow(c.distance(o, c.m[0]), 2) 165 ) 166 167 for i := 1; i < l; i++ { 168 if n = math.Pow(c.distance(o, c.m[i]), 2); n < m { 169 m = n 170 k = i 171 } 172 } 173 174 r <- &HCEvent{ 175 Cluster: k, 176 Observation: o, 177 } 178 179 for i := 0; i < f; i++ { 180 c.m[k][i] = c.alpha*o[i] + h*c.m[k][i] 181 } 182 183 c.d = append(c.d, o) 184 case <-done: 185 go func() { 186 var ( 187 n int 188 d, m float64 189 ) 190 191 c.a = make([]int, len(c.d)) 192 193 for i := 0; i < len(c.d); i++ { 194 m = c.distance(c.d[i], c.m[0]) 195 n = 0 196 197 for j := 1; j < c.number; j++ { 198 if d = c.distance(c.d[i], c.m[j]); d < m { 199 m = d 200 n = j 201 } 202 } 203 204 c.a[i] = n + 1 205 c.b[n]++ 206 } 207 208 c.mu.Unlock() 209 }() 210 211 return 212 } 213 } 214 }() 215 216 return r 217 } 218 219 // private 220 func (c *kmeansClusterer) initializeMeansWithData() { 221 c.m = make([][]float64, c.number) 222 c.n = make([][]float64, c.number) 223 224 rand.Seed(time.Now().UTC().Unix()) 225 226 var ( 227 k int 228 s, t, l, f float64 229 d []float64 = make([]float64, len(c.d)) 230 ) 231 232 c.m[0] = c.d[rand.Intn(len(c.d)-1)] 233 234 for i := 1; i < c.number; i++ { 235 s = 0 236 t = 0 237 for j := 0; j < len(c.d); j++ { 238 239 l = c.distance(c.m[0], c.d[j]) 240 for g := 1; g < i; g++ { 241 if f = c.distance(c.m[g], c.d[j]); f < l { 242 l = f 243 } 244 } 245 246 d[j] = math.Pow(l, 2) 247 s += d[j] 248 } 249 250 t = rand.Float64() * s 251 k = 0 252 for s = d[0]; s < t; s += d[k] { 253 k++ 254 } 255 256 c.m[i] = c.d[k] 257 } 258 259 for i := 0; i < c.number; i++ { 260 c.n[i] = make([]float64, len(c.m[0])) 261 } 262 } 263 264 func (c *kmeansClusterer) initializeMeans() { 265 c.m = make([][]float64, c.number) 266 267 rand.Seed(time.Now().UTC().Unix()) 268 269 for i := 0; i < c.number; i++ { 270 c.m[i] = make([]float64, c.dimension) 271 for j := 0; j < c.dimension; j++ { 272 c.m[i][j] = 10 * (rand.Float64() - 0.5) 273 } 274 } 275 } 276 277 func (c *kmeansClusterer) run() { 278 var ( 279 l, k, n int = len(c.m[0]), 0, 0 280 m, d float64 281 ) 282 283 for i := 0; i < c.number; i++ { 284 c.b[i] = 0 285 } 286 287 for i := 0; i < len(c.d); i++ { 288 m = c.distance(c.d[i], c.m[0]) 289 n = 0 290 291 for j := 1; j < c.number; j++ { 292 if d = c.distance(c.d[i], c.m[j]); d < m { 293 m = d 294 n = j 295 } 296 } 297 298 k = n + 1 299 300 if c.a[i] != k { 301 c.changes++ 302 } 303 304 c.a[i] = k 305 c.b[n]++ 306 307 floats.Add(c.n[n], c.d[i]) 308 } 309 310 for i := 0; i < c.number; i++ { 311 floats.Scale(1/float64(c.b[i]), c.n[i]) 312 313 for j := 0; j < l; j++ { 314 c.m[i][j] = c.n[i][j] 315 c.n[i][j] = 0 316 } 317 } 318 } 319 320 func (c *kmeansClusterer) check() { 321 if c.changes == c.oldchanges { 322 c.counter++ 323 } 324 325 c.oldchanges = c.changes 326 }