
     1  package clustering
     3  import (
     4  	"fmt"
     5  )
     7  // AdjustedRandIndex compares 2 taggings of the data for similarity. A score of
     8  // 1 means identical, a score of 0 means as good as random, and a negative
     9  // score means worse than random.
    10  func AdjustedRandIndex(tags1, tags2 []int) float64 {
    11  	// Check input.
    12  	if len(tags1) != len(tags2) {
    13  		panic(fmt.Sprintf("Mismatching lengths: %d, %d",
    14  			len(tags1), len(tags2)))
    15  	}
    17  	sets1 := tagsToSets(tags1)
    18  	sets2 := tagsToSets(tags2)
    20  	r := randIndex(sets1, sets2)
    21  	e := expectedRandIndex(sets1, sets2)
    22  	m := maxRandIndex(sets1, sets2)
    23  	return (r - e) / (m - e)
    24  }
    26  // randIndex returns the RI part of the adjusted index.
    27  func randIndex(tags1, tags2 []intSet) float64 {
    28  	r := 0
    29  	for _, t1 := range tags1 {
    30  		for _, t2 := range tags2 {
    31  			r += choose2(t1.intersect(t2))
    32  		}
    33  	}
    34  	return float64(r)
    35  }
    37  // expectedRandIndex returns the expected index according to hypergeometrical
    38  // distribution.
    39  func expectedRandIndex(tags1, tags2 []intSet) float64 {
    40  	p1 := 0
    41  	n := 0
    42  	for _, tags := range tags1 {
    43  		n += len(tags)
    44  		p1 += choose2(len(tags))
    45  	}
    46  	p2 := 0
    47  	for _, tags := range tags2 {
    48  		p2 += choose2(len(tags))
    49  	}
    50  	p := float64(choose2(n))
    51  	return float64(p1) * float64(p2) / p
    52  }
    54  // maxRandIndex returns the maximal possible index.
    55  func maxRandIndex(tags1, tags2 []intSet) float64 {
    56  	p := 0
    57  	for _, tags := range tags1 {
    58  		p += choose2(len(tags))
    59  	}
    60  	for _, tags := range tags2 {
    61  		p += choose2(len(tags))
    62  	}
    63  	return float64(p) / 2
    64  }
    66  func choose2(n int) int {
    67  	return n * (n - 1) / 2
    68  }
    70  // ----- INT SET --------------------------------------------------------------
// intSet is a set of ints, stored as the keys of a map; the empty-struct
// values carry no data.
type intSet map[int]struct{}
    75  // tagsToSets converts a list of tags to a list of sets of indexes, one list
    76  // for each tag.
    77  func tagsToSets(tags []int) []intSet {
    78  	// Make map from tag to its set.
    79  	sets := map[int]intSet{}
    80  	for i, tag := range tags {
    81  		if sets[tag] == nil {
    82  			sets[tag] = intSet{}
    83  		}
    84  		sets[tag].add(i)
    85  	}
    87  	// Convert map to slice.
    88  	result := make([]intSet, 0, len(sets))
    89  	for _, set := range sets {
    90  		result = append(result, set)
    91  	}
    93  	return result
    94  }
    96  // add adds a number to the set.
    97  func (is intSet) add(i int) {
    98  	is[i] = struct{}{}
    99  }
   101  // contains checks if a set contains the given element.
   102  func (is intSet) contains(i int) bool {
   103  	_, ok := is[i]
   104  	return ok
   105  }
   107  // intersect returns the size of the intersection of the 2 sets.
   108  func (is intSet) intersect(other intSet) int {
   109  	result := 0
   110  	for i := range is {
   111  		if other.contains(i) {
   112  			result++
   113  		}
   114  	}
   115  	return result
   116  }