github.com/whtcorpsinc/MilevaDB-Prod@v0.0.0-20211104133533-f57f4be3b597/causetstore/petri/acyclic/causet/embedded/memex_rewriter.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package embedded
    15  
    16  import (
    17  	"context"
    18  	"strconv"
    19  	"strings"
    20  
    21  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    22  	"github.com/whtcorpsinc/BerolinaSQL/ast"
    23  	"github.com/whtcorpsinc/BerolinaSQL/charset"
    24  	"github.com/whtcorpsinc/BerolinaSQL/opcode"
    25  	"github.com/whtcorpsinc/BerolinaSQL/perceptron"
    26  	"github.com/whtcorpsinc/errors"
    27  	"github.com/whtcorpsinc/milevadb/causet"
    28  	"github.com/whtcorpsinc/milevadb/memex"
    29  	"github.com/whtcorpsinc/milevadb/memex/aggregation"
    30  	"github.com/whtcorpsinc/milevadb/schemareplicant"
    31  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    32  	"github.com/whtcorpsinc/milevadb/soliton/collate"
    33  	"github.com/whtcorpsinc/milevadb/soliton/hint"
    34  	"github.com/whtcorpsinc/milevadb/soliton/stringutil"
    35  	"github.com/whtcorpsinc/milevadb/stochastikctx"
    36  	"github.com/whtcorpsinc/milevadb/stochastikctx/variable"
    37  	"github.com/whtcorpsinc/milevadb/types"
    38  	driver "github.com/whtcorpsinc/milevadb/types/BerolinaSQL_driver"
    39  )
    40  
// EvalSubqueryFirstRow evaluates an uncorrelated subquery once and returns its
// first row (the result named `event` in the signature).
// It is declared here as a function variable and is not assigned in this
// declaration.
// NOTE(review): presumably another package installs the implementation at
// start-up; confirm before calling it on a path where it could still be nil.
var EvalSubqueryFirstRow func(ctx context.Context, p PhysicalCauset, is schemareplicant.SchemaReplicant, sctx stochastikctx.Context) (event []types.Causet, err error)
    43  
    44  // evalAstExpr evaluates ast memex directly.
    45  func evalAstExpr(sctx stochastikctx.Context, expr ast.ExprNode) (types.Causet, error) {
    46  	if val, ok := expr.(*driver.ValueExpr); ok {
    47  		return val.Causet, nil
    48  	}
    49  	newExpr, err := rewriteAstExpr(sctx, expr, nil, nil)
    50  	if err != nil {
    51  		return types.Causet{}, err
    52  	}
    53  	return newExpr.Eval(chunk.Row{})
    54  }
    55  
    56  // rewriteAstExpr rewrites ast memex directly.
    57  func rewriteAstExpr(sctx stochastikctx.Context, expr ast.ExprNode, schemaReplicant *memex.Schema, names types.NameSlice) (memex.Expression, error) {
    58  	var is schemareplicant.SchemaReplicant
    59  	if sctx.GetStochastikVars().TxnCtx.SchemaReplicant != nil {
    60  		is = sctx.GetStochastikVars().TxnCtx.SchemaReplicant.(schemareplicant.SchemaReplicant)
    61  	}
    62  	b := NewCausetBuilder(sctx, is, &hint.BlockHintProcessor{})
    63  	fakeCauset := LogicalBlockDual{}.Init(sctx, 0)
    64  	if schemaReplicant != nil {
    65  		fakeCauset.schemaReplicant = schemaReplicant
    66  		fakeCauset.names = names
    67  	}
    68  	newExpr, _, err := b.rewrite(context.TODO(), expr, fakeCauset, nil, true)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	return newExpr, nil
    73  }
    74  
    75  func (b *CausetBuilder) rewriteInsertOnDuplicateUFIDelate(ctx context.Context, exprNode ast.ExprNode, mockCauset LogicalCauset, insertCauset *Insert) (memex.Expression, error) {
    76  	b.rewriterCounter++
    77  	defer func() { b.rewriterCounter-- }()
    78  
    79  	rewriter := b.getExpressionRewriter(ctx, mockCauset)
    80  	// The rewriter maybe is obtained from "b.rewriterPool", "rewriter.err" is
    81  	// not nil means certain previous procedure has not handled this error.
    82  	// Here we give us one more chance to make a correct behavior by handling
    83  	// this missed error.
    84  	if rewriter.err != nil {
    85  		return nil, rewriter.err
    86  	}
    87  
    88  	rewriter.insertCauset = insertCauset
    89  	rewriter.asScalar = true
    90  
    91  	expr, _, err := b.rewriteExprNode(rewriter, exprNode, true)
    92  	return expr, err
    93  }
    94  
    95  // rewrite function rewrites ast expr to memex.Expression.
    96  // aggMapper maps ast.AggregateFuncExpr to the columns offset in p's output schemaReplicant.
    97  // asScalar means whether this memex must be treated as a scalar memex.
    98  // And this function returns a result memex, a new plan that may have apply or semi-join.
    99  func (b *CausetBuilder) rewrite(ctx context.Context, exprNode ast.ExprNode, p LogicalCauset, aggMapper map[*ast.AggregateFuncExpr]int, asScalar bool) (memex.Expression, LogicalCauset, error) {
   100  	expr, resultCauset, err := b.rewriteWithPreprocess(ctx, exprNode, p, aggMapper, nil, asScalar, nil)
   101  	return expr, resultCauset, err
   102  }
   103  
   104  // rewriteWithPreprocess is for handling the situation that we need to adjust the input ast tree
   105  // before really using its node in `memexRewriter.Leave`. In that case, we first call
   106  // er.preprocess(expr), which returns a new expr. Then we use the new expr in `Leave`.
   107  func (b *CausetBuilder) rewriteWithPreprocess(
   108  	ctx context.Context,
   109  	exprNode ast.ExprNode,
   110  	p LogicalCauset, aggMapper map[*ast.AggregateFuncExpr]int,
   111  	windowMapper map[*ast.WindowFuncExpr]int,
   112  	asScalar bool,
   113  	preprocess func(ast.Node) ast.Node,
   114  ) (memex.Expression, LogicalCauset, error) {
   115  	b.rewriterCounter++
   116  	defer func() { b.rewriterCounter-- }()
   117  
   118  	rewriter := b.getExpressionRewriter(ctx, p)
   119  	// The rewriter maybe is obtained from "b.rewriterPool", "rewriter.err" is
   120  	// not nil means certain previous procedure has not handled this error.
   121  	// Here we give us one more chance to make a correct behavior by handling
   122  	// this missed error.
   123  	if rewriter.err != nil {
   124  		return nil, nil, rewriter.err
   125  	}
   126  
   127  	rewriter.aggrMap = aggMapper
   128  	rewriter.windowMap = windowMapper
   129  	rewriter.asScalar = asScalar
   130  	rewriter.preprocess = preprocess
   131  
   132  	expr, resultCauset, err := b.rewriteExprNode(rewriter, exprNode, asScalar)
   133  	return expr, resultCauset, err
   134  }
   135  
   136  func (b *CausetBuilder) getExpressionRewriter(ctx context.Context, p LogicalCauset) (rewriter *memexRewriter) {
   137  	defer func() {
   138  		if p != nil {
   139  			rewriter.schemaReplicant = p.Schema()
   140  			rewriter.names = p.OutputNames()
   141  		}
   142  	}()
   143  
   144  	if len(b.rewriterPool) < b.rewriterCounter {
   145  		rewriter = &memexRewriter{p: p, b: b, sctx: b.ctx, ctx: ctx}
   146  		b.rewriterPool = append(b.rewriterPool, rewriter)
   147  		return
   148  	}
   149  
   150  	rewriter = b.rewriterPool[b.rewriterCounter-1]
   151  	rewriter.p = p
   152  	rewriter.asScalar = false
   153  	rewriter.aggrMap = nil
   154  	rewriter.preprocess = nil
   155  	rewriter.insertCauset = nil
   156  	rewriter.disableFoldCounter = 0
   157  	rewriter.tryFoldCounter = 0
   158  	rewriter.ctxStack = rewriter.ctxStack[:0]
   159  	rewriter.ctxNameStk = rewriter.ctxNameStk[:0]
   160  	rewriter.ctx = ctx
   161  	return
   162  }
   163  
   164  func (b *CausetBuilder) rewriteExprNode(rewriter *memexRewriter, exprNode ast.ExprNode, asScalar bool) (memex.Expression, LogicalCauset, error) {
   165  	if rewriter.p != nil {
   166  		curDefCausLen := rewriter.p.Schema().Len()
   167  		defer func() {
   168  			names := rewriter.p.OutputNames().Shallow()[:curDefCausLen]
   169  			for i := curDefCausLen; i < rewriter.p.Schema().Len(); i++ {
   170  				names = append(names, types.EmptyName)
   171  			}
   172  			// After rewriting finished, only old columns are visible.
   173  			// e.g. select * from t where t.a in (select t1.a from t1);
   174  			// The output columns before we enter the subquery are the columns from t.
   175  			// But when we leave the subquery `t.a in (select t1.a from t1)`, we got a Apply operator
   176  			// and the output columns become [t.*, t1.*]. But t1.* is used only inside the subquery. If there's another filter
   177  			// which is also a subquery where t1 is involved. The name resolving will fail if we still expose the column from
   178  			// the previous subquery.
   179  			// So here we just reset the names to empty to avoid this situation.
   180  			// TODO: implement ScalarSubQuery and resolve it during optimizing. In building phase, we will not change the plan's structure.
   181  			rewriter.p.SetOutputNames(names)
   182  		}()
   183  	}
   184  	exprNode.Accept(rewriter)
   185  	if rewriter.err != nil {
   186  		return nil, nil, errors.Trace(rewriter.err)
   187  	}
   188  	if !asScalar && len(rewriter.ctxStack) == 0 {
   189  		return nil, rewriter.p, nil
   190  	}
   191  	if len(rewriter.ctxStack) != 1 {
   192  		return nil, nil, errors.Errorf("context len %v is invalid", len(rewriter.ctxStack))
   193  	}
   194  	rewriter.err = memex.CheckArgsNotMultiDeferredCausetRow(rewriter.ctxStack[0])
   195  	if rewriter.err != nil {
   196  		return nil, nil, errors.Trace(rewriter.err)
   197  	}
   198  	return rewriter.ctxStack[0], rewriter.p, nil
   199  }
   200  
