github.com/team-ide/go-dialect@v1.9.20/vitess/sqlparser/ast_rewriting.go (about)

     1  /*
     2  Copyright 2020 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"strconv"
    21  
    22  	querypb "github.com/team-ide/go-dialect/vitess/query"
    23  	"github.com/team-ide/go-dialect/vitess/vterrors"
    24  	vtrpcpb "github.com/team-ide/go-dialect/vitess/vtrpc"
    25  
    26  	"strings"
    27  
    28  	"github.com/team-ide/go-dialect/vitess/sysvars"
    29  )
    30  
    31  var (
    32  	subQueryBaseArgName = []byte("__sq")
    33  
    34  	// HasValueSubQueryBaseName is the prefix of each parameter representing an EXISTS subquery
    35  	HasValueSubQueryBaseName = []byte("__sq_has_values")
    36  )
    37  
    38  // SQLSelectLimitUnset default value for sql_select_limit not set.
    39  const SQLSelectLimitUnset = -1
    40  
// RewriteASTResult contains the rewritten ast and meta information about it.
// It is what RewriteAST returns: AST is the statement produced by the rewrite
// pass, and the embedded *BindVarNeeds is the set collected by the expression
// rewriter (function results, system and user-defined variables it replaced
// with bind variables, and whether any rewrite was noted).
type RewriteASTResult struct {
	*BindVarNeeds
	AST Statement // The rewritten AST
}
    46  
// ReservedVars keeps track of the bind variable names that have already been used
// in a parsed query.
//
// Fields:
//   - prefix: prefix of the names generated by nextUnusedVar.
//   - reserved: names already taken. NewReservedVars stores the caller's
//     `known` map here directly, so reservations are written into that map.
//   - next: scratch buffer holding prefix followed by the current counter digits.
//   - counter: numeric suffix of the last name generated by nextUnusedVar.
//   - fast: set when no known name started with prefix at construction time;
//     nextUnusedVar then neither checks nor records generated names.
//   - static: set when prefix is "vtg"; nextUnusedVar then serves counters
//     below 100 from precomputed strings.
//   - sqNext: numeric suffix shared by the subquery reservation methods.
type ReservedVars struct {
	prefix       string
	reserved     BindVars
	next         []byte
	counter      int
	fast, static bool
	sqNext       int64
}
    57  
    58  // ReserveAll tries to reserve all the given variable names. If they're all available,
    59  // they are reserved and the function returns true. Otherwise the function returns false.
    60  func (r *ReservedVars) ReserveAll(names ...string) bool {
    61  	for _, name := range names {
    62  		if _, ok := r.reserved[name]; ok {
    63  			return false
    64  		}
    65  	}
    66  	for _, name := range names {
    67  		r.reserved[name] = struct{}{}
    68  	}
    69  	return true
    70  }
    71  
    72  // ReserveColName reserves a variable name for the given column; if a variable
    73  // with the same name already exists, it'll be suffixed with a numberic identifier
    74  // to make it unique.
    75  func (r *ReservedVars) ReserveColName(col *ColName) string {
    76  	compliantName := col.CompliantName()
    77  	if r.fast && strings.HasPrefix(compliantName, r.prefix) {
    78  		compliantName = "_" + compliantName
    79  	}
    80  
    81  	joinVar := []byte(compliantName)
    82  	baseLen := len(joinVar)
    83  	i := int64(1)
    84  
    85  	for {
    86  		if _, ok := r.reserved[string(joinVar)]; !ok {
    87  			r.reserved[string(joinVar)] = struct{}{}
    88  			return string(joinVar)
    89  		}
    90  		joinVar = strconv.AppendInt(joinVar[:baseLen], i, 10)
    91  		i++
    92  	}
    93  }
    94  
    95  // ReserveSubQuery returns the next argument name to replace subquery with pullout value.
    96  func (r *ReservedVars) ReserveSubQuery() string {
    97  	for {
    98  		r.sqNext++
    99  		joinVar := strconv.AppendInt(subQueryBaseArgName, r.sqNext, 10)
   100  		if _, ok := r.reserved[string(joinVar)]; !ok {
   101  			r.reserved[string(joinVar)] = struct{}{}
   102  			return string(joinVar)
   103  		}
   104  	}
   105  }
   106  
   107  // ReserveSubQueryWithHasValues returns the next argument name to replace subquery with pullout value.
   108  func (r *ReservedVars) ReserveSubQueryWithHasValues() (string, string) {
   109  	for {
   110  		r.sqNext++
   111  		joinVar := strconv.AppendInt(subQueryBaseArgName, r.sqNext, 10)
   112  		hasValuesJoinVar := strconv.AppendInt(HasValueSubQueryBaseName, r.sqNext, 10)
   113  		_, joinVarOK := r.reserved[string(joinVar)]
   114  		_, hasValuesJoinVarOK := r.reserved[string(hasValuesJoinVar)]
   115  		if !joinVarOK && !hasValuesJoinVarOK {
   116  			r.reserved[string(joinVar)] = struct{}{}
   117  			r.reserved[string(hasValuesJoinVar)] = struct{}{}
   118  			return string(joinVar), string(hasValuesJoinVar)
   119  		}
   120  	}
   121  }
   122  
   123  // ReserveHasValuesSubQuery returns the next argument name to replace subquery with has value.
   124  func (r *ReservedVars) ReserveHasValuesSubQuery() string {
   125  	for {
   126  		r.sqNext++
   127  		joinVar := strconv.AppendInt(HasValueSubQueryBaseName, r.sqNext, 10)
   128  		if _, ok := r.reserved[string(joinVar)]; !ok {
   129  			r.reserved[string(joinVar)] = struct{}{}
   130  			return string(joinVar)
   131  		}
   132  	}
   133  }
   134  
// staticBvar10 and staticBvar100 are precomputed "vtg"-prefixed bind variable
// names used by nextUnusedVar when the prefix is "vtg". staticBvar10 holds the
// 4-byte names vtg0..vtg9 back to back (vtg0 is never served, since the
// counter is incremented before use); staticBvar100 holds the 5-byte names
// vtg10..vtg99. The fixed entry widths are what nextUnusedVar's slice offsets
// rely on.
const staticBvar10 = "vtg0vtg1vtg2vtg3vtg4vtg5vtg6vtg7vtg8vtg9"
const staticBvar100 = "vtg10vtg11vtg12vtg13vtg14vtg15vtg16vtg17vtg18vtg19vtg20vtg21vtg22vtg23vtg24vtg25vtg26vtg27vtg28vtg29vtg30vtg31vtg32vtg33vtg34vtg35vtg36vtg37vtg38vtg39vtg40vtg41vtg42vtg43vtg44vtg45vtg46vtg47vtg48vtg49vtg50vtg51vtg52vtg53vtg54vtg55vtg56vtg57vtg58vtg59vtg60vtg61vtg62vtg63vtg64vtg65vtg66vtg67vtg68vtg69vtg70vtg71vtg72vtg73vtg74vtg75vtg76vtg77vtg78vtg79vtg80vtg81vtg82vtg83vtg84vtg85vtg86vtg87vtg88vtg89vtg90vtg91vtg92vtg93vtg94vtg95vtg96vtg97vtg98vtg99"
   137  
// nextUnusedVar returns a new bind variable name of the form prefix+N, where
// N is the incremented r.counter.
//
// In fast mode (no known variable shared the prefix when the ReservedVars was
// built) the name is neither checked against nor added to r.reserved. If
// static is also set (prefix "vtg"), counters below 100 are served by slicing
// the precomputed staticBvar10/staticBvar100 strings instead of formatting.
//
// Otherwise the counter is advanced until prefix+N is not in r.reserved, and
// that name is reserved before it is returned.
func (r *ReservedVars) nextUnusedVar() string {
	if r.fast {
		r.counter++

		if r.static {
			switch {
			case r.counter < 10:
				// entries in staticBvar10 are 4 bytes: "vtg" + one digit
				ofs := r.counter * 4
				return staticBvar10[ofs : ofs+4]
			case r.counter < 100:
				// entries in staticBvar100 are 5 bytes: "vtg" + two digits, starting at vtg10
				ofs := (r.counter - 10) * 5
				return staticBvar100[ofs : ofs+5]
			}
		}

		// r.next starts with the prefix; truncate back to it and append the counter.
		r.next = strconv.AppendInt(r.next[:len(r.prefix)], int64(r.counter), 10)
		return string(r.next)
	}

	// Slow path: some known names share the prefix, so skip any that are taken.
	for {
		r.counter++
		r.next = strconv.AppendInt(r.next[:len(r.prefix)], int64(r.counter), 10)

		if _, ok := r.reserved[string(r.next)]; !ok {
			bvar := string(r.next)
			r.reserved[bvar] = struct{}{}
			return bvar
		}
	}
}
   168  
   169  // NewReservedVars allocates a ReservedVar instance that will generate unique
   170  // variable names starting with the given `prefix` and making sure that they
   171  // don't conflict with the given set of `known` variables.
   172  func NewReservedVars(prefix string, known BindVars) *ReservedVars {
   173  	rv := &ReservedVars{
   174  		prefix:   prefix,
   175  		counter:  0,
   176  		reserved: known,
   177  		fast:     true,
   178  		next:     []byte(prefix),
   179  	}
   180  
   181  	if prefix != "" && prefix[0] == '_' {
   182  		panic("cannot reserve variables with a '_' prefix")
   183  	}
   184  
   185  	for bvar := range known {
   186  		if strings.HasPrefix(bvar, prefix) {
   187  			rv.fast = false
   188  			break
   189  		}
   190  	}
   191  
   192  	if prefix == "vtg" {
   193  		rv.static = true
   194  	}
   195  	return rv
   196  }
   197  
   198  // PrepareAST will normalize the query
   199  func PrepareAST(in Statement, reservedVars *ReservedVars, bindVars map[string]*querypb.BindVariable, parameterize bool, keyspace string, selectLimit int) (*RewriteASTResult, error) {
   200  	if parameterize {
   201  		err := Normalize(in, reservedVars, bindVars)
   202  		if err != nil {
   203  			return nil, err
   204  		}
   205  	}
   206  	return RewriteAST(in, keyspace, selectLimit)
   207  }
   208  
   209  // RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries
   210  func RewriteAST(in Statement, keyspace string, selectLimit int) (*RewriteASTResult, error) {
   211  	er := newExpressionRewriter(keyspace, selectLimit)
   212  	er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in)
   213  	setRewriter := &setNormalizer{}
   214  	result := Rewrite(in, er.rewrite, setRewriter.rewriteSetComingUp)
   215  	if setRewriter.err != nil {
   216  		return nil, setRewriter.err
   217  	}
   218  
   219  	out, ok := result.(Statement)
   220  	if !ok {
   221  		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "statement rewriting returned a non statement: %s", String(out))
   222  	}
   223  
   224  	r := &RewriteASTResult{
   225  		AST:          out,
   226  		BindVarNeeds: er.bindVars,
   227  	}
   228  	return r, nil
   229  }
   230  
   231  func shouldRewriteDatabaseFunc(in Statement) bool {
   232  	selct, ok := in.(*Select)
   233  	if !ok {
   234  		return false
   235  	}
   236  	if len(selct.From) != 1 {
   237  		return false
   238  	}
   239  	aliasedTable, ok := selct.From[0].(*AliasedTableExpr)
   240  	if !ok {
   241  		return false
   242  	}
   243  	tableName, ok := aliasedTable.Expr.(TableName)
   244  	if !ok {
   245  		return false
   246  	}
   247  	return tableName.Name.String() == "dual"
   248  }
   249  
// expressionRewriter holds the state of one rewrite pass driven by RewriteAST.
//
// Fields:
//   - bindVars: bind variables (function results, system and user-defined
//     variables) that the rewritten AST refers to, plus whether any rewrite
//     was noted.
//   - shouldRewriteDatabaseFunc: when true, database()/schema() calls are
//     replaced by the DBVarName bind variable (see funcRewrite).
//   - err: most recent error recorded by a rewrite helper.
//   - keyspace: used to qualify unqualified table names when it is a system schema.
//   - selectLimit: when > 0, added as LIMIT to SELECT/UNION statements without one.
type expressionRewriter struct {
	bindVars                  *BindVarNeeds
	shouldRewriteDatabaseFunc bool
	err                       error

	// we need to know this to make a decision if we can safely rewrite JOIN USING => JOIN ON
	hasStarInSelect bool

	keyspace    string
	selectLimit int
}
   261  
   262  func newExpressionRewriter(keyspace string, selectLimit int) *expressionRewriter {
   263  	return &expressionRewriter{bindVars: &BindVarNeeds{}, keyspace: keyspace, selectLimit: selectLimit}
   264  }
   265  
   266  const (
   267  	// LastInsertIDName is a reserved bind var name for last_insert_id()
   268  	LastInsertIDName = "__lastInsertId"
   269  
   270  	// DBVarName is a reserved bind var name for database()
   271  	DBVarName = "__vtdbname"
   272  
   273  	// FoundRowsName is a reserved bind var name for found_rows()
   274  	FoundRowsName = "__vtfrows"
   275  
   276  	// RowCountName is a reserved bind var name for row_count()
   277  	RowCountName = "__vtrcount"
   278  
   279  	// UserDefinedVariableName is what we prepend bind var names for user defined variables
   280  	UserDefinedVariableName = "__vtudv"
   281  )
   282  
   283  func (er *expressionRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) {
   284  	inner := newExpressionRewriter(er.keyspace, er.selectLimit)
   285  	inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc
   286  	tmp := Rewrite(node.Expr, inner.rewrite, nil)
   287  	newExpr, ok := tmp.(Expr)
   288  	if !ok {
   289  		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp))
   290  	}
   291  	node.Expr = newExpr
   292  	return inner.bindVars, nil
   293  }
   294  
// rewrite is the pre-order visitor RewriteAST passes to Rewrite. Depending on
// the node type it mutates the node in place or replaces it through cursor:
//   - *Select: records whether the select list contains `*`. An unaliased
//     select expression whose inner rewrite changed something is aliased
//     with its original text. A LIMIT is added when selectLimit > 0 and none
//     is set.
//   - *Union: adds the same default LIMIT.
//   - *FuncExpr, *ColName (@user and @@system variables), *Subquery and
//     *JoinCondition: delegated to the matching helper.
//   - *NotExpr: inverts an invertible comparison operator, removes a double
//     negation and folds NOT applied to a boolean literal.
//   - *AliasedTableExpr: when er.keyspace is a system schema, qualifies
//     unqualified table names (except dual) with it.
//   - *ShowBasic: for SHOW global/session variables, requests every variable
//     returned by sysvars.GetInterestingVariables.
//
// It returns false only when rewriting a select expression fails (the error
// is stored in er.err). Presumably Rewrite then skips that node's children;
// see Rewrite.
func (er *expressionRewriter) rewrite(cursor *Cursor) bool {
	switch node := cursor.Node().(type) {
	// select last_insert_id() -> select :__lastInsertId as `last_insert_id()`
	case *Select:
		for _, col := range node.SelectExprs {
			_, hasStar := col.(*StarExpr)
			if hasStar {
				er.hasStarInSelect = true
			}

			aliasedExpr, ok := col.(*AliasedExpr)
			if ok && aliasedExpr.As.IsEmpty() {
				// Capture the expression's text before rewriting so it can
				// become the column alias if the rewrite changes it.
				buf := NewTrackedBuffer(nil)
				aliasedExpr.Expr.Format(buf)
				innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr)
				if err != nil {
					er.err = err
					return false
				}
				if innerBindVarNeeds.HasRewrites() {
					aliasedExpr.As = NewColIdent(buf.String())
				}
				er.bindVars.MergeWith(innerBindVarNeeds)
			}
		}
		// set select limit if explicitly not set when sql_select_limit is set on the connection.
		if er.selectLimit > 0 && node.Limit == nil {
			node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))}
		}
	case *Union:
		// set select limit if explicitly not set when sql_select_limit is set on the connection.
		if er.selectLimit > 0 && node.Limit == nil {
			node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))}
		}
	case *FuncExpr:
		er.funcRewrite(cursor, node)
	case *ColName:
		switch node.Name.at {
		case SingleAt:
			// @var: user-defined variable
			er.udvRewrite(cursor, node)
		case DoubleAt:
			// @@var: system variable
			er.sysVarRewrite(cursor, node)
		}
	case *Subquery:
		er.unnestSubQueries(cursor, node)
	case *JoinCondition:
		er.rewriteJoinCondition(cursor, node)
	case *NotExpr:
		switch inner := node.Expr.(type) {
		case *ComparisonExpr:
			// not col = 42 => col != 42
			// not col > 42 => col <= 42
			// etc
			canChange, inverse := inverseOp(inner.Operator)
			if canChange {
				inner.Operator = inverse
				cursor.Replace(inner)
			}
		case *NotExpr:
			// not not true => true
			cursor.Replace(inner.Expr)
		case BoolVal:
			// not true => false
			inner = !inner
			cursor.Replace(inner)
		}
	case *AliasedTableExpr:
		// Only queries running against a system schema get their tables qualified.
		if !SystemSchema(er.keyspace) {
			break
		}
		aliasTableName, ok := node.Expr.(TableName)
		if !ok {
			return true
		}
		// Qualifier should not be added to dual table
		if aliasTableName.Name.String() == "dual" {
			break
		}
		if er.keyspace != "" && aliasTableName.Qualifier.IsEmpty() {
			aliasTableName.Qualifier = NewTableIdent(er.keyspace)
			node.Expr = aliasTableName
			cursor.Replace(node)
		}
	case *ShowBasic:
		if node.Command == VariableGlobal || node.Command == VariableSession {
			varsToAdd := sysvars.GetInterestingVariables()
			for _, sysVar := range varsToAdd {
				er.bindVars.AddSysVar(sysVar)
			}
		}
	}
	return true
}
   388  
   389  func inverseOp(i ComparisonExprOperator) (bool, ComparisonExprOperator) {
   390  	switch i {
   391  	case EqualOp:
   392  		return true, NotEqualOp
   393  	case LessThanOp:
   394  		return true, GreaterEqualOp
   395  	case GreaterThanOp:
   396  		return true, LessEqualOp
   397  	case LessEqualOp:
   398  		return true, GreaterThanOp
   399  	case GreaterEqualOp:
   400  		return true, LessThanOp
   401  	case NotEqualOp:
   402  		return true, EqualOp
   403  	case InOp:
   404  		return true, NotInOp
   405  	case NotInOp:
   406  		return true, InOp
   407  	case LikeOp:
   408  		return true, NotLikeOp
   409  	case NotLikeOp:
   410  		return true, LikeOp
   411  	case RegexpOp:
   412  		return true, NotRegexpOp
   413  	case NotRegexpOp:
   414  		return true, RegexpOp
   415  	}
   416  
   417  	return false, i
   418  }
   419  
   420  func (er *expressionRewriter) rewriteJoinCondition(cursor *Cursor, node *JoinCondition) {
   421  	if node.Using != nil && !er.hasStarInSelect {
   422  		joinTableExpr, ok := cursor.Parent().(*JoinTableExpr)
   423  		if !ok {
   424  			// this is not possible with the current AST
   425  			return
   426  		}
   427  		leftTable, leftOk := joinTableExpr.LeftExpr.(*AliasedTableExpr)
   428  		rightTable, rightOk := joinTableExpr.RightExpr.(*AliasedTableExpr)
   429  		if !(leftOk && rightOk) {
   430  			// we only deal with simple FROM A JOIN B USING queries at the moment
   431  			return
   432  		}
   433  		lft, err := leftTable.TableName()
   434  		if err != nil {
   435  			er.err = err
   436  			return
   437  		}
   438  		rgt, err := rightTable.TableName()
   439  		if err != nil {
   440  			er.err = err
   441  			return
   442  		}
   443  		newCondition := &JoinCondition{}
   444  		for _, colIdent := range node.Using {
   445  			lftCol := NewColNameWithQualifier(colIdent.String(), lft)
   446  			rgtCol := NewColNameWithQualifier(colIdent.String(), rgt)
   447  			cmp := &ComparisonExpr{
   448  				Operator: EqualOp,
   449  				Left:     lftCol,
   450  				Right:    rgtCol,
   451  			}
   452  			if newCondition.On == nil {
   453  				newCondition.On = cmp
   454  			} else {
   455  				newCondition.On = &AndExpr{Left: newCondition.On, Right: cmp}
   456  			}
   457  		}
   458  		cursor.Replace(newCondition)
   459  	}
   460  }
   461  
   462  func (er *expressionRewriter) sysVarRewrite(cursor *Cursor, node *ColName) {
   463  	lowered := node.Name.Lowered()
   464  	switch lowered {
   465  	case sysvars.Autocommit.Name,
   466  		sysvars.Charset.Name,
   467  		sysvars.ClientFoundRows.Name,
   468  		sysvars.DDLStrategy.Name,
   469  		sysvars.Names.Name,
   470  		sysvars.TransactionMode.Name,
   471  		sysvars.ReadAfterWriteGTID.Name,
   472  		sysvars.ReadAfterWriteTimeOut.Name,
   473  		sysvars.SessionEnableSystemSettings.Name,
   474  		sysvars.SessionTrackGTIDs.Name,
   475  		sysvars.SessionUUID.Name,
   476  		sysvars.SkipQueryPlanCache.Name,
   477  		sysvars.Socket.Name,
   478  		sysvars.SQLSelectLimit.Name,
   479  		sysvars.Version.Name,
   480  		sysvars.VersionComment.Name,
   481  		sysvars.Workload.Name:
   482  		cursor.Replace(bindVarExpression("__vt" + lowered))
   483  		er.bindVars.AddSysVar(lowered)
   484  	}
   485  }
   486  
   487  func (er *expressionRewriter) udvRewrite(cursor *Cursor, node *ColName) {
   488  	udv := strings.ToLower(node.Name.CompliantName())
   489  	cursor.Replace(bindVarExpression(UserDefinedVariableName + udv))
   490  	er.bindVars.AddUserDefVar(udv)
   491  }
   492  
   493  var funcRewrites = map[string]string{
   494  	"last_insert_id": LastInsertIDName,
   495  	"database":       DBVarName,
   496  	"schema":         DBVarName,
   497  	"found_rows":     FoundRowsName,
   498  	"row_count":      RowCountName,
   499  }
   500  
   501  func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) {
   502  	bindVar, found := funcRewrites[node.Name.Lowered()]
   503  	if found {
   504  		if bindVar == DBVarName && !er.shouldRewriteDatabaseFunc {
   505  			return
   506  		}
   507  		if len(node.Exprs) > 0 {
   508  			er.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", node.Name.Lowered())
   509  			return
   510  		}
   511  		cursor.Replace(bindVarExpression(bindVar))
   512  		er.bindVars.AddFuncResult(bindVar)
   513  	}
   514  }
   515  
// unnestSubQueries replaces a trivial scalar subquery with its single select
// expression. Only a subquery of the shape `(SELECT expr FROM dual)` is
// unnested: exactly one AliasedExpr, a single FROM table named dual, and no
// WHERE, GROUP BY, HAVING, ORDER BY, LIMIT or lock clause. Subqueries whose
// parent is an EXISTS expression are left alone.
//
// The inner expression is first rewritten with er itself. When the parent is
// an IN / NOT IN comparison, the result is wrapped in a ValTuple so that the
// operand remains a tuple.
func (er *expressionRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) {
	if _, isExists := cursor.Parent().(*ExistsExpr); isExists {
		return
	}
	sel, isSimpleSelect := subquery.Select.(*Select)
	if !isSimpleSelect {
		return
	}

	// Bail out on anything beyond a single expression selected from one table.
	if len(sel.SelectExprs) != 1 ||
		len(sel.OrderBy) != 0 ||
		len(sel.GroupBy) != 0 ||
		len(sel.From) != 1 ||
		sel.Where != nil ||
		sel.Having != nil ||
		sel.Limit != nil || sel.Lock != NoLock {
		return
	}

	aliasedTable, ok := sel.From[0].(*AliasedTableExpr)
	if !ok {
		return
	}
	table, ok := aliasedTable.Expr.(TableName)
	if !ok || table.Name.String() != "dual" {
		return
	}
	expr, ok := sel.SelectExprs[0].(*AliasedExpr)
	if !ok {
		return
	}
	er.bindVars.NoteRewrite()
	// we need to make sure that the inner expression also gets rewritten,
	// so we fire off another rewriter traversal here
	rewritten := Rewrite(expr.Expr, er.rewrite, nil)

	// Here we need to handle the subquery rewrite in case it occurs in an IN clause
	// For example, SELECT id FROM user WHERE id IN (SELECT 1 FROM DUAL)
	// Here we cannot rewrite the query to SELECT id FROM user WHERE id IN 1, since that is syntactically wrong
	// We must rewrite it to SELECT id FROM user WHERE id IN (1)
	// Find more cases in the test file
	rewrittenExpr, isExpr := rewritten.(Expr)
	_, isColTuple := rewritten.(ColTuple)
	comparisonExpr, isCompExpr := cursor.Parent().(*ComparisonExpr)
	// Check that the parent is a comparison operator with IN or NOT IN operation.
	// Also, if rewritten is already a ColTuple (like a subquery), then we do not need this
	// We also need to check that rewritten is an Expr, if it is then we can rewrite it as a ValTuple
	if isCompExpr && (comparisonExpr.Operator == InOp || comparisonExpr.Operator == NotInOp) && !isColTuple && isExpr {
		cursor.Replace(ValTuple{rewrittenExpr})
		return
	}

	cursor.Replace(rewritten)
}
   570  
   571  func bindVarExpression(name string) Expr {
   572  	return NewArgument(name)
   573  }
   574  
   575  // SystemSchema returns true if the schema passed is system schema
   576  func SystemSchema(schema string) bool {
   577  	return strings.EqualFold(schema, "information_schema") ||
   578  		strings.EqualFold(schema, "performance_schema") ||
   579  		strings.EqualFold(schema, "sys") ||
   580  		strings.EqualFold(schema, "mysql")
   581  }
   582  
   583  // RewriteToCNF walks the input AST and rewrites any boolean logic into CNF
   584  // Note: In order to re-plan, we need to empty the accumulated metadata in the AST,
   585  // so ColName.Metadata will be nil:ed out as part of this rewrite
   586  func RewriteToCNF(ast SQLNode) SQLNode {
   587  	for {
   588  		finishedRewrite := true
   589  		ast = Rewrite(ast, func(cursor *Cursor) bool {
   590  			if e, isExpr := cursor.node.(Expr); isExpr {
   591  				rewritten, didRewrite := rewriteToCNFExpr(e)
   592  				if didRewrite {
   593  					finishedRewrite = false
   594  					cursor.Replace(rewritten)
   595  				}
   596  			}
   597  			if col, isCol := cursor.node.(*ColName); isCol {
   598  				col.Metadata = nil
   599  			}
   600  			return true
   601  		}, nil)
   602  
   603  		if finishedRewrite {
   604  			return ast
   605  		}
   606  	}
   607  }
   608  
   609  func distinctOr(in *OrExpr) (Expr, bool) {
   610  	todo := []*OrExpr{in}
   611  	var leaves []Expr
   612  	for len(todo) > 0 {
   613  		curr := todo[0]
   614  		todo = todo[1:]
   615  		addAnd := func(in Expr) {
   616  			and, ok := in.(*OrExpr)
   617  			if ok {
   618  				todo = append(todo, and)
   619  			} else {
   620  				leaves = append(leaves, in)
   621  			}
   622  		}
   623  		addAnd(curr.Left)
   624  		addAnd(curr.Right)
   625  	}
   626  	original := len(leaves)
   627  	var predicates []Expr
   628  
   629  outer1:
   630  	for len(leaves) > 0 {
   631  		curr := leaves[0]
   632  		leaves = leaves[1:]
   633  		for _, alreadyIn := range predicates {
   634  			if EqualsExpr(alreadyIn, curr) {
   635  				continue outer1
   636  			}
   637  		}
   638  		predicates = append(predicates, curr)
   639  	}
   640  	if original == len(predicates) {
   641  		return in, false
   642  	}
   643  	var result Expr
   644  	for i, curr := range predicates {
   645  		if i == 0 {
   646  			result = curr
   647  			continue
   648  		}
   649  		result = &OrExpr{Left: result, Right: curr}
   650  	}
   651  	return result, true
   652  }
   653  func distinctAnd(in *AndExpr) (Expr, bool) {
   654  	todo := []*AndExpr{in}
   655  	var leaves []Expr
   656  	for len(todo) > 0 {
   657  		curr := todo[0]
   658  		todo = todo[1:]
   659  		addAnd := func(in Expr) {
   660  			and, ok := in.(*AndExpr)
   661  			if ok {
   662  				todo = append(todo, and)
   663  			} else {
   664  				leaves = append(leaves, in)
   665  			}
   666  		}
   667  		addAnd(curr.Left)
   668  		addAnd(curr.Right)
   669  	}
   670  	original := len(leaves)
   671  	var predicates []Expr
   672  
   673  outer1:
   674  	for len(leaves) > 0 {
   675  		curr := leaves[0]
   676  		leaves = leaves[1:]
   677  		for _, alreadyIn := range predicates {
   678  			if EqualsExpr(alreadyIn, curr) {
   679  				continue outer1
   680  			}
   681  		}
   682  		predicates = append(predicates, curr)
   683  	}
   684  	if original == len(predicates) {
   685  		return in, false
   686  	}
   687  	var result Expr
   688  	for i, curr := range predicates {
   689  		if i == 0 {
   690  			result = curr
   691  			continue
   692  		}
   693  		result = &AndExpr{Left: result, Right: curr}
   694  	}
   695  	return result, true
   696  }
   697  
   698  func rewriteToCNFExpr(expr Expr) (Expr, bool) {
   699  	switch expr := expr.(type) {
   700  	case *NotExpr:
   701  		switch child := expr.Expr.(type) {
   702  		case *NotExpr:
   703  			// NOT NOT A => A
   704  			return child.Expr, true
   705  		case *OrExpr:
   706  			// DeMorgan Rewriter
   707  			// NOT (A OR B) => NOT A AND NOT B
   708  			return &AndExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, true
   709  		case *AndExpr:
   710  			// DeMorgan Rewriter
   711  			// NOT (A AND B) => NOT A OR NOT B
   712  			return &OrExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, true
   713  		}
   714  	case *OrExpr:
   715  		or := expr
   716  		if and, ok := or.Left.(*AndExpr); ok {
   717  			// Simplification
   718  			// (A AND B) OR A => A
   719  			if EqualsExpr(or.Right, and.Left) || EqualsExpr(or.Right, and.Right) {
   720  				return or.Right, true
   721  			}
   722  			// Distribution Law
   723  			// (A AND B) OR C => (A OR C) AND (B OR C)
   724  			return &AndExpr{Left: &OrExpr{Left: and.Left, Right: or.Right}, Right: &OrExpr{Left: and.Right, Right: or.Right}}, true
   725  		}
   726  		if and, ok := or.Right.(*AndExpr); ok {
   727  			// Simplification
   728  			// A OR (A AND B) => A
   729  			if EqualsExpr(or.Left, and.Left) || EqualsExpr(or.Left, and.Right) {
   730  				return or.Left, true
   731  			}
   732  			// Distribution Law
   733  			// C OR (A AND B) => (C OR A) AND (C OR B)
   734  			return &AndExpr{Left: &OrExpr{Left: or.Left, Right: and.Left}, Right: &OrExpr{Left: or.Left, Right: and.Right}}, true
   735  		}
   736  		// Try to make distinct
   737  		return distinctOr(expr)
   738  
   739  	case *XorExpr:
   740  		// DeMorgan Rewriter
   741  		// (A XOR B) => (A OR B) AND NOT (A AND B)
   742  		return &AndExpr{Left: &OrExpr{Left: expr.Left, Right: expr.Right}, Right: &NotExpr{Expr: &AndExpr{Left: expr.Left, Right: expr.Right}}}, true
   743  	case *AndExpr:
   744  		res, rewritten := distinctAnd(expr)
   745  		if rewritten {
   746  			return res, rewritten
   747  		}
   748  		and := expr
   749  		if or, ok := and.Left.(*OrExpr); ok {
   750  			// Simplification
   751  			// (A OR B) AND A => A
   752  			if EqualsExpr(or.Left, and.Right) || EqualsExpr(or.Right, and.Right) {
   753  				return and.Right, true
   754  			}
   755  		}
   756  		if or, ok := and.Right.(*OrExpr); ok {
   757  			// Simplification
   758  			// A OR (A AND B) => A
   759  			if EqualsExpr(or.Left, and.Left) || EqualsExpr(or.Right, and.Left) {
   760  				return or.Left, true
   761  			}
   762  		}
   763  
   764  	}
   765  	return expr, false
   766  }