// Source: github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/dbs/memristed/memex/simple_rewriter.go

// Copyright 2020 WHTCORPS INC, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package memex

import (
	"context"

	"github.com/whtcorpsinc/errors"
	"github.com/whtcorpsinc/BerolinaSQL"
	"github.com/whtcorpsinc/BerolinaSQL/ast"
	"github.com/whtcorpsinc/BerolinaSQL/charset"
	"github.com/whtcorpsinc/BerolinaSQL/perceptron"
	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
	"github.com/whtcorpsinc/BerolinaSQL/opcode"
	"github.com/whtcorpsinc/milevadb/stochastikctx"
	"github.com/whtcorpsinc/milevadb/types"
	driver "github.com/whtcorpsinc/milevadb/types/BerolinaSQL_driver"
	"github.com/whtcorpsinc/milevadb/soliton"
	"github.com/whtcorpsinc/milevadb/soliton/defCauslate"
)

// simpleRewriter converts an ast.ExprNode into an Expression. It is handed to
// ExprNode.Accept and does its work in the Enter/Leave methods defined in this
// file.
//
// The embedded exprStack supplies the push/pop/popN helpers and the stack
// slice that the methods in this file use to hold intermediate results.
//
// Fields:
//   - schemaReplicant: the defCausumns a defCausumn reference resolves to; it
//     is indexed with the position found in names.
//   - err: the error recorded during rewriting; the Rewrite* entry points
//     return it instead of a result when it is non-nil.
//   - ctx: the stochastik context passed to the function builders.
//   - names: the field names searched when resolving a defCausumn reference.
type simpleRewriter struct {
	exprStack

	schemaReplicant *Schema
	err             error
	ctx             stochastikctx.Context
	names           []*types.FieldName
}

