github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/branch_control/expr_parser_node.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package branch_control
    16  
    17  import (
    18  	"math"
    19  	"sync"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  )
    23  
    24  var (
    25  	aiciSorter = sql.Collation_utf8mb4_0900_ai_ci.Sorter()
    26  	sortFuncs  = []func(r rune) int32{aiciSorter, aiciSorter, sql.Collation_utf8mb4_0900_bin.Sorter(), aiciSorter}
    27  )
    28  
    29  // MatchNode contains a collection of sort orders that allow for an optimized level of traversal compared to
    30  // MatchExpression due to the sharing of like sort orders, reducing the overall number of comparisons needed.
    31  type MatchNode struct {
    32  	SortOrders []int32              // These are the sort orders that will be compared against when matching a given rune.
    33  	Children   map[int32]*MatchNode // These are the children of this node that each represent a different path in the sort orders.
    34  	Data       *MatchNodeData       // This is the collection of data that the node holds. Will be nil if it's not a destination node.
    35  }
    36  
    37  // MatchNodeData is the data contained in a destination MatchNode.
    38  type MatchNodeData struct {
    39  	Permissions Permissions
    40  	RowIndex    uint32
    41  }
    42  
    43  // MatchResult contains the data and expression length of a successful match.
    44  type MatchResult struct {
    45  	MatchNodeData
    46  	Length uint32
    47  }
    48  
    49  // matchNodeCounted is an intermediary node used while processing matches that records the length of the match so far.
    50  // This may be used to distinguish between which matches are the longest.
    51  type matchNodeCounted struct {
    52  	MatchNode
    53  	Length uint32
    54  }
    55  
    56  // matchNodeCountedPool is a pool for MatchNodeCounted.
    57  var matchNodeCountedPool = &sync.Pool{
    58  	New: func() any {
    59  		return make([]matchNodeCounted, 0, 16)
    60  	},
    61  }
    62  
    63  // concatenatedSortOrderPool is a pool for concatenated sort orders.
    64  var concatenatedSortOrderPool = &sync.Pool{
    65  	New: func() any {
    66  		return make([]int32, 0, 128)
    67  	},
    68  }
    69  
    70  // Match returns a collection of results based on the given strings or expressions. When the parameters represent
    71  // standard strings, then this simply matches those strings against the parsed expressions. However, if the parameters
    72  // represent expressions, then this matches against all parsed expressions that are either duplicates or supersets of
    73  // the given expressions. This allows the user to "match" against new expressions to see if they are already covered.
    74  func (mn *MatchNode) Match(database, branch, user, host string) []MatchResult {
    75  	allSortOrders := mn.parseExpression(database, branch, user, host)
    76  	defer func() {
    77  		concatenatedSortOrderPool.Put(allSortOrders)
    78  	}()
    79  
    80  	// This is the slice that we'll put matches into. This will also flip to become the match subset. This way we reuse
    81  	// the underlying arrays. We grab this from the pool. These are not pointers, as we modify the data inside to
    82  	// simplify the loop's logic.
    83  	matches := matchNodeCountedPool.Get().([]matchNodeCounted)[:0]
    84  	// This is the slice we'll iterate over. We also grab this from the pool.
    85  	matchSubset := matchNodeCountedPool.Get().([]matchNodeCounted)[:0]
    86  	matchSubset = append(matchSubset, matchNodeCounted{
    87  		MatchNode: *mn,
    88  		Length:    0,
    89  	})
    90  
    91  	// Loop over the entire set of sort orders
    92  	for _, sortOrder := range allSortOrders {
    93  		for _, node := range matchSubset {
    94  			if len(node.SortOrders) == 0 {
    95  				// At most we'll look at three children that may match, we can ignore all other children
    96  				if child, ok := node.Children[singleMatch]; ok {
    97  					matches = processMatch(matches, matchNodeCounted{
    98  						MatchNode: *child,
    99  						Length:    node.Length,
   100  					}, sortOrder)
   101  				}
   102  				if child, ok := node.Children[anyMatch]; ok {
   103  					matches = processMatch(matches, matchNodeCounted{
   104  						MatchNode: *child,
   105  						Length:    node.Length,
   106  					}, sortOrder)
   107  				}
   108  				if child, ok := node.Children[sortOrder]; ok {
   109  					matches = processMatch(matches, matchNodeCounted{
   110  						MatchNode: *child,
   111  						Length:    node.Length,
   112  					}, sortOrder)
   113  				}
   114  				continue
   115  			}
   116  			matches = processMatch(matches, node, sortOrder)
   117  		}
   118  		// Swap the two, and put the slice of matches to be at the beginning of the previous subset array to reuse it
   119  		matches, matchSubset = matchSubset[:0], matches
   120  	}
   121  	// We're done with the matches slice, so put it back in the pool
   122  	matchNodeCountedPool.Put(matches)
   123  
   124  	// The subset may contain partial matches (which do not count), so we filter for only complete matches
   125  	results := make([]MatchResult, 0, len(matchSubset))
   126  	for _, node := range matchSubset {
   127  		if node.Data != nil {
   128  			if len(node.SortOrders) == 0 {
   129  				results = append(results, MatchResult{
   130  					MatchNodeData: *node.Data,
   131  					Length:        node.Length,
   132  				})
   133  			} else if len(node.SortOrders) == 1 && node.SortOrders[0] == anyMatch {
   134  				results = append(results, MatchResult{
   135  					MatchNodeData: *node.Data,
   136  					Length:        node.Length + 1,
   137  				})
   138  			}
   139  		}
   140  	}
   141  	// Now we're done with the subset slice, so put it back in the pool
   142  	matchNodeCountedPool.Put(matchSubset)
   143  	return results
   144  }
   145  
   146  // processMatch handles the behavior of how to process a sort order against a node. Returns a new slice with any newly
   147  // appended nodes (which should overwrite the first parameter in the calling function).
   148  func processMatch(matches []matchNodeCounted, node matchNodeCounted, sortOrder int32) []matchNodeCounted {
   149  	switch node.SortOrders[0] {
   150  	case singleMatch:
   151  		if sortOrder < singleMatch {
   152  			return matches
   153  		}
   154  		node.SortOrders = node.SortOrders[1:]
   155  		node.Length += 1
   156  		matches = append(matches, node)
   157  	case anyMatch:
   158  		// Since any match can be a zero-length match, we need to check if we also match the next sort order
   159  		if len(node.SortOrders) > 1 && node.SortOrders[1] == sortOrder {
   160  			matches = append(matches, matchNodeCounted{
   161  				MatchNode: MatchNode{
   162  					SortOrders: node.SortOrders[2:],
   163  					Children:   node.Children,
   164  					Data:       node.Data,
   165  				},
   166  				Length: node.Length + 2,
   167  			})
   168  		}
   169  		// Any match cannot match a columnMarker as they represent column boundaries
   170  		if sortOrder != columnMarker {
   171  			matches = append(matches, node)
   172  		}
   173  	default:
   174  		// NOTE: it's worth mentioning that separators only match with themselves, so no need for special logic
   175  		if sortOrder == node.SortOrders[0] {
   176  			node.SortOrders = node.SortOrders[1:]
   177  			node.Length += 1
   178  			matches = append(matches, node)
   179  		}
   180  	}
   181  	return matches
   182  }
   183  
   184  // Add will add the given expressions to the node hierarchy. If the expressions already exists, then this overwrites
   185  // the pre-existing entry. Assumes that the given expressions have already been folded.
   186  func (mn *MatchNode) Add(databaseExpr, branchExpr, userExpr, hostExpr string, data MatchNodeData) {
   187  	root := mn
   188  	allSortOrders := mn.parseExpression(databaseExpr, branchExpr, userExpr, hostExpr)
   189  	defer func() {
   190  		concatenatedSortOrderPool.Put(allSortOrders)
   191  	}()
   192  
   193  	remainingRootSortOrders := root.SortOrders
   194  	allSortOrdersMaxIndex := len(allSortOrders) - 1
   195  ParentLoop:
   196  	for i, sortOrder := range allSortOrders {
   197  		if remainingRootSortOrders[0] == sortOrder {
   198  			if len(remainingRootSortOrders) > 1 && i < allSortOrdersMaxIndex {
   199  				// There are more sort orders on both sides, so we simply continue
   200  				remainingRootSortOrders = remainingRootSortOrders[1:]
   201  				continue
   202  			} else if len(remainingRootSortOrders) > 1 && i == allSortOrdersMaxIndex {
   203  				// We have more sort orders on the root, but no more in our expressions, so we put the remaining root
   204  				// sort orders as a child and set this as a destination node
   205  				root.Children = map[int32]*MatchNode{remainingRootSortOrders[1]: {
   206  					SortOrders: remainingRootSortOrders[1:],
   207  					Children:   root.Children,
   208  					Data:       root.Data,
   209  				}}
   210  				root.SortOrders = root.SortOrders[:len(root.SortOrders)-len(remainingRootSortOrders)+1]
   211  				root.Data = &data
   212  				break
   213  			} else if len(remainingRootSortOrders) == 1 && i < allSortOrdersMaxIndex {
   214  				// We've run out of sort orders on the root, but still have more from children, so check if there's a
   215  				// matching child
   216  				nextSortOrder := allSortOrders[i+1]
   217  				if child, ok := root.Children[nextSortOrder]; ok {
   218  					remainingRootSortOrders = child.SortOrders
   219  					root = root.Children[nextSortOrder]
   220  					continue ParentLoop
   221  				}
   222  				// None of the children matched, so we create a new one and add it. As we're using a pool, we need to
   223  				// create a new slice.
   224  				originalSortOrders := allSortOrders[i+1:]
   225  				newSortOrders := make([]int32, len(originalSortOrders))
   226  				copy(newSortOrders, originalSortOrders)
   227  				root.Children[newSortOrders[0]] = &MatchNode{
   228  					SortOrders: newSortOrders,
   229  					Children:   make(map[int32]*MatchNode),
   230  					Data:       &data,
   231  				}
   232  				break
   233  			} else {
   234  				// We have no more sort orders on either side so this is an exact match, therefore we update the data
   235  				root.Data = &data
   236  				break
   237  			}
   238  		} else {
   239  			// Since the sort orders do not match, we create a child here with the remaining expressions' sort orders,
   240  			// and move the root's remaining sort orders to its own child.
   241  			splitRoot := &MatchNode{
   242  				SortOrders: remainingRootSortOrders,
   243  				Children:   root.Children,
   244  				Data:       root.Data,
   245  			}
   246  			// As we're using a pool, we need to create a new slice
   247  			originalSortOrders := allSortOrders[i:]
   248  			newSortOrders := make([]int32, len(originalSortOrders))
   249  			copy(newSortOrders, originalSortOrders)
   250  			newChild := &MatchNode{
   251  				SortOrders: newSortOrders,
   252  				Children:   make(map[int32]*MatchNode),
   253  				Data:       &data,
   254  			}
   255  			root.SortOrders = root.SortOrders[:len(root.SortOrders)-len(remainingRootSortOrders)]
   256  			root.Children = map[int32]*MatchNode{splitRoot.SortOrders[0]: splitRoot, newChild.SortOrders[0]: newChild}
   257  			// As the root's data is now in the split, we set the data here to nil as it's no longer a destination node.
   258  			// If it wasn't a destination node, then nothing changes (we just set the split's data to nil as well).
   259  			root.Data = nil
   260  			break
   261  		}
   262  	}
   263  }
   264  
   265  // Remove will remove the given expressions to the node hierarchy. If the expressions do not exist, then nothing
   266  // happens. Assumes that the given expressions have already been folded.
   267  func (mn *MatchNode) Remove(databaseExpr, branchExpr, userExpr, hostExpr string) uint32 {
   268  	root := mn
   269  	allSortOrders := mn.parseExpression(databaseExpr, branchExpr, userExpr, hostExpr)
   270  	defer func() {
   271  		concatenatedSortOrderPool.Put(allSortOrders)
   272  	}()
   273  
   274  	// We track the parent of the root node so that we can delete its child if applicable
   275  	var rootParent *MatchNode = nil
   276  	childIndex := int32(0)
   277  
   278  	remainingRootSortOrders := root.SortOrders
   279  	allSortOrdersMaxIndex := len(allSortOrders) - 1
   280  	removedIndex := uint32(math.MaxUint32)
   281  ParentLoop:
   282  	for i, sortOrder := range allSortOrders {
   283  		if remainingRootSortOrders[0] == sortOrder {
   284  			if len(remainingRootSortOrders) > 1 && i < allSortOrdersMaxIndex {
   285  				// There are more sort orders on both sides, so we simply continue
   286  				remainingRootSortOrders = remainingRootSortOrders[1:]
   287  				continue
   288  			} else if len(remainingRootSortOrders) > 1 && i == allSortOrdersMaxIndex {
   289  				// We have more sort orders on the root, but no more in our expressions, so this set of expressions
   290  				// don't have a match
   291  				break
   292  			} else if len(remainingRootSortOrders) == 1 && i < allSortOrdersMaxIndex {
   293  				// We've run out of sort orders on the root, but still have more from the expressions, so check if a
   294  				// child will match the next sort order from the expressions
   295  				nextSortOrder := allSortOrders[i+1]
   296  				if child, ok := root.Children[nextSortOrder]; ok {
   297  					remainingRootSortOrders = child.SortOrders
   298  					rootParent = root
   299  					childIndex = nextSortOrder
   300  					root = child
   301  					continue ParentLoop
   302  				}
   303  				// None of the children matched, so this set of expressions don't have a match
   304  				break
   305  			} else {
   306  				// We have no more sort orders on either side so this is an exact match.
   307  				// If it's a destination node, then we mark it as no longer being one.
   308  				if root.Data != nil {
   309  					removedIndex = root.Data.RowIndex
   310  				}
   311  				root.Data = nil
   312  				if len(root.Children) == 1 {
   313  					// Since there is only a single child, we merge it with this node
   314  					for _, child := range root.Children {
   315  						// The fact that you gotta do a range + break to get a single map element is silly
   316  						root.SortOrders = append(root.SortOrders, child.SortOrders...)
   317  						root.Data = child.Data
   318  						root.Children = nil
   319  						break
   320  					}
   321  				} else if len(root.Children) == 0 {
   322  					if rootParent != nil {
   323  						// With no children, we can remove this node from the parent
   324  						delete(rootParent.Children, childIndex)
   325  						// If the parent only has a single child, and it's not a destination node, we can merge that child
   326  						// with the parent
   327  						if len(rootParent.Children) == 1 && rootParent.Data == nil {
   328  							// Since there is only a single child, we merge it with this node
   329  							for _, child := range rootParent.Children {
   330  								// It was silly a few lines ago, and it's still silly here
   331  								rootParent.SortOrders = append(rootParent.SortOrders, child.SortOrders...)
   332  								rootParent.Data = child.Data
   333  								rootParent.Children = child.Children
   334  							}
   335  						}
   336  					} else {
   337  						// This is the base root of the table, and it has no children (they may have been merged with
   338  						// the base root in a previous deletion), so we completely reset its sort orders to the base state
   339  						root.SortOrders = []int32{columnMarker}
   340  					}
   341  				}
   342  				// If this node has multiple children then we have nothing more to do
   343  				break
   344  			}
   345  		} else {
   346  			// Since the sort orders do not match, that means that this set of expressions don't have a match
   347  			break
   348  		}
   349  	}
   350  	return removedIndex
   351  }
   352  
   353  // parseExpression parses expressions into a concatenated collection of sort orders. The returned slice belongs to the
   354  // pool, which, if possible, should be returned once it is no longer needed. As this function doesn't distinguish
   355  // between strings and expressions, it assumes any given expressions have already been folded.
   356  func (mn *MatchNode) parseExpression(database, branch, user, host string) []int32 {
   357  	if len(database) > math.MaxUint16 {
   358  		database = database[:math.MaxUint16]
   359  	}
   360  	if len(branch) > math.MaxUint16 {
   361  		branch = branch[:math.MaxUint16]
   362  	}
   363  	if len(user) > math.MaxUint16 {
   364  		user = user[:math.MaxUint16]
   365  	}
   366  	if len(host) > math.MaxUint16 {
   367  		host = host[:math.MaxUint16]
   368  	}
   369  
   370  	allSortOrders := concatenatedSortOrderPool.Get().([]int32)[:0]
   371  	for i, str := range []string{database, branch, user, host} {
   372  		escaped := false
   373  		sortFunc := sortFuncs[i]
   374  		allSortOrders = append(allSortOrders, columnMarker)
   375  		for _, r := range str {
   376  			if escaped {
   377  				escaped = false
   378  				allSortOrders = append(allSortOrders, sortFunc(r))
   379  			} else {
   380  				switch r {
   381  				case '\\':
   382  					escaped = true
   383  				case '%':
   384  					allSortOrders = append(allSortOrders, anyMatch)
   385  				case '_':
   386  					allSortOrders = append(allSortOrders, singleMatch)
   387  				default:
   388  					allSortOrders = append(allSortOrders, sortFunc(r))
   389  				}
   390  			}
   391  		}
   392  	}
   393  	return allSortOrders
   394  }