// memexRewriter is the AST visitor that turns an ast.ExprNode into a
// memex.Expression. Instances are pooled on the CausetBuilder and handed out
// by getExpressionRewriter.
//
// Field overview:
//   - ctxStack / ctxNameStk: parallel stacks of rewritten memexs and their
//     field names; they are pushed and popped together through ctxStackAppend
//     and ctxStackPop.
//   - p: the current logical plan; subquery handling replaces it with the
//     plan it builds on top.
//   - schemaReplicant / names: schema and output names taken from p when the
//     rewriter is obtained.
//   - err: error recorded during the traversal; rewriteExprNode checks it
//     after visiting.
//   - aggrMap / windowMap: offsets into schemaReplicant for aggregate and
//     window function AST nodes, consulted in Enter.
//   - b: the owning builder; sctx: the session context copied from b.ctx;
//     ctx: the context passed to the rewrite call.
//   - tryFoldCounter: incremented in Enter for functions listed in
//     memex.TryFoldFunctions.
type memexRewriter struct {
	ctxStack        []memex.Expression
	ctxNameStk      []*types.FieldName
	p               LogicalCauset
	schemaReplicant *memex.Schema
	names           []*types.FieldName
	err             error
	aggrMap         map[*ast.AggregateFuncExpr]int
	windowMap       map[*ast.WindowFuncExpr]int
	b               *CausetBuilder
	sctx            stochastikctx.Context
	ctx             context.Context

	// asScalar indicates the return value must be a scalar value.
	// NOTE: This value can be changed during memex rewritten.
	asScalar bool

	// preprocess is called for every ast.Node in Leave.
	preprocess func(ast.Node) ast.Node

	// insertCauset is only used to rewrite the memexs inside the assignment
	// of the "INSERT" memex.
	insertCauset *Insert

	// disableFoldCounter controls fold-disabled scope. If > 0, rewriter will NOT do constant folding.
	// Typically, during visiting AST, while entering the scope(disable), the counter will +1; while
	// leaving the scope(enable again), the counter will -1.
	// NOTE: This value can be changed during memex rewritten.
	disableFoldCounter int
	tryFoldCounter     int
}
   232  
   233  func (er *memexRewriter) ctxStackLen() int {
   234  	return len(er.ctxStack)
   235  }
   236  
   237  func (er *memexRewriter) ctxStackPop(num int) {
   238  	l := er.ctxStackLen()
   239  	er.ctxStack = er.ctxStack[:l-num]
   240  	er.ctxNameStk = er.ctxNameStk[:l-num]
   241  }
   242  
   243  func (er *memexRewriter) ctxStackAppend(col memex.Expression, name *types.FieldName) {
   244  	er.ctxStack = append(er.ctxStack, col)
   245  	er.ctxNameStk = append(er.ctxNameStk, name)
   246  }
   247  
   248  // constructBinaryOpFunction converts binary operator functions
   249  // 1. If op are EQ or NE or NullEQ, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2)
   250  // 2. Else constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to
   251  // `IF( a0 NE b0, a0 op b0,
   252  // 		IF ( isNull(a0 NE b0), Null,
   253  // 			IF ( a1 NE b1, a1 op b1,
   254  // 				IF ( isNull(a1 NE b1), Null, a2 op b2))))`
   255  func (er *memexRewriter) constructBinaryOpFunction(l memex.Expression, r memex.Expression, op string) (memex.Expression, error) {
   256  	lLen, rLen := memex.GetRowLen(l), memex.GetRowLen(r)
   257  	if lLen == 1 && rLen == 1 {
   258  		return er.newFunction(op, types.NewFieldType(allegrosql.TypeTiny), l, r)
   259  	} else if rLen != lLen {
   260  		return nil, memex.ErrOperandDeferredCausets.GenWithStackByArgs(lLen)
   261  	}
   262  	switch op {
   263  	case ast.EQ, ast.NE, ast.NullEQ:
   264  		funcs := make([]memex.Expression, lLen)
   265  		for i := 0; i < lLen; i++ {
   266  			var err error
   267  			funcs[i], err = er.constructBinaryOpFunction(memex.GetFuncArg(l, i), memex.GetFuncArg(r, i), op)
   268  			if err != nil {
   269  				return nil, err
   270  			}
   271  		}
   272  		if op == ast.NE {
   273  			return memex.ComposeDNFCondition(er.sctx, funcs...), nil
   274  		}
   275  		return memex.ComposeCNFCondition(er.sctx, funcs...), nil
   276  	default:
   277  		larg0, rarg0 := memex.GetFuncArg(l, 0), memex.GetFuncArg(r, 0)
   278  		var expr1, expr2, expr3, expr4, expr5 memex.Expression
   279  		expr1 = memex.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(allegrosql.TypeTiny), larg0, rarg0)
   280  		expr2 = memex.NewFunctionInternal(er.sctx, op, types.NewFieldType(allegrosql.TypeTiny), larg0, rarg0)
   281  		expr3 = memex.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(allegrosql.TypeTiny), expr1)
   282  		var err error
   283  		l, err = memex.PopRowFirstArg(er.sctx, l)
   284  		if err != nil {
   285  			return nil, err
   286  		}
   287  		r, err = memex.PopRowFirstArg(er.sctx, r)
   288  		if err != nil {
   289  			return nil, err
   290  		}
   291  		expr4, err = er.constructBinaryOpFunction(l, r, op)
   292  		if err != nil {
   293  			return nil, err
   294  		}
   295  		expr5, err = er.newFunction(ast.If, types.NewFieldType(allegrosql.TypeTiny), expr3, memex.NewNull(), expr4)
   296  		if err != nil {
   297  			return nil, err
   298  		}
   299  		return er.newFunction(ast.If, types.NewFieldType(allegrosql.TypeTiny), expr1, expr2, expr5)
   300  	}
   301  }
   302  
   303  func (er *memexRewriter) buildSubquery(ctx context.Context, subq *ast.SubqueryExpr) (LogicalCauset, error) {
   304  	if er.schemaReplicant != nil {
   305  		outerSchema := er.schemaReplicant.Clone()
   306  		er.b.outerSchemas = append(er.b.outerSchemas, outerSchema)
   307  		er.b.outerNames = append(er.b.outerNames, er.names)
   308  		defer func() {
   309  			er.b.outerSchemas = er.b.outerSchemas[0 : len(er.b.outerSchemas)-1]
   310  			er.b.outerNames = er.b.outerNames[0 : len(er.b.outerNames)-1]
   311  		}()
   312  	}
   313  
   314  	np, err := er.b.buildResultSetNode(ctx, subq.Query)
   315  	if err != nil {
   316  		return nil, err
   317  	}
   318  	// Pop the handle map generated by the subquery.
   319  	er.b.handleHelper.popMap()
   320  	return np, nil
   321  }
   322  
