github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/causetstore/petri/acyclic/causet/embedded/memex_rewriter.go (about) 1 // Copyright 2020 WHTCORPS INC, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package embedded 15 16 import ( 17 "context" 18 "strconv" 19 "strings" 20 21 "github.com/whtcorpsinc/BerolinaSQL/allegrosql" 22 "github.com/whtcorpsinc/BerolinaSQL/ast" 23 "github.com/whtcorpsinc/BerolinaSQL/charset" 24 "github.com/whtcorpsinc/BerolinaSQL/opcode" 25 "github.com/whtcorpsinc/BerolinaSQL/perceptron" 26 "github.com/whtcorpsinc/errors" 27 "github.com/whtcorpsinc/milevadb/causet" 28 "github.com/whtcorpsinc/milevadb/memex" 29 "github.com/whtcorpsinc/milevadb/memex/aggregation" 30 "github.com/whtcorpsinc/milevadb/schemareplicant" 31 "github.com/whtcorpsinc/milevadb/soliton/chunk" 32 "github.com/whtcorpsinc/milevadb/soliton/collate" 33 "github.com/whtcorpsinc/milevadb/soliton/hint" 34 "github.com/whtcorpsinc/milevadb/soliton/stringutil" 35 "github.com/whtcorpsinc/milevadb/stochastikctx" 36 "github.com/whtcorpsinc/milevadb/stochastikctx/variable" 37 "github.com/whtcorpsinc/milevadb/types" 38 driver "github.com/whtcorpsinc/milevadb/types/BerolinaSQL_driver" 39 ) 40 41 // EvalSubqueryFirstRow evaluates incorrelated subqueries once, and get first event. 42 var EvalSubqueryFirstRow func(ctx context.Context, p PhysicalCauset, is schemareplicant.SchemaReplicant, sctx stochastikctx.Context) (event []types.Causet, err error) 43 44 // evalAstExpr evaluates ast memex directly. 
45 func evalAstExpr(sctx stochastikctx.Context, expr ast.ExprNode) (types.Causet, error) { 46 if val, ok := expr.(*driver.ValueExpr); ok { 47 return val.Causet, nil 48 } 49 newExpr, err := rewriteAstExpr(sctx, expr, nil, nil) 50 if err != nil { 51 return types.Causet{}, err 52 } 53 return newExpr.Eval(chunk.Row{}) 54 } 55 56 // rewriteAstExpr rewrites ast memex directly. 57 func rewriteAstExpr(sctx stochastikctx.Context, expr ast.ExprNode, schemaReplicant *memex.Schema, names types.NameSlice) (memex.Expression, error) { 58 var is schemareplicant.SchemaReplicant 59 if sctx.GetStochastikVars().TxnCtx.SchemaReplicant != nil { 60 is = sctx.GetStochastikVars().TxnCtx.SchemaReplicant.(schemareplicant.SchemaReplicant) 61 } 62 b := NewCausetBuilder(sctx, is, &hint.BlockHintProcessor{}) 63 fakeCauset := LogicalBlockDual{}.Init(sctx, 0) 64 if schemaReplicant != nil { 65 fakeCauset.schemaReplicant = schemaReplicant 66 fakeCauset.names = names 67 } 68 newExpr, _, err := b.rewrite(context.TODO(), expr, fakeCauset, nil, true) 69 if err != nil { 70 return nil, err 71 } 72 return newExpr, nil 73 } 74 75 func (b *CausetBuilder) rewriteInsertOnDuplicateUFIDelate(ctx context.Context, exprNode ast.ExprNode, mockCauset LogicalCauset, insertCauset *Insert) (memex.Expression, error) { 76 b.rewriterCounter++ 77 defer func() { b.rewriterCounter-- }() 78 79 rewriter := b.getExpressionRewriter(ctx, mockCauset) 80 // The rewriter maybe is obtained from "b.rewriterPool", "rewriter.err" is 81 // not nil means certain previous procedure has not handled this error. 82 // Here we give us one more chance to make a correct behavior by handling 83 // this missed error. 84 if rewriter.err != nil { 85 return nil, rewriter.err 86 } 87 88 rewriter.insertCauset = insertCauset 89 rewriter.asScalar = true 90 91 expr, _, err := b.rewriteExprNode(rewriter, exprNode, true) 92 return expr, err 93 } 94 95 // rewrite function rewrites ast expr to memex.Expression. 
96 // aggMapper maps ast.AggregateFuncExpr to the columns offset in p's output schemaReplicant. 97 // asScalar means whether this memex must be treated as a scalar memex. 98 // And this function returns a result memex, a new plan that may have apply or semi-join. 99 func (b *CausetBuilder) rewrite(ctx context.Context, exprNode ast.ExprNode, p LogicalCauset, aggMapper map[*ast.AggregateFuncExpr]int, asScalar bool) (memex.Expression, LogicalCauset, error) { 100 expr, resultCauset, err := b.rewriteWithPreprocess(ctx, exprNode, p, aggMapper, nil, asScalar, nil) 101 return expr, resultCauset, err 102 } 103 104 // rewriteWithPreprocess is for handling the situation that we need to adjust the input ast tree 105 // before really using its node in `memexRewriter.Leave`. In that case, we first call 106 // er.preprocess(expr), which returns a new expr. Then we use the new expr in `Leave`. 107 func (b *CausetBuilder) rewriteWithPreprocess( 108 ctx context.Context, 109 exprNode ast.ExprNode, 110 p LogicalCauset, aggMapper map[*ast.AggregateFuncExpr]int, 111 windowMapper map[*ast.WindowFuncExpr]int, 112 asScalar bool, 113 preprocess func(ast.Node) ast.Node, 114 ) (memex.Expression, LogicalCauset, error) { 115 b.rewriterCounter++ 116 defer func() { b.rewriterCounter-- }() 117 118 rewriter := b.getExpressionRewriter(ctx, p) 119 // The rewriter maybe is obtained from "b.rewriterPool", "rewriter.err" is 120 // not nil means certain previous procedure has not handled this error. 121 // Here we give us one more chance to make a correct behavior by handling 122 // this missed error. 
123 if rewriter.err != nil { 124 return nil, nil, rewriter.err 125 } 126 127 rewriter.aggrMap = aggMapper 128 rewriter.windowMap = windowMapper 129 rewriter.asScalar = asScalar 130 rewriter.preprocess = preprocess 131 132 expr, resultCauset, err := b.rewriteExprNode(rewriter, exprNode, asScalar) 133 return expr, resultCauset, err 134 } 135 136 func (b *CausetBuilder) getExpressionRewriter(ctx context.Context, p LogicalCauset) (rewriter *memexRewriter) { 137 defer func() { 138 if p != nil { 139 rewriter.schemaReplicant = p.Schema() 140 rewriter.names = p.OutputNames() 141 } 142 }() 143 144 if len(b.rewriterPool) < b.rewriterCounter { 145 rewriter = &memexRewriter{p: p, b: b, sctx: b.ctx, ctx: ctx} 146 b.rewriterPool = append(b.rewriterPool, rewriter) 147 return 148 } 149 150 rewriter = b.rewriterPool[b.rewriterCounter-1] 151 rewriter.p = p 152 rewriter.asScalar = false 153 rewriter.aggrMap = nil 154 rewriter.preprocess = nil 155 rewriter.insertCauset = nil 156 rewriter.disableFoldCounter = 0 157 rewriter.tryFoldCounter = 0 158 rewriter.ctxStack = rewriter.ctxStack[:0] 159 rewriter.ctxNameStk = rewriter.ctxNameStk[:0] 160 rewriter.ctx = ctx 161 return 162 } 163 164 func (b *CausetBuilder) rewriteExprNode(rewriter *memexRewriter, exprNode ast.ExprNode, asScalar bool) (memex.Expression, LogicalCauset, error) { 165 if rewriter.p != nil { 166 curDefCausLen := rewriter.p.Schema().Len() 167 defer func() { 168 names := rewriter.p.OutputNames().Shallow()[:curDefCausLen] 169 for i := curDefCausLen; i < rewriter.p.Schema().Len(); i++ { 170 names = append(names, types.EmptyName) 171 } 172 // After rewriting finished, only old columns are visible. 173 // e.g. select * from t where t.a in (select t1.a from t1); 174 // The output columns before we enter the subquery are the columns from t. 175 // But when we leave the subquery `t.a in (select t1.a from t1)`, we got a Apply operator 176 // and the output columns become [t.*, t1.*]. But t1.* is used only inside the subquery. 
If there's another filter 177 // which is also a subquery where t1 is involved. The name resolving will fail if we still expose the column from 178 // the previous subquery. 179 // So here we just reset the names to empty to avoid this situation. 180 // TODO: implement ScalarSubQuery and resolve it during optimizing. In building phase, we will not change the plan's structure. 181 rewriter.p.SetOutputNames(names) 182 }() 183 } 184 exprNode.Accept(rewriter) 185 if rewriter.err != nil { 186 return nil, nil, errors.Trace(rewriter.err) 187 } 188 if !asScalar && len(rewriter.ctxStack) == 0 { 189 return nil, rewriter.p, nil 190 } 191 if len(rewriter.ctxStack) != 1 { 192 return nil, nil, errors.Errorf("context len %v is invalid", len(rewriter.ctxStack)) 193 } 194 rewriter.err = memex.CheckArgsNotMultiDeferredCausetRow(rewriter.ctxStack[0]) 195 if rewriter.err != nil { 196 return nil, nil, errors.Trace(rewriter.err) 197 } 198 return rewriter.ctxStack[0], rewriter.p, nil 199 } 200 201 type memexRewriter struct { 202 ctxStack []memex.Expression 203 ctxNameStk []*types.FieldName 204 p LogicalCauset 205 schemaReplicant *memex.Schema 206 names []*types.FieldName 207 err error 208 aggrMap map[*ast.AggregateFuncExpr]int 209 windowMap map[*ast.WindowFuncExpr]int 210 b *CausetBuilder 211 sctx stochastikctx.Context 212 ctx context.Context 213 214 // asScalar indicates the return value must be a scalar value. 215 // NOTE: This value can be changed during memex rewritten. 216 asScalar bool 217 218 // preprocess is called for every ast.Node in Leave. 219 preprocess func(ast.Node) ast.Node 220 221 // insertCauset is only used to rewrite the memexs inside the assignment 222 // of the "INSERT" memex. 223 insertCauset *Insert 224 225 // disableFoldCounter controls fold-disabled scope. If > 0, rewriter will NOT do constant folding. 226 // Typically, during visiting AST, while entering the scope(disable), the counter will +1; while 227 // leaving the scope(enable again), the counter will -1. 
228 // NOTE: This value can be changed during memex rewritten. 229 disableFoldCounter int 230 tryFoldCounter int 231 } 232 233 func (er *memexRewriter) ctxStackLen() int { 234 return len(er.ctxStack) 235 } 236 237 func (er *memexRewriter) ctxStackPop(num int) { 238 l := er.ctxStackLen() 239 er.ctxStack = er.ctxStack[:l-num] 240 er.ctxNameStk = er.ctxNameStk[:l-num] 241 } 242 243 func (er *memexRewriter) ctxStackAppend(col memex.Expression, name *types.FieldName) { 244 er.ctxStack = append(er.ctxStack, col) 245 er.ctxNameStk = append(er.ctxNameStk, name) 246 } 247 248 // constructBinaryOpFunction converts binary operator functions 249 // 1. If op are EQ or NE or NullEQ, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2) 250 // 2. Else constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to 251 // `IF( a0 NE b0, a0 op b0, 252 // IF ( isNull(a0 NE b0), Null, 253 // IF ( a1 NE b1, a1 op b1, 254 // IF ( isNull(a1 NE b1), Null, a2 op b2))))` 255 func (er *memexRewriter) constructBinaryOpFunction(l memex.Expression, r memex.Expression, op string) (memex.Expression, error) { 256 lLen, rLen := memex.GetRowLen(l), memex.GetRowLen(r) 257 if lLen == 1 && rLen == 1 { 258 return er.newFunction(op, types.NewFieldType(allegrosql.TypeTiny), l, r) 259 } else if rLen != lLen { 260 return nil, memex.ErrOperandDeferredCausets.GenWithStackByArgs(lLen) 261 } 262 switch op { 263 case ast.EQ, ast.NE, ast.NullEQ: 264 funcs := make([]memex.Expression, lLen) 265 for i := 0; i < lLen; i++ { 266 var err error 267 funcs[i], err = er.constructBinaryOpFunction(memex.GetFuncArg(l, i), memex.GetFuncArg(r, i), op) 268 if err != nil { 269 return nil, err 270 } 271 } 272 if op == ast.NE { 273 return memex.ComposeDNFCondition(er.sctx, funcs...), nil 274 } 275 return memex.ComposeCNFCondition(er.sctx, funcs...), nil 276 default: 277 larg0, rarg0 := memex.GetFuncArg(l, 0), memex.GetFuncArg(r, 0) 278 var expr1, expr2, expr3, expr4, expr5 
memex.Expression 279 expr1 = memex.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(allegrosql.TypeTiny), larg0, rarg0) 280 expr2 = memex.NewFunctionInternal(er.sctx, op, types.NewFieldType(allegrosql.TypeTiny), larg0, rarg0) 281 expr3 = memex.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(allegrosql.TypeTiny), expr1) 282 var err error 283 l, err = memex.PopRowFirstArg(er.sctx, l) 284 if err != nil { 285 return nil, err 286 } 287 r, err = memex.PopRowFirstArg(er.sctx, r) 288 if err != nil { 289 return nil, err 290 } 291 expr4, err = er.constructBinaryOpFunction(l, r, op) 292 if err != nil { 293 return nil, err 294 } 295 expr5, err = er.newFunction(ast.If, types.NewFieldType(allegrosql.TypeTiny), expr3, memex.NewNull(), expr4) 296 if err != nil { 297 return nil, err 298 } 299 return er.newFunction(ast.If, types.NewFieldType(allegrosql.TypeTiny), expr1, expr2, expr5) 300 } 301 } 302 303 func (er *memexRewriter) buildSubquery(ctx context.Context, subq *ast.SubqueryExpr) (LogicalCauset, error) { 304 if er.schemaReplicant != nil { 305 outerSchema := er.schemaReplicant.Clone() 306 er.b.outerSchemas = append(er.b.outerSchemas, outerSchema) 307 er.b.outerNames = append(er.b.outerNames, er.names) 308 defer func() { 309 er.b.outerSchemas = er.b.outerSchemas[0 : len(er.b.outerSchemas)-1] 310 er.b.outerNames = er.b.outerNames[0 : len(er.b.outerNames)-1] 311 }() 312 } 313 314 np, err := er.b.buildResultSetNode(ctx, subq.Query) 315 if err != nil { 316 return nil, err 317 } 318 // Pop the handle map generated by the subquery. 319 er.b.handleHelper.popMap() 320 return np, nil 321 } 322 323 // Enter implements Visitor interface. 
324 func (er *memexRewriter) Enter(inNode ast.Node) (ast.Node, bool) { 325 switch v := inNode.(type) { 326 case *ast.AggregateFuncExpr: 327 index, ok := -1, false 328 if er.aggrMap != nil { 329 index, ok = er.aggrMap[v] 330 } 331 if !ok { 332 er.err = ErrInvalidGroupFuncUse 333 return inNode, true 334 } 335 er.ctxStackAppend(er.schemaReplicant.DeferredCausets[index], er.names[index]) 336 return inNode, true 337 case *ast.DeferredCausetNameExpr: 338 if index, ok := er.b.colMapper[v]; ok { 339 er.ctxStackAppend(er.schemaReplicant.DeferredCausets[index], er.names[index]) 340 return inNode, true 341 } 342 case *ast.CompareSubqueryExpr: 343 return er.handleCompareSubquery(er.ctx, v) 344 case *ast.ExistsSubqueryExpr: 345 return er.handleExistSubquery(er.ctx, v) 346 case *ast.PatternInExpr: 347 if v.Sel != nil { 348 return er.handleInSubquery(er.ctx, v) 349 } 350 if len(v.List) != 1 { 351 break 352 } 353 // For 10 in ((select * from t)), the BerolinaSQL won't set v.Sel. 354 // So we must process this case here. 355 x := v.List[0] 356 for { 357 switch y := x.(type) { 358 case *ast.SubqueryExpr: 359 v.Sel = y 360 return er.handleInSubquery(er.ctx, v) 361 case *ast.ParenthesesExpr: 362 x = y.Expr 363 default: 364 return inNode, false 365 } 366 } 367 case *ast.SubqueryExpr: 368 return er.handleScalarSubquery(er.ctx, v) 369 case *ast.ParenthesesExpr: 370 case *ast.ValuesExpr: 371 schemaReplicant, names := er.schemaReplicant, er.names 372 // NOTE: "er.insertCauset != nil" means that we are rewriting the 373 // memexs inside the assignment of "INSERT" memex. we have to 374 // use the "blockSchema" of that "insertCauset". 
375 if er.insertCauset != nil { 376 schemaReplicant = er.insertCauset.blockSchema 377 names = er.insertCauset.blockDefCausNames 378 } 379 idx, err := memex.FindFieldName(names, v.DeferredCauset.Name) 380 if err != nil { 381 er.err = err 382 return inNode, false 383 } 384 if idx < 0 { 385 er.err = ErrUnknownDeferredCauset.GenWithStackByArgs(v.DeferredCauset.Name.OrigDefCausName(), "field list") 386 return inNode, false 387 } 388 col := schemaReplicant.DeferredCausets[idx] 389 er.ctxStackAppend(memex.NewValuesFunc(er.sctx, col.Index, col.RetType), types.EmptyName) 390 return inNode, true 391 case *ast.WindowFuncExpr: 392 index, ok := -1, false 393 if er.windowMap != nil { 394 index, ok = er.windowMap[v] 395 } 396 if !ok { 397 er.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(strings.ToLower(v.F)) 398 return inNode, true 399 } 400 er.ctxStackAppend(er.schemaReplicant.DeferredCausets[index], er.names[index]) 401 return inNode, true 402 case *ast.FuncCallExpr: 403 if _, ok := memex.DisableFoldFunctions[v.FnName.L]; ok { 404 er.disableFoldCounter++ 405 } 406 if _, ok := memex.TryFoldFunctions[v.FnName.L]; ok { 407 er.tryFoldCounter++ 408 } 409 case *ast.CaseExpr: 410 if _, ok := memex.DisableFoldFunctions["case"]; ok { 411 er.disableFoldCounter++ 412 } 413 if _, ok := memex.TryFoldFunctions["case"]; ok { 414 er.tryFoldCounter++ 415 } 416 case *ast.SetDefCauslationExpr: 417 // Do nothing 418 default: 419 er.asScalar = true 420 } 421 return inNode, false 422 } 423 424 func (er *memexRewriter) buildSemiApplyFromEqualSubq(np LogicalCauset, l, r memex.Expression, not bool) { 425 var condition memex.Expression 426 if rDefCaus, ok := r.(*memex.DeferredCauset); ok && (er.asScalar || not) { 427 // If both input columns of `!= all / = any` memex are not null, we can treat the memex 428 // as normal column equal condition. 
429 if lDefCaus, ok := l.(*memex.DeferredCauset); !ok || !allegrosql.HasNotNullFlag(lDefCaus.GetType().Flag) || !allegrosql.HasNotNullFlag(rDefCaus.GetType().Flag) { 430 rDefCausCopy := *rDefCaus 431 rDefCausCopy.InOperand = true 432 r = &rDefCausCopy 433 } 434 } 435 condition, er.err = er.constructBinaryOpFunction(l, r, ast.EQ) 436 if er.err != nil { 437 return 438 } 439 er.p, er.err = er.b.buildSemiApply(er.p, np, []memex.Expression{condition}, er.asScalar, not) 440 } 441 442 func (er *memexRewriter) handleCompareSubquery(ctx context.Context, v *ast.CompareSubqueryExpr) (ast.Node, bool) { 443 v.L.Accept(er) 444 if er.err != nil { 445 return v, true 446 } 447 lexpr := er.ctxStack[len(er.ctxStack)-1] 448 subq, ok := v.R.(*ast.SubqueryExpr) 449 if !ok { 450 er.err = errors.Errorf("Unknown compare type %T.", v.R) 451 return v, true 452 } 453 np, err := er.buildSubquery(ctx, subq) 454 if err != nil { 455 er.err = err 456 return v, true 457 } 458 // Only (a,b,c) = any (...) and (a,b,c) != all (...) can use event memex. 459 canMultiDefCaus := (!v.All && v.Op == opcode.EQ) || (v.All && v.Op == opcode.NE) 460 if !canMultiDefCaus && (memex.GetRowLen(lexpr) != 1 || np.Schema().Len() != 1) { 461 er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(1) 462 return v, true 463 } 464 lLen := memex.GetRowLen(lexpr) 465 if lLen != np.Schema().Len() { 466 er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(lLen) 467 return v, true 468 } 469 var rexpr memex.Expression 470 if np.Schema().Len() == 1 { 471 rexpr = np.Schema().DeferredCausets[0] 472 } else { 473 args := make([]memex.Expression, 0, np.Schema().Len()) 474 for _, col := range np.Schema().DeferredCausets { 475 args = append(args, col) 476 } 477 rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(), args...) 478 if er.err != nil { 479 return v, true 480 } 481 } 482 switch v.Op { 483 // Only EQ, NE and NullEQ can be composed with and. 
484 case opcode.EQ, opcode.NE, opcode.NullEQ: 485 if v.Op == opcode.EQ { 486 if v.All { 487 er.handleEQAll(lexpr, rexpr, np) 488 } else { 489 // `a = any(subq)` will be rewriten as `a in (subq)`. 490 er.buildSemiApplyFromEqualSubq(np, lexpr, rexpr, false) 491 if er.err != nil { 492 return v, true 493 } 494 } 495 } else if v.Op == opcode.NE { 496 if v.All { 497 // `a != all(subq)` will be rewriten as `a not in (subq)`. 498 er.buildSemiApplyFromEqualSubq(np, lexpr, rexpr, true) 499 if er.err != nil { 500 return v, true 501 } 502 } else { 503 er.handleNEAny(lexpr, rexpr, np) 504 } 505 } else { 506 // TODO: Support this in future. 507 er.err = errors.New("We don't support <=> all or <=> any now") 508 return v, true 509 } 510 default: 511 // When < all or > any , the agg function should use min. 512 useMin := ((v.Op == opcode.LT || v.Op == opcode.LE) && v.All) || ((v.Op == opcode.GT || v.Op == opcode.GE) && !v.All) 513 er.handleOtherComparableSubq(lexpr, rexpr, np, useMin, v.Op.String(), v.All) 514 } 515 if er.asScalar { 516 // The parent memex only use the last column in schemaReplicant, which represents whether the condition is matched. 517 er.ctxStack[len(er.ctxStack)-1] = er.p.Schema().DeferredCausets[er.p.Schema().Len()-1] 518 er.ctxNameStk[len(er.ctxNameStk)-1] = er.p.OutputNames()[er.p.Schema().Len()-1] 519 } 520 return v, true 521 } 522 523 // handleOtherComparableSubq handles the queries like < any, < max, etc. For example, if the query is t.id < any (select s.id from s), 524 // it will be rewrote to t.id < (select max(s.id) from s). 525 func (er *memexRewriter) handleOtherComparableSubq(lexpr, rexpr memex.Expression, np LogicalCauset, useMin bool, cmpFunc string, all bool) { 526 plan4Agg := LogicalAggregation{}.Init(er.sctx, er.b.getSelectOffset()) 527 if hint := er.b.BlockHints(); hint != nil { 528 plan4Agg.aggHints = hint.aggHints 529 } 530 plan4Agg.SetChildren(np) 531 532 // Create a "max" or "min" aggregation. 
533 funcName := ast.AggFuncMax 534 if useMin { 535 funcName = ast.AggFuncMin 536 } 537 funcMaxOrMin, err := aggregation.NewAggFuncDesc(er.sctx, funcName, []memex.Expression{rexpr}, false) 538 if err != nil { 539 er.err = err 540 return 541 } 542 543 // Create a column and append it to the schemaReplicant of that aggregation. 544 colMaxOrMin := &memex.DeferredCauset{ 545 UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(), 546 RetType: funcMaxOrMin.RetTp, 547 } 548 schemaReplicant := memex.NewSchema(colMaxOrMin) 549 550 plan4Agg.names = append(plan4Agg.names, types.EmptyName) 551 plan4Agg.SetSchema(schemaReplicant) 552 plan4Agg.AggFuncs = []*aggregation.AggFuncDesc{funcMaxOrMin} 553 554 cond := memex.NewFunctionInternal(er.sctx, cmpFunc, types.NewFieldType(allegrosql.TypeTiny), lexpr, colMaxOrMin) 555 er.buildQuantifierCauset(plan4Agg, cond, lexpr, rexpr, all) 556 } 557 558 // buildQuantifierCauset adds extra condition for any / all subquery. 559 func (er *memexRewriter) buildQuantifierCauset(plan4Agg *LogicalAggregation, cond, lexpr, rexpr memex.Expression, all bool) { 560 innerIsNull := memex.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(allegrosql.TypeTiny), rexpr) 561 outerIsNull := memex.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(allegrosql.TypeTiny), lexpr) 562 563 funcSum, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncSum, []memex.Expression{innerIsNull}, false) 564 if err != nil { 565 er.err = err 566 return 567 } 568 colSum := &memex.DeferredCauset{ 569 UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(), 570 RetType: funcSum.RetTp, 571 } 572 plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum) 573 plan4Agg.schemaReplicant.Append(colSum) 574 innerHasNull := memex.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(allegrosql.TypeTiny), colSum, memex.NewZero()) 575 576 // Build `count(1)` aggregation to check if subquery is empty. 
577 funcCount, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncCount, []memex.Expression{memex.NewOne()}, false) 578 if err != nil { 579 er.err = err 580 return 581 } 582 colCount := &memex.DeferredCauset{ 583 UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(), 584 RetType: funcCount.RetTp, 585 } 586 plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount) 587 plan4Agg.schemaReplicant.Append(colCount) 588 589 if all { 590 // All of the inner record set should not contain null value. So for t.id < all(select s.id from s), it 591 // should be rewrote to t.id < min(s.id) and if(sum(s.id is null) != 0, null, true). 592 innerNullChecker := memex.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(allegrosql.TypeTiny), innerHasNull, memex.NewNull(), memex.NewOne()) 593 cond = memex.ComposeCNFCondition(er.sctx, cond, innerNullChecker) 594 // If the subquery is empty, it should always return true. 595 emptyChecker := memex.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(allegrosql.TypeTiny), colCount, memex.NewZero()) 596 // If outer key is null, and subquery is not empty, it should always return null, even when it is `null = all (1, 2)`. 597 outerNullChecker := memex.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(allegrosql.TypeTiny), outerIsNull, memex.NewNull(), memex.NewZero()) 598 cond = memex.ComposeDNFCondition(er.sctx, cond, emptyChecker, outerNullChecker) 599 } else { 600 // For "any" memex, if the subquery has null and the cond returns false, the result should be NULL. 601 // Specifically, `t.id < any (select s.id from s)` would be rewrote to `t.id < max(s.id) or if(sum(s.id is null) != 0, null, false)` 602 innerNullChecker := memex.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(allegrosql.TypeTiny), innerHasNull, memex.NewNull(), memex.NewZero()) 603 cond = memex.ComposeDNFCondition(er.sctx, cond, innerNullChecker) 604 // If the subquery is empty, it should always return false. 
605 emptyChecker := memex.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(allegrosql.TypeTiny), colCount, memex.NewZero()) 606 // If outer key is null, and subquery is not empty, it should return null. 607 outerNullChecker := memex.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(allegrosql.TypeTiny), outerIsNull, memex.NewNull(), memex.NewOne()) 608 cond = memex.ComposeCNFCondition(er.sctx, cond, emptyChecker, outerNullChecker) 609 } 610 611 // TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions. 612 // plan4Agg.buildProjectionIfNecessary() 613 if !er.asScalar { 614 // For Semi LogicalApply without aux column, the result is no matter false or null. So we can add it to join predicate. 615 er.p, er.err = er.b.buildSemiApply(er.p, plan4Agg, []memex.Expression{cond}, false, false) 616 return 617 } 618 // If we treat the result as a scalar value, we will add a projection with a extra column to output true, false or null. 619 outerSchemaLen := er.p.Schema().Len() 620 er.p = er.b.buildApplyWithJoinType(er.p, plan4Agg, InnerJoin) 621 joinSchema := er.p.Schema() 622 proj := LogicalProjection{ 623 Exprs: memex.DeferredCauset2Exprs(joinSchema.Clone().DeferredCausets[:outerSchemaLen]), 624 }.Init(er.sctx, er.b.getSelectOffset()) 625 proj.names = make([]*types.FieldName, outerSchemaLen, outerSchemaLen+1) 626 copy(proj.names, er.p.OutputNames()) 627 proj.SetSchema(memex.NewSchema(joinSchema.Clone().DeferredCausets[:outerSchemaLen]...)) 628 proj.Exprs = append(proj.Exprs, cond) 629 proj.schemaReplicant.Append(&memex.DeferredCauset{ 630 UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(), 631 RetType: cond.GetType(), 632 }) 633 proj.names = append(proj.names, types.EmptyName) 634 proj.SetChildren(er.p) 635 er.p = proj 636 } 637 638 // handleNEAny handles the case of != any. 
For example, if the query is t.id != any (select s.id from s), it will be rewrote to 639 // t.id != s.id or count(distinct s.id) > 1 or [any checker]. If there are two different values in s.id , 640 // there must exist a s.id that doesn't equal to t.id. 641 func (er *memexRewriter) handleNEAny(lexpr, rexpr memex.Expression, np LogicalCauset) { 642 // If there is NULL in s.id column, s.id should be the value that isn't null in condition t.id != s.id. 643 // So use function max to filter NULL. 644 maxFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncMax, []memex.Expression{rexpr}, false) 645 if err != nil { 646 er.err = err 647 return 648 } 649 countFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncCount, []memex.Expression{rexpr}, true) 650 if err != nil { 651 er.err = err 652 return 653 } 654 plan4Agg := LogicalAggregation{ 655 AggFuncs: []*aggregation.AggFuncDesc{maxFunc, countFunc}, 656 }.Init(er.sctx, er.b.getSelectOffset()) 657 if hint := er.b.BlockHints(); hint != nil { 658 plan4Agg.aggHints = hint.aggHints 659 } 660 plan4Agg.SetChildren(np) 661 maxResultDefCaus := &memex.DeferredCauset{ 662 UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(), 663 RetType: maxFunc.RetTp, 664 } 665 count := &memex.DeferredCauset{ 666 UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(), 667 RetType: countFunc.RetTp, 668 } 669 plan4Agg.names = append(plan4Agg.names, types.EmptyName, types.EmptyName) 670 plan4Agg.SetSchema(memex.NewSchema(maxResultDefCaus, count)) 671 gtFunc := memex.NewFunctionInternal(er.sctx, ast.GT, types.NewFieldType(allegrosql.TypeTiny), count, memex.NewOne()) 672 neCond := memex.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(allegrosql.TypeTiny), lexpr, maxResultDefCaus) 673 cond := memex.ComposeDNFCondition(er.sctx, gtFunc, neCond) 674 er.buildQuantifierCauset(plan4Agg, cond, lexpr, rexpr, false) 675 } 676 677 // handleEQAll handles the case of = all. 
// For example, if the query is t.id = all (select s.id from s), it will be rewritten to
// t.id = (select s.id from s having count(distinct s.id) <= 1 and [all checker]).
// The aggregation outputs firstrow(s.id) and count(distinct s.id); the ALL
// NULL / empty-set checks are added by buildQuantifierCauset.
func (er *memexRewriter) handleEQAll(lexpr, rexpr memex.Expression, np LogicalCauset) {
	firstRowFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncFirstRow, []memex.Expression{rexpr}, false)
	if err != nil {
		er.err = err
		return
	}
	// distinct=true: counts distinct inner values.
	countFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncCount, []memex.Expression{rexpr}, true)
	if err != nil {
		er.err = err
		return
	}
	plan4Agg := LogicalAggregation{
		AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc},
	}.Init(er.sctx, er.b.getSelectOffset())
	if hint := er.b.BlockHints(); hint != nil {
		plan4Agg.aggHints = hint.aggHints
	}
	plan4Agg.SetChildren(np)
	plan4Agg.names = append(plan4Agg.names, types.EmptyName)
	firstRowResultDefCaus := &memex.DeferredCauset{
		UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
		RetType:  firstRowFunc.RetTp,
	}
	plan4Agg.names = append(plan4Agg.names, types.EmptyName)
	count := &memex.DeferredCauset{
		UniqueID: er.sctx.GetStochastikVars().AllocCausetDeferredCausetID(),
		RetType:  countFunc.RetTp,
	}
	plan4Agg.SetSchema(memex.NewSchema(firstRowResultDefCaus, count))
	// "at most one distinct value" AND "that value equals lexpr".
	leFunc := memex.NewFunctionInternal(er.sctx, ast.LE, types.NewFieldType(allegrosql.TypeTiny), count, memex.NewOne())
	eqCond := memex.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(allegrosql.TypeTiny), lexpr, firstRowResultDefCaus)
	cond := memex.ComposeCNFCondition(er.sctx, leFunc, eqCond)
	er.buildQuantifierCauset(plan4Agg, cond, lexpr, rexpr, true)
}

