github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/clustering/kmeans.go (about)

     1  // Package clustering provides basic clustering functions.
     2  package clustering
     3  
     4  import (
     5  	"fmt"
     6  	"math"
     7  	"math/rand"
     8  
     9  	"github.com/fluhus/gostuff/gnum"
    10  )
    11  
    12  // Kmeans performs k-means clustering on the given data. Each vector is an
    13  // element in the clustering. Returns the generated means, and the tag each
    14  // element was given.
    15  func Kmeans(vecs [][]float64, k int) (means [][]float64, tags []int) {
    16  	// K must be at least 1.
    17  	if k < 1 {
    18  		panic(fmt.Sprint("Bad k:", k))
    19  	}
    20  
    21  	// Must have at least 1 vector.
    22  	if len(vecs) == 0 {
    23  		panic("Cannot cluster 0 vectors.")
    24  	}
    25  
    26  	// If k is too large - that's ok just reduce to avoid out-of-range.
    27  	if k > len(vecs) {
    28  		k = len(vecs)
    29  	}
    30  
    31  	// First tagging.
    32  	means = initialMeans(vecs, k)
    33  	tags = tag(vecs, means, make([]int, len(vecs)))
    34  	dist := MeanSquaredError(vecs, means, tags)
    35  	distOld := 2 * dist
    36  
    37  	// Iterate until converged.
    38  	for dist > distOld || dist/distOld < 0.999 {
    39  		distOld = dist
    40  		means = findMeans(vecs, tags, k)
    41  		tags = tag(vecs, means, tags)
    42  		dist = MeanSquaredError(vecs, means, tags)
    43  	}
    44  
    45  	return
    46  }
    47  
    48  // tag tags each row with the index of its nearest centroid. The old tags are
    49  // used for optimization.
    50  func tag(vecs, means [][]float64, oldTags []int) []int {
    51  	if len(means) == 0 {
    52  		panic("Cannot tag on 0 centroids.")
    53  	}
    54  
    55  	// Create a distance matrix of means from one another.
    56  	meansd := make([][]float64, len(means))
    57  	for i := range meansd {
    58  		meansd[i] = make([]float64, len(means))
    59  		for j := range means {
    60  			meansd[i][j] = gnum.L2(means[i], means[j])
    61  		}
    62  	}
    63  
    64  	tags := make([]int, len(vecs))
    65  
    66  	// Go over vectors.
    67  	for i := range vecs {
    68  		// Find nearest centroid.
    69  		tags[i] = oldTags[i]
    70  		d := gnum.L2(means[oldTags[i]], vecs[i])
    71  
    72  		for j := 0; j < len(means); j++ {
    73  			// Use triangle inequality to skip means that are too distant.
    74  			if j == tags[i] || meansd[j][tags[i]] >= 2*d {
    75  				continue
    76  			}
    77  
    78  			dj := gnum.L2(means[j], vecs[i])
    79  			if dj < d {
    80  				d = dj
    81  				tags[i] = j
    82  			}
    83  		}
    84  	}
    85  
    86  	return tags
    87  }
    88  
    89  // findMeans calculates the new means, according to average of tagged rows in
    90  // each group.
    91  func findMeans(vecs [][]float64, tags []int, k int) [][]float64 {
    92  	// Initialize new arrays.
    93  	means := make([][]float64, k)
    94  	for i := range means {
    95  		means[i] = make([]float64, len(vecs[0]))
    96  	}
    97  	counts := make([]int, k)
    98  
    99  	// Sum all vectors according to tags.
   100  	for i := range vecs {
   101  		counts[tags[i]]++
   102  		gnum.Add(means[tags[i]], vecs[i])
   103  	}
   104  
   105  	// Divide by count.
   106  	for i := range means {
   107  		if counts[i] != 0 {
   108  			gnum.Mul1(means[i], 1/float64(counts[i]))
   109  		}
   110  	}
   111  
   112  	return means
   113  }
   114  
   115  // initialMeans picks the initial means with the K-means++ algorithm.
   116  func initialMeans(vecs [][]float64, k int) [][]float64 {
   117  	result := make([][]float64, k)
   118  	perm := rand.Perm(len(vecs))
   119  	numTrials := 2 + int(math.Log(float64(k)))
   120  
   121  	probs := make([]float64, len(vecs))    // Probability of each vector.
   122  	nearest := make([]int, len(vecs))      // Index of nearest mean to each vector.
   123  	distance := make([]float64, len(vecs)) // Distance to nearest mean.
   124  	mdistance := make([][]float64, k)      // Distance between means.
   125  	for i := range mdistance {
   126  		mdistance[i] = make([]float64, k)
   127  	}
   128  
   129  	// Pick each mean.
   130  	for i := range result {
   131  		result[i] = make([]float64, len(vecs[0]))
   132  
   133  		// First mean is first vector.
   134  		if i == 0 {
   135  			copy(result[0], vecs[perm[0]])
   136  			for _, j := range perm {
   137  				distance[j] = gnum.L2(vecs[j], result[0])
   138  			}
   139  			continue
   140  		}
   141  
   142  		// Find next mean.
   143  		bestCandidate := -1
   144  		bestImprovement := -math.MaxFloat64
   145  
   146  		for t := 0; t < numTrials; t++ { // Make a few attempts.
   147  			sum := 0.0
   148  			for _, j := range perm {
   149  				probs[j] = distance[j] * distance[j]
   150  				sum += probs[j]
   151  			}
   152  			// Pick element with probability relative to d^2.
   153  			r := rand.Float64() * sum
   154  			newMean := 0
   155  			for r > probs[newMean] {
   156  				r -= probs[newMean]
   157  				newMean++
   158  			}
   159  			copy(result[i], vecs[newMean])
   160  
   161  			// Update distances from new mean to other means.
   162  			for j := range mdistance[:i] {
   163  				mdistance[j][i] = gnum.L2(result[i], result[j])
   164  				mdistance[i][j] = mdistance[j][i]
   165  			}
   166  
   167  			// Check improvement.
   168  			newImprovement := 0.0
   169  			for j := range vecs {
   170  				if mdistance[i][nearest[j]] < 2*distance[j] {
   171  					d := gnum.L2(vecs[j], result[i])
   172  					d = math.Min(distance[j], d)
   173  					newImprovement += distance[j] - d
   174  				}
   175  			}
   176  			if newImprovement > bestImprovement {
   177  				bestCandidate = newMean
   178  				bestImprovement = newImprovement
   179  			}
   180  		}
   181  
   182  		copy(result[i], vecs[bestCandidate])
   183  
   184  		// Update distances.
   185  		for j := range mdistance[:i] { // From new mean to other means.
   186  			mdistance[j][i] = gnum.L2(result[i], result[j])
   187  			mdistance[i][j] = mdistance[j][i]
   188  		}
   189  		for j := range vecs { // From vecs to nearest means.
   190  			if mdistance[i][nearest[j]] < 2*distance[j] {
   191  				d := gnum.L2(vecs[j], result[i])
   192  				if d < distance[j] {
   193  					distance[j] = math.Min(distance[j], d)
   194  					nearest[j] = i
   195  				}
   196  			}
   197  		}
   198  	}
   199  
   200  	return result
   201  }
   202  
   203  // MeanSquaredError calculates the average squared-distance of elements from
   204  // their assigned means.
   205  func MeanSquaredError(vecs, means [][]float64, tags []int) float64 {
   206  	if len(tags) != len(vecs) {
   207  		panic(fmt.Sprintf("Non-matching lengths of matrix and tags: %d, %d",
   208  			len(vecs), len(tags)))
   209  	}
   210  	if len(vecs) == 0 {
   211  		return 0
   212  	}
   213  
   214  	d := 0.0
   215  	for i := range tags {
   216  		dist := gnum.L2(means[tags[i]], vecs[i])
   217  		d += dist * dist
   218  	}
   219  
   220  	return d / float64(len(vecs))
   221  }