github.com/schollz/clusters@v0.0.0-20221201012527-c6c68863636f/kmeans.go (about)

     1  package clusters
     2  
     3  import (
     4  	"math"
     5  	"math/rand"
     6  	"sync"
     7  	"time"
     8  
     9  	"gonum.org/v1/gonum/floats"
    10  )
    11  
    12  const (
    13  	changesThreshold = 2
    14  )
    15  
    16  type kmeansClusterer struct {
    17  	iterations, number int
    18  
    19  	// variables keeping count of changes of points' membership every iteration. User as a stopping condition.
    20  	changes, oldchanges, counter, threshold int
    21  
    22  	// For online learning only
    23  	alpha     float64
    24  	dimension int
    25  
    26  	distance DistanceFunc
    27  
    28  	// slices holding the cluster mapping and sizes. Access is synchronized to avoid read during computation.
    29  	mu   sync.RWMutex
    30  	a, b []int
    31  
    32  	// slices holding values of centroids of each clusters
    33  	m, n [][]float64
    34  
    35  	// dataset
    36  	d [][]float64
    37  }
    38  
    39  // Implementation of k-means++ algorithm with online learning
    40  func KMeans(iterations, clusters int, distance DistanceFunc) (HardClusterer, error) {
    41  	if iterations < 1 {
    42  		return nil, errZeroIterations
    43  	}
    44  
    45  	if clusters < 2 {
    46  		return nil, errOneCluster
    47  	}
    48  
    49  	var d DistanceFunc
    50  	{
    51  		if distance != nil {
    52  			d = distance
    53  		} else {
    54  			d = EuclideanDistance
    55  		}
    56  	}
    57  
    58  	return &kmeansClusterer{
    59  		iterations: iterations,
    60  		number:     clusters,
    61  		distance:   d,
    62  	}, nil
    63  }
    64  
    65  func (c *kmeansClusterer) IsOnline() bool {
    66  	return true
    67  }
    68  
    69  func (c *kmeansClusterer) WithOnline(o Online) HardClusterer {
    70  	c.alpha = o.Alpha
    71  	c.dimension = o.Dimension
    72  
    73  	c.d = make([][]float64, 0, 100)
    74  
    75  	c.initializeMeans()
    76  
    77  	return c
    78  }
    79  
    80  func (c *kmeansClusterer) Learn(data [][]float64) error {
    81  	if len(data) == 0 {
    82  		return errEmptySet
    83  	}
    84  
    85  	c.mu.Lock()
    86  
    87  	c.d = data
    88  
    89  	c.a = make([]int, len(data))
    90  	c.b = make([]int, c.number)
    91  
    92  	c.counter = 0
    93  	c.threshold = changesThreshold
    94  	c.changes = 0
    95  	c.oldchanges = 0
    96  
    97  	c.initializeMeansWithData()
    98  
    99  	for i := 0; i < c.iterations && c.counter != c.threshold; i++ {
   100  		c.run()
   101  		c.check()
   102  	}
   103  
   104  	c.n = nil
   105  
   106  	c.mu.Unlock()
   107  
   108  	return nil
   109  }
   110  
   111  func (c *kmeansClusterer) Sizes() []int {
   112  	c.mu.RLock()
   113  	defer c.mu.RUnlock()
   114  
   115  	return c.b
   116  }
   117  
   118  func (c *kmeansClusterer) Guesses() []int {
   119  	c.mu.RLock()
   120  	defer c.mu.RUnlock()
   121  
   122  	return c.a
   123  }
   124  
   125  func (c *kmeansClusterer) Predict(p []float64) int {
   126  	var (
   127  		l int
   128  		d float64
   129  		m float64 = c.distance(p, c.m[0])
   130  	)
   131  
   132  	for i := 1; i < c.number; i++ {
   133  		if d = c.distance(p, c.m[i]); d < m {
   134  			m = d
   135  			l = i
   136  		}
   137  	}
   138  
   139  	return l
   140  }
   141  
   142  func (c *kmeansClusterer) Online(observations chan []float64, done chan struct{}) chan *HCEvent {
   143  	c.mu.Lock()
   144  
   145  	var (
   146  		r    chan *HCEvent = make(chan *HCEvent)
   147  		l, f int           = len(c.m), len(c.m[0])
   148  		h    float64       = 1 - c.alpha
   149  	)
   150  
   151  	c.b = make([]int, c.number)
   152  
   153  	/* The first step of online learning is adjusting the centroids by finding the one closes to new data point
   154  	 * and modifying it's location using given alpha. Once the client quits sending new data, the actual clusters
   155  	 * are computed and the mutex is unlocked. */
   156  
   157  	go func() {
   158  		for {
   159  			select {
   160  			case o := <-observations:
   161  				var (
   162  					k int
   163  					n float64
   164  					m float64 = math.Pow(c.distance(o, c.m[0]), 2)
   165  				)
   166  
   167  				for i := 1; i < l; i++ {
   168  					if n = math.Pow(c.distance(o, c.m[i]), 2); n < m {
   169  						m = n
   170  						k = i
   171  					}
   172  				}
   173  
   174  				r <- &HCEvent{
   175  					Cluster:     k,
   176  					Observation: o,
   177  				}
   178  
   179  				for i := 0; i < f; i++ {
   180  					c.m[k][i] = c.alpha*o[i] + h*c.m[k][i]
   181  				}
   182  
   183  				c.d = append(c.d, o)
   184  			case <-done:
   185  				go func() {
   186  					var (
   187  						n    int
   188  						d, m float64
   189  					)
   190  
   191  					c.a = make([]int, len(c.d))
   192  
   193  					for i := 0; i < len(c.d); i++ {
   194  						m = c.distance(c.d[i], c.m[0])
   195  						n = 0
   196  
   197  						for j := 1; j < c.number; j++ {
   198  							if d = c.distance(c.d[i], c.m[j]); d < m {
   199  								m = d
   200  								n = j
   201  							}
   202  						}
   203  
   204  						c.a[i] = n + 1
   205  						c.b[n]++
   206  					}
   207  
   208  					c.mu.Unlock()
   209  				}()
   210  
   211  				return
   212  			}
   213  		}
   214  	}()
   215  
   216  	return r
   217  }
   218  
   219  // private
   220  func (c *kmeansClusterer) initializeMeansWithData() {
   221  	c.m = make([][]float64, c.number)
   222  	c.n = make([][]float64, c.number)
   223  
   224  	rand.Seed(time.Now().UTC().Unix())
   225  
   226  	var (
   227  		k          int
   228  		s, t, l, f float64
   229  		d          []float64 = make([]float64, len(c.d))
   230  	)
   231  
   232  	c.m[0] = c.d[rand.Intn(len(c.d)-1)]
   233  
   234  	for i := 1; i < c.number; i++ {
   235  		s = 0
   236  		t = 0
   237  		for j := 0; j < len(c.d); j++ {
   238  
   239  			l = c.distance(c.m[0], c.d[j])
   240  			for g := 1; g < i; g++ {
   241  				if f = c.distance(c.m[g], c.d[j]); f < l {
   242  					l = f
   243  				}
   244  			}
   245  
   246  			d[j] = math.Pow(l, 2)
   247  			s += d[j]
   248  		}
   249  
   250  		t = rand.Float64() * s
   251  		k = 0
   252  		for s = d[0]; s < t; s += d[k] {
   253  			k++
   254  		}
   255  
   256  		c.m[i] = c.d[k]
   257  	}
   258  
   259  	for i := 0; i < c.number; i++ {
   260  		c.n[i] = make([]float64, len(c.m[0]))
   261  	}
   262  }
   263  
   264  func (c *kmeansClusterer) initializeMeans() {
   265  	c.m = make([][]float64, c.number)
   266  
   267  	rand.Seed(time.Now().UTC().Unix())
   268  
   269  	for i := 0; i < c.number; i++ {
   270  		c.m[i] = make([]float64, c.dimension)
   271  		for j := 0; j < c.dimension; j++ {
   272  			c.m[i][j] = 10 * (rand.Float64() - 0.5)
   273  		}
   274  	}
   275  }
   276  
   277  func (c *kmeansClusterer) run() {
   278  	var (
   279  		l, k, n int = len(c.m[0]), 0, 0
   280  		m, d    float64
   281  	)
   282  
   283  	for i := 0; i < c.number; i++ {
   284  		c.b[i] = 0
   285  	}
   286  
   287  	for i := 0; i < len(c.d); i++ {
   288  		m = c.distance(c.d[i], c.m[0])
   289  		n = 0
   290  
   291  		for j := 1; j < c.number; j++ {
   292  			if d = c.distance(c.d[i], c.m[j]); d < m {
   293  				m = d
   294  				n = j
   295  			}
   296  		}
   297  
   298  		k = n + 1
   299  
   300  		if c.a[i] != k {
   301  			c.changes++
   302  		}
   303  
   304  		c.a[i] = k
   305  		c.b[n]++
   306  
   307  		floats.Add(c.n[n], c.d[i])
   308  	}
   309  
   310  	for i := 0; i < c.number; i++ {
   311  		floats.Scale(1/float64(c.b[i]), c.n[i])
   312  
   313  		for j := 0; j < l; j++ {
   314  			c.m[i][j] = c.n[i][j]
   315  			c.n[i][j] = 0
   316  		}
   317  	}
   318  }
   319  
   320  func (c *kmeansClusterer) check() {
   321  	if c.changes == c.oldchanges {
   322  		c.counter++
   323  	}
   324  
   325  	c.oldchanges = c.changes
   326  }