// handleExistSubquery rewrites EXISTS / NOT EXISTS (subquery). Projections,
// sorts and aggregations that cannot change row existence are first stripped by
// popExistsSubCauset. A correlated subquery becomes a semi-apply on er.p (when
// asScalar, the plan's last column is pushed as the result). An uncorrelated one
// is optimized and executed once through EvalSubqueryFirstRow, and the outcome is
// pushed as the constant 1 or 0.
func (er *memexRewriter) handleExistSubquery(ctx context.Context, v *ast.ExistsSubqueryExpr) (ast.Node, bool) {
	subq, ok := v.Sel.(*ast.SubqueryExpr)
	if !ok {
		er.err = errors.Errorf("Unknown exists type %T.", v.Sel)
		return v, true
	}
	np, err := er.buildSubquery(ctx, subq)
	if err != nil {
		er.err = err
		return v, true
	}
	np = er.popExistsSubCauset(np)
	if len(ExtractCorrelatedDefCauss4LogicalCauset(np)) > 0 {
		er.p, er.err = er.b.buildSemiApply(er.p, np, nil, er.asScalar, v.Not)
		if er.err != nil || !er.asScalar {
			return v, true
		}
		er.ctxStackAppend(er.p.Schema().DeferredCausets[er.p.Schema().Len()-1], er.p.OutputNames()[er.p.Schema().Len()-1])
	} else {
		// We don't want nth_plan hint to affect separately executed subqueries here, so disable nth_plan temporarily.
		// NOTE(review): the backup is restored by straight-line code, not a defer,
		// so a panic inside DoOptimize would leave ForceNthCauset at -1.
		NthCausetBackup := er.sctx.GetStochastikVars().StmtCtx.StmtHints.ForceNthCauset
		er.sctx.GetStochastikVars().StmtCtx.StmtHints.ForceNthCauset = -1
		physicalCauset, _, err := DoOptimize(ctx, er.sctx, er.b.optFlag, np)
		er.sctx.GetStochastikVars().StmtCtx.StmtHints.ForceNthCauset = NthCausetBackup
		if err != nil {
			er.err = err
			return v, true
		}
		event, err := EvalSubqueryFirstRow(ctx, physicalCauset, er.b.is, er.b.ctx)
		if err != nil {
			er.err = err
			return v, true
		}
		// A nil first row means the subquery is empty; XOR it with NOT.
		if (event != nil && !v.Not) || (event == nil && v.Not) {
			er.ctxStackAppend(memex.NewOne(), types.EmptyName)
		} else {
			er.ctxStackAppend(memex.NewZero(), types.EmptyName)
		}
	}
	return v, true
}

