github.com/fluhus/gostuff@v0.4.1-0.20240331134726-be71864f2b5d/clustering/upgma.go (about)

     1  package clustering
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  
     7  	"github.com/fluhus/gostuff/gnum"
     8  	"github.com/fluhus/gostuff/heaps"
     9  )
    10  
// distPyramid is a distance half-matrix (lower triangle, diagonal excluded).
// Row i holds i entries: the distances between element i and each of the
// elements 0..i-1. The distance between a and b, where a > b, is therefore
// stored at [a][b].
type distPyramid [][]float64
    13  
    14  // dist returns the distance between a and b.
    15  func (d distPyramid) dist(a, b int) float64 {
    16  	if a > b {
    17  		return d[a][b]
    18  	}
    19  	return d[b][a]
    20  }
    21  
    22  // makePyramid creates a distance half-matrix.
    23  func makePyramid(n int, f func(int, int) float64) distPyramid {
    24  	nn := n * (n - 1) / 2
    25  	d := make([]float64, 0, nn)
    26  	for i := 1; i < n; i++ {
    27  		for j := 0; j < i; j++ {
    28  			d = append(d, f(j, i))
    29  		}
    30  	}
    31  	result := make([][]float64, n)
    32  	j := 0
    33  	for i := range result {
    34  		result[i] = d[j : j+i]
    35  		j += i
    36  	}
    37  	return result
    38  }
    39  
    40  // upgma is an implementation of UPGMA clustering. The distance between clusters
    41  // is the average distance between pairs of their individual elements.
    42  func upgma(n int, f func(int, int) float64) *AggloResult {
    43  	pi := make([]int, n)         // Index of first merge target of each element.
    44  	lambda := make([]float64, n) // Distance of first merge target of each element.
    45  
    46  	// Last cluster does not get matched with anyone -> max distance.
    47  	lambda[len(lambda)-1] = math.MaxFloat64
    48  
    49  	// Calculate raw distances.
    50  	d := makePyramid(n, f)
    51  	heapss := make([]*heaps.Heap[upgmaCluster], n)
    52  	for i := range heapss {
    53  		heapss[i] = heaps.New(compareUpgmaClusters)
    54  	}
    55  	for i := 1; i < n; i++ {
    56  		for j := 0; j < i; j++ {
    57  			heapss[i].Push(upgmaCluster{j, d[i][j]})
    58  			heapss[j].Push(upgmaCluster{i, d[i][j]})
    59  		}
    60  	}
    61  
    62  	// Clustering.
    63  	sizes := gnum.Ones[[]float64](n) // Cluster sizes
    64  	// The identifier of each cluster = highest index of an element
    65  	names := make([]int, n)
    66  	for i := range names {
    67  		names[i] = i
    68  	}
    69  	for i := 0; i < n-1; i++ {
    70  		// Find lowest distance.
    71  		fmin := math.MaxFloat64
    72  		a, b := -1, -1
    73  		for hi, h := range heapss {
    74  			if h == nil {
    75  				continue
    76  			}
    77  			// Clean up removed clusters.
    78  			if h.Len() == 0 {
    79  				panic(fmt.Sprintf("heap %d with length 0", hi))
    80  			}
    81  			for heapss[h.Head().i] == nil {
    82  				h.Pop()
    83  			}
    84  			if h.Head().d < fmin {
    85  				a = hi
    86  				fmin = h.Head().d
    87  				b = h.Head().i
    88  			}
    89  		}
    90  
    91  		// Create agglo step.
    92  		nmin := min(names[a], names[b])
    93  		nmax := max(names[a], names[b])
    94  		pi[nmin] = nmax
    95  		lambda[nmin] = fmin
    96  
    97  		// Merge clusters.
    98  		names = append(names, nmax)
    99  		sizes = append(sizes, sizes[a]+sizes[b])
   100  		heapss[a] = nil
   101  		heapss[b] = nil
   102  		var cdist []float64
   103  		cheap := heaps.New(compareUpgmaClusters)
   104  		for hi, h := range heapss {
   105  			if h == nil {
   106  				cdist = append(cdist, 0)
   107  				continue
   108  			}
   109  			da := d.dist(a, hi) * sizes[a]
   110  			db := d.dist(b, hi) * sizes[b]
   111  			dd := (da + db) / (sizes[a] + sizes[b])
   112  			cdist = append(cdist, dd)
   113  			h.Push(upgmaCluster{len(sizes) - 1, dd})
   114  			cheap.Push(upgmaCluster{hi, dd})
   115  		}
   116  		d = append(d, cdist)
   117  		heapss = append(heapss, cheap)
   118  	}
   119  
   120  	return newAggloResult(pi, lambda)
   121  }
   122  
// upgmaCluster is a heap entry used by upgma: a candidate cluster to merge
// with, together with its distance from the cluster that owns the heap.
// It is constructed with positional literals, so field order matters.
type upgmaCluster struct {
	i int     // Cluster index
	d float64 // Distance from cluster i
}
   128  
   129  func compareUpgmaClusters(a, b upgmaCluster) bool {
   130  	return a.d < b.d
   131  }