github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/optimizer/validator.go (about)

     1  // Copyright 2015 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package optimizer
    15  
    16  import (
    17  	"math"
    18  
    19  	"github.com/insionng/yougam/libraries/juju/errors"
    20  	"github.com/insionng/yougam/libraries/pingcap/tidb/ast"
    21  	"github.com/insionng/yougam/libraries/pingcap/tidb/mysql"
    22  	"github.com/insionng/yougam/libraries/pingcap/tidb/parser"
    23  	"github.com/insionng/yougam/libraries/pingcap/tidb/parser/opcode"
    24  )
    25  
    26  // Validate checkes whether the node is valid.
    27  func Validate(node ast.Node, inPrepare bool) error {
    28  	v := validator{inPrepare: inPrepare}
    29  	node.Accept(&v)
    30  	return v.err
    31  }
    32  
    33  // validator is an ast.Visitor that validates
    34  // ast Nodes parsed from parser.
    35  type validator struct {
    36  	err           error
    37  	wildCardCount int
    38  	inPrepare     bool
    39  	inAggregate   bool
    40  }
    41  
    42  func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
    43  	switch in.(type) {
    44  	case *ast.AggregateFuncExpr:
    45  		if v.inAggregate {
    46  			// Aggregate function can not contain aggregate function.
    47  			v.err = ErrInvalidGroupFuncUse
    48  			return in, true
    49  		}
    50  		v.inAggregate = true
    51  	}
    52  	return in, false
    53  }
    54  
    55  func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) {
    56  	switch x := in.(type) {
    57  	case *ast.AggregateFuncExpr:
    58  		v.inAggregate = false
    59  	case *ast.BetweenExpr:
    60  		v.checkAllOneColumn(x.Expr, x.Left, x.Right)
    61  	case *ast.BinaryOperationExpr:
    62  		v.checkBinaryOperation(x)
    63  	case *ast.ByItem:
    64  		v.checkAllOneColumn(x.Expr)
    65  	case *ast.CreateTableStmt:
    66  		v.checkAutoIncrement(x)
    67  	case *ast.CompareSubqueryExpr:
    68  		v.checkSameColumns(x.L, x.R)
    69  	case *ast.FieldList:
    70  		v.checkFieldList(x)
    71  	case *ast.HavingClause:
    72  		v.checkAllOneColumn(x.Expr)
    73  	case *ast.IsNullExpr:
    74  		v.checkAllOneColumn(x.Expr)
    75  	case *ast.IsTruthExpr:
    76  		v.checkAllOneColumn(x.Expr)
    77  	case *ast.ParamMarkerExpr:
    78  		if !v.inPrepare {
    79  			v.err = parser.ErrSyntax.Gen("syntax error, unexpected '?'")
    80  		}
    81  	case *ast.PatternInExpr:
    82  		v.checkSameColumns(append(x.List, x.Expr)...)
    83  	case *ast.Limit:
    84  		if x.Count > math.MaxUint64-x.Offset {
    85  			x.Count = math.MaxUint64 - x.Offset
    86  		}
    87  	}
    88  
    89  	return in, v.err == nil
    90  }
    91  
    92  // checkAllOneColumn checks that all expressions have one column.
    93  // Expression may have more than one column when it is a rowExpr or
    94  // a Subquery with more than one result fields.
    95  func (v *validator) checkAllOneColumn(exprs ...ast.ExprNode) {
    96  	for _, expr := range exprs {
    97  		switch x := expr.(type) {
    98  		case *ast.RowExpr:
    99  			v.err = ErrOneColumn
   100  		case *ast.SubqueryExpr:
   101  			if len(x.Query.GetResultFields()) != 1 {
   102  				v.err = ErrOneColumn
   103  			}
   104  		}
   105  	}
   106  	return
   107  }
   108  
   109  func checkAutoIncrementOp(colDef *ast.ColumnDef, num int) (bool, error) {
   110  	var hasAutoIncrement bool
   111  
   112  	if colDef.Options[num].Tp == ast.ColumnOptionAutoIncrement {
   113  		hasAutoIncrement = true
   114  		if len(colDef.Options) == num+1 {
   115  			return hasAutoIncrement, nil
   116  		}
   117  		for _, op := range colDef.Options[num+1:] {
   118  			if op.Tp == ast.ColumnOptionDefaultValue {
   119  				return hasAutoIncrement, errors.Errorf("Invalid default value for '%s'", colDef.Name.Name.O)
   120  			}
   121  		}
   122  	}
   123  	if colDef.Options[num].Tp == ast.ColumnOptionDefaultValue && len(colDef.Options) != num+1 {
   124  		for _, op := range colDef.Options[num+1:] {
   125  			if op.Tp == ast.ColumnOptionAutoIncrement {
   126  				return hasAutoIncrement, errors.Errorf("Invalid default value for '%s'", colDef.Name.Name.O)
   127  			}
   128  		}
   129  	}
   130  
   131  	return hasAutoIncrement, nil
   132  }
   133  
   134  func isConstraintKeyTp(constraints []*ast.Constraint, colDef *ast.ColumnDef) bool {
   135  	for _, c := range constraints {
   136  		if len(c.Keys) < 1 {
   137  		}
   138  		// If the constraint as follows: primary key(c1, c2)
   139  		// we only support c1 column can be auto_increment.
   140  		if colDef.Name.Name.L != c.Keys[0].Column.Name.L {
   141  			continue
   142  		}
   143  		switch c.Tp {
   144  		case ast.ConstraintPrimaryKey, ast.ConstraintKey, ast.ConstraintIndex,
   145  			ast.ConstraintUniq, ast.ConstraintUniqIndex, ast.ConstraintUniqKey:
   146  			return true
   147  		}
   148  	}
   149  
   150  	return false
   151  }
   152  
   153  func (v *validator) checkAutoIncrement(stmt *ast.CreateTableStmt) {
   154  	var (
   155  		isKey            bool
   156  		count            int
   157  		autoIncrementCol *ast.ColumnDef
   158  	)
   159  
   160  	for _, colDef := range stmt.Cols {
   161  		var hasAutoIncrement bool
   162  		for i, op := range colDef.Options {
   163  			ok, err := checkAutoIncrementOp(colDef, i)
   164  			if err != nil {
   165  				v.err = err
   166  				return
   167  			}
   168  			if ok {
   169  				hasAutoIncrement = true
   170  			}
   171  			switch op.Tp {
   172  			case ast.ColumnOptionPrimaryKey, ast.ColumnOptionUniqKey, ast.ColumnOptionUniqIndex,
   173  				ast.ColumnOptionUniq, ast.ColumnOptionKey, ast.ColumnOptionIndex:
   174  				isKey = true
   175  			}
   176  		}
   177  		if hasAutoIncrement {
   178  			count++
   179  			autoIncrementCol = colDef
   180  		}
   181  	}
   182  
   183  	if count < 1 {
   184  		return
   185  	}
   186  
   187  	if !isKey {
   188  		isKey = isConstraintKeyTp(stmt.Constraints, autoIncrementCol)
   189  	}
   190  	if !isKey || count > 1 {
   191  		v.err = errors.New("Incorrect table definition; there can be only one auto column and it must be defined as a key")
   192  	}
   193  
   194  	switch autoIncrementCol.Tp.Tp {
   195  	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeLong,
   196  		mysql.TypeFloat, mysql.TypeDouble, mysql.TypeLonglong, mysql.TypeInt24:
   197  	default:
   198  		v.err = errors.Errorf("Incorrect column specifier for column '%s'", autoIncrementCol.Name.Name.O)
   199  	}
   200  }
   201  
   202  func (v *validator) checkBinaryOperation(x *ast.BinaryOperationExpr) {
   203  	// row constructor only supports comparison operation.
   204  	switch x.Op {
   205  	case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ:
   206  		v.checkSameColumns(x.L, x.R)
   207  	default:
   208  		v.checkAllOneColumn(x.L, x.R)
   209  	}
   210  }
   211  
   212  func columnCount(ex ast.ExprNode) int {
   213  	switch x := ex.(type) {
   214  	case *ast.RowExpr:
   215  		return len(x.Values)
   216  	case *ast.SubqueryExpr:
   217  		return len(x.Query.GetResultFields())
   218  	default:
   219  		return 1
   220  	}
   221  }
   222  
   223  func (v *validator) checkSameColumns(exprs ...ast.ExprNode) {
   224  	if len(exprs) == 0 {
   225  		return
   226  	}
   227  	count := columnCount(exprs[0])
   228  	for i := 1; i < len(exprs); i++ {
   229  		if columnCount(exprs[i]) != count {
   230  			v.err = ErrSameColumns
   231  			return
   232  		}
   233  	}
   234  }
   235  
   236  // checkFieldList checks if there is only one '*' and each field has only one column.
   237  func (v *validator) checkFieldList(x *ast.FieldList) {
   238  	var hasWildCard bool
   239  	for _, val := range x.Fields {
   240  		if val.WildCard != nil && val.WildCard.Table.L == "" {
   241  			if hasWildCard {
   242  				v.err = ErrMultiWildCard
   243  				return
   244  			}
   245  			hasWildCard = true
   246  		}
   247  		v.checkAllOneColumn(val.Expr)
   248  		if v.err != nil {
   249  			return
   250  		}
   251  	}
   252  }