// popExistsSubCauset removes operators that cannot change whether p returns any
// row: projections and sorts are skipped down to their child; an aggregation
// without GROUP BY always yields exactly one row, so the whole subtree is
// replaced by a one-row LogicalBlockDual; an aggregation with GROUP BY is skipped
// down to its child. It stops at the first other operator.
func (er *memexRewriter) popExistsSubCauset(p LogicalCauset) LogicalCauset {
out:
	for {
		switch plan := p.(type) {
		// This can be removed when in exists clause,
		// e.g. exists(select count(*) from t order by a) is equal to exists t.
		case *LogicalProjection, *LogicalSort:
			p = p.Children()[0]
		case *LogicalAggregation:
			if len(plan.GroupByItems) == 0 {
				p = LogicalBlockDual{RowCount: 1}.Init(er.sctx, er.b.getSelectOffset())
				break out
			}
			p = p.Children()[0]
		default:
			break out
		}
	}
	return p
}

// handleInSubquery rewrites `expr [NOT] IN (subquery)`. The left operand is
// rewritten first (and then popped from the context stack). The subquery is then
// attached to er.p either as an inner join against a DISTINCT of the subquery
// (when the session allows it, the form is not NOT IN, no scalar result is needed,
// the subquery is not correlated with er.p, and the collations are compatible) or
// as a semi-apply. When asScalar, the plan's last column is pushed as the result.
// NOTE(review): er.asScalar is forced to true here and is not restored afterwards;
// only the local asScalar keeps the caller's value.
func (er *memexRewriter) handleInSubquery(ctx context.Context, v *ast.PatternInExpr) (ast.Node, bool) {
	asScalar := er.asScalar
	er.asScalar = true
	v.Expr.Accept(er)
	if er.err != nil {
		return v, true
	}
	lexpr := er.ctxStack[len(er.ctxStack)-1]
	subq, ok := v.Sel.(*ast.SubqueryExpr)
	if !ok {
		er.err = errors.Errorf("Unknown compare type %T.", v.Sel)
		return v, true
	}
	np, err := er.buildSubquery(ctx, subq)
	if err != nil {
		er.err = err
		return v, true
	}
	lLen := memex.GetRowLen(lexpr)
	if lLen != np.Schema().Len() {
		er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(lLen)
		return v, true
	}
	var rexpr memex.Expression
	if np.Schema().Len() == 1 {
		rexpr = np.Schema().DeferredCausets[0]
		rDefCaus := rexpr.(*memex.DeferredCauset)
		// For AntiSemiJoin/LeftOuterSemiJoin/AntiLeftOuterSemiJoin, we cannot treat `in` expression as
		// normal column equal condition, so we specially mark the inner operand here.
		if v.Not || asScalar {
			// If both input columns of `in` expression are not null, we can treat the expression
			// as normal column equal condition instead.
			if !allegrosql.HasNotNullFlag(lexpr.GetType().Flag) || !allegrosql.HasNotNullFlag(rDefCaus.GetType().Flag) {
				rDefCausCopy := *rDefCaus
				rDefCausCopy.InOperand = true
				rexpr = &rDefCausCopy
			}
		}
	} else {
		args := make([]memex.Expression, 0, np.Schema().Len())
		for _, col := range np.Schema().DeferredCausets {
			args = append(args, col)
		}
		rexpr, er.err = er.newFunction(ast.RowFunc, args[0].GetType(), args...)
		if er.err != nil {
			return v, true
		}
	}
	checkCondition, err := er.constructBinaryOpFunction(lexpr, rexpr, ast.EQ)
	if err != nil {
		er.err = err
		return v, true
	}

	// If the leftKey and the rightKey have different collations, don't convert the sub-query to an inner-join
	// since when converting we will add a distinct-agg upon the right child and this distinct-agg doesn't have the right collation.
	// To keep it simple, we forbid this converting if they have different collations.
	lt, rt := lexpr.GetType(), rexpr.GetType()
	collFlag := collate.CompatibleDefCauslate(lt.DefCauslate, rt.DefCauslate)

	// If it's not the form of `not in (SUBQUERY)`,
	// and has no correlated column from the current level plan(if the correlated column is from upper level,
	// we can treat it as constant, because the upper LogicalApply cannot be eliminated since current node is a join node),
	// and don't need to append a scalar value, we can rewrite it to inner join.
	if er.sctx.GetStochastikVars().GetAllowInSubqToJoinAnPosetDagg() && !v.Not && !asScalar && len(extractCorDeferredCausetsBySchema4LogicalCauset(np, er.p.Schema())) == 0 && collFlag {
		// We need to try to eliminate the agg and the projection produced by this operation.
		er.b.optFlag |= flagEliminateAgg
		er.b.optFlag |= flagEliminateProjection
		er.b.optFlag |= flagJoinReOrder
		// Build distinct for the inner query.
		agg, err := er.b.buildDistinct(np, np.Schema().Len())
		if err != nil {
			er.err = err
			return v, true
		}
		// Build inner join above the aggregation.
		join := LogicalJoin{JoinType: InnerJoin}.Init(er.sctx, er.b.getSelectOffset())
		join.SetChildren(er.p, agg)
		join.SetSchema(memex.MergeSchema(er.p.Schema(), agg.schemaReplicant))
		join.names = make([]*types.FieldName, er.p.Schema().Len()+agg.Schema().Len())
		copy(join.names, er.p.OutputNames())
		copy(join.names[er.p.Schema().Len():], agg.OutputNames())
		join.AttachOnConds(memex.SplitCNFItems(checkCondition))
		// Set join hint for this join.
		if er.b.BlockHints() != nil {
			join.setPreferredJoinType(er.b.BlockHints())
		}
		er.p = join
	} else {
		er.p, er.err = er.b.buildSemiApply(er.p, np, memex.SplitCNFItems(checkCondition), asScalar, v.Not)
		if er.err != nil {
			return v, true
		}
	}

	// Replace the left operand on the context stack with the result (if any).
	er.ctxStackPop(1)
	if asScalar {
		col := er.p.Schema().DeferredCausets[er.p.Schema().Len()-1]
		er.ctxStackAppend(col, er.p.OutputNames()[er.p.Schema().Len()-1])
	}
	return v, true
}

