github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/prolly/tree/node_cursor.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // This file incorporates work covered by the following copyright and
    16  // permission notice:
    17  //
    18  // Copyright 2016 Attic Labs, Inc. All rights reserved.
    19  // Licensed under the Apache License, version 2.0:
    20  // http://www.apache.org/licenses/LICENSE-2.0
    21  
    22  package tree
    23  
    24  import (
    25  	"context"
    26  	"errors"
    27  	"fmt"
    28  
    29  	"github.com/dolthub/dolt/go/store/hash"
    30  )
    31  
// cursor explores a tree of Nodes. A cursor addresses one position within a
// single node; its |parent| chain holds the positions taken at each higher
// level of the tree.
type cursor struct {
	nd     Node      // node the cursor is positioned within
	idx    int       // index within nd; -1 or nd.count once invalidated
	parent *cursor   // cursor one level up the tree; nil for the top-most cursor
	nrw    NodeStore // store used to fetch child nodes
}
    39  
    40  type SearchFn func(nd Node) (idx int)
    41  
    42  type Ordering[K ~[]byte] interface {
    43  	Compare(left, right K) int
    44  }
    45  
    46  func newCursorAtStart(ctx context.Context, ns NodeStore, nd Node) (cur *cursor, err error) {
    47  	cur = &cursor{nd: nd, nrw: ns}
    48  	for !cur.isLeaf() {
    49  		nd, err = fetchChild(ctx, ns, cur.currentRef())
    50  		if err != nil {
    51  			return nil, err
    52  		}
    53  
    54  		parent := cur
    55  		cur = &cursor{nd: nd, parent: parent, nrw: ns}
    56  	}
    57  	return
    58  }
    59  
    60  func newCursorAtEnd(ctx context.Context, ns NodeStore, nd Node) (cur *cursor, err error) {
    61  	cur = &cursor{nd: nd, nrw: ns}
    62  	cur.skipToNodeEnd()
    63  
    64  	for !cur.isLeaf() {
    65  		nd, err = fetchChild(ctx, ns, cur.currentRef())
    66  		if err != nil {
    67  			return nil, err
    68  		}
    69  
    70  		parent := cur
    71  		cur = &cursor{nd: nd, parent: parent, nrw: ns}
    72  		cur.skipToNodeEnd()
    73  	}
    74  	return
    75  }
    76  
    77  func newCursorPastEnd(ctx context.Context, ns NodeStore, nd Node) (cur *cursor, err error) {
    78  	cur, err = newCursorAtEnd(ctx, ns, nd)
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	// Advance |cur| past the end
    84  	err = cur.advance(ctx)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  	if cur.idx != int(cur.nd.count) {
    89  		panic("expected |ok| to be  false")
    90  	}
    91  
    92  	return
    93  }
    94  
    95  func newCursorAtOrdinal(ctx context.Context, ns NodeStore, nd Node, ord uint64) (cur *cursor, err error) {
    96  	cnt, err := nd.TreeCount()
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	if ord >= uint64(cnt) {
   101  		return newCursorPastEnd(ctx, ns, nd)
   102  	}
   103  
   104  	distance := int64(ord)
   105  	return newCursorFromSearchFn(ctx, ns, nd, func(nd Node) (idx int) {
   106  		if nd.IsLeaf() {
   107  			return int(distance)
   108  		}
   109  		nd, _ = nd.loadSubtrees()
   110  
   111  		for idx = 0; idx < nd.Count(); idx++ {
   112  			cnt, _ := nd.getSubtreeCount(idx)
   113  			card := int64(cnt)
   114  			if (distance - card) < 0 {
   115  				break
   116  			}
   117  			distance -= card
   118  		}
   119  		return
   120  	})
   121  }
   122  
   123  // GetOrdinalOfCursor returns the ordinal position of a cursor.
   124  func getOrdinalOfCursor(curr *cursor) (ord uint64, err error) {
   125  	if !curr.isLeaf() {
   126  		return 0, fmt.Errorf("|cur| must be at a leaf")
   127  	}
   128  
   129  	ord += uint64(curr.idx)
   130  
   131  	for curr.parent != nil {
   132  		curr = curr.parent
   133  
   134  		// If a parent has been invalidated past end, act like we were at the
   135  		// last subtree.
   136  		if curr.idx >= curr.nd.Count() {
   137  			curr.skipToNodeEnd()
   138  		} else if curr.idx < 0 {
   139  			return 0, fmt.Errorf("found invalid parent cursor behind node start")
   140  		}
   141  
   142  		curr.nd, err = curr.nd.loadSubtrees()
   143  		if err != nil {
   144  			return 0, err
   145  		}
   146  
   147  		for idx := curr.idx - 1; idx >= 0; idx-- {
   148  			cnt, err := curr.nd.getSubtreeCount(idx)
   149  			if err != nil {
   150  				return 0, err
   151  			}
   152  			ord += cnt
   153  		}
   154  	}
   155  
   156  	return ord, nil
   157  }
   158  
   159  func newCursorAtKey[K ~[]byte, O Ordering[K]](ctx context.Context, ns NodeStore, nd Node, key K, order O) (cur *cursor, err error) {
   160  	return newCursorFromSearchFn(ctx, ns, nd, searchForKey(key, order))
   161  }
   162  
   163  func newCursorFromSearchFn(ctx context.Context, ns NodeStore, nd Node, search SearchFn) (cur *cursor, err error) {
   164  	cur = &cursor{nd: nd, nrw: ns}
   165  
   166  	cur.idx = search(cur.nd)
   167  	for !cur.isLeaf() {
   168  		// stay in bounds for internal nodes
   169  		cur.keepInBounds()
   170  
   171  		nd, err = fetchChild(ctx, ns, cur.currentRef())
   172  		if err != nil {
   173  			return cur, err
   174  		}
   175  
   176  		parent := cur
   177  		cur = &cursor{nd: nd, parent: parent, nrw: ns}
   178  
   179  		cur.idx = search(cur.nd)
   180  	}
   181  	return
   182  }
   183  
   184  func newLeafCursorAtKey[K ~[]byte, O Ordering[K]](ctx context.Context, ns NodeStore, nd Node, key K, order O) (cursor, error) {
   185  	var err error
   186  	cur := cursor{nd: nd, nrw: ns}
   187  	for {
   188  		// binary search |cur.nd| for |key|
   189  		i, j := 0, cur.nd.Count()
   190  		for i < j {
   191  			h := int(uint(i+j) >> 1)
   192  			cmp := order.Compare(key, K(cur.nd.GetKey(h)))
   193  			if cmp > 0 {
   194  				i = h + 1
   195  			} else {
   196  				j = h
   197  			}
   198  		}
   199  		cur.idx = i
   200  
   201  		if cur.isLeaf() {
   202  			break // done
   203  		}
   204  
   205  		// stay in bounds for internal nodes
   206  		cur.keepInBounds()
   207  
   208  		// reuse |cur| object to keep stack alloc'd
   209  		cur.nd, err = fetchChild(ctx, ns, cur.currentRef())
   210  		if err != nil {
   211  			return cur, err
   212  		}
   213  	}
   214  	return cur, nil
   215  }
   216  
   217  // searchForKey returns a SearchFn for |key|.
   218  func searchForKey[K ~[]byte, O Ordering[K]](key K, order O) SearchFn {
   219  	return func(nd Node) (idx int) {
   220  		n := int(nd.Count())
   221  		// Define f(-1) == false and f(n) == true.
   222  		// Invariant: f(i-1) == false, f(j) == true.
   223  		i, j := 0, n
   224  		for i < j {
   225  			h := int(uint(i+j) >> 1) // avoid overflow when computing h
   226  			less := order.Compare(key, K(nd.GetKey(h))) <= 0
   227  			// i ≤ h < j
   228  			if !less {
   229  				i = h + 1 // preserves f(i-1) == false
   230  			} else {
   231  				j = h // preserves f(j) == true
   232  			}
   233  		}
   234  		// i == j, f(i-1) == false, and
   235  		// f(j) (= f(i)) == true  =>  answer is i.
   236  		return i
   237  	}
   238  }
   239  
