github.com/dolthub/go-mysql-server@v0.18.0/sql/transform/node.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transform
    16  
    17  import (
    18  	"github.com/dolthub/go-mysql-server/sql"
    19  )
    20  
    21  // NodeFunc is a function that given a node will return that node
    22  // as is or transformed, a TreeIdentity to indicate whether the
    23  // node was modified, and an error or nil.
    24  type NodeFunc func(n sql.Node) (sql.Node, TreeIdentity, error)
    25  
    26  // ExprFunc is a function that given an expression will return that
    27  // expression as is or transformed, a TreeIdentity to indicate
    28  // whether the expression was modified, and an error or nil.
    29  type ExprFunc func(e sql.Expression) (sql.Expression, TreeIdentity, error)
    30  
    31  // Context provides additional metadata to a SelectorFunc about the
    32  // active node in a traversal, including the parent node, and a
    33  // partial prefix schema of sibling nodes in a level order traversal.
    34  type Context struct {
    35  	// Node is the currently visited node which will be transformed.
    36  	Node sql.Node
    37  	// Parent is the current parent of the transforming node.
    38  	Parent sql.Node
    39  	// ChildNum is the index of Node in Parent.Children().
    40  	ChildNum int
    41  	// SchemaPrefix is the concatenation of the Parent's SchemaPrefix with
    42  	// child.Schema() for all child with an index < ChildNum in
    43  	// Parent.Children(). For many Node, this represents the schema of the
    44  	// |row| parameter that is going to be passed to this node by its
    45  	// parent in a RowIter() call. This field is only non-nil if the entire
    46  	// in-order traversal of the tree up to this point is Resolved().
    47  	SchemaPrefix sql.Schema
    48  }
    49  
    50  // CtxFunc is a function which will return new sql.Node values for a given
    51  // Context.
    52  type CtxFunc func(Context) (sql.Node, TreeIdentity, error)
    53  
    54  // SelectorFunc is a function which will allow NodeWithCtx to not
    55  // traverse past a certain Context. If this function returns |false|
    56  // for a given Context, the subtree is not transformed and the child
    57  // is kept in its existing place in the parent as-is.
    58  type SelectorFunc func(Context) bool
    59  
    60  // ExprWithNodeFunc is a function that given an expression and the node
    61  // that contains it, will return that expression as is or transformed
    62  // along with an error, if any.
    63  type ExprWithNodeFunc func(sql.Node, sql.Expression) (sql.Expression, TreeIdentity, error)
    64  
    65  // TreeIdentity tracks modifications to node and expression trees.
    66  // Only return SameTree when it is acceptable to return the original
    67  // input and discard the returned result as a performance improvement.
    68  type TreeIdentity bool
    69  
    70  const (
    71  	SameTree TreeIdentity = true
    72  	NewTree  TreeIdentity = false
    73  )
    74  
    75  // NodeExprsWithNode applies a transformation function to all expressions
    76  // on the given tree from the bottom up.
    77  func NodeExprsWithNode(node sql.Node, f ExprWithNodeFunc) (sql.Node, TreeIdentity, error) {
    78  	return Node(node, func(n sql.Node) (sql.Node, TreeIdentity, error) {
    79  		return OneNodeExprsWithNode(n, f)
    80  	})
    81  }
    82  
    83  // NodeExprs applies a transformation function to all expressions
    84  // on the given plan tree from the bottom up.
    85  func NodeExprs(node sql.Node, f ExprFunc) (sql.Node, TreeIdentity, error) {
    86  	return NodeExprsWithNode(node, func(n sql.Node, e sql.Expression) (sql.Expression, TreeIdentity, error) {
    87  		return f(e)
    88  	})
    89  }
    90  
    91  // NodeExprsWithNodeWithOpaque applies a transformation function to all expressions
    92  // on the given tree from the bottom up, including through opaque nodes.
    93  func NodeExprsWithNodeWithOpaque(node sql.Node, f ExprWithNodeFunc) (sql.Node, TreeIdentity, error) {
    94  	return NodeWithOpaque(node, func(n sql.Node) (sql.Node, TreeIdentity, error) {
    95  		return OneNodeExprsWithNode(n, f)
    96  	})
    97  }
    98  
    99  // NodeExprsWithOpaque applies a transformation function to all expressions
   100  // on the given plan tree from the bottom up, including through opaque nodes.
   101  func NodeExprsWithOpaque(node sql.Node, f ExprFunc) (sql.Node, TreeIdentity, error) {
   102  	return NodeExprsWithNodeWithOpaque(node, func(n sql.Node, e sql.Expression) (sql.Expression, TreeIdentity, error) {
   103  		return f(e)
   104  	})
   105  }
   106  
   107  // OneNodeExprsWithNode applies a transformation function to all expressions
   108  // on the specified node. It does not traverse the children of the specified node.
   109  func OneNodeExprsWithNode(n sql.Node, f ExprWithNodeFunc) (sql.Node, TreeIdentity, error) {
   110  	ne, ok := n.(sql.Expressioner)
   111  	if !ok {
   112  		return n, SameTree, nil
   113  	}
   114  
   115  	exprs := ne.Expressions()
   116  	if len(exprs) == 0 {
   117  		return n, SameTree, nil
   118  	}
   119  
   120  	var (
   121  		newExprs []sql.Expression
   122  		err      error
   123  	)
   124  
   125  	for i := range exprs {
   126  		e := exprs[i]
   127  		e, same, err := ExprWithNode(n, e, f)
   128  		if err != nil {
   129  			return nil, SameTree, err
   130  		}
   131  		if !same {
   132  			if newExprs == nil {
   133  				newExprs = make([]sql.Expression, len(exprs))
   134  				copy(newExprs, exprs)
   135  			}
   136  			newExprs[i] = e
   137  		}
   138  	}
   139  
   140  	if len(newExprs) > 0 {
   141  		n, err = ne.WithExpressions(newExprs...)
   142  		if err != nil {
   143  			return nil, SameTree, err
   144  		}
   145  		return n, NewTree, nil
   146  	}
   147  	return n, SameTree, nil
   148  }
   149  
   150  // OneNodeExpressions applies a transformation function to all expressions
   151  // on the specified node. It does not traverse the children of the specified node.
   152  func OneNodeExpressions(n sql.Node, f ExprFunc) (sql.Node, TreeIdentity, error) {
   153  	e, ok := n.(sql.Expressioner)
   154  	if !ok {
   155  		return n, SameTree, nil
   156  	}
   157  
   158  	exprs := e.Expressions()
   159  	if len(exprs) == 0 {
   160  		return n, SameTree, nil
   161  	}
   162  
   163  	var newExprs []sql.Expression
   164  	for i := range exprs {
   165  		expr := exprs[i]
   166  		expr, same, err := Expr(expr, f)
   167  		if err != nil {
   168  			return nil, SameTree, err
   169  		}
   170  		if !same {
   171  			if newExprs == nil {
   172  				newExprs = make([]sql.Expression, len(exprs))
   173  				copy(newExprs, exprs)
   174  			}
   175  			newExprs[i] = expr
   176  		}
   177  	}
   178  	if len(newExprs) > 0 {
   179  		n, err := e.WithExpressions(newExprs...)
   180  		if err != nil {
   181  			return nil, SameTree, err
   182  		}
   183  		return n, NewTree, nil
   184  	}
   185  	return n, SameTree, nil
   186  }
   187  
   188  // NodeWithCtx transforms |n| from the bottom up, left to right, by passing
   189  // each node to |f|. If |s| is non-nil, does not descend into children where
   190  // |s| returns false.
   191  func NodeWithCtx(n sql.Node, s SelectorFunc, f CtxFunc) (sql.Node, TreeIdentity, error) {
   192  	return nodeWithCtxHelper(Context{n, nil, -1, sql.Schema{}}, s, f)
   193  }
   194  
   195  func nodeWithCtxHelper(c Context, s SelectorFunc, f CtxFunc) (sql.Node, TreeIdentity, error) {
   196  	node := c.Node
   197  	_, ok := node.(sql.OpaqueNode)
   198  	if ok {
   199  		return f(c)
   200  	}
   201  
   202  	children := node.Children()
   203  	if len(children) == 0 {
   204  		return f(c)
   205  	}
   206  
   207  	var (
   208  		newChildren []sql.Node
   209  		err         error
   210  	)
   211  	for i := range children {
   212  		child := children[i]
   213  		cc := Context{child, node, i, nil}
   214  		if s == nil || s(cc) {
   215  			child, same, err := nodeWithCtxHelper(cc, s, f)
   216  			if err != nil {
   217  				return nil, SameTree, err
   218  			}
   219  			if !same {
   220  				if newChildren == nil {
   221  					newChildren = make([]sql.Node, len(children))
   222  					copy(newChildren, children)
   223  				}
   224  				newChildren[i] = child
   225  			}
   226  		}
   227  	}
   228  
   229  	sameC := SameTree
   230  	if len(newChildren) > 0 {
   231  		sameC = NewTree
   232  		node, err = node.WithChildren(newChildren...)
   233  		if err != nil {
   234  			return nil, SameTree, err
   235  		}
   236  	}
   237  
   238  	node, sameN, err := f(Context{node, c.Parent, c.ChildNum, c.SchemaPrefix})
   239  	if err != nil {
   240  		return nil, SameTree, err
   241  	}
   242  	return node, sameC && sameN, nil
   243  }
   244  
   245  // NodeWithPrefixSchema transforms |n| from the bottom up, left to right, by passing
   246  // each node to |f|. If |s| is non-nil, does not descend into children where
   247  // |s| returns false.
   248  func NodeWithPrefixSchema(n sql.Node, s SelectorFunc, f CtxFunc) (sql.Node, TreeIdentity, error) {
   249  	return transformUpWithPrefixSchemaHelper(Context{n, nil, -1, sql.Schema{}}, s, f)
   250  }
   251  
   252  func transformUpWithPrefixSchemaHelper(c Context, s SelectorFunc, f CtxFunc) (sql.Node, TreeIdentity, error) {
   253  	node := c.Node
   254  	_, ok := node.(sql.OpaqueNode)
   255  	if ok {
   256  		return f(c)
   257  	}
   258  
   259  	children := node.Children()
   260  	if len(children) == 0 {
   261  		return f(c)
   262  	}
   263  
   264  	var (
   265  		newChildren []sql.Node
   266  		err         error
   267  	)
   268  
   269  	childPrefix := append(sql.Schema{}, c.SchemaPrefix...)
   270  	for i := range children {
   271  		child := children[i]
   272  		cc := Context{child, node, i, childPrefix}
   273  		if s == nil || s(cc) {
   274  			child, same, err := transformUpWithPrefixSchemaHelper(cc, s, f)
   275  			if err != nil {
   276  				return nil, SameTree, err
   277  			}
   278  			if !same {
   279  				if newChildren == nil {
   280  					newChildren = make([]sql.Node, len(children))
   281  					copy(newChildren, children)
   282  				}
   283  				newChildren[i] = child
   284  			}
   285  			if child.Resolved() && childPrefix != nil {
   286  				cs := child.Schema()
   287  				childPrefix = append(childPrefix, cs...)
   288  			} else {
   289  				childPrefix = nil
   290  			}
   291  		}
   292  	}
   293  
   294  	sameC := SameTree
   295  	if len(newChildren) > 0 {
   296  		sameC = NewTree
   297  		node, err = node.WithChildren(newChildren...)
   298  		if err != nil {
   299  			return nil, SameTree, err
   300  		}
   301  	}
   302  
   303  	node, sameN, err := f(Context{node, c.Parent, c.ChildNum, c.SchemaPrefix})
   304  	if err != nil {
   305  		return nil, SameTree, err
   306  	}
   307  	return node, sameC && sameN, nil
   308  }
   309  
   310  // Node applies a transformation function to the given tree from the
   311  // bottom up.
   312  func Node(node sql.Node, f NodeFunc) (sql.Node, TreeIdentity, error) {
   313  	_, ok := node.(sql.OpaqueNode)
   314  	if ok {
   315  		return f(node)
   316  	}
   317  
   318  	children := node.Children()
   319  	if len(children) == 0 {
   320  		return f(node)
   321  	}
   322  
   323  	var (
   324  		newChildren []sql.Node
   325  		child       sql.Node
   326  	)
   327  
   328  	for i := range children {
   329  		child = children[i]
   330  		child, same, err := Node(child, f)
   331  		if err != nil {
   332  			return nil, SameTree, err
   333  		}
   334  		if !same {
   335  			if newChildren == nil {
   336  				newChildren = make([]sql.Node, len(children))
   337  				copy(newChildren, children)
   338  			}
   339  			newChildren[i] = child
   340  		}
   341  	}
   342  
   343  	var err error
   344  	sameC := SameTree
   345  	if len(newChildren) > 0 {
   346  		sameC = NewTree
   347  		node, err = node.WithChildren(newChildren...)
   348  		if err != nil {
   349  			return nil, SameTree, err
   350  		}
   351  	}
   352  
   353  	node, sameN, err := f(node)
   354  	if err != nil {
   355  		return nil, SameTree, err
   356  	}
   357  	return node, sameC && sameN, nil
   358  }
   359  
   360  // NodeWithOpaque applies a transformation function to the given tree from the bottom up, including through
   361  // opaque nodes. This method is generally not safe to use for a transformation. Opaque nodes need to be considered in
   362  // isolation except for very specific exceptions.
   363  func NodeWithOpaque(node sql.Node, f NodeFunc) (sql.Node, TreeIdentity, error) {
   364  	children := node.Children()
   365  	if len(children) == 0 {
   366  		return f(node)
   367  	}
   368  
   369  	var (
   370  		newChildren []sql.Node
   371  		err         error
   372  	)
   373  
   374  	for i := range children {
   375  		c := children[i]
   376  		c, same, err := NodeWithOpaque(c, f)
   377  		if err != nil {
   378  			return nil, SameTree, err
   379  		}
   380  		if !same {
   381  			if newChildren == nil {
   382  				newChildren = make([]sql.Node, len(children))
   383  				copy(newChildren, children)
   384  			}
   385  			newChildren[i] = c
   386  		}
   387  	}
   388  
   389  	sameC := SameTree
   390  	if len(newChildren) > 0 {
   391  		sameC = NewTree
   392  		node, err = node.WithChildren(newChildren...)
   393  		if err != nil {
   394  			return nil, SameTree, err
   395  		}
   396  	}
   397  	node, sameN, err := f(node)
   398  	if err != nil {
   399  		return nil, SameTree, err
   400  	}
   401  	return node, sameC && sameN, nil
   402  }
   403  
   404  // NodeChildren applies a transformation function to the given node's children.
   405  func NodeChildren(node sql.Node, f NodeFunc) (sql.Node, TreeIdentity, error) {
   406  	children := node.Children()
   407  	if len(children) == 0 {
   408  		return node, SameTree, nil
   409  	}
   410  
   411  	var (
   412  		newChildren []sql.Node
   413  		child       sql.Node
   414  	)
   415  
   416  	for i := range children {
   417  		child = children[i]
   418  		child, same, err := f(child)
   419  		if err != nil {
   420  			return nil, SameTree, err
   421  		}
   422  		if !same {
   423  			if newChildren == nil {
   424  				newChildren = make([]sql.Node, len(children))
   425  				copy(newChildren, children)
   426  			}
   427  			newChildren[i] = child
   428  		}
   429  	}
   430  
   431  	var err error
   432  	if len(newChildren) > 0 {
   433  		node, err = node.WithChildren(newChildren...)
   434  		return node, NewTree, err
   435  	}
   436  	return node, SameTree, nil
   437  }