// Enter implements Visitor interface.
//
// It is called for each AST node before its children are visited. Several node
// kinds are resolved completely here: their result is pushed on the rewrite
// stack (or an error is stored in er.err) and true is returned as the second
// result. For the remaining kinds only bookkeeping is done and false is
// returned.
// NOTE(review): the second result is presumably the visitor's "skip children"
// flag — confirm against the ast.Visitor contract.
func (er *memexRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
	switch v := inNode.(type) {
	case *ast.AggregateFuncExpr:
		// Aggregates are not rewritten here; they must already be mapped to a
		// column of the current schema through aggrMap.
		index, ok := -1, false
		if er.aggrMap != nil {
			index, ok = er.aggrMap[v]
		}
		if !ok {
			er.err = ErrInvalidGroupFuncUse
			return inNode, true
		}
		er.ctxStackAppend(er.schemaReplicant.DeferredCausets[index], er.names[index])
		return inNode, true
	case *ast.DeferredCausetNameExpr:
		// A column name the builder has already mapped resolves directly to
		// that schema column; otherwise fall through to the common return.
		if index, ok := er.b.colMapper[v]; ok {
			er.ctxStackAppend(er.schemaReplicant.DeferredCausets[index], er.names[index])
			return inNode, true
		}
	case *ast.CompareSubqueryExpr:
		return er.handleCompareSubquery(er.ctx, v)
	case *ast.ExistsSubqueryExpr:
		return er.handleExistSubquery(er.ctx, v)
	case *ast.PatternInExpr:
		if v.Sel != nil {
			return er.handleInSubquery(er.ctx, v)
		}
		if len(v.List) != 1 {
			break
		}
		// For 10 in ((select * from t)), the BerolinaSQL won't set v.Sel.
		// So we must process this case here.
		x := v.List[0]
		// Strip any number of enclosing parentheses looking for a subquery.
		for {
			switch y := x.(type) {
			case *ast.SubqueryExpr:
				v.Sel = y
				return er.handleInSubquery(er.ctx, v)
			case *ast.ParenthesesExpr:
				x = y.Expr
			default:
				return inNode, false
			}
		}
	case *ast.SubqueryExpr:
		return er.handleScalarSubquery(er.ctx, v)
	case *ast.ParenthesesExpr:
		// Nothing to do; listed so that parentheses do not reach the default
		// case, which would force asScalar.
	case *ast.ValuesExpr:
		schemaReplicant, names := er.schemaReplicant, er.names
		// NOTE: "er.insertCauset != nil" means that we are rewriting the
		// memexs inside the assignment of "INSERT" memex. we have to
		// use the "blockSchema" of that "insertCauset".
		if er.insertCauset != nil {
			schemaReplicant = er.insertCauset.blockSchema
			names = er.insertCauset.blockDefCausNames
		}
		idx, err := memex.FindFieldName(names, v.DeferredCauset.Name)
		if err != nil {
			er.err = err
			return inNode, false
		}
		if idx < 0 {
			er.err = ErrUnknownDeferredCauset.GenWithStackByArgs(v.DeferredCauset.Name.OrigDefCausName(), "field list")
			return inNode, false
		}
		col := schemaReplicant.DeferredCausets[idx]
		er.ctxStackAppend(memex.NewValuesFunc(er.sctx, col.Index, col.RetType), types.EmptyName)
		return inNode, true
	case *ast.WindowFuncExpr:
		// Same scheme as aggregates: the window function must already be
		// mapped to a schema column through windowMap.
		index, ok := -1, false
		if er.windowMap != nil {
			index, ok = er.windowMap[v]
		}
		if !ok {
			er.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(strings.ToLower(v.F))
			return inNode, true
		}
		er.ctxStackAppend(er.schemaReplicant.DeferredCausets[index], er.names[index])
		return inNode, true
	case *ast.FuncCallExpr:
		// Track entry into functions registered as fold-disabling or
		// try-fold by their lower-cased name.
		if _, ok := memex.DisableFoldFunctions[v.FnName.L]; ok {
			er.disableFoldCounter++
		}
		if _, ok := memex.TryFoldFunctions[v.FnName.L]; ok {
			er.tryFoldCounter++
		}
	case *ast.CaseExpr:
		// CASE is looked up in the same tables under the key "case".
		if _, ok := memex.DisableFoldFunctions["case"]; ok {
			er.disableFoldCounter++
		}
		if _, ok := memex.TryFoldFunctions["case"]; ok {
			er.tryFoldCounter++
		}
	case *ast.SetDefCauslationExpr:
		// Do nothing
	default:
		// Any other node kind forces its operands to be rewritten as scalars.
		er.asScalar = true
	}
	return inNode, false
}
   423  
   424  func (er *memexRewriter) buildSemiApplyFromEqualSubq(np LogicalCauset, l, r memex.Expression, not bool) {
   425  	var condition memex.Expression
   426  	if rDefCaus, ok := r.(*memex.DeferredCauset); ok && (er.asScalar || not) {
   427  		// If both input columns of `!= all / = any` memex are not null, we can treat the memex
   428  		// as normal column equal condition.
   429  		if lDefCaus, ok := l.(*memex.DeferredCauset); !ok || !allegrosql.HasNotNullFlag(lDefCaus.GetType().Flag) || !allegrosql.HasNotNullFlag(rDefCaus.GetType().Flag) {
   430  			rDefCausCopy := *rDefCaus
   431  			rDefCausCopy.InOperand = true
   432  			r = &rDefCausCopy
   433  		}
   434  	}
   435  	condition, er.err = er.constructBinaryOpFunction(l, r, ast.EQ)
   436  	if er.err != nil {
   437  		return
   438  	}
   439  	er.p, er.err = er.b.buildSemiApply(er.p, np, []memex.Expression{condition}, er.asScalar, not)
   440  }
   441  
   442  func (er *memexRewriter) handleCompareSubquery(ctx context.Context, v *ast.CompareSubqueryExpr) (ast.Node, bool) {
   443  	v.L.Accept(er)
   444  	if er.err != nil {
   445  		return v, true
   446  	}
   447  	lexpr := er.ctxStack[len(er.ctxStack)-1]
   448  	subq, ok := v.R.(*ast.SubqueryExpr)
   449  	if !ok {
   450  		er.err = errors.Errorf("Unknown compare type %T.", v.R)
   451  		return v, true
   452  	}
   453  	np, err := er.buildSubquery(ctx, subq)
   454  	if err != nil {
   455  		er.err = err
   456  		return v, true
   457  	}
   458  	// Only (a,b,c) = any (...) and (a,b,c) != all (...) can use event memex.
   459  	canMultiDefCaus := (!v.All && v.Op == opcode.EQ) || (v.All && v.Op == opcode.NE)
   460  	if !canMultiDefCaus && (memex.GetRowLen(lexpr) != 1 || np.Schema().Len() != 1) {
   461  		er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(1)
   462  		return v, true
   463  	}
   464  	lLen := memex.GetRowLen(lexpr)
   465  	if lLen != np.Schema().Len() {
   466  		er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(lLen)
   467  		return v, true
   468  	}
   469  	var rexpr memex.Expression
   470  	if np.Schema().Len() == 1 {
   471  		rexpr = np.Schema().DeferredCausets[0]
   472  	} else {
   473  		args := make([]memex.Expression, 0, np.Schema().Len())
   474  		for _, col := range np.Schema().DeferredCausets {
   475  			args = append(args, col)
   476  		}
   477  		rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(), args...)
   478  		if er.err != nil {
   479  			return v, true
   480  		}
   481  	}
   482  	switch v.Op {
   483  	// Only EQ, NE and NullEQ can be composed with and.
   484  	case opcode.EQ, opcode.NE, opcode.NullEQ:
   485  		if v.Op == opcode.EQ {
   486  			if v.All {
   487  				er.handleEQAll(lexpr, rexpr, np)
   488  			} else {
   489  				// `a = any(subq)` will be rewriten as `a in (subq)`.
   490  				er.buildSemiApplyFromEqualSubq(np, lexpr, rexpr, false)
   491  				if er.err != nil {
   492  					return v, true
   493  				}
   494  			}
   495  		} else if v.Op == opcode.NE {
   496  			if v.All {
   497  				// `a != all(subq)` will be rewriten as `a not in (subq)`.
   498  				er.buildSemiApplyFromEqualSubq(np, lexpr, rexpr, true)
   499  				if er.err != nil {
   500  					return v, true
   501  				}
   502  			} else {
   503  				er.handleNEAny(lexpr, rexpr, np)
   504  			}
   505  		} else {
   506  			// TODO: Support this in future.
   507  			er.err = errors.New("We don't support <=> all or <=> any now")
   508  			return v, true
   509  		}
   510  	default:
   511  		// When < all or > any , the agg function should use min.
   512  		useMin := ((v.Op == opcode.LT || v.Op == opcode.LE) && v.All) || ((v.Op == opcode.GT || v.Op == opcode.GE) && !v.All)
   513  		er.handleOtherComparableSubq(lexpr, rexpr, np, useMin, v.Op.String(), v.All)
   514  	}
   515  	if er.asScalar {
   516  		// The parent memex only use the last column in schemaReplicant, which represents whether the condition is matched.
   517  		er.ctxStack[len(er.ctxStack)-1] = er.p.Schema().DeferredCausets[er.p.Schema().Len()-1]
   518  		er.ctxNameStk[len(er.ctxNameStk)-1] = er.p.OutputNames()[er.p.Schema().Len()-1]
   519  	}
   520  	return v, true
   521  }
   522  
   523  // handleOtherComparableSubq handles the queries like < any, < max, etc. For example, if the query is t.id < any (select s.id from s),
   524  // it will be rewrote to t.id < (select max(s.id) from s).
   525  func (er *memexRewriter) handleOtherComparableSubq(lexpr, rexpr memex.Expression, np LogicalCauset, useMin bool, cmpFunc string, all bool) {
   526  	plan4Agg := LogicalAggregation{}.Init(er.sctx, er.b.getSelectOffset())
   527  	if hint := er.b.BlockHints(); hint != nil {
   528  		plan4Agg.aggHints = hint.aggHints
   529  	}
   530  	plan4Agg.SetChildren(np)
   531  
   532  	// Create a "max" or "min" aggregation.
   533  	funcName := ast.AggFuncMax
   534  	if useMin {
   535  		funcName = ast.AggFuncMin
   536  	}
   537  	funcMaxOrMin, err := aggregation.NewAggFuncDesc(er.sctx, funcName, []memex.Expression{rexpr}, false)
   538  	if err != nil {
   539  		er.err = err
   540  		return
   541  	}
   542  
   543  	// Create a column and append it to the schemaReplicant of that aggregation.
   544  	colMaxOrMin := &memex.DeferredCauset{
   545  		UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
   546  		RetType:  funcMaxOrMin.RetTp,
   547  	}
   548  	schemaReplicant := memex.NewSchema(colMaxOrMin)
   549  
   550  	plan4Agg.names = append(plan4Agg.names, types.EmptyName)
   551  	plan4Agg.SetSchema(schemaReplicant)
   552  	plan4Agg.AggFuncs = []*aggregation.AggFuncDesc{funcMaxOrMin}
   553  
   554  	cond := memex.NewFunctionInternal(er.sctx, cmpFunc, types.NewFieldType(allegrosql.TypeTiny), lexpr, colMaxOrMin)
   555  	er.buildQuantifierCauset(plan4Agg, cond, lexpr, rexpr, all)
   556  }
   557  
