vitess.io/vitess@v0.16.2/go/vt/vtgate/planbuilder/expr.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package planbuilder
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  
    23  	"vitess.io/vitess/go/vt/vterrors"
    24  
    25  	"vitess.io/vitess/go/vt/sqlparser"
    26  	"vitess.io/vitess/go/vt/vtgate/engine"
    27  	"vitess.io/vitess/go/vt/vtgate/vindexes"
    28  )
    29  
    30  type subqueryInfo struct {
    31  	ast    *sqlparser.Subquery
    32  	plan   logicalPlan
    33  	origin logicalPlan
    34  }
    35  
    36  // findOrigin identifies the right-most origin referenced by expr. In situations where
    37  // the expression references columns from multiple origins, the expression will be
    38  // pushed to the right-most origin, and the executor will use the results of
    39  // the previous origins to feed the necessary values to the primitives on the right.
    40  //
    41  // If the expression contains a subquery, the right-most origin identification
    42  // also follows the same rules of a normal expression. This is achieved by
    43  // looking at the Externs field of its symbol table that contains the list of
    44  // external references.
    45  //
    46  // Once the target origin is identified, we have to verify that the subquery's
    47  // route can be merged with it. If it cannot, we fail the query. This is because
    48  // we don't have the ability to wire up subqueries through expression evaluation
    49  // primitives. Consequently, if the plan for a subquery comes out as a Join,
    50  // we can immediately error out.
    51  //
    52  // Since findOrigin can itself be called from within a subquery, it has to assume
    53  // that some of the external references may actually be pointing to an outer
    54  // query. The isLocal response from the symtab is used to make sure that we
    55  // only analyze symbols that point to the current symtab.
    56  //
    57  // If an expression has no references to the current query, then the left-most
    58  // origin is chosen as the default.
    59  func (pb *primitiveBuilder) findOrigin(expr sqlparser.Expr, reservedVars *sqlparser.ReservedVars) (pullouts []*pulloutSubquery, origin logicalPlan, pushExpr sqlparser.Expr, err error) {
    60  	// highestOrigin tracks the highest origin referenced by the expression.
    61  	// Default is the first.
    62  	highestOrigin := first(pb.plan)
    63  
    64  	// subqueries tracks the list of subqueries encountered.
    65  	var subqueries []subqueryInfo
    66  
    67  	// constructsMap tracks the sub-construct in which a subquery
    68  	// occurred. The construct type decides on how the query gets
    69  	// pulled out.
    70  	constructsMap := make(map[*sqlparser.Subquery]sqlparser.Expr)
    71  
    72  	err = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
    73  		switch node := node.(type) {
    74  		case *sqlparser.ColName:
    75  			newOrigin, isLocal, err := pb.st.Find(node)
    76  			if err != nil {
    77  				return false, err
    78  			}
    79  			if isLocal && newOrigin.Order() > highestOrigin.Order() {
    80  				highestOrigin = newOrigin
    81  			}
    82  		case *sqlparser.ComparisonExpr:
    83  			if node.Operator == sqlparser.InOp || node.Operator == sqlparser.NotInOp {
    84  				if sq, ok := node.Right.(*sqlparser.Subquery); ok {
    85  					constructsMap[sq] = node
    86  				}
    87  			}
    88  		case *sqlparser.ExistsExpr:
    89  			constructsMap[node.Subquery] = node
    90  		case *sqlparser.Subquery:
    91  			spb := newPrimitiveBuilder(pb.vschema, pb.jt)
    92  			switch stmt := node.Select.(type) {
    93  			case *sqlparser.Select:
    94  				if err := spb.processSelect(stmt, reservedVars, pb.st, ""); err != nil {
    95  					return false, err
    96  				}
    97  			case *sqlparser.Union:
    98  				if err := spb.processUnion(stmt, reservedVars, pb.st); err != nil {
    99  					return false, err
   100  				}
   101  			default:
   102  				return false, vterrors.VT13001(fmt.Sprintf("unexpected SELECT type: %T", node))
   103  			}
   104  			sqi := subqueryInfo{
   105  				ast:  node,
   106  				plan: spb.plan,
   107  			}
   108  			for _, extern := range spb.st.Externs {
   109  				// No error expected. These are resolved externs.
   110  				newOrigin, isLocal, _ := pb.st.Find(extern)
   111  				if !isLocal {
   112  					continue
   113  				}
   114  				if highestOrigin.Order() < newOrigin.Order() {
   115  					highestOrigin = newOrigin
   116  				}
   117  				if sqi.origin == nil {
   118  					sqi.origin = newOrigin
   119  				} else if sqi.origin.Order() < newOrigin.Order() {
   120  					sqi.origin = newOrigin
   121  				}
   122  			}
   123  			subqueries = append(subqueries, sqi)
   124  			return false, nil
   125  		}
   126  		return true, nil
   127  	}, expr)
   128  	if err != nil {
   129  		return nil, nil, nil, err
   130  	}
   131  
   132  	highestRoute, _ := highestOrigin.(*route)
   133  	for _, sqi := range subqueries {
   134  		subroute, _ := sqi.plan.(*route)
   135  		if highestRoute != nil && subroute != nil && highestRoute.MergeSubquery(pb, subroute) {
   136  			continue
   137  		}
   138  		if sqi.origin != nil {
   139  			return nil, nil, nil, vterrors.VT12001("cross-shard correlated subquery")
   140  		}
   141  
   142  		sqName, hasValues := pb.jt.GenerateSubqueryVars()
   143  		construct, ok := constructsMap[sqi.ast]
   144  		if !ok {
   145  			// (subquery) -> :_sq
   146  			expr = sqlparser.ReplaceExpr(expr, sqi.ast, sqlparser.NewArgument(sqName))
   147  			pullouts = append(pullouts, newPulloutSubquery(engine.PulloutValue, sqName, hasValues, sqi.plan))
   148  			continue
   149  		}
   150  		switch construct := construct.(type) {
   151  		case *sqlparser.ComparisonExpr:
   152  			if construct.Operator == sqlparser.InOp {
   153  				// a in (subquery) -> (:__sq_has_values = 1 and (a in ::__sq))
   154  				right := &sqlparser.ComparisonExpr{
   155  					Operator: construct.Operator,
   156  					Left:     construct.Left,
   157  					Right:    sqlparser.ListArg(sqName),
   158  				}
   159  				left := &sqlparser.ComparisonExpr{
   160  					Left:     sqlparser.NewArgument(hasValues),
   161  					Operator: sqlparser.EqualOp,
   162  					Right:    sqlparser.NewIntLiteral("1"),
   163  				}
   164  				newExpr := &sqlparser.AndExpr{
   165  					Left:  left,
   166  					Right: right,
   167  				}
   168  				expr = sqlparser.ReplaceExpr(expr, construct, newExpr)
   169  				pullouts = append(pullouts, newPulloutSubquery(engine.PulloutIn, sqName, hasValues, sqi.plan))
   170  			} else {
   171  				// a not in (subquery) -> (:__sq_has_values = 0 or (a not in ::__sq))
   172  				left := &sqlparser.ComparisonExpr{
   173  					Left:     sqlparser.NewArgument(hasValues),
   174  					Operator: sqlparser.EqualOp,
   175  					Right:    sqlparser.NewIntLiteral("0"),
   176  				}
   177  				right := &sqlparser.ComparisonExpr{
   178  					Operator: construct.Operator,
   179  					Left:     construct.Left,
   180  					Right:    sqlparser.ListArg(sqName),
   181  				}
   182  				newExpr := &sqlparser.OrExpr{
   183  					Left:  left,
   184  					Right: right,
   185  				}
   186  				expr = sqlparser.ReplaceExpr(expr, construct, newExpr)
   187  				pullouts = append(pullouts, newPulloutSubquery(engine.PulloutNotIn, sqName, hasValues, sqi.plan))
   188  			}
   189  		case *sqlparser.ExistsExpr:
   190  			// exists (subquery) -> :__sq_has_values
   191  			expr = sqlparser.ReplaceExpr(expr, construct, sqlparser.NewArgument(hasValues))
   192  			pullouts = append(pullouts, newPulloutSubquery(engine.PulloutExists, sqName, hasValues, sqi.plan))
   193  		}
   194  	}
   195  	return pullouts, highestOrigin, expr, nil
   196  }
   197  
   198  var dummyErr = vterrors.VT13001("dummy")
   199  
   200  func hasSubquery(node sqlparser.SQLNode) bool {
   201  	has := false
   202  	_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
   203  		switch node.(type) {
   204  		case *sqlparser.DerivedTable, *sqlparser.Subquery:
   205  			has = true
   206  			return false, dummyErr
   207  		}
   208  		return true, nil
   209  	}, node)
   210  	return has
   211  }
   212  
   213  func (pb *primitiveBuilder) finalizeUnshardedDMLSubqueries(reservedVars *sqlparser.ReservedVars, nodes ...sqlparser.SQLNode) (bool, []*vindexes.Table) {
   214  	var keyspace string
   215  	var tables []*vindexes.Table
   216  	if rb, ok := pb.plan.(*route); ok {
   217  		keyspace = rb.eroute.Keyspace.Name
   218  	} else {
   219  		// This code is unreachable because the caller checks.
   220  		return false, nil
   221  	}
   222  
   223  	for _, node := range nodes {
   224  		samePlan := true
   225  		inSubQuery := false
   226  		_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
   227  			switch nodeType := node.(type) {
   228  			case *sqlparser.Subquery, *sqlparser.Insert:
   229  				inSubQuery = true
   230  				return true, nil
   231  			case *sqlparser.Select:
   232  				if !inSubQuery {
   233  					return true, nil
   234  				}
   235  				spb := newPrimitiveBuilder(pb.vschema, pb.jt)
   236  				if err := spb.processSelect(nodeType, reservedVars, pb.st, ""); err != nil {
   237  					samePlan = false
   238  					return false, err
   239  				}
   240  				innerRoute, ok := spb.plan.(*route)
   241  				if !ok {
   242  					samePlan = false
   243  					return false, dummyErr
   244  				}
   245  				if innerRoute.eroute.Keyspace.Name != keyspace {
   246  					samePlan = false
   247  					return false, dummyErr
   248  				}
   249  				for _, sub := range innerRoute.substitutions {
   250  					*sub.oldExpr = *sub.newExpr
   251  				}
   252  				spbTables, err := spb.st.AllVschemaTableNames()
   253  				if err != nil {
   254  					return false, err
   255  				}
   256  				tables = append(tables, spbTables...)
   257  			case *sqlparser.Union:
   258  				if !inSubQuery {
   259  					return true, nil
   260  				}
   261  				spb := newPrimitiveBuilder(pb.vschema, pb.jt)
   262  				if err := spb.processUnion(nodeType, reservedVars, pb.st); err != nil {
   263  					samePlan = false
   264  					return false, err
   265  				}
   266  				innerRoute, ok := spb.plan.(*route)
   267  				if !ok {
   268  					samePlan = false
   269  					return false, dummyErr
   270  				}
   271  				if innerRoute.eroute.Keyspace.Name != keyspace {
   272  					samePlan = false
   273  					return false, dummyErr
   274  				}
   275  			}
   276  
   277  			return true, nil
   278  		}, node)
   279  		if !samePlan {
   280  			return false, nil
   281  		}
   282  	}
   283  	return true, tables
   284  }
   285  
   286  func valEqual(a, b sqlparser.Expr) bool {
   287  	switch a := a.(type) {
   288  	case *sqlparser.ColName:
   289  		if b, ok := b.(*sqlparser.ColName); ok {
   290  			return a.Metadata == b.Metadata
   291  		}
   292  	case sqlparser.Argument:
   293  		b, ok := b.(sqlparser.Argument)
   294  		if !ok {
   295  			return false
   296  		}
   297  		return a == b
   298  	case *sqlparser.Literal:
   299  		b, ok := b.(*sqlparser.Literal)
   300  		if !ok {
   301  			return false
   302  		}
   303  		switch a.Type {
   304  		case sqlparser.StrVal:
   305  			switch b.Type {
   306  			case sqlparser.StrVal:
   307  				return a.Val == b.Val
   308  			case sqlparser.HexVal:
   309  				return hexEqual(b, a)
   310  			}
   311  		case sqlparser.HexVal:
   312  			return hexEqual(a, b)
   313  		case sqlparser.IntVal:
   314  			if b.Type == (sqlparser.IntVal) {
   315  				return a.Val == b.Val
   316  			}
   317  		}
   318  	}
   319  	return false
   320  }
   321  
   322  func hexEqual(a, b *sqlparser.Literal) bool {
   323  	v, err := a.HexDecode()
   324  	if err != nil {
   325  		return false
   326  	}
   327  	switch b.Type {
   328  	case sqlparser.StrVal:
   329  		return bytes.Equal(v, b.Bytes())
   330  	case sqlparser.HexVal:
   331  		v2, err := b.HexDecode()
   332  		if err != nil {
   333  			return false
   334  		}
   335  		return bytes.Equal(v, v2)
   336  	}
   337  	return false
   338  }