github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/prolly/tree/node_cursor_test.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tree
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  
    25  	"github.com/dolthub/dolt/go/store/prolly/message"
    26  	"github.com/dolthub/dolt/go/store/val"
    27  )
    28  
    29  func TestNodeCursor(t *testing.T) {
    30  	t.Run("new cursor at item", func(t *testing.T) {
    31  		testNewCursorAtItem(t, 10)
    32  		testNewCursorAtItem(t, 100)
    33  		testNewCursorAtItem(t, 1000)
    34  		testNewCursorAtItem(t, 10_000)
    35  	})
    36  
    37  	t.Run("get ordinal at item", func(t *testing.T) {
    38  		counts := []int{10, 100, 1000, 10_000}
    39  		for _, c := range counts {
    40  			t.Run(fmt.Sprintf("%d", c), func(t *testing.T) {
    41  				testGetOrdinalOfCursor(t, c)
    42  			})
    43  		}
    44  	})
    45  
    46  	t.Run("retreat past beginning", func(t *testing.T) {
    47  		ctx := context.Background()
    48  		root, _, ns := randomTree(t, 10_000)
    49  		assert.NotNil(t, root)
    50  		before, err := newCursorAtStart(ctx, ns, root)
    51  		assert.NoError(t, err)
    52  		err = before.retreat(ctx)
    53  		assert.NoError(t, err)
    54  		assert.False(t, before.Valid())
    55  
    56  		start, err := newCursorAtStart(ctx, ns, root)
    57  		assert.NoError(t, err)
    58  		assert.True(t, start.compare(before) > 0, "start is after before")
    59  		assert.True(t, before.compare(start) < 0, "before is before start")
    60  
    61  		// Backwards iteration...
    62  		end, err := newCursorAtEnd(ctx, ns, root)
    63  		assert.NoError(t, err)
    64  		i := 0
    65  		for end.compare(before) > 0 {
    66  			i++
    67  			err = end.retreat(ctx)
    68  			assert.NoError(t, err)
    69  		}
    70  		assert.Equal(t, 10_000/2, i)
    71  	})
    72  }
    73  
    74  func testNewCursorAtItem(t *testing.T, count int) {
    75  	root, items, ns := randomTree(t, count)
    76  	assert.NotNil(t, root)
    77  
    78  	ctx := context.Background()
    79  	for i := range items {
    80  		key, value := items[i][0], items[i][1]
    81  		cur, err := newCursorAtKey(ctx, ns, root, val.Tuple(key), keyDesc)
    82  		require.NoError(t, err)
    83  		assert.Equal(t, key, cur.CurrentKey())
    84  		assert.Equal(t, value, cur.currentValue())
    85  	}
    86  
    87  	validateTreeItems(t, ns, root, items)
    88  }
    89  
    90  func testGetOrdinalOfCursor(t *testing.T, count int) {
    91  	tuples, desc := AscendingUintTuples(count)
    92  
    93  	ctx := context.Background()
    94  	ns := NewTestNodeStore()
    95  	serializer := message.NewProllyMapSerializer(desc, ns.Pool())
    96  	chkr, err := newEmptyChunker(ctx, ns, serializer)
    97  	require.NoError(t, err)
    98  
    99  	for _, item := range tuples {
   100  		err = chkr.AddPair(ctx, Item(item[0]), Item(item[1]))
   101  		assert.NoError(t, err)
   102  	}
   103  	nd, err := chkr.Done(ctx)
   104  	assert.NoError(t, err)
   105  
   106  	for i := 0; i < len(tuples); i++ {
   107  		curr, err := newCursorAtKey(ctx, ns, nd, tuples[i][0], desc)
   108  		require.NoError(t, err)
   109  
   110  		ord, err := getOrdinalOfCursor(curr)
   111  		require.NoError(t, err)
   112  
   113  		assert.Equal(t, uint64(i), ord)
   114  	}
   115  
   116  	b := val.NewTupleBuilder(desc)
   117  	b.PutUint32(0, uint32(len(tuples)))
   118  	aboveItem := b.Build(sharedPool)
   119  
   120  	curr, err := newCursorAtKey(ctx, ns, nd, aboveItem, desc)
   121  	require.NoError(t, err)
   122  
   123  	ord, err := getOrdinalOfCursor(curr)
   124  	require.NoError(t, err)
   125  
   126  	require.Equal(t, uint64(len(tuples)), ord)
   127  
   128  	// A cursor past the end should return an ordinal count equal to number of
   129  	// nodes.
   130  	curr, err = newCursorPastEnd(ctx, ns, nd)
   131  	require.NoError(t, err)
   132  
   133  	ord, err = getOrdinalOfCursor(curr)
   134  	require.NoError(t, err)
   135  
   136  	require.Equal(t, uint64(len(tuples)), ord)
   137  }
   138  
   139  func randomTree(t *testing.T, count int) (Node, [][2]Item, NodeStore) {
   140  	ctx := context.Background()
   141  	ns := NewTestNodeStore()
   142  	serializer := message.NewProllyMapSerializer(valDesc, ns.Pool())
   143  	chkr, err := newEmptyChunker(ctx, ns, serializer)
   144  	require.NoError(t, err)
   145  
   146  	items := randomTupleItemPairs(count/2, ns)
   147  	for _, item := range items {
   148  		err = chkr.AddPair(ctx, Item(item[0]), Item(item[1]))
   149  		assert.NoError(t, err)
   150  	}
   151  	nd, err := chkr.Done(ctx)
   152  	assert.NoError(t, err)
   153  	return nd, items, ns
   154  }
   155  
   156  var keyDesc = val.NewTupleDescriptor(
   157  	val.Type{Enc: val.Int64Enc, Nullable: false},
   158  )
   159  var valDesc = val.NewTupleDescriptor(
   160  	val.Type{Enc: val.Int64Enc, Nullable: true},
   161  	val.Type{Enc: val.Int64Enc, Nullable: true},
   162  	val.Type{Enc: val.Int64Enc, Nullable: true},
   163  	val.Type{Enc: val.Int64Enc, Nullable: true},
   164  )
   165  
   166  func randomTupleItemPairs(count int, ns NodeStore) (items [][2]Item) {
   167  	tups := RandomTuplePairs(count, keyDesc, valDesc, ns)
   168  	items = make([][2]Item, count)
   169  	if len(tups) != len(items) {
   170  		panic("mismatch")
   171  	}
   172  
   173  	for i := range items {
   174  		items[i][0] = Item(tups[i][0])
   175  		items[i][1] = Item(tups[i][1])
   176  	}
   177  	return
   178  }