// buildQuantifierCauset adds extra condition for any / all subquery.
//
// plan4Agg is the aggregation built over the subquery and cond the comparison
// between lexpr and its aggregate column(s). Two more aggregates are appended
// to plan4Agg: sum(rexpr IS NULL), used to detect NULLs produced by the
// subquery, and count(1), used to detect an empty subquery. cond is then
// extended with the NULL and emptiness checks appropriate for ALL (all is
// true) or ANY (all is false).
//
// If the rewriter is not in scalar mode, the final condition becomes the
// predicate of a semi apply. Otherwise the aggregation is applied with an
// inner join and a projection is placed on top that keeps the outer columns
// and appends one extra column holding the value of the condition.
// Failures are reported through er.err.
func (er *memexRewriter) buildQuantifierCauset(plan4Agg *LogicalAggregation, cond, lexpr, rexpr memex.Expression, all bool) {
	innerIsNull := memex.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(allegrosql.TypeTiny), rexpr)
	outerIsNull := memex.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(allegrosql.TypeTiny), lexpr)

	// Build `sum(rexpr is null)`; a non-zero sum means the subquery produced
	// at least one NULL.
	funcSum, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncSum, []memex.Expression{innerIsNull}, false)
	if err != nil {
		er.err = err
		return
	}
	colSum := &memex.DeferredCauset{
		UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
		RetType:  funcSum.RetTp,
	}
	plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum)
	plan4Agg.schemaReplicant.Append(colSum)
	innerHasNull := memex.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(allegrosql.TypeTiny), colSum, memex.NewZero())

	// Build `count(1)` aggregation to check if subquery is empty.
	funcCount, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncCount, []memex.Expression{memex.NewOne()}, false)
	if err != nil {
		er.err = err
		return
	}
	colCount := &memex.DeferredCauset{
		UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
		RetType:  funcCount.RetTp,
	}
	plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount)
	plan4Agg.schemaReplicant.Append(colCount)

	if all {
		// All of the inner record set should not contain null value. So for t.id < all(select s.id from s), it
		// should be rewrote to t.id < min(s.id) and if(sum(s.id is null) != 0, null, true).
		innerNullChecker := memex.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(allegrosql.TypeTiny), innerHasNull, memex.NewNull(), memex.NewOne())
		cond = memex.ComposeCNFCondition(er.sctx, cond, innerNullChecker)
		// If the subquery is empty, it should always return true.
		emptyChecker := memex.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(allegrosql.TypeTiny), colCount, memex.NewZero())
		// If outer key is null, and subquery is not empty, it should always return null, even when it is `null = all (1, 2)`.
		outerNullChecker := memex.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(allegrosql.TypeTiny), outerIsNull, memex.NewNull(), memex.NewZero())
		cond = memex.ComposeDNFCondition(er.sctx, cond, emptyChecker, outerNullChecker)
	} else {
		// For "any" memex, if the subquery has null and the cond returns false, the result should be NULL.
		// Specifically, `t.id < any (select s.id from s)` would be rewrote to `t.id < max(s.id) or if(sum(s.id is null) != 0, null, false)`
		innerNullChecker := memex.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(allegrosql.TypeTiny), innerHasNull, memex.NewNull(), memex.NewZero())
		cond = memex.ComposeDNFCondition(er.sctx, cond, innerNullChecker)
		// If the subquery is empty, it should always return false.
		emptyChecker := memex.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(allegrosql.TypeTiny), colCount, memex.NewZero())
		// If outer key is null, and subquery is not empty, it should return null.
		outerNullChecker := memex.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(allegrosql.TypeTiny), outerIsNull, memex.NewNull(), memex.NewOne())
		cond = memex.ComposeCNFCondition(er.sctx, cond, emptyChecker, outerNullChecker)
	}

	// TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions.
	// plan4Agg.buildProjectionIfNecessary()
	if !er.asScalar {
		// For Semi LogicalApply without aux column, the result is no matter false or null. So we can add it to join predicate.
		er.p, er.err = er.b.buildSemiApply(er.p, plan4Agg, []memex.Expression{cond}, false, false)
		return
	}
	// If we treat the result as a scalar value, we will add a projection with a extra column to output true, false or null.
	// Remember how many columns the outer plan had before the apply widens it.
	outerSchemaLen := er.p.Schema().Len()
	er.p = er.b.buildApplyWithJoinType(er.p, plan4Agg, InnerJoin)
	joinSchema := er.p.Schema()
	proj := LogicalProjection{
		Exprs: memex.DeferredCauset2Exprs(joinSchema.Clone().DeferredCausets[:outerSchemaLen]),
	}.Init(er.sctx, er.b.getSelectOffset())
	// Keep the outer names; capacity leaves room for the result column.
	proj.names = make([]*types.FieldName, outerSchemaLen, outerSchemaLen+1)
	copy(proj.names, er.p.OutputNames())
	proj.SetSchema(memex.NewSchema(joinSchema.Clone().DeferredCausets[:outerSchemaLen]...))
	// Append the condition as the last, unnamed, output column.
	proj.Exprs = append(proj.Exprs, cond)
	proj.schemaReplicant.Append(&memex.DeferredCauset{
		UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
		RetType:  cond.GetType(),
	})
	proj.names = append(proj.names, types.EmptyName)
	proj.SetChildren(er.p)
	er.p = proj
}
   637  
   638  // handleNEAny handles the case of != any. For example, if the query is t.id != any (select s.id from s), it will be rewrote to
   639  // t.id != s.id or count(distinct s.id) > 1 or [any checker]. If there are two different values in s.id ,
   640  // there must exist a s.id that doesn't equal to t.id.
   641  func (er *memexRewriter) handleNEAny(lexpr, rexpr memex.Expression, np LogicalCauset) {
   642  	// If there is NULL in s.id column, s.id should be the value that isn't null in condition t.id != s.id.
   643  	// So use function max to filter NULL.
   644  	maxFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncMax, []memex.Expression{rexpr}, false)
   645  	if err != nil {
   646  		er.err = err
   647  		return
   648  	}
   649  	countFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncCount, []memex.Expression{rexpr}, true)
   650  	if err != nil {
   651  		er.err = err
   652  		return
   653  	}
   654  	plan4Agg := LogicalAggregation{
   655  		AggFuncs: []*aggregation.AggFuncDesc{maxFunc, countFunc},
   656  	}.Init(er.sctx, er.b.getSelectOffset())
   657  	if hint := er.b.BlockHints(); hint != nil {
   658  		plan4Agg.aggHints = hint.aggHints
   659  	}
   660  	plan4Agg.SetChildren(np)
   661  	maxResultDefCaus := &memex.DeferredCauset{
   662  		UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
   663  		RetType:  maxFunc.RetTp,
   664  	}
   665  	count := &memex.DeferredCauset{
   666  		UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
   667  		RetType:  countFunc.RetTp,
   668  	}
   669  	plan4Agg.names = append(plan4Agg.names, types.EmptyName, types.EmptyName)
   670  	plan4Agg.SetSchema(memex.NewSchema(maxResultDefCaus, count))
   671  	gtFunc := memex.NewFunctionInternal(er.sctx, ast.GT, types.NewFieldType(allegrosql.TypeTiny), count, memex.NewOne())
   672  	neCond := memex.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(allegrosql.TypeTiny), lexpr, maxResultDefCaus)
   673  	cond := memex.ComposeDNFCondition(er.sctx, gtFunc, neCond)
   674  	er.buildQuantifierCauset(plan4Agg, cond, lexpr, rexpr, false)
   675  }
   676  
   677  // handleEQAll handles the case of = all. For example, if the query is t.id = all (select s.id from s), it will be rewrote to
   678  // t.id = (select s.id from s having count(distinct s.id) <= 1 and [all checker]).
   679  func (er *memexRewriter) handleEQAll(lexpr, rexpr memex.Expression, np LogicalCauset) {
   680  	firstRowFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncFirstRow, []memex.Expression{rexpr}, false)
   681  	if err != nil {
   682  		er.err = err
   683  		return
   684  	}
   685  	countFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncCount, []memex.Expression{rexpr}, true)
   686  	if err != nil {
   687  		er.err = err
   688  		return
   689  	}
   690  	plan4Agg := LogicalAggregation{
   691  		AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc},
   692  	}.Init(er.sctx, er.b.getSelectOffset())
   693  	if hint := er.b.BlockHints(); hint != nil {
   694  		plan4Agg.aggHints = hint.aggHints
   695  	}
   696  	plan4Agg.SetChildren(np)
   697  	plan4Agg.names = append(plan4Agg.names, types.EmptyName)
   698  	firstRowResultDefCaus := &memex.DeferredCauset{
   699  		UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
   700  		RetType:  firstRowFunc.RetTp,
   701  	}
   702  	plan4Agg.names = append(plan4Agg.names, types.EmptyName)
   703  	count := &memex.DeferredCauset{
   704  		UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
   705  		RetType:  countFunc.RetTp,
   706  	}
   707  	plan4Agg.SetSchema(memex.NewSchema(firstRowResultDefCaus, count))
   708  	leFunc := memex.NewFunctionInternal(er.sctx, ast.LE, types.NewFieldType(allegrosql.TypeTiny), count, memex.NewOne())
   709  	eqCond := memex.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(allegrosql.TypeTiny), lexpr, firstRowResultDefCaus)
   710  	cond := memex.ComposeCNFCondition(er.sctx, leFunc, eqCond)
   711  	er.buildQuantifierCauset(plan4Agg, cond, lexpr, rexpr, true)
   712  }
   713  
   714  func (er *memexRewriter) handleExistSubquery(ctx context.Context, v *ast.ExistsSubqueryExpr) (ast.Node, bool) {
   715  	subq, ok := v.Sel.(*ast.SubqueryExpr)
   716  	if !ok {
   717  		er.err = errors.Errorf("Unknown exists type %T.", v.Sel)
   718  		return v, true
   719  	}
   720  	np, err := er.buildSubquery(ctx, subq)
   721  	if err != nil {
   722  		er.err = err
   723  		return v, true
   724  	}
   725  	np = er.popExistsSubCauset(np)
   726  	if len(ExtractCorrelatedDefCauss4LogicalCauset(np)) > 0 {
   727  		er.p, er.err = er.b.buildSemiApply(er.p, np, nil, er.asScalar, v.Not)
   728  		if er.err != nil || !er.asScalar {
   729  			return v, true
   730  		}
   731  		er.ctxStackAppend(er.p.Schema().DeferredCausets[er.p.Schema().Len()-1], er.p.OutputNames()[er.p.Schema().Len()-1])
   732  	} else {
   733  		// We don't want nth_plan hint to affect separately executed subqueries here, so disable nth_plan temporarily.
   734  		NthCausetBackup := er.sctx.GetStochastikVars().StmtCtx.StmtHints.ForceNthCauset
   735  		er.sctx.GetStochastikVars().StmtCtx.StmtHints.ForceNthCauset = -1
   736  		physicalCauset, _, err := DoOptimize(ctx, er.sctx, er.b.optFlag, np)
   737  		er.sctx.GetStochastikVars().StmtCtx.StmtHints.ForceNthCauset = NthCausetBackup
   738  		if err != nil {
   739  			er.err = err
   740  			return v, true
   741  		}
   742  		event, err := EvalSubqueryFirstRow(ctx, physicalCauset, er.b.is, er.b.ctx)
   743  		if err != nil {
   744  			er.err = err
   745  			return v, true
   746  		}
   747  		if (event != nil && !v.Not) || (event == nil && v.Not) {
   748  			er.ctxStackAppend(memex.NewOne(), types.EmptyName)
   749  		} else {
   750  			er.ctxStackAppend(memex.NewZero(), types.EmptyName)
   751  		}
   752  	}
   753  	return v, true
   754  }
   755  
   756  // popExistsSubCauset will remove the useless plan in exist's child.
   757  // See comments inside the method for more details.
   758  func (er *memexRewriter) popExistsSubCauset(p LogicalCauset) LogicalCauset {
   759  out:
   760  	for {
   761  		switch plan := p.(type) {
   762  		// This can be removed when in exists clause,
   763  		// e.g. exists(select count(*) from t order by a) is equal to exists t.
   764  		case *LogicalProjection, *LogicalSort:
   765  			p = p.Children()[0]
   766  		case *LogicalAggregation:
   767  			if len(plan.GroupByItems) == 0 {
   768  				p = LogicalBlockDual{RowCount: 1}.Init(er.sctx, er.b.getSelectOffset())
   769  				break out
   770  			}
   771  			p = p.Children()[0]
   772  		default:
   773  			break out
   774  		}
   775  	}
   776  	return p
   777  }
   778  
