gitee.com/quant1x/num@v0.3.2/internal/partial/topk_test.go (about) 1 package partial 2 3 import ( 4 "fmt" 5 "math/rand" 6 "slices" 7 "testing" 8 ) 9 10 func checkTopKInvariants[E any](x []E, k int, less func(E, E) int) bool { 11 sorted := slices.Clone(x) 12 slices.SortFunc(sorted, less) 13 14 if len(x) < 2 { 15 return true 16 } 17 18 // Kth element should be in sorted position 19 if less(x[k-1], sorted[k-1]) < 0 || less(sorted[k-1], x[k-1]) < 0 { // x[k-1] != sorted[k-1] 20 return false 21 } 22 23 // All elements before the kth should be less or equal 24 for _, v := range x[:k-1] { 25 if less(x[k-1], v) < 0 { 26 return false 27 } 28 } 29 30 // All elements following the kth should be greater or equal 31 for _, v := range x[k:] { 32 if less(v, x[k-1]) < 0 { 33 return false 34 } 35 } 36 37 return true 38 } 39 40 type testCase[E any] struct { 41 x []E 42 k int 43 } 44 45 func TestTopK(t *testing.T) { 46 rand.Seed(2) 47 cases := []testCase[int]{ 48 {[]int{}, 1}, 49 {[]int{2}, 1}, 50 {[]int{2, 1}, 1}, 51 {[]int{2, 1}, 2}, 52 {[]int{1, 1, 1}, 2}, 53 {[]int{5, 0, 0, 0, 1}, 2}, 54 {[]int{5, 0, 0, 0, 1}, 5}, 55 } 56 big := make([]int, 100_000) 57 for i := 0; i < 100_000; i++ { 58 big[i] = rand.Intn(10_000) 59 } 60 cases = append(cases, testCase[int]{big, 10_000}) 61 less := func(x, y int) int { return x - y } 62 for _, c := range cases { 63 x := slices.Clone(c.x) 64 TopK(x, c.k) 65 if !checkTopKInvariants(x, c.k, less) { 66 t.Errorf("Invariants failed, in=%v, k=%v, out=%v.", c.x, c.k, x) 67 } 68 } 69 } 70 71 type person struct { 72 name string 73 age int 74 } 75 76 func TestTopKFunc(t *testing.T) { 77 cases := []testCase[person]{ 78 {[]person{{"bob", 45}, {"jane", 31}}, 1}, 79 {[]person{{"bob", 45}, {"jane", 31}}, 2}, 80 {[]person{{"bob", 45}, {"jane", 31}, {"karl", 31}}, 2}, 81 {[]person{{"bob", 45}, {"jane", 31}, {"karl", 31}}, 3}, 82 } 83 less := func(x, y person) int { return x.age - y.age } 84 for _, c := range cases { 85 x := slices.Clone(c.x) 86 TopKFunc(x, c.k, less) 87 if !checkTopKInvariants(x, c.k, less) { 88 t.Errorf("Invariants failed, in=%v, k=%v, out=%v.", c.x, c.k, x) 89 } 90 } 91 } 92 93 func TestTopKOutOfBounds(t *testing.T) { 94 less := func(x, y int) int { return x - y } 95 96 x := []int{9, 2, 5} 97 TopK(x, -1) 98 if !slices.Equal(x, []int{9, 2, 5}) { 99 t.Errorf("Negative k should be treated as zero and sort nothing") 100 } 101 102 y := []int{9, 2, 5} 103 TopK(y, 5) 104 if !checkTopKInvariants(y, 3, less) { 105 t.Errorf("Should take TopK of entire slice when k is greater than len") 106 } 107 } 108 109 func BenchmarkTopK(b *testing.B) { 110 sizes := []int{1_000, 10_000, 100_000} 111 for _, size := range sizes { 112 var x []int 113 for i := 0; i < size; i++ { 114 x = append(x, rand.Intn(size/10)) 115 } 116 k := size / 2 117 b.Run(fmt.Sprintf("slices.Sort_%d", size), func(b *testing.B) { 118 for i := 0; i < b.N; i++ { 119 b.StopTimer() 120 y := slices.Clone(x) 121 b.StartTimer() 122 slices.Sort(y) 123 } 124 }) 125 b.Run(fmt.Sprintf("slices.SortFunc_%d", size), func(b *testing.B) { 126 for i := 0; i < b.N; i++ { 127 b.StopTimer() 128 y := slices.Clone(x) 129 b.StartTimer() 130 slices.SortFunc(y, func(i, j int) int { return i - j }) 131 } 132 }) 133 b.Run(fmt.Sprintf("partial.Sort%d", size), func(b *testing.B) { 134 for i := 0; i < b.N; i++ { 135 b.StopTimer() 136 y := slices.Clone(x) 137 b.StartTimer() 138 Sort(y, k) 139 } 140 }) 141 b.Run(fmt.Sprintf("partial.SortFunc%d", size), func(b *testing.B) { 142 for i := 0; i < b.N; i++ { 143 b.StopTimer() 144 y := slices.Clone(x) 145 b.StartTimer() 146 SortFunc(y, k, func(i, j int) int { return i - j }) 147 } 148 }) 149 b.Run(fmt.Sprintf("partial.TopK_%d", size), func(b *testing.B) { 150 for i := 0; i < b.N; i++ { 151 b.StopTimer() 152 y := slices.Clone(x) 153 b.StartTimer() 154 TopK(y, k) 155 } 156 }) 157 b.Run(fmt.Sprintf("partial.TopKFunc_%d", size), func(b *testing.B) { 158 for i := 0; i < b.N; i++ { 159 b.StopTimer() 160 y := slices.Clone(x) 161 b.StartTimer() 162 TopKFunc(y, k, func(i, j int) int { return i - j }) 163 } 164 }) 165 } 166 }