github.com/grafana/pyroscope@v1.18.0/pkg/util/loser/tree_test.go (about) 1 package loser_test 2 3 import ( 4 "errors" 5 "math" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 "github.com/stretchr/testify/require" 10 11 "github.com/grafana/pyroscope/pkg/util/loser" 12 ) 13 14 type List struct { 15 list []uint64 16 cur uint64 17 18 err error 19 20 closed int 21 } 22 23 func NewList(list ...uint64) *List { 24 return &List{list: list} 25 } 26 27 func (it *List) At() uint64 { 28 return it.cur 29 } 30 31 func (it *List) Err() error { return it.err } 32 33 func (it *List) Next() bool { 34 if it.err != nil { 35 return false 36 } 37 if len(it.list) > 0 { 38 it.cur = it.list[0] 39 it.list = it.list[1:] 40 return true 41 } 42 it.cur = 0 43 return false 44 } 45 46 func (it *List) Close() { it.closed += 1 } 47 48 func (it *List) Seek(val uint64) bool { 49 if it.err != nil { 50 return false 51 } 52 for it.cur < val && len(it.list) > 0 { 53 it.cur = it.list[0] 54 it.list = it.list[1:] 55 } 56 return len(it.list) > 0 57 } 58 59 func checkIterablesEqual[E any, S1 loser.Sequence, S2 loser.Sequence](t *testing.T, a S1, b S2, at1 func(S1) E, at2 func(S2) E, less func(E, E) bool) { 60 t.Helper() 61 count := 0 62 for a.Next() { 63 count++ 64 if !b.Next() { 65 t.Fatalf("b ended before a after %d elements", count) 66 } 67 if less(at1(a), at2(b)) || less(at2(b), at1(a)) { 68 t.Fatalf("position %d: %v != %v", count, at1(a), at2(b)) 69 } 70 } 71 if b.Next() { 72 t.Fatalf("a ended before b after %d elements", count) 73 } 74 } 75 76 var testCases = []struct { 77 name string 78 args []*List 79 want *List 80 }{ 81 { 82 name: "empty input", 83 want: NewList(), 84 }, 85 { 86 name: "one list", 87 args: []*List{NewList(1, 2, 3, 4)}, 88 want: NewList(1, 2, 3, 4), 89 }, 90 { 91 name: "two lists", 92 args: []*List{NewList(3, 4, 5), NewList(1, 2)}, 93 want: NewList(1, 2, 3, 4, 5), 94 }, 95 { 96 name: "two lists, first empty", 97 args: []*List{NewList(), NewList(1, 2)}, 98 want: NewList(1, 2), 99 }, 100 { 101 name: "two lists, second empty", 102 args: []*List{NewList(1, 2), NewList()}, 103 want: NewList(1, 2), 104 }, 105 { 106 name: "two lists b", 107 args: []*List{NewList(1, 2), NewList(3, 4, 5)}, 108 want: NewList(1, 2, 3, 4, 5), 109 }, 110 { 111 name: "two lists c", 112 args: []*List{NewList(1, 3), NewList(2, 4, 5)}, 113 want: NewList(1, 2, 3, 4, 5), 114 }, 115 { 116 name: "three lists", 117 args: []*List{NewList(1, 3), NewList(2, 4), NewList(5)}, 118 want: NewList(1, 2, 3, 4, 5), 119 }, 120 } 121 122 func TestMerge(t *testing.T) { 123 at := func(s *List) uint64 { return s.At() } 124 less := func(a, b uint64) bool { return a < b } 125 at2 := func(s *loser.Tree[uint64, *List]) uint64 { return s.Winner().At() } 126 for _, tt := range testCases { 127 t.Run(tt.name, func(t *testing.T) { 128 numCloses := 0 129 close := func(s *List) { 130 numCloses++ 131 } 132 lt := loser.New(tt.args, math.MaxUint64, at, less, close) 133 checkIterablesEqual(t, tt.want, lt, at, at2, less) 134 if numCloses != len(tt.args) { 135 t.Errorf("Expected %d closes, got %d", len(tt.args), numCloses) 136 } 137 }) 138 } 139 } 140 141 func TestPush(t *testing.T) { 142 at := func(s *List) uint64 { return s.At() } 143 less := func(a, b uint64) bool { return a < b } 144 at2 := func(s *loser.Tree[uint64, *List]) uint64 { return s.Winner().At() } 145 for _, tt := range testCases { 146 t.Run(tt.name, func(t *testing.T) { 147 numCloses := 0 148 close := func(s *List) { 149 numCloses++ 150 } 151 lt := loser.New(nil, math.MaxUint64, at, less, close) 152 for _, s := range tt.args { 153 if err := lt.Push(s); err != nil { 154 t.Fatalf("Push failed: %v", err) 155 } 156 } 157 checkIterablesEqual(t, tt.want, lt, at, at2, less) 158 if numCloses != len(tt.args) { 159 t.Errorf("Expected %d closes, got %d", len(tt.args), numCloses) 160 } 161 }) 162 } 163 } 164 165 func TestInitWithErr(t *testing.T) { 166 lists := []*List{ 167 NewList(), 168 NewList(5, 6, 7, 8), 169 } 170 lists[0].err = errTest 171 tree := loser.New(lists, math.MaxUint64, func(s *List) uint64 { return s.At() }, func(a, b uint64) bool { return a < b }, func(s *List) { s.Close() }) 172 if tree.Next() { 173 t.Errorf("Next() should have returned false") 174 } 175 if tree.Err() != errTest { 176 t.Errorf("Err() should have returned %v, got %v", errTest, tree.Err()) 177 } 178 179 tree.Close() 180 for _, l := range lists { 181 assert.Equal(t, l.closed, 1, "list %+#v not closed exactly once", l) 182 } 183 184 } 185 186 var errTest = errors.New("test") 187 188 func TestErrDuringNext(t *testing.T) { 189 lists := []*List{ 190 NewList(5, 6), 191 NewList(11, 12), 192 } 193 tree := loser.New(lists, math.MaxUint64, func(s *List) uint64 { return s.At() }, func(a, b uint64) bool { return a < b }, func(s *List) { s.Close() }) 194 195 // no error for first element 196 if !tree.Next() { 197 t.Errorf("Next() should have returned true") 198 } 199 // now error for second 200 lists[0].err = errTest 201 if tree.Next() { 202 t.Errorf("Next() should have returned false") 203 } 204 if tree.Err() != errTest { 205 t.Errorf("Err() should have returned %v, got %v", errTest, tree.Err()) 206 } 207 if tree.Next() { 208 t.Errorf("Next() should have returned false") 209 } 210 211 tree.Close() 212 for _, l := range lists { 213 assert.Equal(t, l.closed, 1, "list %+#v not closed exactly once", l) 214 } 215 } 216 217 func TestErrInOneIterator(t *testing.T) { 218 l := NewList() 219 l.err = errors.New("test") 220 221 lists := []*List{ 222 NewList(5, 1), 223 l, 224 NewList(2, 4), 225 } 226 tree := loser.New(lists, math.MaxUint64, func(s *List) uint64 { return s.At() }, func(a, b uint64) bool { return a < b }, func(s *List) { s.Close() }) 227 228 // error for first element 229 require.False(t, tree.Next()) 230 assert.Equal(t, l.err, tree.Err()) 231 232 tree.Close() 233 for _, l := range lists { 234 assert.Equal(t, l.closed, 1, "list %+#v not closed exactly once", l) 235 } 236 }