// handleInSubquery rewrites `expr [NOT] IN (subquery)`.
//
// It visits v.Expr to obtain the left operand from the top of the context stack,
// builds the plan for the subquery, and constructs an equality condition between
// the left operand and the subquery output. The subquery is then attached to the
// current plan either as an inner join over a distinct aggregation or through
// buildSemiApply. The left operand is popped from the stack at the end; when the
// caller asked for a scalar result, the last output column of the new plan is pushed.
//
// It always returns (v, true); failures are reported through er.err.
func (er *memexRewriter) handleInSubquery(ctx context.Context, v *ast.PatternInExpr) (ast.Node, bool) {
	// Remember the caller's asScalar and force scalar mode while visiting the left operand.
	// NOTE(review): er.asScalar is not restored inside this function.
	asScalar := er.asScalar
	er.asScalar = true
	v.Expr.Accept(er)
	if er.err != nil {
		return v, true
	}
	// The visit above leaves the rewritten left operand on top of the stack.
	lexpr := er.ctxStack[len(er.ctxStack)-1]
	subq, ok := v.Sel.(*ast.SubqueryExpr)
	if !ok {
		er.err = errors.Errorf("Unknown compare type %T.", v.Sel)
		return v, true
	}
	np, err := er.buildSubquery(ctx, subq)
	if err != nil {
		er.err = err
		return v, true
	}
	// The left operand must have as many columns as the subquery outputs.
	lLen := memex.GetRowLen(lexpr)
	if lLen != np.Schema().Len() {
		er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(lLen)
		return v, true
	}
	// Build the right operand: the single output column, or a row function over all of them.
	var rexpr memex.Expression
	if np.Schema().Len() == 1 {
		rexpr = np.Schema().DeferredCausets[0]
		rDefCaus := rexpr.(*memex.DeferredCauset)
		// For AntiSemiJoin/LeftOuterSemiJoin/AntiLeftOuterSemiJoin, we cannot treat `in` memex as
		// normal column equal condition, so we specially mark the inner operand here.
		if v.Not || asScalar {
			// If both input columns of `in` memex are not null, we can treat the memex
			// as normal column equal condition instead.
			if !allegrosql.HasNotNullFlag(lexpr.GetType().Flag) || !allegrosql.HasNotNullFlag(rDefCaus.GetType().Flag) {
				// Mark a copy so the column held by the subquery schema is left untouched.
				rDefCausCopy := *rDefCaus
				rDefCausCopy.InOperand = true
				rexpr = &rDefCausCopy
			}
		}
	} else {
		args := make([]memex.Expression, 0, np.Schema().Len())
		for _, col := range np.Schema().DeferredCausets {
			args = append(args, col)
		}
		rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(), args...)
		if er.err != nil {
			return v, true
		}
	}
	checkCondition, err := er.constructBinaryOpFunction(lexpr, rexpr, ast.EQ)
	if err != nil {
		er.err = err
		return v, true
	}

	// If the leftKey and the rightKey have different collations, don't convert the sub-query to an inner-join
	// since when converting we will add a distinct-agg upon the right child and this distinct-agg doesn't have the right collation.
	// To keep it simple, we forbid this converting if they have different collations.
	lt, rt := lexpr.GetType(), rexpr.GetType()
	collFlag := collate.CompatibleDefCauslate(lt.DefCauslate, rt.DefCauslate)

	// If it's not the form of `not in (SUBQUERY)`,
	// and has no correlated column from the current level plan(if the correlated column is from upper level,
	// we can treat it as constant, because the upper LogicalApply cannot be eliminated since current node is a join node),
	// and don't need to append a scalar value, we can rewrite it to inner join.
	if er.sctx.GetStochastikVars().GetAllowInSubqToJoinAnPosetDagg() && !v.Not && !asScalar && len(extractCorDeferredCausetsBySchema4LogicalCauset(np, er.p.Schema())) == 0 && collFlag {
		// We need to try to eliminate the agg and the projection produced by this operation.
		er.b.optFlag |= flagEliminateAgg
		er.b.optFlag |= flagEliminateProjection
		er.b.optFlag |= flagJoinReOrder
		// Build distinct for the inner query.
		agg, err := er.b.buildDistinct(np, np.Schema().Len())
		if err != nil {
			er.err = err
			return v, true
		}
		// Build inner join above the aggregation.
		join := LogicalJoin{JoinType: InnerJoin}.Init(er.sctx, er.b.getSelectOffset())
		join.SetChildren(er.p, agg)
		join.SetSchema(memex.MergeSchema(er.p.Schema(), agg.schemaReplicant))
		// Output names are the outer plan's names followed by the aggregation's names.
		join.names = make([]*types.FieldName, er.p.Schema().Len()+agg.Schema().Len())
		copy(join.names, er.p.OutputNames())
		copy(join.names[er.p.Schema().Len():], agg.OutputNames())
		join.AttachOnConds(memex.SplitCNFItems(checkCondition))
		// Set join hint for this join.
		if er.b.BlockHints() != nil {
			join.setPreferredJoinType(er.b.BlockHints())
		}
		er.p = join
	} else {
		er.p, er.err = er.b.buildSemiApply(er.p, np, memex.SplitCNFItems(checkCondition), asScalar, v.Not)
		if er.err != nil {
			return v, true
		}
	}

	// Drop the left operand; in scalar mode replace it with the last output column of the new plan.
	er.ctxStackPop(1)
	if asScalar {
		col := er.p.Schema().DeferredCausets[er.p.Schema().Len()-1]
		er.ctxStackAppend(col, er.p.OutputNames()[er.p.Schema().Len()-1])
	}
	return v, true
}
   881  
   882  func (er *memexRewriter) handleScalarSubquery(ctx context.Context, v *ast.SubqueryExpr) (ast.Node, bool) {
   883  	np, err := er.buildSubquery(ctx, v)
   884  	if err != nil {
   885  		er.err = err
   886  		return v, true
   887  	}
   888  	np = er.b.buildMaxOneRow(np)
   889  	if len(ExtractCorrelatedDefCauss4LogicalCauset(np)) > 0 {
   890  		er.p = er.b.buildApplyWithJoinType(er.p, np, LeftOuterJoin)
   891  		if np.Schema().Len() > 1 {
   892  			newDefCauss := make([]memex.Expression, 0, np.Schema().Len())
   893  			for _, col := range np.Schema().DeferredCausets {
   894  				newDefCauss = append(newDefCauss, col)
   895  			}
   896  			expr, err1 := er.newFunction(ast.RowFunc, newDefCauss[0].GetType(), newDefCauss...)
   897  			if err1 != nil {
   898  				er.err = err1
   899  				return v, true
   900  			}
   901  			er.ctxStackAppend(expr, types.EmptyName)
   902  		} else {
   903  			er.ctxStackAppend(er.p.Schema().DeferredCausets[er.p.Schema().Len()-1], er.p.OutputNames()[er.p.Schema().Len()-1])
   904  		}
   905  		return v, true
   906  	}
   907  	// We don't want nth_plan hint to affect separately executed subqueries here, so disable nth_plan temporarily.
   908  	NthCausetBackup := er.sctx.GetStochastikVars().StmtCtx.StmtHints.ForceNthCauset
   909  	er.sctx.GetStochastikVars().StmtCtx.StmtHints.ForceNthCauset = -1
   910  	physicalCauset, _, err := DoOptimize(ctx, er.sctx, er.b.optFlag, np)
   911  	er.sctx.GetStochastikVars().StmtCtx.StmtHints.ForceNthCauset = NthCausetBackup
   912  	if err != nil {
   913  		er.err = err
   914  		return v, true
   915  	}
   916  	event, err := EvalSubqueryFirstRow(ctx, physicalCauset, er.b.is, er.b.ctx)
   917  	if err != nil {
   918  		er.err = err
   919  		return v, true
   920  	}
   921  	if np.Schema().Len() > 1 {
   922  		newDefCauss := make([]memex.Expression, 0, np.Schema().Len())
   923  		for i, data := range event {
   924  			newDefCauss = append(newDefCauss, &memex.Constant{
   925  				Value:   data,
   926  				RetType: np.Schema().DeferredCausets[i].GetType()})
   927  		}
   928  		expr, err1 := er.newFunction(ast.RowFunc, newDefCauss[0].GetType(), newDefCauss...)
   929  		if err1 != nil {
   930  			er.err = err1
   931  			return v, true
   932  		}
   933  		er.ctxStackAppend(expr, types.EmptyName)
   934  	} else {
   935  		er.ctxStackAppend(&memex.Constant{
   936  			Value:   event[0],
   937  			RetType: np.Schema().DeferredCausets[0].GetType(),
   938  		}, types.EmptyName)
   939  	}
   940  	return v, true
   941  }
   942  
