github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/bindvar.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"github.com/dolthub/go-mysql-server/sql"
    19  	"github.com/dolthub/go-mysql-server/sql/expression"
    20  	"github.com/dolthub/go-mysql-server/sql/transform"
    21  )
    22  
    23  // ApplyBindings replaces all `BindVar` expressions in the given sql.Node with
    24  // their corresponding sql.Expression entries in the provided |bindings| map.
    25  // If a binding for a |BindVar| expression is not found in the map, no error is
    26  // returned and the |BindVar| expression is left in place. There is no check on
    27  // whether all entries in |bindings| are used at least once throughout the |n|
    28  // but a map of all the used |bindings| are returned.
    29  // sql.DeferredType instances will be resolved by the binding types.
    30  func ApplyBindings(n sql.Node, bindings map[string]sql.Expression) (sql.Node, map[string]bool, error) {
    31  	n, _, usedBindings, err := applyBindingsHelper(n, bindings)
    32  	if err != nil {
    33  		return nil, nil, err
    34  	}
    35  	return n, usedBindings, err
    36  }
    37  
    38  func fixBindings(expr sql.Expression, bindings map[string]sql.Expression) (sql.Expression, transform.TreeIdentity, map[string]bool, error) {
    39  	usedBindings := map[string]bool{}
    40  	switch e := expr.(type) {
    41  	case *expression.BindVar:
    42  		val, found := bindings[e.Name]
    43  		if found {
    44  			usedBindings[e.Name] = true
    45  			return val, transform.NewTree, usedBindings, nil
    46  		}
    47  	case *expression.GetField:
    48  		//TODO: aliases derived from arithmetic
    49  		// expressions on BindVars should have types
    50  		// re-evaluated
    51  		t, ok := e.Type().(sql.DeferredType)
    52  		if !ok {
    53  			return expr, transform.SameTree, nil, nil
    54  		}
    55  		val, found := bindings[t.Name()]
    56  		if !found {
    57  			return expr, transform.SameTree, nil, nil
    58  		}
    59  		usedBindings[t.Name()] = true
    60  		return expression.NewGetFieldWithTable(e.Index(), int(e.TableId()), val.Type().Promote(), e.Database(), e.Table(), e.Name(), val.IsNullable()), transform.NewTree, usedBindings, nil
    61  	case *Subquery:
    62  		// *Subquery is a sql.Expression with a sql.Node not reachable
    63  		// by the visitor. Manually apply bindings to [Query] field.
    64  		q, subUsedBindings, err := ApplyBindings(e.Query, bindings)
    65  		if err != nil {
    66  			return nil, transform.SameTree, nil, err
    67  		}
    68  		for binding := range subUsedBindings {
    69  			usedBindings[binding] = true
    70  		}
    71  		return e.WithQuery(q), transform.NewTree, usedBindings, nil
    72  	}
    73  	return expr, transform.SameTree, nil, nil
    74  }
    75  
    76  func applyBindingsHelper(n sql.Node, bindings map[string]sql.Expression) (sql.Node, transform.TreeIdentity, map[string]bool, error) {
    77  	usedBindings := map[string]bool{}
    78  	fixBindingsTransform := func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
    79  		newN, same, subUsedBindings, err := fixBindings(e, bindings)
    80  		for binding := range subUsedBindings {
    81  			usedBindings[binding] = true
    82  		}
    83  		return newN, same, err
    84  	}
    85  	newN, same, err := transform.NodeWithOpaque(n, func(node sql.Node) (sql.Node, transform.TreeIdentity, error) {
    86  		switch n := node.(type) {
    87  		case *JoinNode:
    88  			// *plan.IndexedJoin cannot implement sql.Expressioner
    89  			// because the column indexes get mis-ordered by FixFieldIndexesForExpressions.
    90  			if n.Op.IsLookup() {
    91  				cond, same, err := transform.Expr(n.Filter, fixBindingsTransform)
    92  				if err != nil {
    93  					return nil, transform.SameTree, err
    94  				}
    95  				return NewJoin(n.left, n.right, n.Op, cond).WithScopeLen(n.ScopeLen), same, nil
    96  			}
    97  		case *InsertInto:
    98  			// Manually apply bindings to [Source] because only [Destination]
    99  			// is a proper child.
   100  			newSource, same, subUsedBindings, err := applyBindingsHelper(n.Source, bindings)
   101  			if err != nil {
   102  				return nil, transform.SameTree, err
   103  			}
   104  			for binding := range subUsedBindings {
   105  				usedBindings[binding] = true
   106  			}
   107  			if same {
   108  				return transform.NodeExprs(n, fixBindingsTransform)
   109  			}
   110  			ne, _, err := transform.NodeExprs(n.WithSource(newSource), fixBindingsTransform)
   111  			return ne, transform.NewTree, err
   112  		case *DeferredFilteredTable:
   113  			ft := n.Table.(sql.FilteredTable)
   114  			var fixedFilters []sql.Expression
   115  			for _, filter := range ft.Filters() {
   116  				newFilter, _, err := transform.Expr(filter, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   117  					if bindVar, ok := e.(*expression.BindVar); ok {
   118  						if val, found := bindings[bindVar.Name]; found {
   119  							usedBindings[bindVar.Name] = true
   120  							return val, transform.NewTree, nil
   121  						}
   122  					}
   123  					return e, transform.SameTree, nil
   124  				})
   125  				if err != nil {
   126  					return nil, transform.SameTree, err
   127  				}
   128  				fixedFilters = append(fixedFilters, newFilter)
   129  			}
   130  
   131  			newTbl := ft.WithFilters(nil, fixedFilters)
   132  			n.ResolvedTable.Table = newTbl
   133  			return n.ResolvedTable, transform.NewTree, nil
   134  		}
   135  		return transform.NodeExprs(node, fixBindingsTransform)
   136  	})
   137  	return newN, same, usedBindings, err
   138  }
   139  
   140  func HasEmptyTable(n sql.Node) bool {
   141  	found := transform.InspectUp(n, func(n sql.Node) bool {
   142  		_, ok := n.(*EmptyTable)
   143  		return ok
   144  	})
   145  	if found {
   146  		return true
   147  	}
   148  	ne, ok := n.(sql.Expressioner)
   149  	if !ok {
   150  		return false
   151  	}
   152  	for _, e := range ne.Expressions() {
   153  		found := transform.InspectExpr(e, func(e sql.Expression) bool {
   154  			sq, ok := e.(*Subquery)
   155  			if ok {
   156  				return HasEmptyTable(sq.Query)
   157  			}
   158  			return false
   159  		})
   160  		if found {
   161  			return true
   162  		}
   163  	}
   164  	return false
   165  }