// ParseSimpleExprWithBlockInfo parses simple memex string to Expression.
// The memex string must only reference the defCausumn in causet Info.
44 func ParseSimpleExprWithBlockInfo(ctx stochastikctx.Context, exprStr string, blockInfo *perceptron.BlockInfo) (Expression, error) { 45 exprStr = "select " + exprStr 46 var stmts []ast.StmtNode 47 var err error 48 var warns []error 49 if p, ok := ctx.(interface { 50 ParseALLEGROSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error) 51 }); ok { 52 stmts, warns, err = p.ParseALLEGROSQL(context.Background(), exprStr, "", "") 53 } else { 54 stmts, warns, err = BerolinaSQL.New().Parse(exprStr, "", "") 55 } 56 for _, warn := range warns { 57 ctx.GetStochastikVars().StmtCtx.AppendWarning(soliton.SyntaxWarn(warn)) 58 } 59 60 if err != nil { 61 return nil, errors.Trace(err) 62 } 63 expr := stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr 64 return RewriteSimpleExprWithBlockInfo(ctx, blockInfo, expr) 65 } 66 67 // ParseSimpleExprCastWithBlockInfo parses simple memex string to Expression. 68 // And the expr returns will cast to the target type. 69 func ParseSimpleExprCastWithBlockInfo(ctx stochastikctx.Context, exprStr string, blockInfo *perceptron.BlockInfo, targetFt *types.FieldType) (Expression, error) { 70 e, err := ParseSimpleExprWithBlockInfo(ctx, exprStr, blockInfo) 71 if err != nil { 72 return nil, err 73 } 74 e = BuildCastFunction(ctx, e, targetFt) 75 return e, nil 76 } 77 78 // RewriteSimpleExprWithBlockInfo rewrites simple ast.ExprNode to memex.Expression. 
79 func RewriteSimpleExprWithBlockInfo(ctx stochastikctx.Context, tbl *perceptron.BlockInfo, expr ast.ExprNode) (Expression, error) { 80 dbName := perceptron.NewCIStr(ctx.GetStochastikVars().CurrentDB) 81 defCausumns, names, err := DeferredCausetInfos2DeferredCausetsAndNames(ctx, dbName, tbl.Name, tbl.DefCauss(), tbl) 82 if err != nil { 83 return nil, err 84 } 85 rewriter := &simpleRewriter{ctx: ctx, schemaReplicant: NewSchema(defCausumns...), names: names} 86 expr.Accept(rewriter) 87 if rewriter.err != nil { 88 return nil, rewriter.err 89 } 90 return rewriter.pop(), nil 91 } 92 93 // ParseSimpleExprsWithSchema parses simple memex string to Expression. 94 // The memex string must only reference the defCausumn in the given schemaReplicant. 95 func ParseSimpleExprsWithSchema(ctx stochastikctx.Context, exprStr string, schemaReplicant *Schema) ([]Expression, error) { 96 exprStr = "select " + exprStr 97 stmts, warns, err := BerolinaSQL.New().Parse(exprStr, "", "") 98 if err != nil { 99 return nil, soliton.SyntaxWarn(err) 100 } 101 for _, warn := range warns { 102 ctx.GetStochastikVars().StmtCtx.AppendWarning(soliton.SyntaxWarn(warn)) 103 } 104 105 fields := stmts[0].(*ast.SelectStmt).Fields.Fields 106 exprs := make([]Expression, 0, len(fields)) 107 for _, field := range fields { 108 expr, err := RewriteSimpleExprWithSchema(ctx, field.Expr, schemaReplicant) 109 if err != nil { 110 return nil, err 111 } 112 exprs = append(exprs, expr) 113 } 114 return exprs, nil 115 } 116 117 // ParseSimpleExprsWithNames parses simple memex string to Expression. 118 // The memex string must only reference the defCausumn in the given NameSlice. 
119 func ParseSimpleExprsWithNames(ctx stochastikctx.Context, exprStr string, schemaReplicant *Schema, names types.NameSlice) ([]Expression, error) { 120 exprStr = "select " + exprStr 121 var stmts []ast.StmtNode 122 var err error 123 var warns []error 124 if p, ok := ctx.(interface { 125 ParseALLEGROSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error) 126 }); ok { 127 stmts, warns, err = p.ParseALLEGROSQL(context.Background(), exprStr, "", "") 128 } else { 129 stmts, warns, err = BerolinaSQL.New().Parse(exprStr, "", "") 130 } 131 if err != nil { 132 return nil, soliton.SyntaxWarn(err) 133 } 134 for _, warn := range warns { 135 ctx.GetStochastikVars().StmtCtx.AppendWarning(soliton.SyntaxWarn(warn)) 136 } 137 138 fields := stmts[0].(*ast.SelectStmt).Fields.Fields 139 exprs := make([]Expression, 0, len(fields)) 140 for _, field := range fields { 141 expr, err := RewriteSimpleExprWithNames(ctx, field.Expr, schemaReplicant, names) 142 if err != nil { 143 return nil, err 144 } 145 exprs = append(exprs, expr) 146 } 147 return exprs, nil 148 } 149 150 // RewriteSimpleExprWithNames rewrites simple ast.ExprNode to memex.Expression. 151 func RewriteSimpleExprWithNames(ctx stochastikctx.Context, expr ast.ExprNode, schemaReplicant *Schema, names []*types.FieldName) (Expression, error) { 152 rewriter := &simpleRewriter{ctx: ctx, schemaReplicant: schemaReplicant, names: names} 153 expr.Accept(rewriter) 154 if rewriter.err != nil { 155 return nil, rewriter.err 156 } 157 return rewriter.pop(), nil 158 } 159 160 // RewriteSimpleExprWithSchema rewrites simple ast.ExprNode to memex.Expression. 
161 func RewriteSimpleExprWithSchema(ctx stochastikctx.Context, expr ast.ExprNode, schemaReplicant *Schema) (Expression, error) { 162 rewriter := &simpleRewriter{ctx: ctx, schemaReplicant: schemaReplicant} 163 expr.Accept(rewriter) 164 if rewriter.err != nil { 165 return nil, rewriter.err 166 } 167 return rewriter.pop(), nil 168 } 169 170 // FindFieldName finds the defCausumn name from NameSlice. 171 func FindFieldName(names types.NameSlice, astDefCaus *ast.DeferredCausetName) (int, error) { 172 dbName, tblName, defCausName := astDefCaus.Schema, astDefCaus.Block, astDefCaus.Name 173 idx := -1 174 for i, name := range names { 175 if (dbName.L == "" || dbName.L == name.DBName.L) && 176 (tblName.L == "" || tblName.L == name.TblName.L) && 177 (defCausName.L == name.DefCausName.L) { 178 if idx == -1 { 179 idx = i 180 } else { 181 return -1, errNonUniq.GenWithStackByArgs(astDefCaus.String(), "field list") 182 } 183 } 184 } 185 return idx, nil 186 } 187 188 // FindFieldNameIdxByDefCausName finds the index of corresponding name in the given slice. -1 for not found. 
189 func FindFieldNameIdxByDefCausName(names []*types.FieldName, defCausName string) int { 190 for i, name := range names { 191 if name.DefCausName.L == defCausName { 192 return i 193 } 194 } 195 return -1 196 } 197 198 func (sr *simpleRewriter) rewriteDeferredCauset(nodeDefCausName *ast.DeferredCausetNameExpr) (*DeferredCauset, error) { 199 idx, err := FindFieldName(sr.names, nodeDefCausName.Name) 200 if idx >= 0 && err == nil { 201 return sr.schemaReplicant.DeferredCausets[idx], nil 202 } 203 return nil, errBadField.GenWithStackByArgs(nodeDefCausName.Name.Name.O, "memex") 204 } 205 206 func (sr *simpleRewriter) Enter(inNode ast.Node) (ast.Node, bool) { 207 return inNode, false 208 } 209 210 func (sr *simpleRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok bool) { 211 switch v := originInNode.(type) { 212 case *ast.DeferredCausetNameExpr: 213 defCausumn, err := sr.rewriteDeferredCauset(v) 214 if err != nil { 215 sr.err = err 216 return originInNode, false 217 } 218 sr.push(defCausumn) 219 case *driver.ValueExpr: 220 value := &Constant{Value: v.Causet, RetType: &v.Type} 221 sr.push(value) 222 case *ast.FuncCallExpr: 223 sr.funcCallToExpression(v) 224 case *ast.FuncCastExpr: 225 arg := sr.pop() 226 sr.err = CheckArgsNotMultiDeferredCausetEvent(arg) 227 if sr.err != nil { 228 return retNode, false 229 } 230 sr.push(BuildCastFunction(sr.ctx, arg, v.Tp)) 231 case *ast.BinaryOperationExpr: 232 sr.binaryOpToExpression(v) 233 case *ast.UnaryOperationExpr: 234 sr.unaryOpToExpression(v) 235 case *ast.BetweenExpr: 236 sr.betweenToExpression(v) 237 case *ast.IsNullExpr: 238 sr.isNullToExpression(v) 239 case *ast.IsTruthExpr: 240 sr.isTrueToScalarFunc(v) 241 case *ast.PatternLikeExpr: 242 sr.likeToScalarFunc(v) 243 case *ast.PatternRegexpExpr: 244 sr.regexpToScalarFunc(v) 245 case *ast.PatternInExpr: 246 if v.Sel == nil { 247 sr.inToExpression(len(v.List), v.Not, &v.Type) 248 } 249 case *driver.ParamMarkerExpr: 250 var value Expression 251 value, sr.err = 
ParamMarkerExpression(sr.ctx, v) 252 if sr.err != nil { 253 return retNode, false 254 } 255 sr.push(value) 256 case *ast.EventExpr: 257 sr.rowToScalarFunc(v) 258 case *ast.ParenthesesExpr: 259 case *ast.DeferredCausetName: 260 // TODO: Perhaps we don't need to transcode these back to generic integers/strings 261 case *ast.TrimDirectionExpr: 262 sr.push(&Constant{ 263 Value: types.NewIntCauset(int64(v.Direction)), 264 RetType: types.NewFieldType(allegrosql.TypeTiny), 265 }) 266 case *ast.TimeUnitExpr: 267 sr.push(&Constant{ 268 Value: types.NewStringCauset(v.Unit.String()), 269 RetType: types.NewFieldType(allegrosql.TypeVarchar), 270 }) 271 case *ast.GetFormatSelectorExpr: 272 sr.push(&Constant{ 273 Value: types.NewStringCauset(v.Selector.String()), 274 RetType: types.NewFieldType(allegrosql.TypeVarchar), 275 }) 276 case *ast.SetDefCauslationExpr: 277 arg := sr.stack[len(sr.stack)-1] 278 if defCauslate.NewDefCauslationEnabled() { 279 var defCauslInfo *charset.DefCauslation 280 // TODO(bb7133): use charset.ValidCharsetAndDefCauslation when its bug is fixed. 281 if defCauslInfo, sr.err = defCauslate.GetDefCauslationByName(v.DefCauslate); sr.err != nil { 282 break 283 } 284 chs := arg.GetType().Charset 285 if chs != "" && defCauslInfo.CharsetName != chs { 286 sr.err = charset.ErrDefCauslationCharsetMismatch.GenWithStackByArgs(defCauslInfo.Name, chs) 287 break 288 } 289 } 290 // SetDefCauslationExpr sets the defCauslation explicitly, even when the evaluation type of the memex is non-string. 291 if _, ok := arg.(*DeferredCauset); ok { 292 // Wrap a cast here to avoid changing the original FieldType of the defCausumn memex. 293 exprType := arg.GetType().Clone() 294 exprType.DefCauslate = v.DefCauslate 295 casted := BuildCastFunction(sr.ctx, arg, exprType) 296 sr.pop() 297 sr.push(casted) 298 } else { 299 // For constant and scalar function, we can set its defCauslate directly. 
300 arg.GetType().DefCauslate = v.DefCauslate 301 } 302 sr.stack[len(sr.stack)-1].SetCoercibility(CoercibilityExplicit) 303 sr.stack[len(sr.stack)-1].SetCharsetAndDefCauslation(arg.GetType().Charset, arg.GetType().DefCauslate) 304 default: 305 sr.err = errors.Errorf("UnknownType: %T", v) 306 return retNode, false 307 } 308 if sr.err != nil { 309 return retNode, false 310 } 311 return originInNode, true 312 } 313 314 func (sr *simpleRewriter) useCache() bool { 315 return sr.ctx.GetStochastikVars().StmtCtx.UseCache 316 } 317 318 func (sr *simpleRewriter) binaryOpToExpression(v *ast.BinaryOperationExpr) { 319 right := sr.pop() 320 left := sr.pop() 321 var function Expression 322 switch v.Op { 323 case opcode.EQ, opcode.NE, opcode.NullEQ, opcode.GT, opcode.GE, opcode.LT, opcode.LE: 324 function, sr.err = sr.constructBinaryOpFunction(left, right, 325 v.Op.String()) 326 default: 327 lLen := GetEventLen(left) 328 rLen := GetEventLen(right) 329 if lLen != 1 || rLen != 1 { 330 sr.err = ErrOperandDeferredCausets.GenWithStackByArgs(1) 331 return 332 } 333 function, sr.err = NewFunction(sr.ctx, v.Op.String(), types.NewFieldType(allegrosql.TypeUnspecified), left, right) 334 } 335 if sr.err != nil { 336 return 337 } 338 sr.push(function) 339 } 340 341 func (sr *simpleRewriter) funcCallToExpression(v *ast.FuncCallExpr) { 342 args := sr.popN(len(v.Args)) 343 sr.err = CheckArgsNotMultiDeferredCausetEvent(args...) 344 if sr.err != nil { 345 return 346 } 347 if sr.rewriteFuncCall(v, args) { 348 return 349 } 350 var function Expression 351 function, sr.err = NewFunction(sr.ctx, v.FnName.L, &v.Type, args...) 
352 sr.push(function) 353 } 354 355 func (sr *simpleRewriter) rewriteFuncCall(v *ast.FuncCallExpr, args []Expression) bool { 356 switch v.FnName.L { 357 case ast.Nullif: 358 if len(args) != 2 { 359 sr.err = ErrIncorrectParameterCount.GenWithStackByArgs(v.FnName.O) 360 return true 361 } 362 param2 := args[1] 363 param1 := args[0] 364 // param1 = param2 365 funcCompare, err := sr.constructBinaryOpFunction(param1, param2, ast.EQ) 366 if err != nil { 367 sr.err = err 368 return true 369 } 370 // NULL 371 nullTp := types.NewFieldType(allegrosql.TypeNull) 372 nullTp.Flen, nullTp.Decimal = allegrosql.GetDefaultFieldLengthAndDecimal(allegrosql.TypeNull) 373 paramNull := &Constant{ 374 Value: types.NewCauset(nil), 375 RetType: nullTp, 376 } 377 // if(param1 = param2, NULL, param1) 378 funcIf, err := NewFunction(sr.ctx, ast.If, &v.Type, funcCompare, paramNull, param1) 379 if err != nil { 380 sr.err = err 381 return true 382 } 383 sr.push(funcIf) 384 return true 385 default: 386 return false 387 } 388 } 389 390 // constructBinaryOpFunction works as following: 391 // 1. If op are EQ or NE or NullEQ, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2) 392 // 2. If op are LE or GE, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to 393 // `IF( (a0 op b0) EQ 0, 0, 394 // IF ( (a1 op b1) EQ 0, 0, a2 op b2))` 395 // 3. 
If op are LT or GT, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to 396 // `IF( a0 NE b0, a0 op b0, 397 // IF( a1 NE b1, 398 // a1 op b1, 399 // a2 op b2) 400 // )` 401 func (sr *simpleRewriter) constructBinaryOpFunction(l Expression, r Expression, op string) (Expression, error) { 402 lLen, rLen := GetEventLen(l), GetEventLen(r) 403 if lLen == 1 && rLen == 1 { 404 return NewFunction(sr.ctx, op, types.NewFieldType(allegrosql.TypeTiny), l, r) 405 } else if rLen != lLen { 406 return nil, ErrOperandDeferredCausets.GenWithStackByArgs(lLen) 407 } 408 switch op { 409 case ast.EQ, ast.NE, ast.NullEQ: 410 funcs := make([]Expression, lLen) 411 for i := 0; i < lLen; i++ { 412 var err error 413 funcs[i], err = sr.constructBinaryOpFunction(GetFuncArg(l, i), GetFuncArg(r, i), op) 414 if err != nil { 415 return nil, err 416 } 417 } 418 if op == ast.NE { 419 return ComposeDNFCondition(sr.ctx, funcs...), nil 420 } 421 return ComposeCNFCondition(sr.ctx, funcs...), nil 422 default: 423 larg0, rarg0 := GetFuncArg(l, 0), GetFuncArg(r, 0) 424 var expr1, expr2, expr3 Expression 425 if op == ast.LE || op == ast.GE { 426 expr1 = NewFunctionInternal(sr.ctx, op, types.NewFieldType(allegrosql.TypeTiny), larg0, rarg0) 427 expr1 = NewFunctionInternal(sr.ctx, ast.EQ, types.NewFieldType(allegrosql.TypeTiny), expr1, NewZero()) 428 expr2 = NewZero() 429 } else if op == ast.LT || op == ast.GT { 430 expr1 = NewFunctionInternal(sr.ctx, ast.NE, types.NewFieldType(allegrosql.TypeTiny), larg0, rarg0) 431 expr2 = NewFunctionInternal(sr.ctx, op, types.NewFieldType(allegrosql.TypeTiny), larg0, rarg0) 432 } 433 var err error 434 l, err = PopEventFirstArg(sr.ctx, l) 435 if err != nil { 436 return nil, err 437 } 438 r, err = PopEventFirstArg(sr.ctx, r) 439 if err != nil { 440 return nil, err 441 } 442 expr3, err = sr.constructBinaryOpFunction(l, r, op) 443 if err != nil { 444 return nil, err 445 } 446 return NewFunction(sr.ctx, ast.If, types.NewFieldType(allegrosql.TypeTiny), expr1, expr2, 
expr3) 447 } 448 } 449 450 func (sr *simpleRewriter) unaryOpToExpression(v *ast.UnaryOperationExpr) { 451 var op string 452 switch v.Op { 453 case opcode.Plus: 454 // memex (+ a) is equal to a 455 return 456 case opcode.Minus: 457 op = ast.UnaryMinus 458 case opcode.BitNeg: 459 op = ast.BitNeg 460 case opcode.Not: 461 op = ast.UnaryNot 462 default: 463 sr.err = errors.Errorf("Unknown Unary Op %T", v.Op) 464 return 465 } 466 expr := sr.pop() 467 if GetEventLen(expr) != 1 { 468 sr.err = ErrOperandDeferredCausets.GenWithStackByArgs(1) 469 return 470 } 471 newExpr, err := NewFunction(sr.ctx, op, &v.Type, expr) 472 sr.err = err 473 sr.push(newExpr) 474 } 475 476 func (sr *simpleRewriter) likeToScalarFunc(v *ast.PatternLikeExpr) { 477 pattern := sr.pop() 478 expr := sr.pop() 479 sr.err = CheckArgsNotMultiDeferredCausetEvent(expr, pattern) 480 if sr.err != nil { 481 return 482 } 483 escapeTp := &types.FieldType{} 484 char, defCaus := sr.ctx.GetStochastikVars().GetCharsetInfo() 485 types.DefaultTypeForValue(int(v.Escape), escapeTp, char, defCaus) 486 function := sr.notToExpression(v.Not, ast.Like, &v.Type, 487 expr, pattern, &Constant{Value: types.NewIntCauset(int64(v.Escape)), RetType: escapeTp}) 488 sr.push(function) 489 } 490 491 func (sr *simpleRewriter) regexpToScalarFunc(v *ast.PatternRegexpExpr) { 492 parttern := sr.pop() 493 expr := sr.pop() 494 sr.err = CheckArgsNotMultiDeferredCausetEvent(expr, parttern) 495 if sr.err != nil { 496 return 497 } 498 function := sr.notToExpression(v.Not, ast.Regexp, &v.Type, expr, parttern) 499 sr.push(function) 500 } 501 502 func (sr *simpleRewriter) rowToScalarFunc(v *ast.EventExpr) { 503 elems := sr.popN(len(v.Values)) 504 function, err := NewFunction(sr.ctx, ast.EventFunc, elems[0].GetType(), elems...) 
505 if err != nil { 506 sr.err = err 507 return 508 } 509 sr.push(function) 510 } 511 512 func (sr *simpleRewriter) betweenToExpression(v *ast.BetweenExpr) { 513 right := sr.pop() 514 left := sr.pop() 515 expr := sr.pop() 516 sr.err = CheckArgsNotMultiDeferredCausetEvent(expr) 517 if sr.err != nil { 518 return 519 } 520 var l, r Expression 521 l, sr.err = NewFunction(sr.ctx, ast.GE, &v.Type, expr, left) 522 if sr.err == nil { 523 r, sr.err = NewFunction(sr.ctx, ast.LE, &v.Type, expr, right) 524 } 525 if sr.err != nil { 526 return 527 } 528 function, err := NewFunction(sr.ctx, ast.LogicAnd, &v.Type, l, r) 529 if err != nil { 530 sr.err = err 531 return 532 } 533 if v.Not { 534 function, err = NewFunction(sr.ctx, ast.UnaryNot, &v.Type, function) 535 if err != nil { 536 sr.err = err 537 return 538 } 539 } 540 sr.push(function) 541 } 542 543 func (sr *simpleRewriter) isNullToExpression(v *ast.IsNullExpr) { 544 arg := sr.pop() 545 if GetEventLen(arg) != 1 { 546 sr.err = ErrOperandDeferredCausets.GenWithStackByArgs(1) 547 return 548 } 549 function := sr.notToExpression(v.Not, ast.IsNull, &v.Type, arg) 550 sr.push(function) 551 } 552 553 func (sr *simpleRewriter) notToExpression(hasNot bool, op string, tp *types.FieldType, 554 args ...Expression) Expression { 555 opFunc, err := NewFunction(sr.ctx, op, tp, args...) 
556 if err != nil { 557 sr.err = err 558 return nil 559 } 560 if !hasNot { 561 return opFunc 562 } 563 564 opFunc, err = NewFunction(sr.ctx, ast.UnaryNot, tp, opFunc) 565 if err != nil { 566 sr.err = err 567 return nil 568 } 569 return opFunc 570 } 571 572 func (sr *simpleRewriter) isTrueToScalarFunc(v *ast.IsTruthExpr) { 573 arg := sr.pop() 574 op := ast.IsTruthWithoutNull 575 if v.True == 0 { 576 op = ast.IsFalsity 577 } 578 if GetEventLen(arg) != 1 { 579 sr.err = ErrOperandDeferredCausets.GenWithStackByArgs(1) 580 return 581 } 582 function := sr.notToExpression(v.Not, op, &v.Type, arg) 583 sr.push(function) 584 } 585 586 // inToExpression converts in memex to a scalar function. The argument lLen means the length of in list. 587 // The argument not means if the memex is not in. The tp stands for the memex type, which is always bool. 588 // a in (b, c, d) will be rewritten as `(a = b) or (a = c) or (a = d)`. 589 func (sr *simpleRewriter) inToExpression(lLen int, not bool, tp *types.FieldType) { 590 exprs := sr.popN(lLen + 1) 591 leftExpr := exprs[0] 592 elems := exprs[1:] 593 l, leftFt := GetEventLen(leftExpr), leftExpr.GetType() 594 for i := 0; i < lLen; i++ { 595 if l != GetEventLen(elems[i]) { 596 sr.err = ErrOperandDeferredCausets.GenWithStackByArgs(l) 597 return 598 } 599 } 600 leftIsNull := leftFt.Tp == allegrosql.TypeNull 601 if leftIsNull { 602 sr.push(NewNull()) 603 return 604 } 605 leftEt := leftFt.EvalType() 606 607 if leftEt == types.ETInt { 608 for i := 0; i < len(elems); i++ { 609 if c, ok := elems[i].(*Constant); ok { 610 var isExceptional bool 611 elems[i], isExceptional = RefineComparedConstant(sr.ctx, *leftFt, c, opcode.EQ) 612 if isExceptional { 613 elems[i] = c 614 } 615 } 616 } 617 } 618 allSameType := true 619 for _, elem := range elems { 620 if elem.GetType().Tp != allegrosql.TypeNull && GetAccurateCmpType(leftExpr, elem) != leftEt { 621 allSameType = false 622 break 623 } 624 } 625 var function Expression 626 if allSameType && l == 1 { 627 
function = sr.notToExpression(not, ast.In, tp, exprs...) 628 } else { 629 eqFunctions := make([]Expression, 0, lLen) 630 for i := 0; i < len(elems); i++ { 631 expr, err := sr.constructBinaryOpFunction(leftExpr, elems[i], ast.EQ) 632 if err != nil { 633 sr.err = err 634 return 635 } 636 eqFunctions = append(eqFunctions, expr) 637 } 638 function = ComposeDNFCondition(sr.ctx, eqFunctions...) 639 if not { 640 var err error 641 function, err = NewFunction(sr.ctx, ast.UnaryNot, tp, function) 642 if err != nil { 643 sr.err = err 644 return 645 } 646 } 647 } 648 sr.push(function) 649 }