// Leave implements Visitor interface.
//
// It dispatches on the concrete type of the (optionally preprocessed) node and
// converts it into an expression on the context stack, mostly by delegating to
// the per-node helpers of memexRewriter. On success it returns the original
// node and true. If er.err is already set on entry, or gets set while handling
// the node, it returns the zero-valued retNode (nil) and false.
func (er *memexRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok bool) {
	if er.err != nil {
		return retNode, false
	}
	// The preprocess hook, when present, may substitute the node that is dispatched on;
	// the node returned to the caller is still originInNode.
	var inNode = originInNode
	if er.preprocess != nil {
		inNode = er.preprocess(inNode)
	}
	switch v := inNode.(type) {
	case *ast.AggregateFuncExpr, *ast.DeferredCausetNameExpr, *ast.ParenthesesExpr, *ast.WhenClause,
		*ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr, *ast.WindowFuncExpr, *ast.BlockNameExpr:
		// Nothing is done for these node types on leave.
	case *driver.ValueExpr:
		// Literal value: push it as a constant typed with the node's own type.
		v.Causet.SetValue(v.Causet.GetValue(), &v.Type)
		value := &memex.Constant{Value: v.Causet, RetType: &v.Type}
		er.ctxStackAppend(value, types.EmptyName)
	case *driver.ParamMarkerExpr:
		var value memex.Expression
		value, er.err = memex.ParamMarkerExpression(er.sctx, v)
		if er.err != nil {
			return retNode, false
		}
		er.ctxStackAppend(value, types.EmptyName)
	case *ast.VariableExpr:
		er.rewriteVariable(v)
	case *ast.FuncCallExpr:
		// The fold counters are decremented around the conversion for functions listed in
		// TryFoldFunctions / DisableFoldFunctions (presumably incremented on enter — not visible here).
		if _, ok := memex.TryFoldFunctions[v.FnName.L]; ok {
			er.tryFoldCounter--
		}
		er.funcCallToExpression(v)
		if _, ok := memex.DisableFoldFunctions[v.FnName.L]; ok {
			er.disableFoldCounter--
		}
	case *ast.BlockName:
		er.toBlock(v)
	case *ast.DeferredCausetName:
		er.toDeferredCauset(v)
	case *ast.UnaryOperationExpr:
		er.unaryOpToExpression(v)
	case *ast.BinaryOperationExpr:
		er.binaryOpToExpression(v)
	case *ast.BetweenExpr:
		er.betweenToExpression(v)
	case *ast.CaseExpr:
		// Same counter handling as for function calls, keyed by the name "case".
		if _, ok := memex.TryFoldFunctions["case"]; ok {
			er.tryFoldCounter--
		}
		er.caseToExpression(v)
		if _, ok := memex.DisableFoldFunctions["case"]; ok {
			er.disableFoldCounter--
		}
	case *ast.FuncCastExpr:
		// The cast operand is the expression currently on top of the stack.
		arg := er.ctxStack[len(er.ctxStack)-1]
		er.err = memex.CheckArgsNotMultiDeferredCausetRow(arg)
		if er.err != nil {
			return retNode, false
		}

		// check the decimal precision of "CAST(AS TIME)".
		er.err = er.checkTimePrecision(v.Tp)
		if er.err != nil {
			return retNode, false
		}

		// Replace the operand in place with the cast and clear its name.
		er.ctxStack[len(er.ctxStack)-1] = memex.BuildCastFunction(er.sctx, arg, v.Tp)
		er.ctxNameStk[len(er.ctxNameStk)-1] = types.EmptyName
	case *ast.PatternLikeExpr:
		er.patternLikeToExpression(v)
	case *ast.PatternRegexpExpr:
		er.regexpToScalarFunc(v)
	case *ast.RowExpr:
		er.rowToScalarFunc(v)
	case *ast.PatternInExpr:
		// Only the list form is converted here; nothing is done when v.Sel is set.
		if v.Sel == nil {
			er.inToExpression(len(v.List), v.Not, &v.Type)
		}
	case *ast.PositionExpr:
		er.positionToScalarFunc(v)
	case *ast.IsNullExpr:
		er.isNullToExpression(v)
	case *ast.IsTruthExpr:
		er.isTrueToScalarFunc(v)
	case *ast.DefaultExpr:
		er.evalDefaultExpr(v)
	// TODO: Perhaps we don't need to transcode these back to generic integers/strings
	case *ast.TrimDirectionExpr:
		er.ctxStackAppend(&memex.Constant{
			Value:   types.NewIntCauset(int64(v.Direction)),
			RetType: types.NewFieldType(allegrosql.TypeTiny),
		}, types.EmptyName)
	case *ast.TimeUnitExpr:
		er.ctxStackAppend(&memex.Constant{
			Value:   types.NewStringCauset(v.Unit.String()),
			RetType: types.NewFieldType(allegrosql.TypeVarchar),
		}, types.EmptyName)
	case *ast.GetFormatSelectorExpr:
		er.ctxStackAppend(&memex.Constant{
			Value:   types.NewStringCauset(v.Selector.String()),
			RetType: types.NewFieldType(allegrosql.TypeVarchar),
		}, types.EmptyName)
	case *ast.SetDefCauslationExpr:
		arg := er.ctxStack[len(er.ctxStack)-1]
		if collate.NewDefCauslationEnabled() {
			var collInfo *charset.DefCauslation
			// TODO(bb7133): use charset.ValidCharsetAndDefCauslation when its bug is fixed.
			// The breaks below leave the switch; the er.err check after it reports the failure.
			if collInfo, er.err = collate.GetDefCauslationByName(v.DefCauslate); er.err != nil {
				break
			}
			chs := arg.GetType().Charset
			if chs != "" && collInfo.CharsetName != chs {
				er.err = charset.ErrDefCauslationCharsetMismatch.GenWithStackByArgs(collInfo.Name, chs)
				break
			}
		}
		// SetDefCauslationExpr sets the collation explicitly, even when the evaluation type of the memex is non-string.
		if _, ok := arg.(*memex.DeferredCauset); ok {
			// Wrap a cast here to avoid changing the original FieldType of the column memex.
			exprType := arg.GetType().Clone()
			exprType.DefCauslate = v.DefCauslate
			casted := memex.BuildCastFunction(er.sctx, arg, exprType)
			er.ctxStackPop(1)
			er.ctxStackAppend(casted, types.EmptyName)
		} else {
			// For constant and scalar function, we can set its collate directly.
			arg.GetType().DefCauslate = v.DefCauslate
		}
		// Mark whatever is now on top of the stack as explicitly collated.
		er.ctxStack[len(er.ctxStack)-1].SetCoercibility(memex.CoercibilityExplicit)
		er.ctxStack[len(er.ctxStack)-1].SetCharsetAndDefCauslation(arg.GetType().Charset, arg.GetType().DefCauslate)
	default:
		er.err = errors.Errorf("UnknownType: %T", v)
		return retNode, false
	}

	// Helpers report failures through er.err rather than a return value.
	if er.err != nil {
		return retNode, false
	}
	return originInNode, true
}
  1081  
  1082  // newFunction chooses which memex.NewFunctionImpl() will be used.
  1083  func (er *memexRewriter) newFunction(funcName string, retType *types.FieldType, args ...memex.Expression) (memex.Expression, error) {
  1084  	if er.disableFoldCounter > 0 {
  1085  		return memex.NewFunctionBase(er.sctx, funcName, retType, args...)
  1086  	}
  1087  	if er.tryFoldCounter > 0 {
  1088  		return memex.NewFunctionTryFold(er.sctx, funcName, retType, args...)
  1089  	}
  1090  	return memex.NewFunction(er.sctx, funcName, retType, args...)
  1091  }
  1092  
  1093  func (er *memexRewriter) checkTimePrecision(ft *types.FieldType) error {
  1094  	if ft.EvalType() == types.ETDuration && ft.Decimal > int(types.MaxFsp) {
  1095  		return errTooBigPrecision.GenWithStackByArgs(ft.Decimal, "CAST", types.MaxFsp)
  1096  	}
  1097  	return nil
  1098  }
  1099  
  1100  func (er *memexRewriter) useCache() bool {
  1101  	return er.sctx.GetStochastikVars().StmtCtx.UseCache
  1102  }
  1103  
  1104  func (er *memexRewriter) rewriteVariable(v *ast.VariableExpr) {
  1105  	stkLen := len(er.ctxStack)
  1106  	name := strings.ToLower(v.Name)
  1107  	stochastikVars := er.b.ctx.GetStochastikVars()
  1108  	if !v.IsSystem {
  1109  		if v.Value != nil {
  1110  			er.ctxStack[stkLen-1], er.err = er.newFunction(ast.SetVar,
  1111  				er.ctxStack[stkLen-1].GetType(),
  1112  				memex.CausetToConstant(types.NewCauset(name), allegrosql.TypeString),
  1113  				er.ctxStack[stkLen-1])
  1114  			er.ctxNameStk[stkLen-1] = types.EmptyName
  1115  			return
  1116  		}
  1117  		f, err := er.newFunction(ast.GetVar,
  1118  			// TODO: Here is wrong, the stochastikVars should causetstore a name -> Causet map. Will fix it later.
  1119  			types.NewFieldType(allegrosql.TypeString),
  1120  			memex.CausetToConstant(types.NewStringCauset(name), allegrosql.TypeString))
  1121  		if err != nil {
  1122  			er.err = err
  1123  			return
  1124  		}
  1125  		f.SetCoercibility(memex.CoercibilityImplicit)
  1126  		er.ctxStackAppend(f, types.EmptyName)
  1127  		return
  1128  	}
  1129  	var val string
  1130  	var err error
  1131  	if v.ExplicitScope {
  1132  		err = variable.ValidateGetSystemVar(name, v.IsGlobal)
  1133  		if err != nil {
  1134  			er.err = err
  1135  			return
  1136  		}
  1137  	}
  1138  	sysVar := variable.SysVars[name]
  1139  	if sysVar == nil {
  1140  		er.err = variable.ErrUnknownSystemVar.GenWithStackByArgs(name)
  1141  		return
  1142  	}
  1143  	// Variable is @@gobal.variable_name or variable is only global scope variable.
  1144  	if v.IsGlobal || sysVar.Scope == variable.ScopeGlobal {
  1145  		val, err = variable.GetGlobalSystemVar(stochastikVars, name)
  1146  	} else {
  1147  		val, err = variable.GetStochastikSystemVar(stochastikVars, name)
  1148  	}
  1149  	if err != nil {
  1150  		er.err = err
  1151  		return
  1152  	}
  1153  	e := memex.CausetToConstant(types.NewStringCauset(val), allegrosql.TypeVarString)
  1154  	e.GetType().Charset, _ = er.sctx.GetStochastikVars().GetSystemVar(variable.CharacterSetConnection)
  1155  	e.GetType().DefCauslate, _ = er.sctx.GetStochastikVars().GetSystemVar(variable.DefCauslationConnection)
  1156  	er.ctxStackAppend(e, types.EmptyName)
  1157  }
  1158  
  1159  func (er *memexRewriter) unaryOpToExpression(v *ast.UnaryOperationExpr) {
  1160  	stkLen := len(er.ctxStack)
  1161  	var op string
  1162  	switch v.Op {
  1163  	case opcode.Plus:
  1164  		// memex (+ a) is equal to a
  1165  		return
  1166  	case opcode.Minus:
  1167  		op = ast.UnaryMinus
  1168  	case opcode.BitNeg:
  1169  		op = ast.BitNeg
  1170  	case opcode.Not:
  1171  		op = ast.UnaryNot
  1172  	default:
  1173  		er.err = errors.Errorf("Unknown Unary Op %T", v.Op)
  1174  		return
  1175  	}
  1176  	if memex.GetRowLen(er.ctxStack[stkLen-1]) != 1 {
  1177  		er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(1)
  1178  		return
  1179  	}
  1180  	er.ctxStack[stkLen-1], er.err = er.newFunction(op, &v.Type, er.ctxStack[stkLen-1])
  1181  	er.ctxNameStk[stkLen-1] = types.EmptyName
  1182  }
  1183  
  1184  func (er *memexRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) {
  1185  	stkLen := len(er.ctxStack)
  1186  	var function memex.Expression
  1187  	switch v.Op {
  1188  	case opcode.EQ, opcode.NE, opcode.NullEQ, opcode.GT, opcode.GE, opcode.LT, opcode.LE:
  1189  		function, er.err = er.constructBinaryOpFunction(er.ctxStack[stkLen-2], er.ctxStack[stkLen-1],
  1190  			v.Op.String())
  1191  	default:
  1192  		lLen := memex.GetRowLen(er.ctxStack[stkLen-2])
  1193  		rLen := memex.GetRowLen(er.ctxStack[stkLen-1])
  1194  		if lLen != 1 || rLen != 1 {
  1195  			er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(1)
  1196  			return
  1197  		}
  1198  		function, er.err = er.newFunction(v.Op.String(), types.NewFieldType(allegrosql.TypeUnspecified), er.ctxStack[stkLen-2:]...)
  1199  	}
  1200  	if er.err != nil {
  1201  		return
  1202  	}
  1203  	er.ctxStackPop(2)
  1204  	er.ctxStackAppend(function, types.EmptyName)
  1205  }
  1206  
  1207  func (er *memexRewriter) notToExpression(hasNot bool, op string, tp *types.FieldType,
  1208  	args ...memex.Expression) memex.Expression {
  1209  	opFunc, err := er.newFunction(op, tp, args...)
  1210  	if err != nil {
  1211  		er.err = err
  1212  		return nil
  1213  	}
  1214  	if !hasNot {
  1215  		return opFunc
  1216  	}
  1217  
  1218  	opFunc, err = er.newFunction(ast.UnaryNot, tp, opFunc)
  1219  	if err != nil {
  1220  		er.err = err
  1221  		return nil
  1222  	}
  1223  	return opFunc
  1224  }
  1225  
  1226  func (er *memexRewriter) isNullToExpression(v *ast.IsNullExpr) {
  1227  	stkLen := len(er.ctxStack)
  1228  	if memex.GetRowLen(er.ctxStack[stkLen-1]) != 1 {
  1229  		er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(1)
  1230  		return
  1231  	}
  1232  	function := er.notToExpression(v.Not, ast.IsNull, &v.Type, er.ctxStack[stkLen-1])
  1233  	er.ctxStackPop(1)
  1234  	er.ctxStackAppend(function, types.EmptyName)
  1235  }
  1236  
  1237  func (er *memexRewriter) positionToScalarFunc(v *ast.PositionExpr) {
  1238  	pos := v.N
  1239  	str := strconv.Itoa(pos)
  1240  	if v.P != nil {
  1241  		stkLen := len(er.ctxStack)
  1242  		val := er.ctxStack[stkLen-1]
  1243  		intNum, isNull, err := memex.GetIntFromConstant(er.sctx, val)
  1244  		str = "?"
  1245  		if err == nil {
  1246  			if isNull {
  1247  				return
  1248  			}
  1249  			pos = intNum
  1250  			er.ctxStackPop(1)
  1251  		}
  1252  		er.err = err
  1253  	}
  1254  	if er.err == nil && pos > 0 && pos <= er.schemaReplicant.Len() {
  1255  		er.ctxStackAppend(er.schemaReplicant.DeferredCausets[pos-1], er.names[pos-1])
  1256  	} else {
  1257  		er.err = ErrUnknownDeferredCauset.GenWithStackByArgs(str, clauseMsg[er.b.curClause])
  1258  	}
  1259  }
  1260  
  1261  func (er *memexRewriter) isTrueToScalarFunc(v *ast.IsTruthExpr) {
  1262  	stkLen := len(er.ctxStack)
  1263  	op := ast.IsTruthWithoutNull
  1264  	if v.True == 0 {
  1265  		op = ast.IsFalsity
  1266  	}
  1267  	if memex.GetRowLen(er.ctxStack[stkLen-1]) != 1 {
  1268  		er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(1)
  1269  		return
  1270  	}
  1271  	function := er.notToExpression(v.Not, op, &v.Type, er.ctxStack[stkLen-1])
  1272  	er.ctxStackPop(1)
  1273  	er.ctxStackAppend(function, types.EmptyName)
  1274  }
  1275  
