github.com/grafana/pyroscope@v1.18.0/pkg/util/loser/tree_test.go (about)

     1  package loser_test
     2  
     3  import (
     4  	"errors"
     5  	"math"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/grafana/pyroscope/pkg/util/loser"
    12  )
    13  
    14  type List struct {
    15  	list []uint64
    16  	cur  uint64
    17  
    18  	err error
    19  
    20  	closed int
    21  }
    22  
    23  func NewList(list ...uint64) *List {
    24  	return &List{list: list}
    25  }
    26  
    27  func (it *List) At() uint64 {
    28  	return it.cur
    29  }
    30  
    31  func (it *List) Err() error { return it.err }
    32  
    33  func (it *List) Next() bool {
    34  	if it.err != nil {
    35  		return false
    36  	}
    37  	if len(it.list) > 0 {
    38  		it.cur = it.list[0]
    39  		it.list = it.list[1:]
    40  		return true
    41  	}
    42  	it.cur = 0
    43  	return false
    44  }
    45  
    46  func (it *List) Close() { it.closed += 1 }
    47  
    48  func (it *List) Seek(val uint64) bool {
    49  	if it.err != nil {
    50  		return false
    51  	}
    52  	for it.cur < val && len(it.list) > 0 {
    53  		it.cur = it.list[0]
    54  		it.list = it.list[1:]
    55  	}
    56  	return len(it.list) > 0
    57  }
    58  
    59  func checkIterablesEqual[E any, S1 loser.Sequence, S2 loser.Sequence](t *testing.T, a S1, b S2, at1 func(S1) E, at2 func(S2) E, less func(E, E) bool) {
    60  	t.Helper()
    61  	count := 0
    62  	for a.Next() {
    63  		count++
    64  		if !b.Next() {
    65  			t.Fatalf("b ended before a after %d elements", count)
    66  		}
    67  		if less(at1(a), at2(b)) || less(at2(b), at1(a)) {
    68  			t.Fatalf("position %d: %v != %v", count, at1(a), at2(b))
    69  		}
    70  	}
    71  	if b.Next() {
    72  		t.Fatalf("a ended before b after %d elements", count)
    73  	}
    74  }
    75  
    76  var testCases = []struct {
    77  	name string
    78  	args []*List
    79  	want *List
    80  }{
    81  	{
    82  		name: "empty input",
    83  		want: NewList(),
    84  	},
    85  	{
    86  		name: "one list",
    87  		args: []*List{NewList(1, 2, 3, 4)},
    88  		want: NewList(1, 2, 3, 4),
    89  	},
    90  	{
    91  		name: "two lists",
    92  		args: []*List{NewList(3, 4, 5), NewList(1, 2)},
    93  		want: NewList(1, 2, 3, 4, 5),
    94  	},
    95  	{
    96  		name: "two lists, first empty",
    97  		args: []*List{NewList(), NewList(1, 2)},
    98  		want: NewList(1, 2),
    99  	},
   100  	{
   101  		name: "two lists, second empty",
   102  		args: []*List{NewList(1, 2), NewList()},
   103  		want: NewList(1, 2),
   104  	},
   105  	{
   106  		name: "two lists b",
   107  		args: []*List{NewList(1, 2), NewList(3, 4, 5)},
   108  		want: NewList(1, 2, 3, 4, 5),
   109  	},
   110  	{
   111  		name: "two lists c",
   112  		args: []*List{NewList(1, 3), NewList(2, 4, 5)},
   113  		want: NewList(1, 2, 3, 4, 5),
   114  	},
   115  	{
   116  		name: "three lists",
   117  		args: []*List{NewList(1, 3), NewList(2, 4), NewList(5)},
   118  		want: NewList(1, 2, 3, 4, 5),
   119  	},
   120  }
   121  
   122  func TestMerge(t *testing.T) {
   123  	at := func(s *List) uint64 { return s.At() }
   124  	less := func(a, b uint64) bool { return a < b }
   125  	at2 := func(s *loser.Tree[uint64, *List]) uint64 { return s.Winner().At() }
   126  	for _, tt := range testCases {
   127  		t.Run(tt.name, func(t *testing.T) {
   128  			numCloses := 0
   129  			close := func(s *List) {
   130  				numCloses++
   131  			}
   132  			lt := loser.New(tt.args, math.MaxUint64, at, less, close)
   133  			checkIterablesEqual(t, tt.want, lt, at, at2, less)
   134  			if numCloses != len(tt.args) {
   135  				t.Errorf("Expected %d closes, got %d", len(tt.args), numCloses)
   136  			}
   137  		})
   138  	}
   139  }
   140  
   141  func TestPush(t *testing.T) {
   142  	at := func(s *List) uint64 { return s.At() }
   143  	less := func(a, b uint64) bool { return a < b }
   144  	at2 := func(s *loser.Tree[uint64, *List]) uint64 { return s.Winner().At() }
   145  	for _, tt := range testCases {
   146  		t.Run(tt.name, func(t *testing.T) {
   147  			numCloses := 0
   148  			close := func(s *List) {
   149  				numCloses++
   150  			}
   151  			lt := loser.New(nil, math.MaxUint64, at, less, close)
   152  			for _, s := range tt.args {
   153  				if err := lt.Push(s); err != nil {
   154  					t.Fatalf("Push failed: %v", err)
   155  				}
   156  			}
   157  			checkIterablesEqual(t, tt.want, lt, at, at2, less)
   158  			if numCloses != len(tt.args) {
   159  				t.Errorf("Expected %d closes, got %d", len(tt.args), numCloses)
   160  			}
   161  		})
   162  	}
   163  }
   164  
   165  func TestInitWithErr(t *testing.T) {
   166  	lists := []*List{
   167  		NewList(),
   168  		NewList(5, 6, 7, 8),
   169  	}
   170  	lists[0].err = errTest
   171  	tree := loser.New(lists, math.MaxUint64, func(s *List) uint64 { return s.At() }, func(a, b uint64) bool { return a < b }, func(s *List) { s.Close() })
   172  	if tree.Next() {
   173  		t.Errorf("Next() should have returned false")
   174  	}
   175  	if tree.Err() != errTest {
   176  		t.Errorf("Err() should have returned %v, got %v", errTest, tree.Err())
   177  	}
   178  
   179  	tree.Close()
   180  	for _, l := range lists {
   181  		assert.Equal(t, l.closed, 1, "list %+#v not closed exactly once", l)
   182  	}
   183  
   184  }
   185  
   186  var errTest = errors.New("test")
   187  
   188  func TestErrDuringNext(t *testing.T) {
   189  	lists := []*List{
   190  		NewList(5, 6),
   191  		NewList(11, 12),
   192  	}
   193  	tree := loser.New(lists, math.MaxUint64, func(s *List) uint64 { return s.At() }, func(a, b uint64) bool { return a < b }, func(s *List) { s.Close() })
   194  
   195  	// no error for first element
   196  	if !tree.Next() {
   197  		t.Errorf("Next() should have returned true")
   198  	}
   199  	// now error for second
   200  	lists[0].err = errTest
   201  	if tree.Next() {
   202  		t.Errorf("Next() should have returned false")
   203  	}
   204  	if tree.Err() != errTest {
   205  		t.Errorf("Err() should have returned %v, got %v", errTest, tree.Err())
   206  	}
   207  	if tree.Next() {
   208  		t.Errorf("Next() should have returned false")
   209  	}
   210  
   211  	tree.Close()
   212  	for _, l := range lists {
   213  		assert.Equal(t, l.closed, 1, "list %+#v not closed exactly once", l)
   214  	}
   215  }
   216  
   217  func TestErrInOneIterator(t *testing.T) {
   218  	l := NewList()
   219  	l.err = errors.New("test")
   220  
   221  	lists := []*List{
   222  		NewList(5, 1),
   223  		l,
   224  		NewList(2, 4),
   225  	}
   226  	tree := loser.New(lists, math.MaxUint64, func(s *List) uint64 { return s.At() }, func(a, b uint64) bool { return a < b }, func(s *List) { s.Close() })
   227  
   228  	// error for first element
   229  	require.False(t, tree.Next())
   230  	assert.Equal(t, l.err, tree.Err())
   231  
   232  	tree.Close()
   233  	for _, l := range lists {
   234  		assert.Equal(t, l.closed, 1, "list %+#v not closed exactly once", l)
   235  	}
   236  }