github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/prolly/tree/map.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tree
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  
    22  	"github.com/dolthub/dolt/go/store/hash"
    23  	"github.com/dolthub/dolt/go/store/prolly/message"
    24  	"github.com/dolthub/dolt/go/store/skip"
    25  )
    26  
    27  type KeyValueFn[K, V ~[]byte] func(key K, value V) error
    28  
    29  type KvIter[K, V ~[]byte] interface {
    30  	Next(ctx context.Context) (K, V, error)
    31  }
    32  
    33  // StaticMap is a static prolly Tree with ordered elements.
    34  type StaticMap[K, V ~[]byte, O Ordering[K]] struct {
    35  	Root      Node
    36  	NodeStore NodeStore
    37  	Order     O
    38  }
    39  
    40  // DiffOrderedTrees invokes `cb` for each difference between `from` and `to. If `considerAllRowsModified`
    41  // is true, then a key that exists in both trees will be considered a modification even if the bytes are the same.
    42  // This is used when `from` and `to` have different schemas.
    43  func DiffOrderedTrees[K, V ~[]byte, O Ordering[K]](
    44  	ctx context.Context,
    45  	from, to StaticMap[K, V, O],
    46  	considerAllRowsModified bool,
    47  	cb DiffFn,
    48  ) error {
    49  	differ, err := DifferFromRoots[K](ctx, from.NodeStore, to.NodeStore, from.Root, to.Root, from.Order, considerAllRowsModified)
    50  	if err != nil {
    51  		return err
    52  	}
    53  
    54  	for {
    55  		var diff Diff
    56  		if diff, err = differ.Next(ctx); err != nil {
    57  			break
    58  		}
    59  
    60  		if err = cb(ctx, diff); err != nil {
    61  			break
    62  		}
    63  	}
    64  	return err
    65  }
    66  
    67  func DiffKeyRangeOrderedTrees[K, V ~[]byte, O Ordering[K]](
    68  	ctx context.Context,
    69  	from, to StaticMap[K, V, O],
    70  	start, stop K,
    71  	cb DiffFn,
    72  ) error {
    73  	var fromStart, fromStop, toStart, toStop *cursor
    74  	var err error
    75  
    76  	if len(start) == 0 {
    77  		fromStart, err = newCursorAtStart(ctx, from.NodeStore, from.Root)
    78  		if err != nil {
    79  			return err
    80  		}
    81  
    82  		toStart, err = newCursorAtStart(ctx, to.NodeStore, to.Root)
    83  		if err != nil {
    84  			return err
    85  		}
    86  	} else {
    87  		fromStart, err = newCursorAtKey(ctx, from.NodeStore, from.Root, start, from.Order)
    88  		if err != nil {
    89  			return err
    90  		}
    91  
    92  		toStart, err = newCursorAtKey(ctx, to.NodeStore, to.Root, start, to.Order)
    93  		if err != nil {
    94  			return err
    95  		}
    96  	}
    97  
    98  	if len(stop) == 0 {
    99  		fromStop, err = newCursorPastEnd(ctx, from.NodeStore, from.Root)
   100  		if err != nil {
   101  			return err
   102  		}
   103  
   104  		toStop, err = newCursorPastEnd(ctx, to.NodeStore, to.Root)
   105  		if err != nil {
   106  			return err
   107  		}
   108  	} else {
   109  		fromStop, err = newCursorAtKey(ctx, from.NodeStore, from.Root, stop, from.Order)
   110  		if err != nil {
   111  			return err
   112  		}
   113  
   114  		toStop, err = newCursorAtKey(ctx, to.NodeStore, to.Root, stop, to.Order)
   115  		if err != nil {
   116  			return err
   117  		}
   118  	}
   119  
   120  	differ := Differ[K, O]{
   121  		from:     fromStart,
   122  		to:       toStart,
   123  		fromStop: fromStop,
   124  		toStop:   toStop,
   125  		order:    from.Order,
   126  	}
   127  
   128  	for {
   129  		var diff Diff
   130  		if diff, err = differ.Next(ctx); err != nil {
   131  			break
   132  		}
   133  
   134  		if err = cb(ctx, diff); err != nil {
   135  			break
   136  		}
   137  	}
   138  	return err
   139  }
   140  
   141  func MergeOrderedTrees[K, V ~[]byte, O Ordering[K], S message.Serializer](
   142  	ctx context.Context,
   143  	l, r, base StaticMap[K, V, O],
   144  	cb CollisionFn,
   145  	leftSchemaChanged, rightSchemaChanged bool,
   146  	serializer S,
   147  ) (StaticMap[K, V, O], MergeStats, error) {
   148  	root, stats, err := ThreeWayMerge[K](ctx, base.NodeStore, l.Root, r.Root, base.Root, cb, leftSchemaChanged, rightSchemaChanged, base.Order, serializer)
   149  	if err != nil {
   150  		return StaticMap[K, V, O]{}, MergeStats{}, err
   151  	}
   152  
   153  	return StaticMap[K, V, O]{
   154  		Root:      root,
   155  		NodeStore: base.NodeStore,
   156  		Order:     base.Order,
   157  	}, stats, nil
   158  }
   159  
   160  // VisitMapLevelOrder visits each internal node of the tree in level order and calls the provided callback `cb` on each hash
   161  // encountered. This function is used primarily for building appendix table files for databases to help optimize reads.
   162  func VisitMapLevelOrder[K, V ~[]byte, O Ordering[K]](
   163  	ctx context.Context,
   164  	m StaticMap[K, V, O],
   165  	cb func(h hash.Hash) (int64, error),
   166  ) error {
   167  	// get cursor to leaves
   168  	cur, err := newCursorAtStart(ctx, m.NodeStore, m.Root)
   169  	if err != nil {
   170  		return err
   171  	}
   172  	first := cur.CurrentKey()
   173  
   174  	// start by iterating level 1 nodes,
   175  	// then recurse upwards until we're at the root
   176  	for cur.parent != nil {
   177  		cur = cur.parent
   178  		for cur.Valid() {
   179  			_, err = cb(cur.currentRef())
   180  			if err != nil {
   181  				return err
   182  			}
   183  			if err = cur.advance(ctx); err != nil {
   184  				return err
   185  			}
   186  		}
   187  
   188  		// return cursor to the start of the map
   189  		if err = Seek(ctx, cur, K(first), m.Order); err != nil {
   190  			return err
   191  		}
   192  	}
   193  	return err
   194  }
   195  
   196  func (t StaticMap[K, V, O]) Count() (int, error) {
   197  	return t.Root.TreeCount()
   198  }
   199  
   200  func (t StaticMap[K, V, O]) Height() int {
   201  	return t.Root.Level() + 1
   202  }
   203  
   204  func (t StaticMap[K, V, O]) HashOf() hash.Hash {
   205  	return t.Root.HashOf()
   206  }
   207  
   208  func (t StaticMap[K, V, O]) Mutate() MutableMap[K, V, O] {
   209  	return MutableMap[K, V, O]{
   210  		Edits: skip.NewSkipList(func(left, right []byte) int {
   211  			return t.Order.Compare(left, right)
   212  		}),
   213  		Static: t,
   214  	}
   215  }
   216  
   217  func (t StaticMap[K, V, O]) WalkAddresses(ctx context.Context, cb AddressCb) error {
   218  	return WalkAddresses(ctx, t.Root, t.NodeStore, cb)
   219  }
   220  
   221  func (t StaticMap[K, V, O]) WalkNodes(ctx context.Context, cb NodeCb) error {
   222  	return WalkNodes(ctx, t.Root, t.NodeStore, cb)
   223  }
   224  
   225  func (t StaticMap[K, V, O]) Get(ctx context.Context, query K, cb KeyValueFn[K, V]) (err error) {
   226  	cur, err := newLeafCursorAtKey(ctx, t.NodeStore, t.Root, query, t.Order)
   227  	if err != nil {
   228  		return err
   229  	}
   230  
   231  	var key K
   232  	var value V
   233  
   234  	if cur.Valid() {
   235  		key = K(cur.CurrentKey())
   236  		if t.Order.Compare(query, key) == 0 {
   237  			value = V(cur.currentValue())
   238  		} else {
   239  			key = nil
   240  		}
   241  	}
   242  	return cb(key, value)
   243  }
   244  
   245  func (t StaticMap[K, V, O]) GetPrefix(ctx context.Context, query K, prefixOrder O, cb KeyValueFn[K, V]) (err error) {
   246  	cur, err := newLeafCursorAtKey(ctx, t.NodeStore, t.Root, query, prefixOrder)
   247  	if err != nil {
   248  		return err
   249  	}
   250  
   251  	var key K
   252  	var value V
   253  
   254  	if cur.Valid() {
   255  		key = K(cur.CurrentKey())
   256  		if prefixOrder.Compare(query, key) == 0 {
   257  			value = V(cur.currentValue())
   258  		} else {
   259  			key = nil
   260  		}
   261  	}
   262  	return cb(key, value)
   263  }
   264  
   265  func (t StaticMap[K, V, O]) Has(ctx context.Context, query K) (ok bool, err error) {
   266  	cur, err := newLeafCursorAtKey(ctx, t.NodeStore, t.Root, query, t.Order)
   267  	if err != nil {
   268  		return false, err
   269  	} else if cur.Valid() {
   270  		ok = t.Order.Compare(query, K(cur.CurrentKey())) == 0
   271  	}
   272  	return
   273  }
   274  
   275  func (t StaticMap[K, V, O]) HasPrefix(ctx context.Context, query K, prefixOrder O) (ok bool, err error) {
   276  	cur, err := newLeafCursorAtKey(ctx, t.NodeStore, t.Root, query, prefixOrder)
   277  	if err != nil {
   278  		return false, err
   279  	} else if cur.Valid() {
   280  		// true if |query| is a prefix of |cur.currentKey()|
   281  		ok = prefixOrder.Compare(query, K(cur.CurrentKey())) == 0
   282  	}
   283  	return
   284  }
   285  
   286  func (t StaticMap[K, V, O]) LastKey(ctx context.Context) (key K) {
   287  	if t.Root.count > 0 {
   288  		// if |t.Root| is a leaf node, it represents the entire map
   289  		// if |t.Root| is an internal node, its last key is the
   290  		// delimiter for last subtree and is the last key in the map
   291  		key = K(getLastKey(t.Root))
   292  	}
   293  	return
   294  }
   295  
   296  func (t StaticMap[K, V, O]) IterAll(ctx context.Context) (*OrderedTreeIter[K, V], error) {
   297  	c, err := newCursorAtStart(ctx, t.NodeStore, t.Root)
   298  	if err != nil {
   299  		return nil, err
   300  	}
   301  
   302  	s, err := newCursorPastEnd(ctx, t.NodeStore, t.Root)
   303  	if err != nil {
   304  		return nil, err
   305  	}
   306  
   307  	stop := func(curr *cursor) bool {
   308  		return curr.compare(s) >= 0
   309  	}
   310  
   311  	if stop(c) {
   312  		// empty range
   313  		return &OrderedTreeIter[K, V]{curr: nil}, nil
   314  	}
   315  
   316  	return &OrderedTreeIter[K, V]{curr: c, stop: stop, step: c.advance}, nil
   317  }
   318  
   319  func (t StaticMap[K, V, O]) IterAllReverse(ctx context.Context) (*OrderedTreeIter[K, V], error) {
   320  	beginning, err := newCursorAtStart(ctx, t.NodeStore, t.Root)
   321  	if err != nil {
   322  		return nil, err
   323  	}
   324  	err = beginning.retreat(ctx)
   325  	if err != nil {
   326  		return nil, err
   327  	}
   328  
   329  	end, err := newCursorAtEnd(ctx, t.NodeStore, t.Root)
   330  	if err != nil {
   331  		return nil, err
   332  	}
   333  
   334  	stop := func(curr *cursor) bool {
   335  		return curr.compare(beginning) <= 0
   336  	}
   337  
   338  	if stop(end) {
   339  		// empty range
   340  		return &OrderedTreeIter[K, V]{curr: nil}, nil
   341  	}
   342  
   343  	return &OrderedTreeIter[K, V]{curr: end, stop: stop, step: end.retreat}, nil
   344  }
   345  
   346  func (t StaticMap[K, V, O]) IterOrdinalRange(ctx context.Context, start, stop uint64) (*OrderedTreeIter[K, V], error) {
   347  	if stop == start {
   348  		return &OrderedTreeIter[K, V]{curr: nil}, nil
   349  	}
   350  	if stop < start {
   351  		return nil, fmt.Errorf("invalid ordinal bounds (%d, %d)", start, stop)
   352  	} else {
   353  		c, err := t.Count()
   354  		if err != nil {
   355  			return nil, err
   356  		}
   357  		if stop > uint64(c) {
   358  			return nil, fmt.Errorf("stop index (%d) out of bounds", stop)
   359  		}
   360  	}
   361  
   362  	lo, err := newCursorAtOrdinal(ctx, t.NodeStore, t.Root, start)
   363  	if err != nil {
   364  		return nil, err
   365  	}
   366  
   367  	hi, err := newCursorAtOrdinal(ctx, t.NodeStore, t.Root, stop)
   368  	if err != nil {
   369  		return nil, err
   370  	}
   371  
   372  	stopF := func(curr *cursor) bool {
   373  		return curr.compare(hi) >= 0
   374  	}
   375  
   376  	return &OrderedTreeIter[K, V]{curr: lo, stop: stopF, step: lo.advance}, nil
   377  }
   378  
   379  func (t StaticMap[K, V, O]) FetchOrdinalRange(ctx context.Context, start, stop uint64) (*orderedLeafSpanIter[K, V], error) {
   380  	if stop == start {
   381  		return &orderedLeafSpanIter[K, V]{}, nil
   382  	}
   383  	if stop < start {
   384  		return nil, fmt.Errorf("invalid ordinal bounds (%d, %d)", start, stop)
   385  	} else {
   386  		c, err := t.Count()
   387  		if err != nil {
   388  			return nil, err
   389  		} else if stop > uint64(c) {
   390  			return nil, fmt.Errorf("stop index (%d) out of bounds", stop)
   391  		}
   392  	}
   393  
   394  	span, err := fetchLeafNodeSpan(ctx, t.NodeStore, t.Root, start, stop)
   395  	if err != nil {
   396  		return nil, err
   397  	}
   398  
   399  	nd, leaves := span.Leaves[0], span.Leaves[1:]
   400  	c, s := span.LocalStart, nd.Count()
   401  	if len(leaves) == 0 {
   402  		s = span.LocalStop // one leaf span
   403  	}
   404  
   405  	return &orderedLeafSpanIter[K, V]{
   406  		nd:     nd,
   407  		curr:   c,
   408  		stop:   s,
   409  		leaves: leaves,
   410  		final:  span.LocalStop,
   411  	}, nil
   412  }
   413  
   414  func (t StaticMap[K, V, O]) IterKeyRange(ctx context.Context, start, stop K) (*OrderedTreeIter[K, V], error) {
   415  	lo, hi, err := t.getKeyRangeCursors(ctx, start, stop)
   416  	if err != nil {
   417  		return nil, err
   418  	}
   419  
   420  	stopF := func(curr *cursor) bool {
   421  		return curr.compare(hi) >= 0
   422  	}
   423  
   424  	if stopF(lo) {
   425  		return &OrderedTreeIter[K, V]{curr: nil}, nil
   426  	}
   427  
   428  	return &OrderedTreeIter[K, V]{curr: lo, stop: stopF, step: lo.advance}, nil
   429  }
   430  
   431  func (t StaticMap[K, V, O]) GetKeyRangeCardinality(ctx context.Context, start, stop K) (uint64, error) {
   432  	lo, hi, err := t.getKeyRangeCursors(ctx, start, stop)
   433  	if err != nil {
   434  		return 0, err
   435  	}
   436  
   437  	startOrd, err := getOrdinalOfCursor(lo)
   438  	if err != nil {
   439  		return 0, err
   440  	}
   441  
   442  	endOrd, err := getOrdinalOfCursor(hi)
   443  	if err != nil {
   444  		return 0, err
   445  	}
   446  
   447  	if startOrd > endOrd {
   448  		return 0, nil
   449  	}
   450  	return endOrd - startOrd, nil
   451  }
   452  
   453  func (t StaticMap[K, V, O]) getKeyRangeCursors(ctx context.Context, startInclusive, stopExclusive K) (lo, hi *cursor, err error) {
   454  	if len(startInclusive) == 0 {
   455  		lo, err = newCursorAtStart(ctx, t.NodeStore, t.Root)
   456  		if err != nil {
   457  			return nil, nil, err
   458  		}
   459  	} else {
   460  		lo, err = newCursorAtKey(ctx, t.NodeStore, t.Root, startInclusive, t.Order)
   461  		if err != nil {
   462  			return nil, nil, err
   463  		}
   464  	}
   465  
   466  	if len(stopExclusive) == 0 {
   467  		hi, err = newCursorPastEnd(ctx, t.NodeStore, t.Root)
   468  		if err != nil {
   469  			return nil, nil, err
   470  		}
   471  	} else {
   472  		hi, err = newCursorAtKey(ctx, t.NodeStore, t.Root, stopExclusive, t.Order)
   473  		if err != nil {
   474  			return nil, nil, err
   475  		}
   476  	}
   477  	return
   478  }
   479  
   480  // GetOrdinalForKey returns the smallest ordinal position at which the key >= |query|.
   481  func (t StaticMap[K, V, O]) GetOrdinalForKey(ctx context.Context, query K) (uint64, error) {
   482  	cur, err := newCursorAtKey(ctx, t.NodeStore, t.Root, query, t.Order)
   483  	if err != nil {
   484  		return 0, err
   485  	}
   486  	return getOrdinalOfCursor(cur)
   487  }
   488  
   489  type OrderedTreeIter[K, V ~[]byte] struct {
   490  	// current tuple location
   491  	curr *cursor
   492  
   493  	// the function called to moved |curr| forward in the direction of iteration.
   494  	step func(context.Context) error
   495  	// should return |true| if the passed in cursor is past the iteration's stopping point.
   496  	stop func(*cursor) bool
   497  }
   498  
   499  func ReverseOrderedTreeIterFromCursors[K, V ~[]byte](
   500  	ctx context.Context,
   501  	root Node, ns NodeStore,
   502  	findStart, findEnd SearchFn,
   503  ) (*OrderedTreeIter[K, V], error) {
   504  	start, err := newCursorFromSearchFn(ctx, ns, root, findStart)
   505  	if err != nil {
   506  		return nil, err
   507  	}
   508  	end, err := newCursorFromSearchFn(ctx, ns, root, findEnd)
   509  	if err != nil {
   510  		return nil, err
   511  	}
   512  	err = end.retreat(ctx)
   513  	if err != nil {
   514  		return nil, err
   515  	}
   516  
   517  	stopFn := func(curr *cursor) bool {
   518  		return curr.compare(start) < 0
   519  	}
   520  
   521  	if stopFn(end) {
   522  		end = nil // empty range
   523  	}
   524  
   525  	return &OrderedTreeIter[K, V]{curr: end, stop: stopFn, step: end.retreat}, nil
   526  }
   527  
   528  func OrderedTreeIterFromCursors[K, V ~[]byte](
   529  	ctx context.Context,
   530  	root Node, ns NodeStore,
   531  	findStart, findStop SearchFn,
   532  ) (*OrderedTreeIter[K, V], error) {
   533  	start, err := newCursorFromSearchFn(ctx, ns, root, findStart)
   534  	if err != nil {
   535  		return nil, err
   536  	}
   537  	stop, err := newCursorFromSearchFn(ctx, ns, root, findStop)
   538  	if err != nil {
   539  		return nil, err
   540  	}
   541  
   542  	stopFn := func(curr *cursor) bool {
   543  		return curr.compare(stop) >= 0
   544  	}
   545  
   546  	if stopFn(start) {
   547  		start = nil // empty range
   548  	}
   549  
   550  	return &OrderedTreeIter[K, V]{curr: start, stop: stopFn, step: start.advance}, nil
   551  }
   552  
   553  func (it *OrderedTreeIter[K, V]) Next(ctx context.Context) (key K, value V, err error) {
   554  	if it.curr == nil {
   555  		return nil, nil, io.EOF
   556  	}
   557  
   558  	k, v := currentCursorItems(it.curr)
   559  	key, value = K(k), V(v)
   560  
   561  	err = it.step(ctx)
   562  	if err != nil {
   563  		return nil, nil, err
   564  	}
   565  	if it.stop(it.curr) {
   566  		// past the end of the range
   567  		it.curr = nil
   568  	}
   569  
   570  	return
   571  }
   572  
   573  func (it *OrderedTreeIter[K, V]) Current() (key K, value V) {
   574  	// |it.curr| is set to nil when its range is exhausted
   575  	if it.curr != nil && it.curr.Valid() {
   576  		k, v := currentCursorItems(it.curr)
   577  		key, value = K(k), V(v)
   578  	}
   579  	return
   580  }
   581  
   582  func (it *OrderedTreeIter[K, V]) Iterate(ctx context.Context) (err error) {
   583  	err = it.step(ctx)
   584  	if err != nil {
   585  		return err
   586  	}
   587  
   588  	if it.stop(it.curr) {
   589  		// past the end of the range
   590  		it.curr = nil
   591  	}
   592  
   593  	return
   594  }
   595  
   596  type orderedLeafSpanIter[K, V ~[]byte] struct {
   597  	// in-progress node
   598  	nd Node
   599  	// current index,
   600  	curr int
   601  	// last index for |nd|
   602  	stop int
   603  	// remaining leaves
   604  	leaves []Node
   605  	// stop index in last leaf node
   606  	final int
   607  }
   608  
   609  func (s *orderedLeafSpanIter[K, V]) Next(ctx context.Context) (key K, value V, err error) {
   610  	if s.curr >= s.stop {
   611  		// |s.nd| exhausted
   612  		if len(s.leaves) == 0 {
   613  			// span exhausted
   614  			return nil, nil, io.EOF
   615  		}
   616  
   617  		s.nd = s.leaves[0]
   618  		s.curr = 0
   619  		s.stop = s.nd.Count()
   620  
   621  		s.leaves = s.leaves[1:]
   622  		if len(s.leaves) == 0 {
   623  			// |s.nd| is the last leaf
   624  			s.stop = s.final
   625  		}
   626  	}
   627  
   628  	key = K(s.nd.GetKey(s.curr))
   629  	value = V(s.nd.GetValue(s.curr))
   630  	s.curr++
   631  	return
   632  }