// inToExpression converts in memex to a scalar function. The argument lLen means the length of in list.
// The argument not means if the memex is not in. The tp stands for the memex type, which is always bool.
//
// The top lLen+1 entries of the context stack are the left operand followed by the list items; they are
// replaced by a single expression:
//   - a NULL constant when the left operand has type NULL;
//   - an `in` function when the left operand is single-column, the list has more than one item and every
//     non-NULL item compares with the left operand under the left operand's evaluation type;
//   - otherwise a in (b, c, d) will be rewritten as `(a = b) or (a = c) or (a = d)`.
//
// Failures are stored in er.err and leave the operands on the stack.
func (er *memexRewriter) inToExpression(lLen int, not bool, tp *types.FieldType) {
	stkLen := len(er.ctxStack)
	// Every list item must have the same number of columns as the left operand.
	l := memex.GetRowLen(er.ctxStack[stkLen-lLen-1])
	for i := 0; i < lLen; i++ {
		if l != memex.GetRowLen(er.ctxStack[stkLen-lLen+i]) {
			er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(l)
			return
		}
	}
	// args shares its backing array with er.ctxStack, so writes to args[i] below are
	// visible through er.ctxStack as well.
	args := er.ctxStack[stkLen-lLen-1:]
	leftFt := args[0].GetType()
	leftEt, leftIsNull := leftFt.EvalType(), leftFt.Tp == allegrosql.TypeNull
	if leftIsNull {
		er.ctxStackPop(lLen + 1)
		er.ctxStackAppend(memex.NewNull(), types.EmptyName)
		return
	}
	if leftEt == types.ETInt {
		// For an integer left operand, constant list items are passed through RefineComparedConstant;
		// the original constant is kept when the refinement is reported as exceptional.
		for i := 1; i < len(args); i++ {
			if c, ok := args[i].(*memex.Constant); ok {
				var isExceptional bool
				args[i], isExceptional = memex.RefineComparedConstant(er.sctx, *leftFt, c, opcode.EQ)
				if isExceptional {
					args[i] = c
				}
			}
		}
	}
	// Check whether every non-NULL item compares under the left operand's evaluation type.
	allSameType := true
	for _, arg := range args[1:] {
		if arg.GetType().Tp != allegrosql.TypeNull && memex.GetAccurateCmpType(args[0], arg) != leftEt {
			allSameType = false
			break
		}
	}
	var function memex.Expression
	if allSameType && l == 1 && lLen > 1 {
		function = er.notToExpression(not, ast.In, tp, er.ctxStack[stkLen-lLen-1:]...)
	} else {
		// Fall back to a disjunction of equality comparisons, one per list item.
		eqFunctions := make([]memex.Expression, 0, lLen)
		for i := stkLen - lLen; i < stkLen; i++ {
			expr, err := er.constructBinaryOpFunction(args[0], er.ctxStack[i], ast.EQ)
			if err != nil {
				er.err = err
				return
			}
			eqFunctions = append(eqFunctions, expr)
		}
		function = memex.ComposeDNFCondition(er.sctx, eqFunctions...)
		if not {
			var err error
			function, err = er.newFunction(ast.UnaryNot, tp, function)
			if err != nil {
				er.err = err
				return
			}
		}
	}
	er.ctxStackPop(lLen + 1)
	er.ctxStackAppend(function, types.EmptyName)
}
  1340  
  1341  func (er *memexRewriter) caseToExpression(v *ast.CaseExpr) {
  1342  	stkLen := len(er.ctxStack)
  1343  	argsLen := 2 * len(v.WhenClauses)
  1344  	if v.ElseClause != nil {
  1345  		argsLen++
  1346  	}
  1347  	er.err = memex.CheckArgsNotMultiDeferredCausetRow(er.ctxStack[stkLen-argsLen:]...)
  1348  	if er.err != nil {
  1349  		return
  1350  	}
  1351  
  1352  	// value                          -> ctxStack[stkLen-argsLen-1]
  1353  	// when clause(condition, result) -> ctxStack[stkLen-argsLen:stkLen-1];
  1354  	// else clause                    -> ctxStack[stkLen-1]
  1355  	var args []memex.Expression
  1356  	if v.Value != nil {
  1357  		// args:  eq scalar func(args: value, condition1), result1,
  1358  		//        eq scalar func(args: value, condition2), result2,
  1359  		//        ...
  1360  		//        else clause
  1361  		value := er.ctxStack[stkLen-argsLen-1]
  1362  		args = make([]memex.Expression, 0, argsLen)
  1363  		for i := stkLen - argsLen; i < stkLen-1; i += 2 {
  1364  			arg, err := er.newFunction(ast.EQ, types.NewFieldType(allegrosql.TypeTiny), value, er.ctxStack[i])
  1365  			if err != nil {
  1366  				er.err = err
  1367  				return
  1368  			}
  1369  			args = append(args, arg)
  1370  			args = append(args, er.ctxStack[i+1])
  1371  		}
  1372  		if v.ElseClause != nil {
  1373  			args = append(args, er.ctxStack[stkLen-1])
  1374  		}
  1375  		argsLen++ // for trimming the value element later
  1376  	} else {
  1377  		// args:  condition1, result1,
  1378  		//        condition2, result2,
  1379  		//        ...
  1380  		//        else clause
  1381  		args = er.ctxStack[stkLen-argsLen:]
  1382  	}
  1383  	function, err := er.newFunction(ast.Case, &v.Type, args...)
  1384  	if err != nil {
  1385  		er.err = err
  1386  		return
  1387  	}
  1388  	er.ctxStackPop(argsLen)
  1389  	er.ctxStackAppend(function, types.EmptyName)
  1390  }
  1391  
  1392  func (er *memexRewriter) patternLikeToExpression(v *ast.PatternLikeExpr) {
  1393  	l := len(er.ctxStack)
  1394  	er.err = memex.CheckArgsNotMultiDeferredCausetRow(er.ctxStack[l-2:]...)
  1395  	if er.err != nil {
  1396  		return
  1397  	}
  1398  
  1399  	char, col := er.sctx.GetStochastikVars().GetCharsetInfo()
  1400  	var function memex.Expression
  1401  	fieldType := &types.FieldType{}
  1402  	isPatternExactMatch := false
  1403  	// Treat predicate 'like' the same way as predicate '=' when it is an exact match.
  1404  	if patExpression, ok := er.ctxStack[l-1].(*memex.Constant); ok {
  1405  		patString, isNull, err := patExpression.EvalString(nil, chunk.Row{})
  1406  		if err != nil {
  1407  			er.err = err
  1408  			return
  1409  		}
  1410  		if !isNull {
  1411  			patValue, patTypes := stringutil.CompilePattern(patString, v.Escape)
  1412  			if stringutil.IsExactMatch(patTypes) && er.ctxStack[l-2].GetType().EvalType() == types.ETString {
  1413  				op := ast.EQ
  1414  				if v.Not {
  1415  					op = ast.NE
  1416  				}
  1417  				types.DefaultTypeForValue(string(patValue), fieldType, char, col)
  1418  				function, er.err = er.constructBinaryOpFunction(er.ctxStack[l-2],
  1419  					&memex.Constant{Value: types.NewStringCauset(string(patValue)), RetType: fieldType},
  1420  					op)
  1421  				isPatternExactMatch = true
  1422  			}
  1423  		}
  1424  	}
  1425  	if !isPatternExactMatch {
  1426  		types.DefaultTypeForValue(int(v.Escape), fieldType, char, col)
  1427  		function = er.notToExpression(v.Not, ast.Like, &v.Type,
  1428  			er.ctxStack[l-2], er.ctxStack[l-1], &memex.Constant{Value: types.NewIntCauset(int64(v.Escape)), RetType: fieldType})
  1429  	}
  1430  
  1431  	er.ctxStackPop(2)
  1432  	er.ctxStackAppend(function, types.EmptyName)
  1433  }
  1434  
  1435  func (er *memexRewriter) regexpToScalarFunc(v *ast.PatternRegexpExpr) {
  1436  	l := len(er.ctxStack)
  1437  	er.err = memex.CheckArgsNotMultiDeferredCausetRow(er.ctxStack[l-2:]...)
  1438  	if er.err != nil {
  1439  		return
  1440  	}
  1441  	function := er.notToExpression(v.Not, ast.Regexp, &v.Type, er.ctxStack[l-2], er.ctxStack[l-1])
  1442  	er.ctxStackPop(2)
  1443  	er.ctxStackAppend(function, types.EmptyName)
  1444  }
  1445  
  1446  func (er *memexRewriter) rowToScalarFunc(v *ast.RowExpr) {
  1447  	stkLen := len(er.ctxStack)
  1448  	length := len(v.Values)
  1449  	rows := make([]memex.Expression, 0, length)
  1450  	for i := stkLen - length; i < stkLen; i++ {
  1451  		rows = append(rows, er.ctxStack[i])
  1452  	}
  1453  	er.ctxStackPop(length)
  1454  	function, err := er.newFunction(ast.RowFunc, rows[0].GetType(), rows...)
  1455  	if err != nil {
  1456  		er.err = err
  1457  		return
  1458  	}
  1459  	er.ctxStackAppend(function, types.EmptyName)
  1460  }
  1461  
  1462  func (er *memexRewriter) betweenToExpression(v *ast.BetweenExpr) {
  1463  	stkLen := len(er.ctxStack)
  1464  	er.err = memex.CheckArgsNotMultiDeferredCausetRow(er.ctxStack[stkLen-3:]...)
  1465  	if er.err != nil {
  1466  		return
  1467  	}
  1468  
  1469  	expr, lexp, rexp := er.ctxStack[stkLen-3], er.ctxStack[stkLen-2], er.ctxStack[stkLen-1]
  1470  
  1471  	if memex.GetCmpTp4MinMax([]memex.Expression{expr, lexp, rexp}) == types.ETDatetime {
  1472  		expr = memex.WrapWithCastAsTime(er.sctx, expr, types.NewFieldType(allegrosql.TypeDatetime))
  1473  		lexp = memex.WrapWithCastAsTime(er.sctx, lexp, types.NewFieldType(allegrosql.TypeDatetime))
  1474  		rexp = memex.WrapWithCastAsTime(er.sctx, rexp, types.NewFieldType(allegrosql.TypeDatetime))
  1475  	}
  1476  
  1477  	var op string
  1478  	var l, r memex.Expression
  1479  	l, er.err = er.newFunction(ast.GE, &v.Type, expr, lexp)
  1480  	if er.err == nil {
  1481  		r, er.err = er.newFunction(ast.LE, &v.Type, expr, rexp)
  1482  	}
  1483  	op = ast.LogicAnd
  1484  	if er.err != nil {
  1485  		return
  1486  	}
  1487  	function, err := er.newFunction(op, &v.Type, l, r)
  1488  	if err != nil {
  1489  		er.err = err
  1490  		return
  1491  	}
  1492  	if v.Not {
  1493  		function, err = er.newFunction(ast.UnaryNot, &v.Type, function)
  1494  		if err != nil {
  1495  			er.err = err
  1496  			return
  1497  		}
  1498  	}
  1499  	er.ctxStackPop(3)
  1500  	er.ctxStackAppend(function, types.EmptyName)
  1501  }
  1502  
  1503  // rewriteFuncCall handles a FuncCallExpr and generates a customized function.
  1504  // It should return true if for the given FuncCallExpr a rewrite is performed so that original behavior is skipped.
  1505  // Otherwise it should return false to indicate (the caller) that original behavior needs to be performed.
  1506  func (er *memexRewriter) rewriteFuncCall(v *ast.FuncCallExpr) bool {
  1507  	switch v.FnName.L {
  1508  	// when column is not null, ifnull on such column is not necessary.
  1509  	case ast.Ifnull:
  1510  		if len(v.Args) != 2 {
  1511  			er.err = memex.ErrIncorrectParameterCount.GenWithStackByArgs(v.FnName.O)
  1512  			return true
  1513  		}
  1514  		stackLen := len(er.ctxStack)
  1515  		arg1 := er.ctxStack[stackLen-2]
  1516  		col, isDeferredCauset := arg1.(*memex.DeferredCauset)
  1517  		// if expr1 is a column and column has not null flag, then we can eliminate ifnull on
  1518  		// this column.
  1519  		if isDeferredCauset && allegrosql.HasNotNullFlag(col.RetType.Flag) {
  1520  			name := er.ctxNameStk[stackLen-2]
  1521  			newDefCaus := col.Clone().(*memex.DeferredCauset)
  1522  			er.ctxStackPop(len(v.Args))
  1523  			er.ctxStackAppend(newDefCaus, name)
  1524  			return true
  1525  		}
  1526  
  1527  		return false
  1528  	case ast.Nullif:
  1529  		if len(v.Args) != 2 {
  1530  			er.err = memex.ErrIncorrectParameterCount.GenWithStackByArgs(v.FnName.O)
  1531  			return true
  1532  		}
  1533  		stackLen := len(er.ctxStack)
  1534  		param1 := er.ctxStack[stackLen-2]
  1535  		param2 := er.ctxStack[stackLen-1]
  1536  		// param1 = param2
  1537  		funcCompare, err := er.constructBinaryOpFunction(param1, param2, ast.EQ)
  1538  		if err != nil {
  1539  			er.err = err
  1540  			return true
  1541  		}
  1542  		// NULL
  1543  		nullTp := types.NewFieldType(allegrosql.TypeNull)
  1544  		nullTp.Flen, nullTp.Decimal = allegrosql.GetDefaultFieldLengthAndDecimal(allegrosql.TypeNull)
  1545  		paramNull := &memex.Constant{
  1546  			Value:   types.NewCauset(nil),
  1547  			RetType: nullTp,
  1548  		}
  1549  		// if(param1 = param2, NULL, param1)
  1550  		funcIf, err := er.newFunction(ast.If, &v.Type, funcCompare, paramNull, param1)
  1551  		if err != nil {
  1552  			er.err = err
  1553  			return true
  1554  		}
  1555  		er.ctxStackPop(len(v.Args))
  1556  		er.ctxStackAppend(funcIf, types.EmptyName)
  1557  		return true
  1558  	default:
  1559  		return false
  1560  	}
  1561  }
  1562  
  1563  func (er *memexRewriter) funcCallToExpression(v *ast.FuncCallExpr) {
  1564  	stackLen := len(er.ctxStack)
  1565  	args := er.ctxStack[stackLen-len(v.Args):]
  1566  	er.err = memex.CheckArgsNotMultiDeferredCausetRow(args...)
  1567  	if er.err != nil {
  1568  		return
  1569  	}
  1570  
  1571  	if er.rewriteFuncCall(v) {
  1572  		return
  1573  	}
  1574  
  1575  	var function memex.Expression
  1576  	er.ctxStackPop(len(v.Args))
  1577  	if _, ok := memex.DeferredFunctions[v.FnName.L]; er.useCache() && ok {
  1578  		// When the memex is unix_timestamp and the number of argument is not zero,
  1579  		// we deal with it as normal memex.
  1580  		if v.FnName.L == ast.UnixTimestamp && len(v.Args) != 0 {
  1581  			function, er.err = er.newFunction(v.FnName.L, &v.Type, args...)
  1582  			er.ctxStackAppend(function, types.EmptyName)
  1583  		} else {
  1584  			function, er.err = memex.NewFunctionBase(er.sctx, v.FnName.L, &v.Type, args...)
  1585  			c := &memex.Constant{Value: types.NewCauset(nil), RetType: function.GetType().Clone(), DeferredExpr: function}
  1586  			er.ctxStackAppend(c, types.EmptyName)
  1587  		}
  1588  	} else {
  1589  		function, er.err = er.newFunction(v.FnName.L, &v.Type, args...)
  1590  		er.ctxStackAppend(function, types.EmptyName)
  1591  	}
  1592  }
  1593  
  1594  // Now BlockName in memex only used by sequence function like nextval(seq).
  1595  // The function arg should be evaluated as a causet name rather than normal column name like allegrosql does.
  1596  func (er *memexRewriter) toBlock(v *ast.BlockName) {
  1597  	fullName := v.Name.L
  1598  	if len(v.Schema.L) != 0 {
  1599  		fullName = v.Schema.L + "." + fullName
  1600  	}
  1601  	val := &memex.Constant{
  1602  		Value:   types.NewCauset(fullName),
  1603  		RetType: types.NewFieldType(allegrosql.TypeString),
  1604  	}
  1605  	er.ctxStackAppend(val, types.EmptyName)
  1606  }
  1607  
  1608  func (er *memexRewriter) toDeferredCauset(v *ast.DeferredCausetName) {
  1609  	idx, err := memex.FindFieldName(er.names, v)
  1610  	if err != nil {
  1611  		er.err = ErrAmbiguous.GenWithStackByArgs(v.Name, clauseMsg[fieldList])
  1612  		return
  1613  	}
  1614  	if idx >= 0 {
  1615  		column := er.schemaReplicant.DeferredCausets[idx]
  1616  		if column.IsHidden {
  1617  			er.err = ErrUnknownDeferredCauset.GenWithStackByArgs(v.Name, clauseMsg[er.b.curClause])
  1618  			return
  1619  		}
  1620  		er.ctxStackAppend(column, er.names[idx])
  1621  		return
  1622  	}
  1623  	for i := len(er.b.outerSchemas) - 1; i >= 0; i-- {
  1624  		outerSchema, outerName := er.b.outerSchemas[i], er.b.outerNames[i]
  1625  		idx, err = memex.FindFieldName(outerName, v)
  1626  		if idx >= 0 {
  1627  			column := outerSchema.DeferredCausets[idx]
  1628  			er.ctxStackAppend(&memex.CorrelatedDeferredCauset{DeferredCauset: *column, Data: new(types.Causet)}, outerName[idx])
  1629  			return
  1630  		}
  1631  		if err != nil {
  1632  			er.err = ErrAmbiguous.GenWithStackByArgs(v.Name, clauseMsg[fieldList])
  1633  			return
  1634  		}
  1635  	}
  1636  	if join, ok := er.p.(*LogicalJoin); ok && join.redundantSchema != nil {
  1637  		idx, err := memex.FindFieldName(join.redundantNames, v)
  1638  		if err != nil {
  1639  			er.err = err
  1640  			return
  1641  		}
  1642  		if idx >= 0 {
  1643  			er.ctxStackAppend(join.redundantSchema.DeferredCausets[idx], join.redundantNames[idx])
  1644  			return
  1645  		}
  1646  	}
  1647  	if _, ok := er.p.(*LogicalUnionAll); ok && v.Block.O != "" {
  1648  		er.err = ErrBlocknameNotAllowedHere.GenWithStackByArgs(v.Block.O, "SELECT", clauseMsg[er.b.curClause])
  1649  		return
  1650  	}
  1651  	if er.b.curClause == globalOrderByClause {
  1652  		er.b.curClause = orderByClause
  1653  	}
  1654  	er.err = ErrUnknownDeferredCauset.GenWithStackByArgs(v.String(), clauseMsg[er.b.curClause])
  1655  }
  1656  
  1657  func (er *memexRewriter) evalDefaultExpr(v *ast.DefaultExpr) {
  1658  	var name *types.FieldName
  1659  	// Here we will find the corresponding column for default function. At the same time, we need to consider the issue
  1660  	// of subquery and name space.
  1661  	// For example, we have two blocks t1(a int default 1, b int) and t2(a int default -1, c int). Consider the following ALLEGROALLEGROSQL:
  1662  	// 		select a from t1 where a > (select default(a) from t2)
  1663  	// Refer to the behavior of MyALLEGROSQL, we need to find column a in causet t2. If causet t2 does not have column a, then find it
  1664  	// in causet t1. If there are none, return an error message.
  1665  	// Based on the above description, we need to look in er.b.allNames from back to front.
  1666  	for i := len(er.b.allNames) - 1; i >= 0; i-- {
  1667  		idx, err := memex.FindFieldName(er.b.allNames[i], v.Name)
  1668  		if err != nil {
  1669  			er.err = err
  1670  			return
  1671  		}
  1672  		if idx >= 0 {
  1673  			name = er.b.allNames[i][idx]
  1674  			break
  1675  		}
  1676  	}
  1677  	if name == nil {
  1678  		idx, err := memex.FindFieldName(er.names, v.Name)
  1679  		if err != nil {
  1680  			er.err = err
  1681  			return
  1682  		}
  1683  		if idx < 0 {
  1684  			er.err = ErrUnknownDeferredCauset.GenWithStackByArgs(v.Name.OrigDefCausName(), "field list")
  1685  			return
  1686  		}
  1687  		name = er.names[idx]
  1688  	}
  1689  
  1690  	dbName := name.DBName
  1691  	if dbName.O == "" {
  1692  		// if database name is not specified, use current database name
  1693  		dbName = perceptron.NewCIStr(er.sctx.GetStochastikVars().CurrentDB)
  1694  	}
  1695  	if name.OrigTblName.O == "" {
  1696  		// column is evaluated by some memexs, for example:
  1697  		// `select default(c) from (select (a+1) as c from t) as t0`
  1698  		// in such case, a 'no default' error is returned
  1699  		er.err = causet.ErrNoDefaultValue.GenWithStackByArgs(name.DefCausName)
  1700  		return
  1701  	}
  1702  	var tbl causet.Block
  1703  	tbl, er.err = er.b.is.BlockByName(dbName, name.OrigTblName)
  1704  	if er.err != nil {
  1705  		return
  1706  	}
  1707  	colName := name.OrigDefCausName.O
  1708  	if colName == "" {
  1709  		// in some cases, OrigDefCausName is empty, use DefCausName instead
  1710  		colName = name.DefCausName.O
  1711  	}
  1712  	col := causet.FindDefCaus(tbl.DefCauss(), colName)
  1713  	if col == nil {
  1714  		er.err = ErrUnknownDeferredCauset.GenWithStackByArgs(v.Name, "field_list")
  1715  		return
  1716  	}
  1717  	isCurrentTimestamp := hasCurrentDatetimeDefault(col)
  1718  	var val *memex.Constant
  1719  	switch {
  1720  	case isCurrentTimestamp && col.Tp == allegrosql.TypeDatetime:
  1721  		// for DATETIME column with current_timestamp, use NULL to be compatible with MyALLEGROSQL 5.7
  1722  		val = memex.NewNull()
  1723  	case isCurrentTimestamp && col.Tp == allegrosql.TypeTimestamp:
  1724  		// for TIMESTAMP column with current_timestamp, use 0 to be compatible with MyALLEGROSQL 5.7
  1725  		zero := types.NewTime(types.ZeroCoreTime, allegrosql.TypeTimestamp, int8(col.Decimal))
  1726  		val = &memex.Constant{
  1727  			Value:   types.NewCauset(zero),
  1728  			RetType: types.NewFieldType(allegrosql.TypeTimestamp),
  1729  		}
  1730  	default:
  1731  		// for other columns, just use what it is
  1732  		val, er.err = er.b.getDefaultValue(col)
  1733  	}
  1734  	if er.err != nil {
  1735  		return
  1736  	}
  1737  	er.ctxStackAppend(val, types.EmptyName)
  1738  }
  1739  
  1740  // hasCurrentDatetimeDefault checks if column has current_timestamp default value
  1741  func hasCurrentDatetimeDefault(col *causet.DeferredCauset) bool {
  1742  	x, ok := col.DefaultValue.(string)
  1743  	if !ok {
  1744  		return false
  1745  	}
  1746  	return strings.ToLower(x) == ast.CurrentTimestamp
  1747  }