github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/clustering/agglo.go (about)

     1  package clustering
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"sort"
     7  )
     8  
     9  // How agglomerative clustering should calculate distance between clusters.
    10  const (
    11  	AggloMin     = iota // Minimal distance between any pair of elements.
    12  	AggloMax            // Maximal distance between any pair of elements.
    13  	AggloAverage        // Average distance between any pair of elements.
    14  )
    15  
    16  // Agglo performs agglomerative clustering on the indexes 0 to n-1. d should
    17  // return the distance between the i'th and j'th element, such that
    18  // d(i,j)=d(j,i) and d(i,i)=0.
    19  //
    20  // clusterDist should be one of AggloMin or AggloMax.
    21  //
    22  // Works in O(n^2) time and makes O(n^2) calls to d.
    23  func Agglo(n int, clusterDist int, d func(int, int) float64) *AggloResult {
    24  	if n <= 0 {
    25  		panic(fmt.Sprintf("Bad n: %d, must be positive", n))
    26  	}
    27  
    28  	switch clusterDist {
    29  	case AggloMin:
    30  		return slink(n, d)
    31  	case AggloMax:
    32  		return clink(n, d)
    33  	case AggloAverage:
    34  		return upgma(n, d)
    35  	default:
    36  		panic(fmt.Sprintf("Unsupported cluster distance: %v, "+
    37  			"want AggloMin or AggloMax", clusterDist))
    38  	}
    39  }
    40  
    41  // slink is an implementation of the SLINK algorithm.
    42  //
    43  // Copied from:
    44  // https://www.cs.ucsb.edu/~veronika/MAE/SLINK_sibson.pdf
    45  func slink(n int, d func(int, int) float64) *AggloResult {
    46  	// Implementation copied from paper, pardon the crap names.
    47  	m := make([]float64, n)      // Distance of i'th element from elements/clusters.
    48  	pi := make([]int, n)         // Index of first merge target of each element.
    49  	lambda := make([]float64, n) // Distance of first merge target of each element.
    50  
    51  	lambda[0] = math.MaxFloat64
    52  
    53  	for i := 1; i < n; i++ {
    54  		pi[i] = i
    55  		lambda[i] = math.MaxFloat64
    56  
    57  		for j := 0; j < i; j++ {
    58  			m[j] = d(i, j)
    59  		}
    60  
    61  		for j := 0; j < i; j++ {
    62  			if m[j] <= lambda[j] {
    63  				m[pi[j]] = math.Min(m[pi[j]], lambda[j])
    64  				lambda[j] = m[j]
    65  				pi[j] = i
    66  			} else {
    67  				m[pi[j]] = math.Min(m[pi[j]], m[j])
    68  			}
    69  		}
    70  
    71  		for j := 0; j < i; j++ {
    72  			if lambda[j] >= lambda[pi[j]] {
    73  				pi[j] = i
    74  			}
    75  		}
    76  	}
    77  
    78  	return newAggloResult(pi, lambda)
    79  }
    80  
    81  // clink is an implementation of the CLINK algorithm.
    82  //
    83  // Copied from:
    84  // https://academic.oup.com/comjnl/article-pdf/20/4/364/1108735/200364.pdf
    85  func clink(n int, d func(int, int) float64) *AggloResult {
    86  	// Implementation copied from paper, pardon the crap names.
    87  	m := make([]float64, n)      // Distance of i'th element from elements/clusters.
    88  	pi := make([]int, n)         // Index of first merge target of each element.
    89  	lambda := make([]float64, n) // Distance of first merge target of each element.
    90  
    91  	lambda[0] = math.MaxFloat64
    92  
    93  	for i := 1; i < n; i++ {
    94  		pi[i] = i
    95  		lambda[i] = math.MaxFloat64
    96  
    97  		for j := 0; j < i; j++ {
    98  			m[j] = d(i, j)
    99  		}
   100  
   101  		for j := 0; j < i; j++ {
   102  			if lambda[j] < m[j] {
   103  				m[pi[j]] = math.Max(m[pi[j]], m[j])
   104  				m[j] = math.MaxFloat64
   105  			}
   106  		}
   107  
   108  		a := i - 1
   109  		for j := 0; j < i; j++ {
   110  			if lambda[i-j-1] >= m[pi[i-j-1]] {
   111  				if m[i-j-1] < m[a] {
   112  					a = i - j - 1
   113  				}
   114  			} else {
   115  				m[i-j-1] = math.MaxFloat64
   116  			}
   117  		}
   118  
   119  		b := pi[a]
   120  		c := lambda[a]
   121  		pi[a] = i
   122  		lambda[a] = m[a]
   123  		for a < i-1 {
   124  			if b < i-1 {
   125  				d := pi[b]
   126  				e := lambda[b]
   127  				pi[b] = i
   128  				lambda[b] = c
   129  				b = d
   130  				c = e
   131  			} else if b == i-1 {
   132  				pi[b] = i
   133  				lambda[b] = c
   134  				break
   135  			}
   136  		}
   137  
   138  		for j := 0; j < i; j++ {
   139  			if pi[pi[j]] == i && lambda[j] >= lambda[pi[j]] {
   140  				pi[j] = i
   141  			}
   142  		}
   143  	}
   144  
   145  	return newAggloResult(pi, lambda)
   146  }
   147  
   148  // AggloResult is an interactive agglomerative-clustering result.
   149  type AggloResult struct {
   150  	pi     []int
   151  	lambda []float64
   152  	perm   []int
   153  	dict   []string
   154  }
   155  
   156  // Dict returns the string representations of elements in the clustering.
   157  func (r *AggloResult) Dict() []string {
   158  	return r.dict
   159  }
   160  
   161  // newAggloResult creates a new result.
   162  func newAggloResult(pi []int, lambda []float64) *AggloResult {
   163  	result := &AggloResult{pi, lambda, make([]int, len(pi)), nil}
   164  	for i := range result.perm {
   165  		result.perm[i] = i
   166  	}
   167  	sort.Sort((*aggloSorter)(result))
   168  	return result
   169  }
   170  
   171  // aggloSorter is a sorting interface for AggloResult, for sorting by distance.
   172  // This actually sorts the agglomerative steps by their order of occurrence.
   173  type aggloSorter AggloResult
   174  
   175  // Len returns the number of elements in the sorter.
   176  func (r *aggloSorter) Len() int {
   177  	return len(r.perm)
   178  }
   179  
   180  // Less compares two steps by their order of occurrence.
   181  func (r *aggloSorter) Less(i, j int) bool {
   182  	return r.lambda[r.perm[i]] < r.lambda[r.perm[j]]
   183  }
   184  
   185  // Swap swaps two steps.
   186  func (r *aggloSorter) Swap(i, j int) {
   187  	r.perm[i], r.perm[j] = r.perm[j], r.perm[i]
   188  }
   189  
   190  // SetDict sets the string representation of each element, for the String()
   191  // function. Returns itself for chaining.
   192  func (r *AggloResult) SetDict(dict []string) *AggloResult {
   193  	if len(dict) != len(r.perm) {
   194  		panic(fmt.Sprintf("Bad dictionary size: %d, expected %d",
   195  			len(dict), len(r.perm)))
   196  	}
   197  	r.dict = dict
   198  	return r
   199  }
   200  
   201  // String returns a representation of the clustering. If SetDict was not
   202  // called, will use element numbers.
   203  func (r *AggloResult) String() string {
   204  	strs := make([]string, len(r.perm))
   205  	for i := range strs {
   206  		if r.dict == nil {
   207  			strs[i] = fmt.Sprint(i)
   208  		} else {
   209  			strs[i] = r.dict[i]
   210  		}
   211  	}
   212  	for _, i := range r.perm {
   213  		j := r.pi[i]
   214  		if i == j { // Reached the end.
   215  			return strs[i]
   216  		}
   217  		strs[j] = fmt.Sprintf("[%s, %s]", strs[i], strs[j])
   218  	}
   219  
   220  	// TODO(amit): Panic if reached here.
   221  	return ""
   222  }
   223  
   224  // Len returns the number of steps in this clustering. Equals the number of
   225  // elements - 1.
   226  func (r *AggloResult) Len() int {
   227  	return len(r.perm) - 1
   228  }
   229  
   230  // An AggloStep is a single step in the clustering process.
   231  // The index of a cluster is the greatest indexed element in it.
   232  // C2 is always greater than C1.
   233  type AggloStep struct {
   234  	C1 int     // Index of the first merged cluster.
   235  	C2 int     // Index of the second merged cluster.
   236  	D  float64 // Distance between the clusters when merging.
   237  }
   238  
   239  // Step returns the i'th step in the clustering.
   240  func (r *AggloResult) Step(i int) AggloStep {
   241  	return AggloStep{r.perm[i], r.pi[r.perm[i]], r.lambda[r.perm[i]]}
   242  }