func (er *memexRewriter) handleScalarSubquery(ctx context.Context, v *ast.SubqueryExpr) (ast.Node, bool) {
	np, err := er.buildSubquery(ctx, v)
	if err != nil {
		er.err = err
		return v, true
	}
	np = er.b.buildMaxOneRow(np)
	if len(ExtractCorrelatedDefCauss4LogicalCauset(np)) > 0 {
		er.p = er.b.buildApplyWithJoinType(er.p, np, LeftOuterJoin)
		if np.Schema().Len() > 1 {
			newDefCauss := make([]memex.Expression, 0, np.Schema().Len())
			for _, col := range np.Schema().DeferredCausets {
				newDefCauss = append(newDefCauss, col)
			}
			expr, err1 := er.newFunction(ast.RowFunc, newDefCauss[0].GetType(), newDefCauss...)
897 if err1 != nil { 898 er.err = err1 899 return v, true 900 } 901 er.ctxStackAppend(expr, types.EmptyName) 902 } else { 903 er.ctxStackAppend(er.p.Schema().DeferredCausets[er.p.Schema().Len()-1], er.p.OutputNames()[er.p.Schema().Len()-1]) 904 } 905 return v, true 906 } 907 // We don't want nth_plan hint to affect separately executed subqueries here, so disable nth_plan temporarily. 908 NthCausetBackup := er.sctx.GetStochastikVars().StmtCtx.StmtHints.ForceNthCauset 909 er.sctx.GetStochastikVars().StmtCtx.StmtHints.ForceNthCauset = -1 910 physicalCauset, _, err := DoOptimize(ctx, er.sctx, er.b.optFlag, np) 911 er.sctx.GetStochastikVars().StmtCtx.StmtHints.ForceNthCauset = NthCausetBackup 912 if err != nil { 913 er.err = err 914 return v, true 915 } 916 event, err := EvalSubqueryFirstRow(ctx, physicalCauset, er.b.is, er.b.ctx) 917 if err != nil { 918 er.err = err 919 return v, true 920 } 921 if np.Schema().Len() > 1 { 922 newDefCauss := make([]memex.Expression, 0, np.Schema().Len()) 923 for i, data := range event { 924 newDefCauss = append(newDefCauss, &memex.Constant{ 925 Value: data, 926 RetType: np.Schema().DeferredCausets[i].GetType()}) 927 } 928 expr, err1 := er.newFunction(ast.RowFunc, newDefCauss[0].GetType(), newDefCauss...) 929 if err1 != nil { 930 er.err = err1 931 return v, true 932 } 933 er.ctxStackAppend(expr, types.EmptyName) 934 } else { 935 er.ctxStackAppend(&memex.Constant{ 936 Value: event[0], 937 RetType: np.Schema().DeferredCausets[0].GetType(), 938 }, types.EmptyName) 939 } 940 return v, true 941 } 942 943 // Leave implements Visitor interface. 
944 func (er *memexRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok bool) { 945 if er.err != nil { 946 return retNode, false 947 } 948 var inNode = originInNode 949 if er.preprocess != nil { 950 inNode = er.preprocess(inNode) 951 } 952 switch v := inNode.(type) { 953 case *ast.AggregateFuncExpr, *ast.DeferredCausetNameExpr, *ast.ParenthesesExpr, *ast.WhenClause, 954 *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr, *ast.ValuesExpr, *ast.WindowFuncExpr, *ast.BlockNameExpr: 955 case *driver.ValueExpr: 956 v.Causet.SetValue(v.Causet.GetValue(), &v.Type) 957 value := &memex.Constant{Value: v.Causet, RetType: &v.Type} 958 er.ctxStackAppend(value, types.EmptyName) 959 case *driver.ParamMarkerExpr: 960 var value memex.Expression 961 value, er.err = memex.ParamMarkerExpression(er.sctx, v) 962 if er.err != nil { 963 return retNode, false 964 } 965 er.ctxStackAppend(value, types.EmptyName) 966 case *ast.VariableExpr: 967 er.rewriteVariable(v) 968 case *ast.FuncCallExpr: 969 if _, ok := memex.TryFoldFunctions[v.FnName.L]; ok { 970 er.tryFoldCounter-- 971 } 972 er.funcCallToExpression(v) 973 if _, ok := memex.DisableFoldFunctions[v.FnName.L]; ok { 974 er.disableFoldCounter-- 975 } 976 case *ast.BlockName: 977 er.toBlock(v) 978 case *ast.DeferredCausetName: 979 er.toDeferredCauset(v) 980 case *ast.UnaryOperationExpr: 981 er.unaryOpToExpression(v) 982 case *ast.BinaryOperationExpr: 983 er.binaryOpToExpression(v) 984 case *ast.BetweenExpr: 985 er.betweenToExpression(v) 986 case *ast.CaseExpr: 987 if _, ok := memex.TryFoldFunctions["case"]; ok { 988 er.tryFoldCounter-- 989 } 990 er.caseToExpression(v) 991 if _, ok := memex.DisableFoldFunctions["case"]; ok { 992 er.disableFoldCounter-- 993 } 994 case *ast.FuncCastExpr: 995 arg := er.ctxStack[len(er.ctxStack)-1] 996 er.err = memex.CheckArgsNotMultiDeferredCausetRow(arg) 997 if er.err != nil { 998 return retNode, false 999 } 1000 1001 // check the decimal precision of "CAST(AS TIME)". 
1002 er.err = er.checkTimePrecision(v.Tp) 1003 if er.err != nil { 1004 return retNode, false 1005 } 1006 1007 er.ctxStack[len(er.ctxStack)-1] = memex.BuildCastFunction(er.sctx, arg, v.Tp) 1008 er.ctxNameStk[len(er.ctxNameStk)-1] = types.EmptyName 1009 case *ast.PatternLikeExpr: 1010 er.patternLikeToExpression(v) 1011 case *ast.PatternRegexpExpr: 1012 er.regexpToScalarFunc(v) 1013 case *ast.RowExpr: 1014 er.rowToScalarFunc(v) 1015 case *ast.PatternInExpr: 1016 if v.Sel == nil { 1017 er.inToExpression(len(v.List), v.Not, &v.Type) 1018 } 1019 case *ast.PositionExpr: 1020 er.positionToScalarFunc(v) 1021 case *ast.IsNullExpr: 1022 er.isNullToExpression(v) 1023 case *ast.IsTruthExpr: 1024 er.isTrueToScalarFunc(v) 1025 case *ast.DefaultExpr: 1026 er.evalDefaultExpr(v) 1027 // TODO: Perhaps we don't need to transcode these back to generic integers/strings 1028 case *ast.TrimDirectionExpr: 1029 er.ctxStackAppend(&memex.Constant{ 1030 Value: types.NewIntCauset(int64(v.Direction)), 1031 RetType: types.NewFieldType(allegrosql.TypeTiny), 1032 }, types.EmptyName) 1033 case *ast.TimeUnitExpr: 1034 er.ctxStackAppend(&memex.Constant{ 1035 Value: types.NewStringCauset(v.Unit.String()), 1036 RetType: types.NewFieldType(allegrosql.TypeVarchar), 1037 }, types.EmptyName) 1038 case *ast.GetFormatSelectorExpr: 1039 er.ctxStackAppend(&memex.Constant{ 1040 Value: types.NewStringCauset(v.Selector.String()), 1041 RetType: types.NewFieldType(allegrosql.TypeVarchar), 1042 }, types.EmptyName) 1043 case *ast.SetDefCauslationExpr: 1044 arg := er.ctxStack[len(er.ctxStack)-1] 1045 if collate.NewDefCauslationEnabled() { 1046 var collInfo *charset.DefCauslation 1047 // TODO(bb7133): use charset.ValidCharsetAndDefCauslation when its bug is fixed. 
1048 if collInfo, er.err = collate.GetDefCauslationByName(v.DefCauslate); er.err != nil { 1049 break 1050 } 1051 chs := arg.GetType().Charset 1052 if chs != "" && collInfo.CharsetName != chs { 1053 er.err = charset.ErrDefCauslationCharsetMismatch.GenWithStackByArgs(collInfo.Name, chs) 1054 break 1055 } 1056 } 1057 // SetDefCauslationExpr sets the collation explicitly, even when the evaluation type of the memex is non-string. 1058 if _, ok := arg.(*memex.DeferredCauset); ok { 1059 // Wrap a cast here to avoid changing the original FieldType of the column memex. 1060 exprType := arg.GetType().Clone() 1061 exprType.DefCauslate = v.DefCauslate 1062 casted := memex.BuildCastFunction(er.sctx, arg, exprType) 1063 er.ctxStackPop(1) 1064 er.ctxStackAppend(casted, types.EmptyName) 1065 } else { 1066 // For constant and scalar function, we can set its collate directly. 1067 arg.GetType().DefCauslate = v.DefCauslate 1068 } 1069 er.ctxStack[len(er.ctxStack)-1].SetCoercibility(memex.CoercibilityExplicit) 1070 er.ctxStack[len(er.ctxStack)-1].SetCharsetAndDefCauslation(arg.GetType().Charset, arg.GetType().DefCauslate) 1071 default: 1072 er.err = errors.Errorf("UnknownType: %T", v) 1073 return retNode, false 1074 } 1075 1076 if er.err != nil { 1077 return retNode, false 1078 } 1079 return originInNode, true 1080 } 1081 1082 // newFunction chooses which memex.NewFunctionImpl() will be used. 1083 func (er *memexRewriter) newFunction(funcName string, retType *types.FieldType, args ...memex.Expression) (memex.Expression, error) { 1084 if er.disableFoldCounter > 0 { 1085 return memex.NewFunctionBase(er.sctx, funcName, retType, args...) 1086 } 1087 if er.tryFoldCounter > 0 { 1088 return memex.NewFunctionTryFold(er.sctx, funcName, retType, args...) 1089 } 1090 return memex.NewFunction(er.sctx, funcName, retType, args...) 
1091 } 1092 1093 func (er *memexRewriter) checkTimePrecision(ft *types.FieldType) error { 1094 if ft.EvalType() == types.ETDuration && ft.Decimal > int(types.MaxFsp) { 1095 return errTooBigPrecision.GenWithStackByArgs(ft.Decimal, "CAST", types.MaxFsp) 1096 } 1097 return nil 1098 } 1099 1100 func (er *memexRewriter) useCache() bool { 1101 return er.sctx.GetStochastikVars().StmtCtx.UseCache 1102 } 1103 1104 func (er *memexRewriter) rewriteVariable(v *ast.VariableExpr) { 1105 stkLen := len(er.ctxStack) 1106 name := strings.ToLower(v.Name) 1107 stochastikVars := er.b.ctx.GetStochastikVars() 1108 if !v.IsSystem { 1109 if v.Value != nil { 1110 er.ctxStack[stkLen-1], er.err = er.newFunction(ast.SetVar, 1111 er.ctxStack[stkLen-1].GetType(), 1112 memex.CausetToConstant(types.NewCauset(name), allegrosql.TypeString), 1113 er.ctxStack[stkLen-1]) 1114 er.ctxNameStk[stkLen-1] = types.EmptyName 1115 return 1116 } 1117 f, err := er.newFunction(ast.GetVar, 1118 // TODO: Here is wrong, the stochastikVars should causetstore a name -> Causet map. Will fix it later. 1119 types.NewFieldType(allegrosql.TypeString), 1120 memex.CausetToConstant(types.NewStringCauset(name), allegrosql.TypeString)) 1121 if err != nil { 1122 er.err = err 1123 return 1124 } 1125 f.SetCoercibility(memex.CoercibilityImplicit) 1126 er.ctxStackAppend(f, types.EmptyName) 1127 return 1128 } 1129 var val string 1130 var err error 1131 if v.ExplicitScope { 1132 err = variable.ValidateGetSystemVar(name, v.IsGlobal) 1133 if err != nil { 1134 er.err = err 1135 return 1136 } 1137 } 1138 sysVar := variable.SysVars[name] 1139 if sysVar == nil { 1140 er.err = variable.ErrUnknownSystemVar.GenWithStackByArgs(name) 1141 return 1142 } 1143 // Variable is @@gobal.variable_name or variable is only global scope variable. 
1144 if v.IsGlobal || sysVar.Scope == variable.ScopeGlobal { 1145 val, err = variable.GetGlobalSystemVar(stochastikVars, name) 1146 } else { 1147 val, err = variable.GetStochastikSystemVar(stochastikVars, name) 1148 } 1149 if err != nil { 1150 er.err = err 1151 return 1152 } 1153 e := memex.CausetToConstant(types.NewStringCauset(val), allegrosql.TypeVarString) 1154 e.GetType().Charset, _ = er.sctx.GetStochastikVars().GetSystemVar(variable.CharacterSetConnection) 1155 e.GetType().DefCauslate, _ = er.sctx.GetStochastikVars().GetSystemVar(variable.DefCauslationConnection) 1156 er.ctxStackAppend(e, types.EmptyName) 1157 } 1158 1159 func (er *memexRewriter) unaryOpToExpression(v *ast.UnaryOperationExpr) { 1160 stkLen := len(er.ctxStack) 1161 var op string 1162 switch v.Op { 1163 case opcode.Plus: 1164 // memex (+ a) is equal to a 1165 return 1166 case opcode.Minus: 1167 op = ast.UnaryMinus 1168 case opcode.BitNeg: 1169 op = ast.BitNeg 1170 case opcode.Not: 1171 op = ast.UnaryNot 1172 default: 1173 er.err = errors.Errorf("Unknown Unary Op %T", v.Op) 1174 return 1175 } 1176 if memex.GetRowLen(er.ctxStack[stkLen-1]) != 1 { 1177 er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(1) 1178 return 1179 } 1180 er.ctxStack[stkLen-1], er.err = er.newFunction(op, &v.Type, er.ctxStack[stkLen-1]) 1181 er.ctxNameStk[stkLen-1] = types.EmptyName 1182 } 1183 1184 func (er *memexRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) { 1185 stkLen := len(er.ctxStack) 1186 var function memex.Expression 1187 switch v.Op { 1188 case opcode.EQ, opcode.NE, opcode.NullEQ, opcode.GT, opcode.GE, opcode.LT, opcode.LE: 1189 function, er.err = er.constructBinaryOpFunction(er.ctxStack[stkLen-2], er.ctxStack[stkLen-1], 1190 v.Op.String()) 1191 default: 1192 lLen := memex.GetRowLen(er.ctxStack[stkLen-2]) 1193 rLen := memex.GetRowLen(er.ctxStack[stkLen-1]) 1194 if lLen != 1 || rLen != 1 { 1195 er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(1) 1196 return 1197 } 1198 function, 
er.err = er.newFunction(v.Op.String(), types.NewFieldType(allegrosql.TypeUnspecified), er.ctxStack[stkLen-2:]...) 1199 } 1200 if er.err != nil { 1201 return 1202 } 1203 er.ctxStackPop(2) 1204 er.ctxStackAppend(function, types.EmptyName) 1205 } 1206 1207 func (er *memexRewriter) notToExpression(hasNot bool, op string, tp *types.FieldType, 1208 args ...memex.Expression) memex.Expression { 1209 opFunc, err := er.newFunction(op, tp, args...) 1210 if err != nil { 1211 er.err = err 1212 return nil 1213 } 1214 if !hasNot { 1215 return opFunc 1216 } 1217 1218 opFunc, err = er.newFunction(ast.UnaryNot, tp, opFunc) 1219 if err != nil { 1220 er.err = err 1221 return nil 1222 } 1223 return opFunc 1224 } 1225 1226 func (er *memexRewriter) isNullToExpression(v *ast.IsNullExpr) { 1227 stkLen := len(er.ctxStack) 1228 if memex.GetRowLen(er.ctxStack[stkLen-1]) != 1 { 1229 er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(1) 1230 return 1231 } 1232 function := er.notToExpression(v.Not, ast.IsNull, &v.Type, er.ctxStack[stkLen-1]) 1233 er.ctxStackPop(1) 1234 er.ctxStackAppend(function, types.EmptyName) 1235 } 1236 1237 func (er *memexRewriter) positionToScalarFunc(v *ast.PositionExpr) { 1238 pos := v.N 1239 str := strconv.Itoa(pos) 1240 if v.P != nil { 1241 stkLen := len(er.ctxStack) 1242 val := er.ctxStack[stkLen-1] 1243 intNum, isNull, err := memex.GetIntFromConstant(er.sctx, val) 1244 str = "?" 
1245 if err == nil { 1246 if isNull { 1247 return 1248 } 1249 pos = intNum 1250 er.ctxStackPop(1) 1251 } 1252 er.err = err 1253 } 1254 if er.err == nil && pos > 0 && pos <= er.schemaReplicant.Len() { 1255 er.ctxStackAppend(er.schemaReplicant.DeferredCausets[pos-1], er.names[pos-1]) 1256 } else { 1257 er.err = ErrUnknownDeferredCauset.GenWithStackByArgs(str, clauseMsg[er.b.curClause]) 1258 } 1259 } 1260 1261 func (er *memexRewriter) isTrueToScalarFunc(v *ast.IsTruthExpr) { 1262 stkLen := len(er.ctxStack) 1263 op := ast.IsTruthWithoutNull 1264 if v.True == 0 { 1265 op = ast.IsFalsity 1266 } 1267 if memex.GetRowLen(er.ctxStack[stkLen-1]) != 1 { 1268 er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(1) 1269 return 1270 } 1271 function := er.notToExpression(v.Not, op, &v.Type, er.ctxStack[stkLen-1]) 1272 er.ctxStackPop(1) 1273 er.ctxStackAppend(function, types.EmptyName) 1274 } 1275 1276 // inToExpression converts in memex to a scalar function. The argument lLen means the length of in list. 1277 // The argument not means if the memex is not in. The tp stands for the memex type, which is always bool. 1278 // a in (b, c, d) will be rewritten as `(a = b) or (a = c) or (a = d)`. 
1279 func (er *memexRewriter) inToExpression(lLen int, not bool, tp *types.FieldType) { 1280 stkLen := len(er.ctxStack) 1281 l := memex.GetRowLen(er.ctxStack[stkLen-lLen-1]) 1282 for i := 0; i < lLen; i++ { 1283 if l != memex.GetRowLen(er.ctxStack[stkLen-lLen+i]) { 1284 er.err = memex.ErrOperandDeferredCausets.GenWithStackByArgs(l) 1285 return 1286 } 1287 } 1288 args := er.ctxStack[stkLen-lLen-1:] 1289 leftFt := args[0].GetType() 1290 leftEt, leftIsNull := leftFt.EvalType(), leftFt.Tp == allegrosql.TypeNull 1291 if leftIsNull { 1292 er.ctxStackPop(lLen + 1) 1293 er.ctxStackAppend(memex.NewNull(), types.EmptyName) 1294 return 1295 } 1296 if leftEt == types.ETInt { 1297 for i := 1; i < len(args); i++ { 1298 if c, ok := args[i].(*memex.Constant); ok { 1299 var isExceptional bool 1300 args[i], isExceptional = memex.RefineComparedConstant(er.sctx, *leftFt, c, opcode.EQ) 1301 if isExceptional { 1302 args[i] = c 1303 } 1304 } 1305 } 1306 } 1307 allSameType := true 1308 for _, arg := range args[1:] { 1309 if arg.GetType().Tp != allegrosql.TypeNull && memex.GetAccurateCmpType(args[0], arg) != leftEt { 1310 allSameType = false 1311 break 1312 } 1313 } 1314 var function memex.Expression 1315 if allSameType && l == 1 && lLen > 1 { 1316 function = er.notToExpression(not, ast.In, tp, er.ctxStack[stkLen-lLen-1:]...) 1317 } else { 1318 eqFunctions := make([]memex.Expression, 0, lLen) 1319 for i := stkLen - lLen; i < stkLen; i++ { 1320 expr, err := er.constructBinaryOpFunction(args[0], er.ctxStack[i], ast.EQ) 1321 if err != nil { 1322 er.err = err 1323 return 1324 } 1325 eqFunctions = append(eqFunctions, expr) 1326 } 1327 function = memex.ComposeDNFCondition(er.sctx, eqFunctions...) 
1328 if not { 1329 var err error 1330 function, err = er.newFunction(ast.UnaryNot, tp, function) 1331 if err != nil { 1332 er.err = err 1333 return 1334 } 1335 } 1336 } 1337 er.ctxStackPop(lLen + 1) 1338 er.ctxStackAppend(function, types.EmptyName) 1339 } 1340 1341 func (er *memexRewriter) caseToExpression(v *ast.CaseExpr) { 1342 stkLen := len(er.ctxStack) 1343 argsLen := 2 * len(v.WhenClauses) 1344 if v.ElseClause != nil { 1345 argsLen++ 1346 } 1347 er.err = memex.CheckArgsNotMultiDeferredCausetRow(er.ctxStack[stkLen-argsLen:]...) 1348 if er.err != nil { 1349 return 1350 } 1351 1352 // value -> ctxStack[stkLen-argsLen-1] 1353 // when clause(condition, result) -> ctxStack[stkLen-argsLen:stkLen-1]; 1354 // else clause -> ctxStack[stkLen-1] 1355 var args []memex.Expression 1356 if v.Value != nil { 1357 // args: eq scalar func(args: value, condition1), result1, 1358 // eq scalar func(args: value, condition2), result2, 1359 // ... 1360 // else clause 1361 value := er.ctxStack[stkLen-argsLen-1] 1362 args = make([]memex.Expression, 0, argsLen) 1363 for i := stkLen - argsLen; i < stkLen-1; i += 2 { 1364 arg, err := er.newFunction(ast.EQ, types.NewFieldType(allegrosql.TypeTiny), value, er.ctxStack[i]) 1365 if err != nil { 1366 er.err = err 1367 return 1368 } 1369 args = append(args, arg) 1370 args = append(args, er.ctxStack[i+1]) 1371 } 1372 if v.ElseClause != nil { 1373 args = append(args, er.ctxStack[stkLen-1]) 1374 } 1375 argsLen++ // for trimming the value element later 1376 } else { 1377 // args: condition1, result1, 1378 // condition2, result2, 1379 // ... 1380 // else clause 1381 args = er.ctxStack[stkLen-argsLen:] 1382 } 1383 function, err := er.newFunction(ast.Case, &v.Type, args...) 
1384 if err != nil { 1385 er.err = err 1386 return 1387 } 1388 er.ctxStackPop(argsLen) 1389 er.ctxStackAppend(function, types.EmptyName) 1390 } 1391 1392 func (er *memexRewriter) patternLikeToExpression(v *ast.PatternLikeExpr) { 1393 l := len(er.ctxStack) 1394 er.err = memex.CheckArgsNotMultiDeferredCausetRow(er.ctxStack[l-2:]...) 1395 if er.err != nil { 1396 return 1397 } 1398 1399 char, col := er.sctx.GetStochastikVars().GetCharsetInfo() 1400 var function memex.Expression 1401 fieldType := &types.FieldType{} 1402 isPatternExactMatch := false 1403 // Treat predicate 'like' the same way as predicate '=' when it is an exact match. 1404 if patExpression, ok := er.ctxStack[l-1].(*memex.Constant); ok { 1405 patString, isNull, err := patExpression.EvalString(nil, chunk.Row{}) 1406 if err != nil { 1407 er.err = err 1408 return 1409 } 1410 if !isNull { 1411 patValue, patTypes := stringutil.CompilePattern(patString, v.Escape) 1412 if stringutil.IsExactMatch(patTypes) && er.ctxStack[l-2].GetType().EvalType() == types.ETString { 1413 op := ast.EQ 1414 if v.Not { 1415 op = ast.NE 1416 } 1417 types.DefaultTypeForValue(string(patValue), fieldType, char, col) 1418 function, er.err = er.constructBinaryOpFunction(er.ctxStack[l-2], 1419 &memex.Constant{Value: types.NewStringCauset(string(patValue)), RetType: fieldType}, 1420 op) 1421 isPatternExactMatch = true 1422 } 1423 } 1424 } 1425 if !isPatternExactMatch { 1426 types.DefaultTypeForValue(int(v.Escape), fieldType, char, col) 1427 function = er.notToExpression(v.Not, ast.Like, &v.Type, 1428 er.ctxStack[l-2], er.ctxStack[l-1], &memex.Constant{Value: types.NewIntCauset(int64(v.Escape)), RetType: fieldType}) 1429 } 1430 1431 er.ctxStackPop(2) 1432 er.ctxStackAppend(function, types.EmptyName) 1433 } 1434 1435 func (er *memexRewriter) regexpToScalarFunc(v *ast.PatternRegexpExpr) { 1436 l := len(er.ctxStack) 1437 er.err = memex.CheckArgsNotMultiDeferredCausetRow(er.ctxStack[l-2:]...) 
1438 if er.err != nil { 1439 return 1440 } 1441 function := er.notToExpression(v.Not, ast.Regexp, &v.Type, er.ctxStack[l-2], er.ctxStack[l-1]) 1442 er.ctxStackPop(2) 1443 er.ctxStackAppend(function, types.EmptyName) 1444 } 1445 1446 func (er *memexRewriter) rowToScalarFunc(v *ast.RowExpr) { 1447 stkLen := len(er.ctxStack) 1448 length := len(v.Values) 1449 rows := make([]memex.Expression, 0, length) 1450 for i := stkLen - length; i < stkLen; i++ { 1451 rows = append(rows, er.ctxStack[i]) 1452 } 1453 er.ctxStackPop(length) 1454 function, err := er.newFunction(ast.RowFunc, rows[0].GetType(), rows...) 1455 if err != nil { 1456 er.err = err 1457 return 1458 } 1459 er.ctxStackAppend(function, types.EmptyName) 1460 } 1461 1462 func (er *memexRewriter) betweenToExpression(v *ast.BetweenExpr) { 1463 stkLen := len(er.ctxStack) 1464 er.err = memex.CheckArgsNotMultiDeferredCausetRow(er.ctxStack[stkLen-3:]...) 1465 if er.err != nil { 1466 return 1467 } 1468 1469 expr, lexp, rexp := er.ctxStack[stkLen-3], er.ctxStack[stkLen-2], er.ctxStack[stkLen-1] 1470 1471 if memex.GetCmpTp4MinMax([]memex.Expression{expr, lexp, rexp}) == types.ETDatetime { 1472 expr = memex.WrapWithCastAsTime(er.sctx, expr, types.NewFieldType(allegrosql.TypeDatetime)) 1473 lexp = memex.WrapWithCastAsTime(er.sctx, lexp, types.NewFieldType(allegrosql.TypeDatetime)) 1474 rexp = memex.WrapWithCastAsTime(er.sctx, rexp, types.NewFieldType(allegrosql.TypeDatetime)) 1475 } 1476 1477 var op string 1478 var l, r memex.Expression 1479 l, er.err = er.newFunction(ast.GE, &v.Type, expr, lexp) 1480 if er.err == nil { 1481 r, er.err = er.newFunction(ast.LE, &v.Type, expr, rexp) 1482 } 1483 op = ast.LogicAnd 1484 if er.err != nil { 1485 return 1486 } 1487 function, err := er.newFunction(op, &v.Type, l, r) 1488 if err != nil { 1489 er.err = err 1490 return 1491 } 1492 if v.Not { 1493 function, err = er.newFunction(ast.UnaryNot, &v.Type, function) 1494 if err != nil { 1495 er.err = err 1496 return 1497 } 1498 } 1499 
er.ctxStackPop(3) 1500 er.ctxStackAppend(function, types.EmptyName) 1501 } 1502 1503 // rewriteFuncCall handles a FuncCallExpr and generates a customized function. 1504 // It should return true if for the given FuncCallExpr a rewrite is performed so that original behavior is skipped. 1505 // Otherwise it should return false to indicate (the caller) that original behavior needs to be performed. 1506 func (er *memexRewriter) rewriteFuncCall(v *ast.FuncCallExpr) bool { 1507 switch v.FnName.L { 1508 // when column is not null, ifnull on such column is not necessary. 1509 case ast.Ifnull: 1510 if len(v.Args) != 2 { 1511 er.err = memex.ErrIncorrectParameterCount.GenWithStackByArgs(v.FnName.O) 1512 return true 1513 } 1514 stackLen := len(er.ctxStack) 1515 arg1 := er.ctxStack[stackLen-2] 1516 col, isDeferredCauset := arg1.(*memex.DeferredCauset) 1517 // if expr1 is a column and column has not null flag, then we can eliminate ifnull on 1518 // this column. 1519 if isDeferredCauset && allegrosql.HasNotNullFlag(col.RetType.Flag) { 1520 name := er.ctxNameStk[stackLen-2] 1521 newDefCaus := col.Clone().(*memex.DeferredCauset) 1522 er.ctxStackPop(len(v.Args)) 1523 er.ctxStackAppend(newDefCaus, name) 1524 return true 1525 } 1526 1527 return false 1528 case ast.Nullif: 1529 if len(v.Args) != 2 { 1530 er.err = memex.ErrIncorrectParameterCount.GenWithStackByArgs(v.FnName.O) 1531 return true 1532 } 1533 stackLen := len(er.ctxStack) 1534 param1 := er.ctxStack[stackLen-2] 1535 param2 := er.ctxStack[stackLen-1] 1536 // param1 = param2 1537 funcCompare, err := er.constructBinaryOpFunction(param1, param2, ast.EQ) 1538 if err != nil { 1539 er.err = err 1540 return true 1541 } 1542 // NULL 1543 nullTp := types.NewFieldType(allegrosql.TypeNull) 1544 nullTp.Flen, nullTp.Decimal = allegrosql.GetDefaultFieldLengthAndDecimal(allegrosql.TypeNull) 1545 paramNull := &memex.Constant{ 1546 Value: types.NewCauset(nil), 1547 RetType: nullTp, 1548 } 1549 // if(param1 = param2, NULL, param1) 1550 funcIf, 
err := er.newFunction(ast.If, &v.Type, funcCompare, paramNull, param1) 1551 if err != nil { 1552 er.err = err 1553 return true 1554 } 1555 er.ctxStackPop(len(v.Args)) 1556 er.ctxStackAppend(funcIf, types.EmptyName) 1557 return true 1558 default: 1559 return false 1560 } 1561 } 1562 1563 func (er *memexRewriter) funcCallToExpression(v *ast.FuncCallExpr) { 1564 stackLen := len(er.ctxStack) 1565 args := er.ctxStack[stackLen-len(v.Args):] 1566 er.err = memex.CheckArgsNotMultiDeferredCausetRow(args...) 1567 if er.err != nil { 1568 return 1569 } 1570 1571 if er.rewriteFuncCall(v) { 1572 return 1573 } 1574 1575 var function memex.Expression 1576 er.ctxStackPop(len(v.Args)) 1577 if _, ok := memex.DeferredFunctions[v.FnName.L]; er.useCache() && ok { 1578 // When the memex is unix_timestamp and the number of argument is not zero, 1579 // we deal with it as normal memex. 1580 if v.FnName.L == ast.UnixTimestamp && len(v.Args) != 0 { 1581 function, er.err = er.newFunction(v.FnName.L, &v.Type, args...) 1582 er.ctxStackAppend(function, types.EmptyName) 1583 } else { 1584 function, er.err = memex.NewFunctionBase(er.sctx, v.FnName.L, &v.Type, args...) 1585 c := &memex.Constant{Value: types.NewCauset(nil), RetType: function.GetType().Clone(), DeferredExpr: function} 1586 er.ctxStackAppend(c, types.EmptyName) 1587 } 1588 } else { 1589 function, er.err = er.newFunction(v.FnName.L, &v.Type, args...) 1590 er.ctxStackAppend(function, types.EmptyName) 1591 } 1592 } 1593 1594 // Now BlockName in memex only used by sequence function like nextval(seq). 1595 // The function arg should be evaluated as a causet name rather than normal column name like allegrosql does. 1596 func (er *memexRewriter) toBlock(v *ast.BlockName) { 1597 fullName := v.Name.L 1598 if len(v.Schema.L) != 0 { 1599 fullName = v.Schema.L + "." 
+ fullName 1600 } 1601 val := &memex.Constant{ 1602 Value: types.NewCauset(fullName), 1603 RetType: types.NewFieldType(allegrosql.TypeString), 1604 } 1605 er.ctxStackAppend(val, types.EmptyName) 1606 } 1607 1608 func (er *memexRewriter) toDeferredCauset(v *ast.DeferredCausetName) { 1609 idx, err := memex.FindFieldName(er.names, v) 1610 if err != nil { 1611 er.err = ErrAmbiguous.GenWithStackByArgs(v.Name, clauseMsg[fieldList]) 1612 return 1613 } 1614 if idx >= 0 { 1615 column := er.schemaReplicant.DeferredCausets[idx] 1616 if column.IsHidden { 1617 er.err = ErrUnknownDeferredCauset.GenWithStackByArgs(v.Name, clauseMsg[er.b.curClause]) 1618 return 1619 } 1620 er.ctxStackAppend(column, er.names[idx]) 1621 return 1622 } 1623 for i := len(er.b.outerSchemas) - 1; i >= 0; i-- { 1624 outerSchema, outerName := er.b.outerSchemas[i], er.b.outerNames[i] 1625 idx, err = memex.FindFieldName(outerName, v) 1626 if idx >= 0 { 1627 column := outerSchema.DeferredCausets[idx] 1628 er.ctxStackAppend(&memex.CorrelatedDeferredCauset{DeferredCauset: *column, Data: new(types.Causet)}, outerName[idx]) 1629 return 1630 } 1631 if err != nil { 1632 er.err = ErrAmbiguous.GenWithStackByArgs(v.Name, clauseMsg[fieldList]) 1633 return 1634 } 1635 } 1636 if join, ok := er.p.(*LogicalJoin); ok && join.redundantSchema != nil { 1637 idx, err := memex.FindFieldName(join.redundantNames, v) 1638 if err != nil { 1639 er.err = err 1640 return 1641 } 1642 if idx >= 0 { 1643 er.ctxStackAppend(join.redundantSchema.DeferredCausets[idx], join.redundantNames[idx]) 1644 return 1645 } 1646 } 1647 if _, ok := er.p.(*LogicalUnionAll); ok && v.Block.O != "" { 1648 er.err = ErrBlocknameNotAllowedHere.GenWithStackByArgs(v.Block.O, "SELECT", clauseMsg[er.b.curClause]) 1649 return 1650 } 1651 if er.b.curClause == globalOrderByClause { 1652 er.b.curClause = orderByClause 1653 } 1654 er.err = ErrUnknownDeferredCauset.GenWithStackByArgs(v.String(), clauseMsg[er.b.curClause]) 1655 } 1656 1657 func (er *memexRewriter) 
evalDefaultExpr(v *ast.DefaultExpr) { 1658 var name *types.FieldName 1659 // Here we will find the corresponding column for default function. At the same time, we need to consider the issue 1660 // of subquery and name space. 1661 // For example, we have two blocks t1(a int default 1, b int) and t2(a int default -1, c int). Consider the following ALLEGROALLEGROSQL: 1662 // select a from t1 where a > (select default(a) from t2) 1663 // Refer to the behavior of MyALLEGROSQL, we need to find column a in causet t2. If causet t2 does not have column a, then find it 1664 // in causet t1. If there are none, return an error message. 1665 // Based on the above description, we need to look in er.b.allNames from back to front. 1666 for i := len(er.b.allNames) - 1; i >= 0; i-- { 1667 idx, err := memex.FindFieldName(er.b.allNames[i], v.Name) 1668 if err != nil { 1669 er.err = err 1670 return 1671 } 1672 if idx >= 0 { 1673 name = er.b.allNames[i][idx] 1674 break 1675 } 1676 } 1677 if name == nil { 1678 idx, err := memex.FindFieldName(er.names, v.Name) 1679 if err != nil { 1680 er.err = err 1681 return 1682 } 1683 if idx < 0 { 1684 er.err = ErrUnknownDeferredCauset.GenWithStackByArgs(v.Name.OrigDefCausName(), "field list") 1685 return 1686 } 1687 name = er.names[idx] 1688 } 1689 1690 dbName := name.DBName 1691 if dbName.O == "" { 1692 // if database name is not specified, use current database name 1693 dbName = perceptron.NewCIStr(er.sctx.GetStochastikVars().CurrentDB) 1694 } 1695 if name.OrigTblName.O == "" { 1696 // column is evaluated by some memexs, for example: 1697 // `select default(c) from (select (a+1) as c from t) as t0` 1698 // in such case, a 'no default' error is returned 1699 er.err = causet.ErrNoDefaultValue.GenWithStackByArgs(name.DefCausName) 1700 return 1701 } 1702 var tbl causet.Block 1703 tbl, er.err = er.b.is.BlockByName(dbName, name.OrigTblName) 1704 if er.err != nil { 1705 return 1706 } 1707 colName := name.OrigDefCausName.O 1708 if colName == "" { 1709 // 
in some cases, OrigDefCausName is empty, use DefCausName instead 1710 colName = name.DefCausName.O 1711 } 1712 col := causet.FindDefCaus(tbl.DefCauss(), colName) 1713 if col == nil { 1714 er.err = ErrUnknownDeferredCauset.GenWithStackByArgs(v.Name, "field_list") 1715 return 1716 } 1717 isCurrentTimestamp := hasCurrentDatetimeDefault(col) 1718 var val *memex.Constant 1719 switch { 1720 case isCurrentTimestamp && col.Tp == allegrosql.TypeDatetime: 1721 // for DATETIME column with current_timestamp, use NULL to be compatible with MyALLEGROSQL 5.7 1722 val = memex.NewNull() 1723 case isCurrentTimestamp && col.Tp == allegrosql.TypeTimestamp: 1724 // for TIMESTAMP column with current_timestamp, use 0 to be compatible with MyALLEGROSQL 5.7 1725 zero := types.NewTime(types.ZeroCoreTime, allegrosql.TypeTimestamp, int8(col.Decimal)) 1726 val = &memex.Constant{ 1727 Value: types.NewCauset(zero), 1728 RetType: types.NewFieldType(allegrosql.TypeTimestamp), 1729 } 1730 default: 1731 // for other columns, just use what it is 1732 val, er.err = er.b.getDefaultValue(col) 1733 } 1734 if er.err != nil { 1735 return 1736 } 1737 er.ctxStackAppend(val, types.EmptyName) 1738 } 1739 1740 // hasCurrentDatetimeDefault checks if column has current_timestamp default value 1741 func hasCurrentDatetimeDefault(col *causet.DeferredCauset) bool { 1742 x, ok := col.DefaultValue.(string) 1743 if !ok { 1744 return false 1745 } 1746 return strings.ToLower(x) == ast.CurrentTimestamp 1747 }