gitee.com/quant1x/num@v0.3.2/internal/partial/topk_test.go (about)

     1  package partial
     2  
     3  import (
     4  	"fmt"
     5  	"math/rand"
     6  	"slices"
     7  	"testing"
     8  )
     9  
    10  func checkTopKInvariants[E any](x []E, k int, less func(E, E) int) bool {
    11  	sorted := slices.Clone(x)
    12  	slices.SortFunc(sorted, less)
    13  
    14  	if len(x) < 2 {
    15  		return true
    16  	}
    17  
    18  	// Kth element should be in sorted position
    19  	if less(x[k-1], sorted[k-1]) < 0 || less(sorted[k-1], x[k-1]) < 0 { // x[k-1] != sorted[k-1]
    20  		return false
    21  	}
    22  
    23  	// All elements before the kth should be less or equal
    24  	for _, v := range x[:k-1] {
    25  		if less(x[k-1], v) < 0 {
    26  			return false
    27  		}
    28  	}
    29  
    30  	// All elements following the kth should be greater or equal
    31  	for _, v := range x[k:] {
    32  		if less(v, x[k-1]) < 0 {
    33  			return false
    34  		}
    35  	}
    36  
    37  	return true
    38  }
    39  
// testCase is a single table entry for the TopK tests in this file.
type testCase[E any] struct {
	// x is the input slice; the tests clone it before passing it on, so the
	// original stays available for the failure message.
	x []E
	// k is the count handed to TopK / TopKFunc and to checkTopKInvariants.
	k int
}
    44  
    45  func TestTopK(t *testing.T) {
    46  	rand.Seed(2)
    47  	cases := []testCase[int]{
    48  		{[]int{}, 1},
    49  		{[]int{2}, 1},
    50  		{[]int{2, 1}, 1},
    51  		{[]int{2, 1}, 2},
    52  		{[]int{1, 1, 1}, 2},
    53  		{[]int{5, 0, 0, 0, 1}, 2},
    54  		{[]int{5, 0, 0, 0, 1}, 5},
    55  	}
    56  	big := make([]int, 100_000)
    57  	for i := 0; i < 100_000; i++ {
    58  		big[i] = rand.Intn(10_000)
    59  	}
    60  	cases = append(cases, testCase[int]{big, 10_000})
    61  	less := func(x, y int) int { return x - y }
    62  	for _, c := range cases {
    63  		x := slices.Clone(c.x)
    64  		TopK(x, c.k)
    65  		if !checkTopKInvariants(x, c.k, less) {
    66  			t.Errorf("Invariants failed, in=%v, k=%v, out=%v.", c.x, c.k, x)
    67  		}
    68  	}
    69  }
    70  
// person is a small struct element type used to exercise TopKFunc with a
// custom comparison.
type person struct {
	name string
	// age is the only field the comparison in TestTopKFunc looks at.
	age int
}
    75  
    76  func TestTopKFunc(t *testing.T) {
    77  	cases := []testCase[person]{
    78  		{[]person{{"bob", 45}, {"jane", 31}}, 1},
    79  		{[]person{{"bob", 45}, {"jane", 31}}, 2},
    80  		{[]person{{"bob", 45}, {"jane", 31}, {"karl", 31}}, 2},
    81  		{[]person{{"bob", 45}, {"jane", 31}, {"karl", 31}}, 3},
    82  	}
    83  	less := func(x, y person) int { return x.age - y.age }
    84  	for _, c := range cases {
    85  		x := slices.Clone(c.x)
    86  		TopKFunc(x, c.k, less)
    87  		if !checkTopKInvariants(x, c.k, less) {
    88  			t.Errorf("Invariants failed, in=%v, k=%v, out=%v.", c.x, c.k, x)
    89  		}
    90  	}
    91  }
    92  
    93  func TestTopKOutOfBounds(t *testing.T) {
    94  	less := func(x, y int) int { return x - y }
    95  
    96  	x := []int{9, 2, 5}
    97  	TopK(x, -1)
    98  	if !slices.Equal(x, []int{9, 2, 5}) {
    99  		t.Errorf("Negative k should be treated as zero and sort nothing")
   100  	}
   101  
   102  	y := []int{9, 2, 5}
   103  	TopK(y, 5)
   104  	if !checkTopKInvariants(y, 3, less) {
   105  		t.Errorf("Should take TopK of entire slice when k is greater than len")
   106  	}
   107  }
   108  
   109  func BenchmarkTopK(b *testing.B) {
   110  	sizes := []int{1_000, 10_000, 100_000}
   111  	for _, size := range sizes {
   112  		var x []int
   113  		for i := 0; i < size; i++ {
   114  			x = append(x, rand.Intn(size/10))
   115  		}
   116  		k := size / 2
   117  		b.Run(fmt.Sprintf("slices.Sort_%d", size), func(b *testing.B) {
   118  			for i := 0; i < b.N; i++ {
   119  				b.StopTimer()
   120  				y := slices.Clone(x)
   121  				b.StartTimer()
   122  				slices.Sort(y)
   123  			}
   124  		})
   125  		b.Run(fmt.Sprintf("slices.SortFunc_%d", size), func(b *testing.B) {
   126  			for i := 0; i < b.N; i++ {
   127  				b.StopTimer()
   128  				y := slices.Clone(x)
   129  				b.StartTimer()
   130  				slices.SortFunc(y, func(i, j int) int { return i - j })
   131  			}
   132  		})
   133  		b.Run(fmt.Sprintf("partial.Sort%d", size), func(b *testing.B) {
   134  			for i := 0; i < b.N; i++ {
   135  				b.StopTimer()
   136  				y := slices.Clone(x)
   137  				b.StartTimer()
   138  				Sort(y, k)
   139  			}
   140  		})
   141  		b.Run(fmt.Sprintf("partial.SortFunc%d", size), func(b *testing.B) {
   142  			for i := 0; i < b.N; i++ {
   143  				b.StopTimer()
   144  				y := slices.Clone(x)
   145  				b.StartTimer()
   146  				SortFunc(y, k, func(i, j int) int { return i - j })
   147  			}
   148  		})
   149  		b.Run(fmt.Sprintf("partial.TopK_%d", size), func(b *testing.B) {
   150  			for i := 0; i < b.N; i++ {
   151  				b.StopTimer()
   152  				y := slices.Clone(x)
   153  				b.StartTimer()
   154  				TopK(y, k)
   155  			}
   156  		})
   157  		b.Run(fmt.Sprintf("partial.TopKFunc_%d", size), func(b *testing.B) {
   158  			for i := 0; i < b.N; i++ {
   159  				b.StopTimer()
   160  				y := slices.Clone(x)
   161  				b.StartTimer()
   162  				TopKFunc(y, k, func(i, j int) int { return i - j })
   163  			}
   164  		})
   165  	}
   166  }