github.com/vedadiyan/sqlparser@v1.0.0/pkg/sqlparser/ast_rewriting.go (about)

     1  /*
     2  Copyright 2020 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"strconv"
    21  	"strings"
    22  
    23  	querypb "github.com/vedadiyan/sqlparser/pkg/query"
    24  	"github.com/vedadiyan/sqlparser/pkg/sysvars"
    25  	"github.com/vedadiyan/sqlparser/pkg/vterrors"
    26  	vtrpcpb "github.com/vedadiyan/sqlparser/pkg/vtrpc"
    27  )
    28  
var (
	// subQueryBaseArgName is the prefix of each parameter that replaces a
	// subquery with its pulled-out value; ReserveSubQuery and
	// ReserveSubQueryWithHasValues add a decimal counter after it.
	// This slice is shared package state and should be treated as read-only.
	subQueryBaseArgName = []byte("__sq")

	// HasValueSubQueryBaseName is the prefix of each parameter representing an EXISTS subquery.
	// A decimal counter is added after it by ReserveHasValuesSubQuery and
	// ReserveSubQueryWithHasValues. Treat it as read-only.
	HasValueSubQueryBaseName = []byte("__sq_has_values")
)

// SQLSelectLimitUnset is the value meaning sql_select_limit has not been set.
// The rewriter only adds a LIMIT when the configured select limit is positive,
// so this value leaves queries without an implicit LIMIT.
const SQLSelectLimitUnset = -1
    38  
// RewriteASTResult contains the rewritten ast and meta information about it.
// The embedded BindVarNeeds records which special bind variables (system
// variables, user-defined variables and function results) the rewritten AST
// refers to, so the caller can supply them at execution time.
type RewriteASTResult struct {
	*BindVarNeeds
	AST Statement // The rewritten AST
}
    44  
// ReservedVars keeps track of the bind variable names that have already been used
// in a parsed query, and hands out new names that do not clash with them.
type ReservedVars struct {
	// prefix is placed in front of every name generated by nextUnusedVar.
	prefix string
	// reserved is the set of names in use. NewReservedVars stores the caller's
	// `known` map here without copying it, so reservations show up there too.
	reserved BindVars
	// next is a scratch buffer holding prefix followed by the latest counter value.
	next []byte
	// counter is the numeric suffix of the most recently generated name.
	counter int
	// fast is true when no known name started with prefix at construction time;
	// nextUnusedVar then neither looks names up in nor records them into reserved.
	// static is true for the "vtg" prefix, letting nextUnusedVar return
	// precomputed names vtg1..vtg99 in fast mode.
	fast, static bool
	// sqNext is the numeric suffix of the most recently reserved subquery name.
	sqNext int64
}
    55  
// VSchemaViews gives the rewriter access to view definitions, so references
// to views can be expanded into derived tables.
type VSchemaViews interface {
	// FindView returns the SELECT that defines the view called name, or nil
	// when name does not refer to a known view (the rewriter treats nil as
	// "not a view" and leaves the table reference alone).
	FindView(name TableName) SelectStatement
}
    59  
    60  // ReserveAll tries to reserve all the given variable names. If they're all available,
    61  // they are reserved and the function returns true. Otherwise, the function returns false.
    62  func (r *ReservedVars) ReserveAll(names ...string) bool {
    63  	for _, name := range names {
    64  		if _, ok := r.reserved[name]; ok {
    65  			return false
    66  		}
    67  	}
    68  	for _, name := range names {
    69  		r.reserved[name] = struct{}{}
    70  	}
    71  	return true
    72  }
    73  
    74  // ReserveColName reserves a variable name for the given column; if a variable
    75  // with the same name already exists, it'll be suffixed with a numberic identifier
    76  // to make it unique.
    77  func (r *ReservedVars) ReserveColName(col *ColName) string {
    78  	compliantName := col.CompliantName()
    79  	if r.fast && strings.HasPrefix(compliantName, r.prefix) {
    80  		compliantName = "_" + compliantName
    81  	}
    82  
    83  	joinVar := []byte(compliantName)
    84  	baseLen := len(joinVar)
    85  	i := int64(1)
    86  
    87  	for {
    88  		if _, ok := r.reserved[string(joinVar)]; !ok {
    89  			r.reserved[string(joinVar)] = struct{}{}
    90  			return string(joinVar)
    91  		}
    92  		joinVar = strconv.AppendInt(joinVar[:baseLen], i, 10)
    93  		i++
    94  	}
    95  }
    96  
    97  // ReserveSubQuery returns the next argument name to replace subquery with pullout value.
    98  func (r *ReservedVars) ReserveSubQuery() string {
    99  	for {
   100  		r.sqNext++
   101  		joinVar := strconv.AppendInt(subQueryBaseArgName, r.sqNext, 10)
   102  		if _, ok := r.reserved[string(joinVar)]; !ok {
   103  			r.reserved[string(joinVar)] = struct{}{}
   104  			return string(joinVar)
   105  		}
   106  	}
   107  }
   108  
   109  // ReserveSubQueryWithHasValues returns the next argument name to replace subquery with pullout value.
   110  func (r *ReservedVars) ReserveSubQueryWithHasValues() (string, string) {
   111  	for {
   112  		r.sqNext++
   113  		joinVar := strconv.AppendInt(subQueryBaseArgName, r.sqNext, 10)
   114  		hasValuesJoinVar := strconv.AppendInt(HasValueSubQueryBaseName, r.sqNext, 10)
   115  		_, joinVarOK := r.reserved[string(joinVar)]
   116  		_, hasValuesJoinVarOK := r.reserved[string(hasValuesJoinVar)]
   117  		if !joinVarOK && !hasValuesJoinVarOK {
   118  			r.reserved[string(joinVar)] = struct{}{}
   119  			r.reserved[string(hasValuesJoinVar)] = struct{}{}
   120  			return string(joinVar), string(hasValuesJoinVar)
   121  		}
   122  	}
   123  }
   124  
   125  // ReserveHasValuesSubQuery returns the next argument name to replace subquery with has value.
   126  func (r *ReservedVars) ReserveHasValuesSubQuery() string {
   127  	for {
   128  		r.sqNext++
   129  		joinVar := strconv.AppendInt(HasValueSubQueryBaseName, r.sqNext, 10)
   130  		if _, ok := r.reserved[string(joinVar)]; !ok {
   131  			r.reserved[string(joinVar)] = struct{}{}
   132  			return string(joinVar)
   133  		}
   134  	}
   135  }
   136  
// staticBvar10 and staticBvar100 are precomputed bind variable names for the
// "vtg" prefix. staticBvar10 holds vtg0..vtg9 at 4 bytes each, and
// staticBvar100 holds vtg10..vtg99 at 5 bytes each. nextUnusedVar slices a
// name out of them instead of formatting a new string.
const staticBvar10 = "vtg0vtg1vtg2vtg3vtg4vtg5vtg6vtg7vtg8vtg9"
const staticBvar100 = "vtg10vtg11vtg12vtg13vtg14vtg15vtg16vtg17vtg18vtg19vtg20vtg21vtg22vtg23vtg24vtg25vtg26vtg27vtg28vtg29vtg30vtg31vtg32vtg33vtg34vtg35vtg36vtg37vtg38vtg39vtg40vtg41vtg42vtg43vtg44vtg45vtg46vtg47vtg48vtg49vtg50vtg51vtg52vtg53vtg54vtg55vtg56vtg57vtg58vtg59vtg60vtg61vtg62vtg63vtg64vtg65vtg66vtg67vtg68vtg69vtg70vtg71vtg72vtg73vtg74vtg75vtg76vtg77vtg78vtg79vtg80vtg81vtg82vtg83vtg84vtg85vtg86vtg87vtg88vtg89vtg90vtg91vtg92vtg93vtg94vtg95vtg96vtg97vtg98vtg99"
   139  
   140  func (r *ReservedVars) nextUnusedVar() string {
   141  	if r.fast {
   142  		r.counter++
   143  
   144  		if r.static {
   145  			switch {
   146  			case r.counter < 10:
   147  				ofs := r.counter * 4
   148  				return staticBvar10[ofs : ofs+4]
   149  			case r.counter < 100:
   150  				ofs := (r.counter - 10) * 5
   151  				return staticBvar100[ofs : ofs+5]
   152  			}
   153  		}
   154  
   155  		r.next = strconv.AppendInt(r.next[:len(r.prefix)], int64(r.counter), 10)
   156  		return string(r.next)
   157  	}
   158  
   159  	for {
   160  		r.counter++
   161  		r.next = strconv.AppendInt(r.next[:len(r.prefix)], int64(r.counter), 10)
   162  
   163  		if _, ok := r.reserved[string(r.next)]; !ok {
   164  			bvar := string(r.next)
   165  			r.reserved[bvar] = struct{}{}
   166  			return bvar
   167  		}
   168  	}
   169  }
   170  
   171  // NewReservedVars allocates a ReservedVar instance that will generate unique
   172  // variable names starting with the given `prefix` and making sure that they
   173  // don't conflict with the given set of `known` variables.
   174  func NewReservedVars(prefix string, known BindVars) *ReservedVars {
   175  	rv := &ReservedVars{
   176  		prefix:   prefix,
   177  		counter:  0,
   178  		reserved: known,
   179  		fast:     true,
   180  		next:     []byte(prefix),
   181  	}
   182  
   183  	if prefix != "" && prefix[0] == '_' {
   184  		panic("cannot reserve variables with a '_' prefix")
   185  	}
   186  
   187  	for bvar := range known {
   188  		if strings.HasPrefix(bvar, prefix) {
   189  			rv.fast = false
   190  			break
   191  		}
   192  	}
   193  
   194  	if prefix == "vtg" {
   195  		rv.static = true
   196  	}
   197  	return rv
   198  }
   199  
   200  // PrepareAST will normalize the query
   201  func PrepareAST(
   202  	in Statement,
   203  	reservedVars *ReservedVars,
   204  	bindVars map[string]*querypb.BindVariable,
   205  	parameterize bool,
   206  	keyspace string,
   207  	selectLimit int,
   208  	setVarComment string,
   209  	sysVars map[string]string,
   210  	views VSchemaViews,
   211  ) (*RewriteASTResult, error) {
   212  	if parameterize {
   213  		err := Normalize(in, reservedVars, bindVars)
   214  		if err != nil {
   215  			return nil, err
   216  		}
   217  	}
   218  	return RewriteAST(in, keyspace, selectLimit, setVarComment, sysVars, views)
   219  }
   220  
   221  // RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries.
   222  // SET_VAR comments are also added to the AST if required.
   223  func RewriteAST(
   224  	in Statement,
   225  	keyspace string,
   226  	selectLimit int,
   227  	setVarComment string,
   228  	sysVars map[string]string,
   229  	views VSchemaViews,
   230  ) (*RewriteASTResult, error) {
   231  	er := newASTRewriter(keyspace, selectLimit, setVarComment, sysVars, views)
   232  	er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in)
   233  	result := SafeRewrite(in, er.rewriteDown, er.rewriteUp)
   234  	if er.err != nil {
   235  		return nil, er.err
   236  	}
   237  
   238  	out, ok := result.(Statement)
   239  	if !ok {
   240  		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "statement rewriting returned a non statement: %s", String(out))
   241  	}
   242  
   243  	r := &RewriteASTResult{
   244  		AST:          out,
   245  		BindVarNeeds: er.bindVars,
   246  	}
   247  	return r, nil
   248  }
   249  
   250  func shouldRewriteDatabaseFunc(in Statement) bool {
   251  	selct, ok := in.(*Select)
   252  	if !ok {
   253  		return false
   254  	}
   255  	if len(selct.From) != 1 {
   256  		return false
   257  	}
   258  	aliasedTable, ok := selct.From[0].(*AliasedTableExpr)
   259  	if !ok {
   260  		return false
   261  	}
   262  	tableName, ok := aliasedTable.Expr.(TableName)
   263  	if !ok {
   264  		return false
   265  	}
   266  	return tableName.Name.String() == "dual"
   267  }
   268  
// astRewriter carries the configuration and the accumulated state of one
// rewrite pass over an AST.
type astRewriter struct {
	// bindVars collects the special bind variables the rewritten AST refers to.
	bindVars *BindVarNeeds
	// shouldRewriteDatabaseFunc allows database()/schema() to be replaced by a
	// bind variable; RewriteAST sets it from shouldRewriteDatabaseFunc(in).
	shouldRewriteDatabaseFunc bool
	// err holds an error recorded during the traversal; each assignment
	// overwrites any previous value.
	err error

	// we need to know this to make a decision if we can safely rewrite JOIN USING => JOIN ON
	// (visitSelect sets it when a SELECT projects `*`).
	hasStarInSelect bool

	// keyspace is the current keyspace; when it is a system schema,
	// unqualified tables are qualified with it.
	keyspace string
	// selectLimit, when positive, becomes the LIMIT of SELECT and UNION
	// statements that have none.
	selectLimit int
	// setVarComment, when non-empty, is added as a query hint to every node
	// that supports optimizer hints.
	setVarComment string
	// sysVars holds extra system variable names (looked up lowercased) that
	// are rewritten to bind variables.
	sysVars map[string]string
	// views resolves table names to view definitions; it may be nil.
	views VSchemaViews
}
   283  
   284  func newASTRewriter(keyspace string, selectLimit int, setVarComment string, sysVars map[string]string, views VSchemaViews) *astRewriter {
   285  	return &astRewriter{
   286  		bindVars:      &BindVarNeeds{},
   287  		keyspace:      keyspace,
   288  		selectLimit:   selectLimit,
   289  		setVarComment: setVarComment,
   290  		sysVars:       sysVars,
   291  		views:         views,
   292  	}
   293  }
   294  
// Reserved bind variable names that the rewriter substitutes for special
// functions and user-defined variables.
const (
	// LastInsertIDName is a reserved bind var name for last_insert_id()
	LastInsertIDName = "__lastInsertId"

	// DBVarName is a reserved bind var name for database() and its synonym schema()
	DBVarName = "__vtdbname"

	// FoundRowsName is a reserved bind var name for found_rows()
	FoundRowsName = "__vtfrows"

	// RowCountName is a reserved bind var name for row_count()
	RowCountName = "__vtrcount"

	// UserDefinedVariableName is the prefix of the bind var names used for
	// user-defined variables; the lowercased variable name follows it.
	UserDefinedVariableName = "__vtudv"
)
   311  
   312  func (er *astRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) {
   313  	inner := newASTRewriter(er.keyspace, er.selectLimit, er.setVarComment, er.sysVars, er.views)
   314  	inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc
   315  	tmp := SafeRewrite(node.Expr, inner.rewriteDown, inner.rewriteUp)
   316  	newExpr, ok := tmp.(Expr)
   317  	if !ok {
   318  		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp))
   319  	}
   320  	node.Expr = newExpr
   321  	return inner.bindVars, nil
   322  }
   323  
   324  func (er *astRewriter) rewriteDown(node SQLNode, _ SQLNode) bool {
   325  	switch node := node.(type) {
   326  	case *Select:
   327  		er.visitSelect(node)
   328  	}
   329  	return true
   330  }
   331  
   332  func (er *astRewriter) rewriteUp(cursor *Cursor) bool {
   333  	// Add SET_VAR comment to this node if it supports it and is needed
   334  	if supportOptimizerHint, supportsOptimizerHint := cursor.Node().(SupportOptimizerHint); supportsOptimizerHint && er.setVarComment != "" {
   335  		newComments, err := supportOptimizerHint.GetParsedComments().AddQueryHint(er.setVarComment)
   336  		if err != nil {
   337  			er.err = err
   338  			return false
   339  		}
   340  		supportOptimizerHint.SetComments(newComments)
   341  	}
   342  
   343  	switch node := cursor.Node().(type) {
   344  	case *Union:
   345  		er.rewriteUnion(node)
   346  	case *FuncExpr:
   347  		er.funcRewrite(cursor, node)
   348  	case *Variable:
   349  		er.rewriteVariable(cursor, node)
   350  	case *Subquery:
   351  		er.unnestSubQueries(cursor, node)
   352  	case *NotExpr:
   353  		er.rewriteNotExpr(cursor, node)
   354  	case *AliasedTableExpr:
   355  		er.rewriteAliasedTable(cursor, node)
   356  	case *ShowBasic:
   357  		er.rewriteShowBasic(node)
   358  	case *ExistsExpr:
   359  		er.existsRewrite(cursor, node)
   360  	}
   361  	return true
   362  }
   363  
   364  func (er *astRewriter) rewriteUnion(node *Union) {
   365  	// set select limit if explicitly not set when sql_select_limit is set on the connection.
   366  	if er.selectLimit > 0 && node.Limit == nil {
   367  		node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))}
   368  	}
   369  }
   370  
   371  func (er *astRewriter) rewriteAliasedTable(cursor *Cursor, node *AliasedTableExpr) {
   372  	aliasTableName, ok := node.Expr.(TableName)
   373  	if !ok {
   374  		return
   375  	}
   376  
   377  	// Qualifier should not be added to dual table
   378  	tblName := aliasTableName.Name.String()
   379  	if tblName == "dual" {
   380  		return
   381  	}
   382  
   383  	if SystemSchema(er.keyspace) {
   384  		if aliasTableName.Qualifier.IsEmpty() {
   385  			aliasTableName.Qualifier = NewIdentifierCS(er.keyspace)
   386  			node.Expr = aliasTableName
   387  			cursor.Replace(node)
   388  		}
   389  		return
   390  	}
   391  
   392  	// Could we be dealing with a view?
   393  	if er.views == nil {
   394  		return
   395  	}
   396  	view := er.views.FindView(aliasTableName)
   397  	if view == nil {
   398  		return
   399  	}
   400  
   401  	// Aha! It's a view. Let's replace it with a derived table
   402  	node.Expr = &DerivedTable{Select: CloneSelectStatement(view)}
   403  	if node.As.IsEmpty() {
   404  		node.As = NewIdentifierCS(tblName)
   405  	}
   406  }
   407  
   408  func (er *astRewriter) rewriteShowBasic(node *ShowBasic) {
   409  	if node.Command == VariableGlobal || node.Command == VariableSession {
   410  		varsToAdd := sysvars.GetInterestingVariables()
   411  		for _, sysVar := range varsToAdd {
   412  			er.bindVars.AddSysVar(sysVar)
   413  		}
   414  	}
   415  }
   416  
   417  func (er *astRewriter) rewriteNotExpr(cursor *Cursor, node *NotExpr) {
   418  	switch inner := node.Expr.(type) {
   419  	case *ComparisonExpr:
   420  		// not col = 42 => col != 42
   421  		// not col > 42 => col <= 42
   422  		// etc
   423  		canChange, inverse := inverseOp(inner.Operator)
   424  		if canChange {
   425  			inner.Operator = inverse
   426  			cursor.Replace(inner)
   427  		}
   428  	case *NotExpr:
   429  		// not not true => true
   430  		cursor.Replace(inner.Expr)
   431  	case BoolVal:
   432  		// not true => false
   433  		inner = !inner
   434  		cursor.Replace(inner)
   435  	}
   436  }
   437  
   438  func (er *astRewriter) rewriteVariable(cursor *Cursor, node *Variable) {
   439  	// Iff we are in SET, we want to change the scope of variables if a modifier has been set
   440  	// and only on the lhs of the assignment:
   441  	// set session sql_mode = @someElse
   442  	// here we need to change the scope of `sql_mode` and not of `@someElse`
   443  	if v, isSet := cursor.Parent().(*SetExpr); isSet && v.Var == node {
   444  		return
   445  	}
   446  	switch node.Scope {
   447  	case VariableScope:
   448  		er.udvRewrite(cursor, node)
   449  	case GlobalScope, SessionScope, NextTxScope:
   450  		er.sysVarRewrite(cursor, node)
   451  	}
   452  }
   453  
   454  func (er *astRewriter) visitSelect(node *Select) {
   455  	for _, col := range node.SelectExprs {
   456  		if _, hasStar := col.(*StarExpr); hasStar {
   457  			er.hasStarInSelect = true
   458  			continue
   459  		}
   460  
   461  		aliasedExpr, ok := col.(*AliasedExpr)
   462  		if !ok || !aliasedExpr.As.IsEmpty() {
   463  			continue
   464  		}
   465  		buf := NewTrackedBuffer(nil)
   466  		aliasedExpr.Expr.Format(buf)
   467  		// select last_insert_id() -> select :__lastInsertId as `last_insert_id()`
   468  		innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr)
   469  		if err != nil {
   470  			er.err = err
   471  			return
   472  		}
   473  		if innerBindVarNeeds.HasRewrites() {
   474  			aliasedExpr.As = NewIdentifierCI(buf.String())
   475  		}
   476  		er.bindVars.MergeWith(innerBindVarNeeds)
   477  
   478  	}
   479  	// set select limit if explicitly not set when sql_select_limit is set on the connection.
   480  	if er.selectLimit > 0 && node.Limit == nil {
   481  		node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))}
   482  	}
   483  }
   484  
   485  func inverseOp(i ComparisonExprOperator) (bool, ComparisonExprOperator) {
   486  	switch i {
   487  	case EqualOp:
   488  		return true, NotEqualOp
   489  	case LessThanOp:
   490  		return true, GreaterEqualOp
   491  	case GreaterThanOp:
   492  		return true, LessEqualOp
   493  	case LessEqualOp:
   494  		return true, GreaterThanOp
   495  	case GreaterEqualOp:
   496  		return true, LessThanOp
   497  	case NotEqualOp:
   498  		return true, EqualOp
   499  	case InOp:
   500  		return true, NotInOp
   501  	case NotInOp:
   502  		return true, InOp
   503  	case LikeOp:
   504  		return true, NotLikeOp
   505  	case NotLikeOp:
   506  		return true, LikeOp
   507  	case RegexpOp:
   508  		return true, NotRegexpOp
   509  	case NotRegexpOp:
   510  		return true, RegexpOp
   511  	}
   512  
   513  	return false, i
   514  }
   515  
   516  func (er *astRewriter) sysVarRewrite(cursor *Cursor, node *Variable) {
   517  	lowered := node.Name.Lowered()
   518  
   519  	var found bool
   520  	if er.sysVars != nil {
   521  		_, found = er.sysVars[lowered]
   522  	}
   523  
   524  	switch lowered {
   525  	case sysvars.Autocommit.Name,
   526  		sysvars.Charset.Name,
   527  		sysvars.ClientFoundRows.Name,
   528  		sysvars.DDLStrategy.Name,
   529  		sysvars.Names.Name,
   530  		sysvars.TransactionMode.Name,
   531  		sysvars.ReadAfterWriteGTID.Name,
   532  		sysvars.ReadAfterWriteTimeOut.Name,
   533  		sysvars.SessionEnableSystemSettings.Name,
   534  		sysvars.SessionTrackGTIDs.Name,
   535  		sysvars.SessionUUID.Name,
   536  		sysvars.SkipQueryPlanCache.Name,
   537  		sysvars.Socket.Name,
   538  		sysvars.SQLSelectLimit.Name,
   539  		sysvars.Version.Name,
   540  		sysvars.VersionComment.Name,
   541  		sysvars.QueryTimeout.Name,
   542  		sysvars.Workload.Name:
   543  		found = true
   544  	}
   545  
   546  	if found {
   547  		cursor.Replace(bindVarExpression("__vt" + lowered))
   548  		er.bindVars.AddSysVar(lowered)
   549  	}
   550  }
   551  
   552  func (er *astRewriter) udvRewrite(cursor *Cursor, node *Variable) {
   553  	udv := strings.ToLower(node.Name.CompliantName())
   554  	cursor.Replace(bindVarExpression(UserDefinedVariableName + udv))
   555  	er.bindVars.AddUserDefVar(udv)
   556  }
   557  
// funcRewrites maps lowercased function names to the reserved bind variable
// that funcRewrite substitutes for a call to them. database() and schema()
// share DBVarName. Calls with arguments are rejected by funcRewrite rather
// than rewritten.
var funcRewrites = map[string]string{
	"last_insert_id": LastInsertIDName,
	"database":       DBVarName,
	"schema":         DBVarName,
	"found_rows":     FoundRowsName,
	"row_count":      RowCountName,
}
   565  
   566  func (er *astRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) {
   567  	bindVar, found := funcRewrites[node.Name.Lowered()]
   568  	if !found || (bindVar == DBVarName && !er.shouldRewriteDatabaseFunc) {
   569  		return
   570  	}
   571  	if len(node.Exprs) > 0 {
   572  		er.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", node.Name.Lowered())
   573  		return
   574  	}
   575  	cursor.Replace(bindVarExpression(bindVar))
   576  	er.bindVars.AddFuncResult(bindVar)
   577  }
   578  
   579  func (er *astRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) {
   580  	if _, isExists := cursor.Parent().(*ExistsExpr); isExists {
   581  		return
   582  	}
   583  	sel, isSimpleSelect := subquery.Select.(*Select)
   584  	if !isSimpleSelect {
   585  		return
   586  	}
   587  
   588  	if len(sel.SelectExprs) != 1 ||
   589  		len(sel.OrderBy) != 0 ||
   590  		len(sel.GroupBy) != 0 ||
   591  		len(sel.From) != 1 ||
   592  		sel.Where != nil ||
   593  		sel.Having != nil ||
   594  		sel.Limit != nil || sel.Lock != NoLock {
   595  		return
   596  	}
   597  
   598  	aliasedTable, ok := sel.From[0].(*AliasedTableExpr)
   599  	if !ok {
   600  		return
   601  	}
   602  	table, ok := aliasedTable.Expr.(TableName)
   603  	if !ok || table.Name.String() != "dual" {
   604  		return
   605  	}
   606  	expr, ok := sel.SelectExprs[0].(*AliasedExpr)
   607  	if !ok {
   608  		return
   609  	}
   610  	_, isColName := expr.Expr.(*ColName)
   611  	if isColName {
   612  		// If we find a single col-name in a `dual` subquery, we can be pretty sure the user is returning a column
   613  		// already projected.
   614  		// `select 1 as x, (select x)`
   615  		// is perfectly valid - any aliased columns to the left are available inside subquery scopes
   616  		return
   617  	}
   618  	er.bindVars.NoteRewrite()
   619  	// we need to make sure that the inner expression also gets rewritten,
   620  	// so we fire off another rewriter traversal here
   621  	rewritten := SafeRewrite(expr.Expr, er.rewriteDown, er.rewriteUp)
   622  
   623  	// Here we need to handle the subquery rewrite in case in occurs in an IN clause
   624  	// For example, SELECT id FROM user WHERE id IN (SELECT 1 FROM DUAL)
   625  	// Here we cannot rewrite the query to SELECT id FROM user WHERE id IN 1, since that is syntactically wrong
   626  	// We must rewrite it to SELECT id FROM user WHERE id IN (1)
   627  	// Find more cases in the test file
   628  	rewrittenExpr, isExpr := rewritten.(Expr)
   629  	_, isColTuple := rewritten.(ColTuple)
   630  	comparisonExpr, isCompExpr := cursor.Parent().(*ComparisonExpr)
   631  	// Check that the parent is a comparison operator with IN or NOT IN operation.
   632  	// Also, if rewritten is already a ColTuple (like a subquery), then we do not need this
   633  	// We also need to check that rewritten is an Expr, if it is then we can rewrite it as a ValTuple
   634  	if isCompExpr && (comparisonExpr.Operator == InOp || comparisonExpr.Operator == NotInOp) && !isColTuple && isExpr {
   635  		cursor.Replace(ValTuple{rewrittenExpr})
   636  		return
   637  	}
   638  
   639  	cursor.Replace(rewritten)
   640  }
   641  
   642  func (er *astRewriter) existsRewrite(cursor *Cursor, node *ExistsExpr) {
   643  	sel, ok := node.Subquery.Select.(*Select)
   644  	if !ok {
   645  		return
   646  	}
   647  
   648  	if sel.Limit == nil {
   649  		sel.Limit = &Limit{}
   650  	}
   651  	sel.Limit.Rowcount = NewIntLiteral("1")
   652  
   653  	if sel.Having != nil {
   654  		// If the query has HAVING, we can't take any shortcuts
   655  		return
   656  	}
   657  
   658  	if len(sel.GroupBy) == 0 && sel.SelectExprs.AllAggregation() {
   659  		// in these situations, we are guaranteed to always get a non-empty result,
   660  		// so we can replace the EXISTS with a literal true
   661  		cursor.Replace(BoolVal(true))
   662  	}
   663  
   664  	// If we are not doing HAVING, we can safely replace all select expressions with a
   665  	// single `1` and remove any grouping
   666  	sel.SelectExprs = SelectExprs{
   667  		&AliasedExpr{Expr: NewIntLiteral("1")},
   668  	}
   669  	sel.GroupBy = nil
   670  }
   671  
   672  func bindVarExpression(name string) Expr {
   673  	return NewArgument(name)
   674  }
   675  
   676  // SystemSchema returns true if the schema passed is system schema
   677  func SystemSchema(schema string) bool {
   678  	return strings.EqualFold(schema, "information_schema") ||
   679  		strings.EqualFold(schema, "performance_schema") ||
   680  		strings.EqualFold(schema, "sys") ||
   681  		strings.EqualFold(schema, "mysql")
   682  }