// LeafSpan is a contiguous run of leaf Nodes covering an ordinal range, as
// produced by fetchLeafNodeSpan.
type LeafSpan struct {
	Leaves     []Node // leaves covering the range, in tree order
	LocalStart int    // index of the range start within Leaves[0]
	LocalStop  int    // exclusive index of the range end within the last leaf
}
   245  
   246  // FetchLeafNodeSpan returns the leaf Node span for the ordinal range [start, stop). It fetches the span using
   247  // an eager breadth-first search and makes batch read calls to the persistence layer via NodeStore.ReadMany.
   248  func fetchLeafNodeSpan(ctx context.Context, ns NodeStore, root Node, start, stop uint64) (LeafSpan, error) {
   249  	leaves, localStart, err := recursiveFetchLeafNodeSpan(ctx, ns, []Node{root}, start, stop)
   250  	if err != nil {
   251  		return LeafSpan{}, err
   252  	}
   253  
   254  	localStop := (stop - start) + localStart
   255  	for i := 0; i < len(leaves)-1; i++ {
   256  		localStop -= uint64(leaves[i].Count())
   257  	}
   258  
   259  	return LeafSpan{
   260  		Leaves:     leaves,
   261  		LocalStart: int(localStart),
   262  		LocalStop:  int(localStop),
   263  	}, nil
   264  }
   265  
   266  func recursiveFetchLeafNodeSpan(ctx context.Context, ns NodeStore, nodes []Node, start, stop uint64) ([]Node, uint64, error) {
   267  	if nodes[0].IsLeaf() {
   268  		// verify leaf homogeneity
   269  		for i := range nodes {
   270  			if !nodes[i].IsLeaf() {
   271  				return nil, 0, errors.New("mixed leaf/non-leaf set")
   272  			}
   273  		}
   274  		return nodes, start, nil
   275  	}
   276  
   277  	gets := make(hash.HashSlice, 0, len(nodes)*nodes[0].Count())
   278  	acc := uint64(0)
   279  
   280  	var err error
   281  	for _, nd := range nodes {
   282  		if nd, err = nd.loadSubtrees(); err != nil {
   283  			return nil, 0, err
   284  		}
   285  
   286  		for i := 0; i < nd.Count(); i++ {
   287  			card, err := nd.getSubtreeCount(i)
   288  			if err != nil {
   289  				return nil, 0, err
   290  			}
   291  
   292  			if acc == 0 && card < start {
   293  				start -= card
   294  				stop -= card
   295  				continue
   296  			}
   297  
   298  			gets = append(gets, hash.New(nd.GetValue(i)))
   299  			acc += card
   300  			if acc >= stop {
   301  				break
   302  			}
   303  		}
   304  	}
   305  
   306  	children, err := ns.ReadMany(ctx, gets)
   307  	if err != nil {
   308  		return nil, 0, err
   309  	}
   310  	return recursiveFetchLeafNodeSpan(ctx, ns, children, start, stop)
   311  }
   312  
   313  func currentCursorItems(cur *cursor) (key, value Item) {
   314  	key = cur.nd.keys.GetItem(cur.idx, cur.nd.msg)
   315  	value = cur.nd.values.GetItem(cur.idx, cur.nd.msg)
   316  	return
   317  }
   318  
   319  // Seek updates the cursor's node to one whose range spans the key's value, or the last
   320  // node if the key is greater than all existing keys.
   321  // If a node does not contain the key, we recurse upwards to the parent cursor. If the
   322  // node contains a key, we recurse downwards into child nodes.
   323  func Seek[K ~[]byte, O Ordering[K]](ctx context.Context, cur *cursor, key K, order O) (err error) {
   324  	inBounds := true
   325  	if cur.parent != nil {
   326  		inBounds = inBounds && order.Compare(key, K(cur.firstKey())) >= 0
   327  		inBounds = inBounds && order.Compare(key, K(cur.lastKey())) <= 0
   328  	}
   329  
   330  	if !inBounds {
   331  		// |item| is outside the bounds of |cur.nd|, search up the tree
   332  		err = Seek(ctx, cur.parent, key, order)
   333  		if err != nil {
   334  			return err
   335  		}
   336  		// stay in bounds for internal nodes
   337  		cur.parent.keepInBounds()
   338  
   339  		cur.nd, err = fetchChild(ctx, cur.nrw, cur.parent.currentRef())
   340  		if err != nil {
   341  			return err
   342  		}
   343  	}
   344  
   345  	cur.idx = searchForKey(key, order)(cur.nd)
   346  
   347  	return
   348  }
   349  
   350  func (cur *cursor) Valid() bool {
   351  	return cur.nd.count != 0 &&
   352  		cur.nd.bytes() != nil &&
   353  		cur.idx >= 0 &&
   354  		cur.idx < int(cur.nd.count)
   355  }
   356  
   357  func (cur *cursor) CurrentKey() Item {
   358  	return cur.nd.GetKey(cur.idx)
   359  }
   360  
   361  func (cur *cursor) currentValue() Item {
   362  	return cur.nd.GetValue(cur.idx)
   363  }
   364  
   365  func (cur *cursor) currentRef() hash.Hash {
   366  	return cur.nd.getAddress(cur.idx)
   367  }
   368  
   369  func (cur *cursor) currentSubtreeSize() (uint64, error) {
   370  	if cur.isLeaf() {
   371  		return 1, nil
   372  	}
   373  	var err error
   374  	cur.nd, err = cur.nd.loadSubtrees()
   375  	if err != nil {
   376  		return 0, err
   377  	}
   378  	return cur.nd.getSubtreeCount(cur.idx)
   379  }
   380  
   381  func (cur *cursor) firstKey() Item {
   382  	return cur.nd.GetKey(0)
   383  }
   384  
   385  func (cur *cursor) lastKey() Item {
   386  	lastKeyIdx := int(cur.nd.count) - 1
   387  	return cur.nd.GetKey(lastKeyIdx)
   388  }
   389  
   390  func (cur *cursor) skipToNodeStart() {
   391  	cur.idx = 0
   392  }
   393  
   394  func (cur *cursor) skipToNodeEnd() {
   395  	lastKeyIdx := int(cur.nd.count) - 1
   396  	cur.idx = lastKeyIdx
   397  }
   398  
   399  func (cur *cursor) keepInBounds() {
   400  	if cur.idx < 0 {
   401  		cur.skipToNodeStart()
   402  	}
   403  	lastKeyIdx := int(cur.nd.count) - 1
   404  	if cur.idx > lastKeyIdx {
   405  		cur.skipToNodeEnd()
   406  	}
   407  }
   408  
   409  func (cur *cursor) atNodeStart() bool {
   410  	return cur.idx == 0
   411  }
   412  
   413  // atNodeEnd returns true if the cursor's current |idx|
   414  // points to the last node item
   415  func (cur *cursor) atNodeEnd() bool {
   416  	lastKeyIdx := int(cur.nd.count) - 1
   417  	return cur.idx == lastKeyIdx
   418  }
   419  
   420  func (cur *cursor) isLeaf() bool {
   421  	return cur.nd.level == 0
   422  }
   423  
   424  func (cur *cursor) level() (uint64, error) {
   425  	return uint64(cur.nd.level), nil
   426  }
   427  
   428  // invalidateAtEnd sets the cursor's index to the node count.
   429  func (cur *cursor) invalidateAtEnd() {
   430  	cur.idx = int(cur.nd.count)
   431  }
   432  
   433  // invalidateAtStart sets the cursor's index to -1.
   434  func (cur *cursor) invalidateAtStart() {
   435  	cur.idx = -1
   436  }
   437  
   438  // hasNext returns true if we do not need to recursively
   439  // check the parent to know that the current cursor
   440  // has more keys. hasNext can be false even if parent
   441  // cursors are not exhausted.
   442  func (cur *cursor) hasNext() bool {
   443  	return cur.idx < int(cur.nd.count)-1
   444  }
   445  
   446  // hasPrev returns true if the current node has preceding
   447  // keys. hasPrev can be false even in a parent node has
   448  // preceding keys.
   449  func (cur *cursor) hasPrev() bool {
   450  	return cur.idx > 0
   451  }
   452  
   453  // outOfBounds returns true if the current cursor and
   454  // all parents are exhausted.
   455  func (cur *cursor) outOfBounds() bool {
   456  	return cur.idx < 0 || cur.idx >= int(cur.nd.count)
   457  }
   458  
   459  // advance either increments the current key index by one,
   460  // or has reached the end of the current node and skips to the next
   461  // child of the parent cursor, recursively if necessary, returning
   462  // either an error or nil.
   463  //
   464  // More specifically, one of three things happens:
   465  //
   466  // 1) The current chunk still has keys, iterate to
   467  // the next |idx|;
   468  //
   469  // 2) We've exhausted the current cursor, but there is at least
   470  // one |parent| cursor with more keys. We find that |parent| recursively,
   471  // perform step (1), and then have every child initialize itself
   472  // using the new |parent|.
   473  //
   474  // 3) We've exhausted the current cursor and every |parent|. Jump
   475  // to an end state (idx = node.count).
   476  func (cur *cursor) advance(ctx context.Context) error {
   477  	if cur.hasNext() {
   478  		cur.idx++
   479  		return nil
   480  	}
   481  
   482  	if cur.parent == nil {
   483  		cur.invalidateAtEnd()
   484  		return nil
   485  	}
   486  
   487  	// recursively increment the parent
   488  	err := cur.parent.advance(ctx)
   489  	if err != nil {
   490  		return err
   491  	}
   492  
   493  	if cur.parent.outOfBounds() {
   494  		// exhausted every parent cursor
   495  		cur.invalidateAtEnd()
   496  		return nil
   497  	}
   498  
   499  	// new parent cursor points to new cur node
   500  	err = cur.fetchNode(ctx)
   501  	if err != nil {
   502  		return err
   503  	}
   504  
   505  	cur.skipToNodeStart()
   506  	return nil
   507  }
   508  
   509  // retreat decrements to the previous key, if necessary by
   510  // recursively decrementing parent nodes.
   511  func (cur *cursor) retreat(ctx context.Context) error {
   512  	if cur.hasPrev() {
   513  		cur.idx--
   514  		return nil
   515  	}
   516  
   517  	if cur.parent == nil {
   518  		cur.invalidateAtStart()
   519  		return nil
   520  	}
   521  
   522  	// recursively decrement the parent
   523  	err := cur.parent.retreat(ctx)
   524  	if err != nil {
   525  		return err
   526  	}
   527  
   528  	if cur.parent.outOfBounds() {
   529  		// exhausted every parent cursor
   530  		cur.invalidateAtStart()
   531  		return nil
   532  	}
   533  
   534  	// new parent cursor points to new cur node
   535  	err = cur.fetchNode(ctx)
   536  	if err != nil {
   537  		return err
   538  	}
   539  
   540  	cur.skipToNodeEnd()
   541  	return nil
   542  }
   543  
   544  // fetchNode loads the Node that the cursor index points to.
   545  // It's called whenever the cursor advances/retreats to a different chunk.
   546  func (cur *cursor) fetchNode(ctx context.Context) (err error) {
   547  	assertTrue(cur.parent != nil, "cannot fetch node for cursor with nil parent")
   548  	cur.nd, err = fetchChild(ctx, cur.nrw, cur.parent.currentRef())
   549  	cur.idx = -1 // caller must set
   550  	return err
   551  }
   552  
   553  // Compare returns the highest relative index difference
   554  // between two cursor trees. A parent has a higher precedence
   555  // than its child.
   556  //
   557  // Ex:
   558  //
   559  // cur:   L3 -> 4, L2 -> 2, L1 -> 5, L0 -> 2
   560  // other: L3 -> 4, L2 -> 2, L1 -> 5, L0 -> 4
   561  //
   562  //	res => -2 (from level 0)
   563  //
   564  // cur:   L3 -> 4, L2 -> 2, L1 -> 5, L0 -> 2
   565  // other: L3 -> 4, L2 -> 3, L1 -> 5, L0 -> 4
   566  //
   567  //	res => +1 (from level 2)
   568  func (cur *cursor) compare(other *cursor) int {
   569  	return compareCursors(cur, other)
   570  }
   571  
   572  func (cur *cursor) clone() *cursor {
   573  	cln := cursor{
   574  		nd:  cur.nd,
   575  		idx: cur.idx,
   576  		nrw: cur.nrw,
   577  	}
   578  
   579  	if cur.parent != nil {
   580  		cln.parent = cur.parent.clone()
   581  	}
   582  
   583  	return &cln
   584  }
   585  
   586  func (cur *cursor) copy(other *cursor) {
   587  	cur.nd = other.nd
   588  	cur.idx = other.idx
   589  	cur.nrw = other.nrw
   590  
   591  	if cur.parent != nil {
   592  		assertTrue(other.parent != nil, "cursors must be of equal height to call copy()")
   593  		cur.parent.copy(other.parent)
   594  	} else {
   595  		assertTrue(other.parent == nil, "cursors must be of equal height to call copy()")
   596  	}
   597  }
   598  
   599  func compareCursors(left, right *cursor) (diff int) {
   600  	diff = 0
   601  	for {
   602  		d := left.idx - right.idx
   603  		if d != 0 {
   604  			diff = d
   605  		}
   606  
   607  		if left.parent == nil || right.parent == nil {
   608  			break
   609  		}
   610  		left, right = left.parent, right.parent
   611  	}
   612  	return
   613  }
   614  
   615  func fetchChild(ctx context.Context, ns NodeStore, ref hash.Hash) (Node, error) {
   616  	// todo(andy) handle nil Node, dangling ref
   617  	return ns.Read(ctx, ref)
   618  }
   619  
   620  func assertTrue(b bool, msg string) {
   621  	if !b {
   622  		panic("assertion failed: " + msg)
   623  	}
   624  }