github.com/team-ide/go-dialect@v1.9.20/vitess/sqlparser/ast_rewriting.go (about) 1 /* 2 Copyright 2020 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sqlparser 18 19 import ( 20 "strconv" 21 22 querypb "github.com/team-ide/go-dialect/vitess/query" 23 "github.com/team-ide/go-dialect/vitess/vterrors" 24 vtrpcpb "github.com/team-ide/go-dialect/vitess/vtrpc" 25 26 "strings" 27 28 "github.com/team-ide/go-dialect/vitess/sysvars" 29 ) 30 31 var ( 32 subQueryBaseArgName = []byte("__sq") 33 34 // HasValueSubQueryBaseName is the prefix of each parameter representing an EXISTS subquery 35 HasValueSubQueryBaseName = []byte("__sq_has_values") 36 ) 37 38 // SQLSelectLimitUnset default value for sql_select_limit not set. 39 const SQLSelectLimitUnset = -1 40 41 // RewriteASTResult contains the rewritten ast and meta information about it 42 type RewriteASTResult struct { 43 *BindVarNeeds 44 AST Statement // The rewritten AST 45 } 46 47 // ReservedVars keeps track of the bind variable names that have already been used 48 // in a parsed query. 49 type ReservedVars struct { 50 prefix string 51 reserved BindVars 52 next []byte 53 counter int 54 fast, static bool 55 sqNext int64 56 } 57 58 // ReserveAll tries to reserve all the given variable names. If they're all available, 59 // they are reserved and the function returns true. Otherwise the function returns false. 60 func (r *ReservedVars) ReserveAll(names ...string) bool { 61 for _, name := range names { 62 if _, ok := r.reserved[name]; ok { 63 return false 64 } 65 } 66 for _, name := range names { 67 r.reserved[name] = struct{}{} 68 } 69 return true 70 } 71 72 // ReserveColName reserves a variable name for the given column; if a variable 73 // with the same name already exists, it'll be suffixed with a numberic identifier 74 // to make it unique. 75 func (r *ReservedVars) ReserveColName(col *ColName) string { 76 compliantName := col.CompliantName() 77 if r.fast && strings.HasPrefix(compliantName, r.prefix) { 78 compliantName = "_" + compliantName 79 } 80 81 joinVar := []byte(compliantName) 82 baseLen := len(joinVar) 83 i := int64(1) 84 85 for { 86 if _, ok := r.reserved[string(joinVar)]; !ok { 87 r.reserved[string(joinVar)] = struct{}{} 88 return string(joinVar) 89 } 90 joinVar = strconv.AppendInt(joinVar[:baseLen], i, 10) 91 i++ 92 } 93 } 94 95 // ReserveSubQuery returns the next argument name to replace subquery with pullout value. 96 func (r *ReservedVars) ReserveSubQuery() string { 97 for { 98 r.sqNext++ 99 joinVar := strconv.AppendInt(subQueryBaseArgName, r.sqNext, 10) 100 if _, ok := r.reserved[string(joinVar)]; !ok { 101 r.reserved[string(joinVar)] = struct{}{} 102 return string(joinVar) 103 } 104 } 105 } 106 107 // ReserveSubQueryWithHasValues returns the next argument name to replace subquery with pullout value. 108 func (r *ReservedVars) ReserveSubQueryWithHasValues() (string, string) { 109 for { 110 r.sqNext++ 111 joinVar := strconv.AppendInt(subQueryBaseArgName, r.sqNext, 10) 112 hasValuesJoinVar := strconv.AppendInt(HasValueSubQueryBaseName, r.sqNext, 10) 113 _, joinVarOK := r.reserved[string(joinVar)] 114 _, hasValuesJoinVarOK := r.reserved[string(hasValuesJoinVar)] 115 if !joinVarOK && !hasValuesJoinVarOK { 116 r.reserved[string(joinVar)] = struct{}{} 117 r.reserved[string(hasValuesJoinVar)] = struct{}{} 118 return string(joinVar), string(hasValuesJoinVar) 119 } 120 } 121 } 122 123 // ReserveHasValuesSubQuery returns the next argument name to replace subquery with has value. 124 func (r *ReservedVars) ReserveHasValuesSubQuery() string { 125 for { 126 r.sqNext++ 127 joinVar := strconv.AppendInt(HasValueSubQueryBaseName, r.sqNext, 10) 128 if _, ok := r.reserved[string(joinVar)]; !ok { 129 r.reserved[string(joinVar)] = struct{}{} 130 return string(joinVar) 131 } 132 } 133 } 134 135 const staticBvar10 = "vtg0vtg1vtg2vtg3vtg4vtg5vtg6vtg7vtg8vtg9" 136 const staticBvar100 = "vtg10vtg11vtg12vtg13vtg14vtg15vtg16vtg17vtg18vtg19vtg20vtg21vtg22vtg23vtg24vtg25vtg26vtg27vtg28vtg29vtg30vtg31vtg32vtg33vtg34vtg35vtg36vtg37vtg38vtg39vtg40vtg41vtg42vtg43vtg44vtg45vtg46vtg47vtg48vtg49vtg50vtg51vtg52vtg53vtg54vtg55vtg56vtg57vtg58vtg59vtg60vtg61vtg62vtg63vtg64vtg65vtg66vtg67vtg68vtg69vtg70vtg71vtg72vtg73vtg74vtg75vtg76vtg77vtg78vtg79vtg80vtg81vtg82vtg83vtg84vtg85vtg86vtg87vtg88vtg89vtg90vtg91vtg92vtg93vtg94vtg95vtg96vtg97vtg98vtg99" 137 138 func (r *ReservedVars) nextUnusedVar() string { 139 if r.fast { 140 r.counter++ 141 142 if r.static { 143 switch { 144 case r.counter < 10: 145 ofs := r.counter * 4 146 return staticBvar10[ofs : ofs+4] 147 case r.counter < 100: 148 ofs := (r.counter - 10) * 5 149 return staticBvar100[ofs : ofs+5] 150 } 151 } 152 153 r.next = strconv.AppendInt(r.next[:len(r.prefix)], int64(r.counter), 10) 154 return string(r.next) 155 } 156 157 for { 158 r.counter++ 159 r.next = strconv.AppendInt(r.next[:len(r.prefix)], int64(r.counter), 10) 160 161 if _, ok := r.reserved[string(r.next)]; !ok { 162 bvar := string(r.next) 163 r.reserved[bvar] = struct{}{} 164 return bvar 165 } 166 } 167 } 168 169 // NewReservedVars allocates a ReservedVar instance that will generate unique 170 // variable names starting with the given `prefix` and making sure that they 171 // don't conflict with the given set of `known` variables. 172 func NewReservedVars(prefix string, known BindVars) *ReservedVars { 173 rv := &ReservedVars{ 174 prefix: prefix, 175 counter: 0, 176 reserved: known, 177 fast: true, 178 next: []byte(prefix), 179 } 180 181 if prefix != "" && prefix[0] == '_' { 182 panic("cannot reserve variables with a '_' prefix") 183 } 184 185 for bvar := range known { 186 if strings.HasPrefix(bvar, prefix) { 187 rv.fast = false 188 break 189 } 190 } 191 192 if prefix == "vtg" { 193 rv.static = true 194 } 195 return rv 196 } 197 198 // PrepareAST will normalize the query 199 func PrepareAST(in Statement, reservedVars *ReservedVars, bindVars map[string]*querypb.BindVariable, parameterize bool, keyspace string, selectLimit int) (*RewriteASTResult, error) { 200 if parameterize { 201 err := Normalize(in, reservedVars, bindVars) 202 if err != nil { 203 return nil, err 204 } 205 } 206 return RewriteAST(in, keyspace, selectLimit) 207 } 208 209 // RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries 210 func RewriteAST(in Statement, keyspace string, selectLimit int) (*RewriteASTResult, error) { 211 er := newExpressionRewriter(keyspace, selectLimit) 212 er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in) 213 setRewriter := &setNormalizer{} 214 result := Rewrite(in, er.rewrite, setRewriter.rewriteSetComingUp) 215 if setRewriter.err != nil { 216 return nil, setRewriter.err 217 } 218 219 out, ok := result.(Statement) 220 if !ok { 221 return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "statement rewriting returned a non statement: %s", String(out)) 222 } 223 224 r := &RewriteASTResult{ 225 AST: out, 226 BindVarNeeds: er.bindVars, 227 } 228 return r, nil 229 } 230 231 func shouldRewriteDatabaseFunc(in Statement) bool { 232 selct, ok := in.(*Select) 233 if !ok { 234 return false 235 } 236 if len(selct.From) != 1 { 237 return false 238 } 239 aliasedTable, ok := selct.From[0].(*AliasedTableExpr) 240 if !ok { 241 return false 242 } 243 tableName, ok := aliasedTable.Expr.(TableName) 244 if !ok { 245 return false 246 } 247 return tableName.Name.String() == "dual" 248 } 249 250 type expressionRewriter struct { 251 bindVars *BindVarNeeds 252 shouldRewriteDatabaseFunc bool 253 err error 254 255 // we need to know this to make a decision if we can safely rewrite JOIN USING => JOIN ON 256 hasStarInSelect bool 257 258 keyspace string 259 selectLimit int 260 } 261 262 func newExpressionRewriter(keyspace string, selectLimit int) *expressionRewriter { 263 return &expressionRewriter{bindVars: &BindVarNeeds{}, keyspace: keyspace, selectLimit: selectLimit} 264 } 265 266 const ( 267 // LastInsertIDName is a reserved bind var name for last_insert_id() 268 LastInsertIDName = "__lastInsertId" 269 270 // DBVarName is a reserved bind var name for database() 271 DBVarName = "__vtdbname" 272 273 // FoundRowsName is a reserved bind var name for found_rows() 274 FoundRowsName = "__vtfrows" 275 276 // RowCountName is a reserved bind var name for row_count() 277 RowCountName = "__vtrcount" 278 279 // UserDefinedVariableName is what we prepend bind var names for user defined variables 280 UserDefinedVariableName = "__vtudv" 281 ) 282 283 func (er *expressionRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) { 284 inner := newExpressionRewriter(er.keyspace, er.selectLimit) 285 inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc 286 tmp := Rewrite(node.Expr, inner.rewrite, nil) 287 newExpr, ok := tmp.(Expr) 288 if !ok { 289 return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) 290 } 291 node.Expr = newExpr 292 return inner.bindVars, nil 293 } 294 295 func (er *expressionRewriter) rewrite(cursor *Cursor) bool { 296 switch node := cursor.Node().(type) { 297 // select last_insert_id() -> select :__lastInsertId as `last_insert_id()` 298 case *Select: 299 for _, col := range node.SelectExprs { 300 _, hasStar := col.(*StarExpr) 301 if hasStar { 302 er.hasStarInSelect = true 303 } 304 305 aliasedExpr, ok := col.(*AliasedExpr) 306 if ok && aliasedExpr.As.IsEmpty() { 307 buf := NewTrackedBuffer(nil) 308 aliasedExpr.Expr.Format(buf) 309 innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr) 310 if err != nil { 311 er.err = err 312 return false 313 } 314 if innerBindVarNeeds.HasRewrites() { 315 aliasedExpr.As = NewColIdent(buf.String()) 316 } 317 er.bindVars.MergeWith(innerBindVarNeeds) 318 } 319 } 320 // set select limit if explicitly not set when sql_select_limit is set on the connection. 321 if er.selectLimit > 0 && node.Limit == nil { 322 node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))} 323 } 324 case *Union: 325 // set select limit if explicitly not set when sql_select_limit is set on the connection. 326 if er.selectLimit > 0 && node.Limit == nil { 327 node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))} 328 } 329 case *FuncExpr: 330 er.funcRewrite(cursor, node) 331 case *ColName: 332 switch node.Name.at { 333 case SingleAt: 334 er.udvRewrite(cursor, node) 335 case DoubleAt: 336 er.sysVarRewrite(cursor, node) 337 } 338 case *Subquery: 339 er.unnestSubQueries(cursor, node) 340 case *JoinCondition: 341 er.rewriteJoinCondition(cursor, node) 342 case *NotExpr: 343 switch inner := node.Expr.(type) { 344 case *ComparisonExpr: 345 // not col = 42 => col != 42 346 // not col > 42 => col <= 42 347 // etc 348 canChange, inverse := inverseOp(inner.Operator) 349 if canChange { 350 inner.Operator = inverse 351 cursor.Replace(inner) 352 } 353 case *NotExpr: 354 // not not true => true 355 cursor.Replace(inner.Expr) 356 case BoolVal: 357 // not true => false 358 inner = !inner 359 cursor.Replace(inner) 360 } 361 case *AliasedTableExpr: 362 if !SystemSchema(er.keyspace) { 363 break 364 } 365 aliasTableName, ok := node.Expr.(TableName) 366 if !ok { 367 return true 368 } 369 // Qualifier should not be added to dual table 370 if aliasTableName.Name.String() == "dual" { 371 break 372 } 373 if er.keyspace != "" && aliasTableName.Qualifier.IsEmpty() { 374 aliasTableName.Qualifier = NewTableIdent(er.keyspace) 375 node.Expr = aliasTableName 376 cursor.Replace(node) 377 } 378 case *ShowBasic: 379 if node.Command == VariableGlobal || node.Command == VariableSession { 380 varsToAdd := sysvars.GetInterestingVariables() 381 for _, sysVar := range varsToAdd { 382 er.bindVars.AddSysVar(sysVar) 383 } 384 } 385 } 386 return true 387 } 388 389 func inverseOp(i ComparisonExprOperator) (bool, ComparisonExprOperator) { 390 switch i { 391 case EqualOp: 392 return true, NotEqualOp 393 case LessThanOp: 394 return true, GreaterEqualOp 395 case GreaterThanOp: 396 return true, LessEqualOp 397 case LessEqualOp: 398 return true, GreaterThanOp 399 case GreaterEqualOp: 400 return true, LessThanOp 401 case NotEqualOp: 402 return true, EqualOp 403 case InOp: 404 return true, NotInOp 405 case NotInOp: 406 return true, InOp 407 case LikeOp: 408 return true, NotLikeOp 409 case NotLikeOp: 410 return true, LikeOp 411 case RegexpOp: 412 return true, NotRegexpOp 413 case NotRegexpOp: 414 return true, RegexpOp 415 } 416 417 return false, i 418 } 419 420 func (er *expressionRewriter) rewriteJoinCondition(cursor *Cursor, node *JoinCondition) { 421 if node.Using != nil && !er.hasStarInSelect { 422 joinTableExpr, ok := cursor.Parent().(*JoinTableExpr) 423 if !ok { 424 // this is not possible with the current AST 425 return 426 } 427 leftTable, leftOk := joinTableExpr.LeftExpr.(*AliasedTableExpr) 428 rightTable, rightOk := joinTableExpr.RightExpr.(*AliasedTableExpr) 429 if !(leftOk && rightOk) { 430 // we only deal with simple FROM A JOIN B USING queries at the moment 431 return 432 } 433 lft, err := leftTable.TableName() 434 if err != nil { 435 er.err = err 436 return 437 } 438 rgt, err := rightTable.TableName() 439 if err != nil { 440 er.err = err 441 return 442 } 443 newCondition := &JoinCondition{} 444 for _, colIdent := range node.Using { 445 lftCol := NewColNameWithQualifier(colIdent.String(), lft) 446 rgtCol := NewColNameWithQualifier(colIdent.String(), rgt) 447 cmp := &ComparisonExpr{ 448 Operator: EqualOp, 449 Left: lftCol, 450 Right: rgtCol, 451 } 452 if newCondition.On == nil { 453 newCondition.On = cmp 454 } else { 455 newCondition.On = &AndExpr{Left: newCondition.On, Right: cmp} 456 } 457 } 458 cursor.Replace(newCondition) 459 } 460 } 461 462 func (er *expressionRewriter) sysVarRewrite(cursor *Cursor, node *ColName) { 463 lowered := node.Name.Lowered() 464 switch lowered { 465 case sysvars.Autocommit.Name, 466 sysvars.Charset.Name, 467 sysvars.ClientFoundRows.Name, 468 sysvars.DDLStrategy.Name, 469 sysvars.Names.Name, 470 sysvars.TransactionMode.Name, 471 sysvars.ReadAfterWriteGTID.Name, 472 sysvars.ReadAfterWriteTimeOut.Name, 473 sysvars.SessionEnableSystemSettings.Name, 474 sysvars.SessionTrackGTIDs.Name, 475 sysvars.SessionUUID.Name, 476 sysvars.SkipQueryPlanCache.Name, 477 sysvars.Socket.Name, 478 sysvars.SQLSelectLimit.Name, 479 sysvars.Version.Name, 480 sysvars.VersionComment.Name, 481 sysvars.Workload.Name: 482 cursor.Replace(bindVarExpression("__vt" + lowered)) 483 er.bindVars.AddSysVar(lowered) 484 } 485 } 486 487 func (er *expressionRewriter) udvRewrite(cursor *Cursor, node *ColName) { 488 udv := strings.ToLower(node.Name.CompliantName()) 489 cursor.Replace(bindVarExpression(UserDefinedVariableName + udv)) 490 er.bindVars.AddUserDefVar(udv) 491 } 492 493 var funcRewrites = map[string]string{ 494 "last_insert_id": LastInsertIDName, 495 "database": DBVarName, 496 "schema": DBVarName, 497 "found_rows": FoundRowsName, 498 "row_count": RowCountName, 499 } 500 501 func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { 502 bindVar, found := funcRewrites[node.Name.Lowered()] 503 if found { 504 if bindVar == DBVarName && !er.shouldRewriteDatabaseFunc { 505 return 506 } 507 if len(node.Exprs) > 0 { 508 er.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", node.Name.Lowered()) 509 return 510 } 511 cursor.Replace(bindVarExpression(bindVar)) 512 er.bindVars.AddFuncResult(bindVar) 513 } 514 } 515 516 func (er *expressionRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) { 517 if _, isExists := cursor.Parent().(*ExistsExpr); isExists { 518 return 519 } 520 sel, isSimpleSelect := subquery.Select.(*Select) 521 if !isSimpleSelect { 522 return 523 } 524 525 if len(sel.SelectExprs) != 1 || 526 len(sel.OrderBy) != 0 || 527 len(sel.GroupBy) != 0 || 528 len(sel.From) != 1 || 529 sel.Where != nil || 530 sel.Having != nil || 531 sel.Limit != nil || sel.Lock != NoLock { 532 return 533 } 534 535 aliasedTable, ok := sel.From[0].(*AliasedTableExpr) 536 if !ok { 537 return 538 } 539 table, ok := aliasedTable.Expr.(TableName) 540 if !ok || table.Name.String() != "dual" { 541 return 542 } 543 expr, ok := sel.SelectExprs[0].(*AliasedExpr) 544 if !ok { 545 return 546 } 547 er.bindVars.NoteRewrite() 548 // we need to make sure that the inner expression also gets rewritten, 549 // so we fire off another rewriter traversal here 550 rewritten := Rewrite(expr.Expr, er.rewrite, nil) 551 552 // Here we need to handle the subquery rewrite in case in occurs in an IN clause 553 // For example, SELECT id FROM user WHERE id IN (SELECT 1 FROM DUAL) 554 // Here we cannot rewrite the query to SELECT id FROM user WHERE id IN 1, since that is syntactically wrong 555 // We must rewrite it to SELECT id FROM user WHERE id IN (1) 556 // Find more cases in the test file 557 rewrittenExpr, isExpr := rewritten.(Expr) 558 _, isColTuple := rewritten.(ColTuple) 559 comparisonExpr, isCompExpr := cursor.Parent().(*ComparisonExpr) 560 // Check that the parent is a comparison operator with IN or NOT IN operation. 561 // Also, if rewritten is already a ColTuple (like a subquery), then we do not need this 562 // We also need to check that rewritten is an Expr, if it is then we can rewrite it as a ValTuple 563 if isCompExpr && (comparisonExpr.Operator == InOp || comparisonExpr.Operator == NotInOp) && !isColTuple && isExpr { 564 cursor.Replace(ValTuple{rewrittenExpr}) 565 return 566 } 567 568 cursor.Replace(rewritten) 569 } 570 571 func bindVarExpression(name string) Expr { 572 return NewArgument(name) 573 } 574 575 // SystemSchema returns true if the schema passed is system schema 576 func SystemSchema(schema string) bool { 577 return strings.EqualFold(schema, "information_schema") || 578 strings.EqualFold(schema, "performance_schema") || 579 strings.EqualFold(schema, "sys") || 580 strings.EqualFold(schema, "mysql") 581 } 582 583 // RewriteToCNF walks the input AST and rewrites any boolean logic into CNF 584 // Note: In order to re-plan, we need to empty the accumulated metadata in the AST, 585 // so ColName.Metadata will be nil:ed out as part of this rewrite 586 func RewriteToCNF(ast SQLNode) SQLNode { 587 for { 588 finishedRewrite := true 589 ast = Rewrite(ast, func(cursor *Cursor) bool { 590 if e, isExpr := cursor.node.(Expr); isExpr { 591 rewritten, didRewrite := rewriteToCNFExpr(e) 592 if didRewrite { 593 finishedRewrite = false 594 cursor.Replace(rewritten) 595 } 596 } 597 if col, isCol := cursor.node.(*ColName); isCol { 598 col.Metadata = nil 599 } 600 return true 601 }, nil) 602 603 if finishedRewrite { 604 return ast 605 } 606 } 607 } 608 609 func distinctOr(in *OrExpr) (Expr, bool) { 610 todo := []*OrExpr{in} 611 var leaves []Expr 612 for len(todo) > 0 { 613 curr := todo[0] 614 todo = todo[1:] 615 addAnd := func(in Expr) { 616 and, ok := in.(*OrExpr) 617 if ok { 618 todo = append(todo, and) 619 } else { 620 leaves = append(leaves, in) 621 } 622 } 623 addAnd(curr.Left) 624 addAnd(curr.Right) 625 } 626 original := len(leaves) 627 var predicates []Expr 628 629 outer1: 630 for len(leaves) > 0 { 631 curr := leaves[0] 632 leaves = leaves[1:] 633 for _, alreadyIn := range predicates { 634 if EqualsExpr(alreadyIn, curr) { 635 continue outer1 636 } 637 } 638 predicates = append(predicates, curr) 639 } 640 if original == len(predicates) { 641 return in, false 642 } 643 var result Expr 644 for i, curr := range predicates { 645 if i == 0 { 646 result = curr 647 continue 648 } 649 result = &OrExpr{Left: result, Right: curr} 650 } 651 return result, true 652 } 653 func distinctAnd(in *AndExpr) (Expr, bool) { 654 todo := []*AndExpr{in} 655 var leaves []Expr 656 for len(todo) > 0 { 657 curr := todo[0] 658 todo = todo[1:] 659 addAnd := func(in Expr) { 660 and, ok := in.(*AndExpr) 661 if ok { 662 todo = append(todo, and) 663 } else { 664 leaves = append(leaves, in) 665 } 666 } 667 addAnd(curr.Left) 668 addAnd(curr.Right) 669 } 670 original := len(leaves) 671 var predicates []Expr 672 673 outer1: 674 for len(leaves) > 0 { 675 curr := leaves[0] 676 leaves = leaves[1:] 677 for _, alreadyIn := range predicates { 678 if EqualsExpr(alreadyIn, curr) { 679 continue outer1 680 } 681 } 682 predicates = append(predicates, curr) 683 } 684 if original == len(predicates) { 685 return in, false 686 } 687 var result Expr 688 for i, curr := range predicates { 689 if i == 0 { 690 result = curr 691 continue 692 } 693 result = &AndExpr{Left: result, Right: curr} 694 } 695 return result, true 696 } 697 698 func rewriteToCNFExpr(expr Expr) (Expr, bool) { 699 switch expr := expr.(type) { 700 case *NotExpr: 701 switch child := expr.Expr.(type) { 702 case *NotExpr: 703 // NOT NOT A => A 704 return child.Expr, true 705 case *OrExpr: 706 // DeMorgan Rewriter 707 // NOT (A OR B) => NOT A AND NOT B 708 return &AndExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, true 709 case *AndExpr: 710 // DeMorgan Rewriter 711 // NOT (A AND B) => NOT A OR NOT B 712 return &OrExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, true 713 } 714 case *OrExpr: 715 or := expr 716 if and, ok := or.Left.(*AndExpr); ok { 717 // Simplification 718 // (A AND B) OR A => A 719 if EqualsExpr(or.Right, and.Left) || EqualsExpr(or.Right, and.Right) { 720 return or.Right, true 721 } 722 // Distribution Law 723 // (A AND B) OR C => (A OR C) AND (B OR C) 724 return &AndExpr{Left: &OrExpr{Left: and.Left, Right: or.Right}, Right: &OrExpr{Left: and.Right, Right: or.Right}}, true 725 } 726 if and, ok := or.Right.(*AndExpr); ok { 727 // Simplification 728 // A OR (A AND B) => A 729 if EqualsExpr(or.Left, and.Left) || EqualsExpr(or.Left, and.Right) { 730 return or.Left, true 731 } 732 // Distribution Law 733 // C OR (A AND B) => (C OR A) AND (C OR B) 734 return &AndExpr{Left: &OrExpr{Left: or.Left, Right: and.Left}, Right: &OrExpr{Left: or.Left, Right: and.Right}}, true 735 } 736 // Try to make distinct 737 return distinctOr(expr) 738 739 case *XorExpr: 740 // DeMorgan Rewriter 741 // (A XOR B) => (A OR B) AND NOT (A AND B) 742 return &AndExpr{Left: &OrExpr{Left: expr.Left, Right: expr.Right}, Right: &NotExpr{Expr: &AndExpr{Left: expr.Left, Right: expr.Right}}}, true 743 case *AndExpr: 744 res, rewritten := distinctAnd(expr) 745 if rewritten { 746 return res, rewritten 747 } 748 and := expr 749 if or, ok := and.Left.(*OrExpr); ok { 750 // Simplification 751 // (A OR B) AND A => A 752 if EqualsExpr(or.Left, and.Right) || EqualsExpr(or.Right, and.Right) { 753 return and.Right, true 754 } 755 } 756 if or, ok := and.Right.(*OrExpr); ok { 757 // Simplification 758 // A OR (A AND B) => A 759 if EqualsExpr(or.Left, and.Left) || EqualsExpr(or.Right, and.Left) { 760 return or.Left, true 761 } 762 } 763 764 } 765 return expr, false 766 }