github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/optimizer/typeinferer.go (about)

     1  // Copyright 2015 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package optimizer
    15  
    16  import (
    17  	"strings"
    18  
    19  	"github.com/insionng/yougam/libraries/pingcap/tidb/ast"
    20  	"github.com/insionng/yougam/libraries/pingcap/tidb/mysql"
    21  	"github.com/insionng/yougam/libraries/pingcap/tidb/parser/opcode"
    22  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/charset"
    23  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/types"
    24  )
    25  
    26  // InferType infers result type for ast.ExprNode.
    27  func InferType(node ast.Node) error {
    28  	var inferrer typeInferrer
    29  	// TODO: get the default charset from ctx
    30  	inferrer.defaultCharset = "utf8"
    31  	node.Accept(&inferrer)
    32  	return inferrer.err
    33  }
    34  
    35  type typeInferrer struct {
    36  	err            error
    37  	defaultCharset string
    38  }
    39  
    40  func (v *typeInferrer) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
    41  	return in, false
    42  }
    43  
    44  func (v *typeInferrer) Leave(in ast.Node) (out ast.Node, ok bool) {
    45  	switch x := in.(type) {
    46  	case *ast.AggregateFuncExpr:
    47  		v.aggregateFunc(x)
    48  	case *ast.BetweenExpr:
    49  		x.SetType(types.NewFieldType(mysql.TypeLonglong))
    50  		x.Type.Charset = charset.CharsetBin
    51  		x.Type.Collate = charset.CollationBin
    52  	case *ast.BinaryOperationExpr:
    53  		v.binaryOperation(x)
    54  	case *ast.CaseExpr:
    55  		v.handleCaseExpr(x)
    56  	case *ast.ColumnNameExpr:
    57  		x.SetType(&x.Refer.Column.FieldType)
    58  	case *ast.CompareSubqueryExpr:
    59  		x.SetType(types.NewFieldType(mysql.TypeLonglong))
    60  		x.Type.Charset = charset.CharsetBin
    61  		x.Type.Collate = charset.CollationBin
    62  	case *ast.ExistsSubqueryExpr:
    63  		x.SetType(types.NewFieldType(mysql.TypeLonglong))
    64  		x.Type.Charset = charset.CharsetBin
    65  		x.Type.Collate = charset.CollationBin
    66  	case *ast.FuncCallExpr:
    67  		v.handleFuncCallExpr(x)
    68  	case *ast.FuncCastExpr:
    69  		// Copy a new field type.
    70  		tp := *x.Tp
    71  		x.SetType(&tp)
    72  		if len(x.Type.Charset) == 0 {
    73  			x.Type.Charset, x.Type.Collate = types.DefaultCharsetForType(x.Type.Tp)
    74  		}
    75  	case *ast.IsNullExpr:
    76  		x.SetType(types.NewFieldType(mysql.TypeLonglong))
    77  		x.Type.Charset = charset.CharsetBin
    78  		x.Type.Collate = charset.CollationBin
    79  	case *ast.IsTruthExpr:
    80  		x.SetType(types.NewFieldType(mysql.TypeLonglong))
    81  		x.Type.Charset = charset.CharsetBin
    82  		x.Type.Collate = charset.CollationBin
    83  	case *ast.ParamMarkerExpr:
    84  		x.SetType(types.DefaultTypeForValue(x.GetValue()))
    85  	case *ast.ParenthesesExpr:
    86  		x.SetType(x.Expr.GetType())
    87  	case *ast.PatternInExpr:
    88  		x.SetType(types.NewFieldType(mysql.TypeLonglong))
    89  		x.Type.Charset = charset.CharsetBin
    90  		x.Type.Collate = charset.CollationBin
    91  	case *ast.PatternLikeExpr:
    92  		x.SetType(types.NewFieldType(mysql.TypeLonglong))
    93  		x.Type.Charset = charset.CharsetBin
    94  		x.Type.Collate = charset.CollationBin
    95  	case *ast.PatternRegexpExpr:
    96  		x.SetType(types.NewFieldType(mysql.TypeLonglong))
    97  		x.Type.Charset = charset.CharsetBin
    98  		x.Type.Collate = charset.CollationBin
    99  	case *ast.SelectStmt:
   100  		v.selectStmt(x)
   101  	case *ast.UnaryOperationExpr:
   102  		v.unaryOperation(x)
   103  	case *ast.ValueExpr:
   104  		v.handleValueExpr(x)
   105  	case *ast.ValuesExpr:
   106  		v.handleValuesExpr(x)
   107  	case *ast.VariableExpr:
   108  		x.SetType(types.NewFieldType(mysql.TypeVarString))
   109  		x.Type.Charset = v.defaultCharset
   110  		cln, err := charset.GetDefaultCollation(v.defaultCharset)
   111  		if err != nil {
   112  			v.err = err
   113  		}
   114  		x.Type.Collate = cln
   115  		// TODO: handle all expression types.
   116  	}
   117  	return in, true
   118  }
   119  
   120  func (v *typeInferrer) selectStmt(x *ast.SelectStmt) {
   121  	rf := x.GetResultFields()
   122  	for _, val := range rf {
   123  		// column ID is 0 means it is not a real column from table, but a temporary column,
   124  		// so its type is not pre-defined, we need to set it.
   125  		if val.Column.ID == 0 && val.Expr.GetType() != nil {
   126  			val.Column.FieldType = *(val.Expr.GetType())
   127  		}
   128  	}
   129  }
   130  
   131  func (v *typeInferrer) aggregateFunc(x *ast.AggregateFuncExpr) {
   132  	name := strings.ToLower(x.F)
   133  	switch name {
   134  	case ast.AggFuncCount:
   135  		ft := types.NewFieldType(mysql.TypeLonglong)
   136  		ft.Flen = 21
   137  		ft.Charset = charset.CharsetBin
   138  		ft.Collate = charset.CollationBin
   139  		x.SetType(ft)
   140  	case ast.AggFuncMax, ast.AggFuncMin:
   141  		x.SetType(x.Args[0].GetType())
   142  	case ast.AggFuncSum, ast.AggFuncAvg:
   143  		ft := types.NewFieldType(mysql.TypeNewDecimal)
   144  		ft.Charset = charset.CharsetBin
   145  		ft.Collate = charset.CollationBin
   146  		x.SetType(ft)
   147  	case ast.AggFuncGroupConcat:
   148  		ft := types.NewFieldType(mysql.TypeVarString)
   149  		ft.Charset = v.defaultCharset
   150  		cln, err := charset.GetDefaultCollation(v.defaultCharset)
   151  		if err != nil {
   152  			v.err = err
   153  		}
   154  		ft.Collate = cln
   155  		x.SetType(ft)
   156  	}
   157  }
   158  
   159  func (v *typeInferrer) binaryOperation(x *ast.BinaryOperationExpr) {
   160  	switch x.Op {
   161  	case opcode.AndAnd, opcode.OrOr, opcode.LogicXor:
   162  		x.Type = types.NewFieldType(mysql.TypeLonglong)
   163  	case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ:
   164  		x.Type = types.NewFieldType(mysql.TypeLonglong)
   165  	case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor:
   166  		x.Type = types.NewFieldType(mysql.TypeLonglong)
   167  		x.Type.Flag |= mysql.UnsignedFlag
   168  	case opcode.IntDiv:
   169  		x.Type = types.NewFieldType(mysql.TypeLonglong)
   170  	case opcode.Plus, opcode.Minus, opcode.Mul, opcode.Mod:
   171  		if x.L.GetType() != nil && x.R.GetType() != nil {
   172  			xTp := mergeArithType(x.L.GetType().Tp, x.R.GetType().Tp)
   173  			x.Type = types.NewFieldType(xTp)
   174  			leftUnsigned := x.L.GetType().Flag & mysql.UnsignedFlag
   175  			rightUnsigned := x.R.GetType().Flag & mysql.UnsignedFlag
   176  			// If both operands are unsigned, result is unsigned.
   177  			x.Type.Flag |= (leftUnsigned & rightUnsigned)
   178  		}
   179  	case opcode.Div:
   180  		if x.L.GetType() != nil && x.R.GetType() != nil {
   181  			xTp := mergeArithType(x.L.GetType().Tp, x.R.GetType().Tp)
   182  			if xTp == mysql.TypeLonglong {
   183  				xTp = mysql.TypeDecimal
   184  			}
   185  			x.Type = types.NewFieldType(xTp)
   186  		}
   187  	}
   188  	x.Type.Charset = charset.CharsetBin
   189  	x.Type.Collate = charset.CollationBin
   190  }
   191  
   192  func mergeArithType(a, b byte) byte {
   193  	switch a {
   194  	case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat:
   195  		return mysql.TypeDouble
   196  	}
   197  	switch b {
   198  	case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat:
   199  		return mysql.TypeDouble
   200  	}
   201  	if a == mysql.TypeNewDecimal || b == mysql.TypeNewDecimal {
   202  		return mysql.TypeNewDecimal
   203  	}
   204  	return mysql.TypeLonglong
   205  }
   206  
   207  func (v *typeInferrer) unaryOperation(x *ast.UnaryOperationExpr) {
   208  	switch x.Op {
   209  	case opcode.Not:
   210  		x.Type = types.NewFieldType(mysql.TypeLonglong)
   211  	case opcode.BitNeg:
   212  		x.Type = types.NewFieldType(mysql.TypeLonglong)
   213  		x.Type.Flag |= mysql.UnsignedFlag
   214  	case opcode.Plus:
   215  		x.Type = x.V.GetType()
   216  	case opcode.Minus:
   217  		x.Type = types.NewFieldType(mysql.TypeLonglong)
   218  		if x.V.GetType() != nil {
   219  			switch x.V.GetType().Tp {
   220  			case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat:
   221  				x.Type.Tp = mysql.TypeDouble
   222  			case mysql.TypeNewDecimal:
   223  				x.Type.Tp = mysql.TypeNewDecimal
   224  			}
   225  		}
   226  	}
   227  	x.Type.Charset = charset.CharsetBin
   228  	x.Type.Collate = charset.CollationBin
   229  }
   230  
   231  func (v *typeInferrer) handleValueExpr(x *ast.ValueExpr) {
   232  	tp := types.DefaultTypeForValue(x.GetValue())
   233  	// Set charset and collation
   234  	x.SetType(tp)
   235  }
   236  
   237  func (v *typeInferrer) handleValuesExpr(x *ast.ValuesExpr) {
   238  	x.SetType(x.Column.GetType())
   239  }
   240  
   241  func (v *typeInferrer) getFsp(x *ast.FuncCallExpr) int {
   242  	if len(x.Args) == 1 {
   243  		a := x.Args[0].GetValue()
   244  		fsp, err := types.ToInt64(a)
   245  		if err != nil {
   246  			v.err = err
   247  		}
   248  		return int(fsp)
   249  	}
   250  	return 0
   251  }
   252  
   253  func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) {
   254  	var (
   255  		tp  *types.FieldType
   256  		chs = charset.CharsetBin
   257  	)
   258  	switch x.FnName.L {
   259  	case "abs", "ifnull", "nullif":
   260  		tp = x.Args[0].GetType()
   261  		// TODO: We should cover all types.
   262  		if x.FnName.L == "abs" && tp.Tp == mysql.TypeDatetime {
   263  			tp = types.NewFieldType(mysql.TypeDouble)
   264  		}
   265  	case "pow", "power", "rand":
   266  		tp = types.NewFieldType(mysql.TypeDouble)
   267  	case "curdate", "current_date", "date":
   268  		tp = types.NewFieldType(mysql.TypeDate)
   269  	case "curtime", "current_time":
   270  		tp = types.NewFieldType(mysql.TypeDuration)
   271  		tp.Decimal = v.getFsp(x)
   272  	case "current_timestamp", "date_arith":
   273  		tp = types.NewFieldType(mysql.TypeDatetime)
   274  	case "microsecond", "second", "minute", "hour", "day", "week", "month", "year",
   275  		"dayofweek", "dayofmonth", "dayofyear", "weekday", "weekofyear", "yearweek",
   276  		"found_rows", "length", "extract", "locate":
   277  		tp = types.NewFieldType(mysql.TypeLonglong)
   278  	case "now", "sysdate":
   279  		tp = types.NewFieldType(mysql.TypeDatetime)
   280  		tp.Decimal = v.getFsp(x)
   281  	case "dayname", "version", "database", "user", "current_user",
   282  		"concat", "concat_ws", "left", "lcase", "lower", "repeat",
   283  		"replace", "ucase", "upper", "convert", "substring",
   284  		"substring_index", "trim", "ltrim", "rtrim", "reverse":
   285  		tp = types.NewFieldType(mysql.TypeVarString)
   286  		chs = v.defaultCharset
   287  	case "strcmp", "isnull":
   288  		tp = types.NewFieldType(mysql.TypeLonglong)
   289  	case "connection_id":
   290  		tp = types.NewFieldType(mysql.TypeLonglong)
   291  		tp.Flag |= mysql.UnsignedFlag
   292  	case "if":
   293  		// TODO: fix this
   294  		// See: https://dev.mysql.com/doc/refman/5.5/en/control-flow-functions.html#function_if
   295  		// The default return type of IF() (which may matter when it is stored into a temporary table) is calculated as follows.
   296  		// Expression	Return Value
   297  		// expr2 or expr3 returns a string	string
   298  		// expr2 or expr3 returns a floating-point value	floating-point
   299  		// expr2 or expr3 returns an integer	integer
   300  		tp = x.Args[1].GetType()
   301  	default:
   302  		tp = types.NewFieldType(mysql.TypeUnspecified)
   303  	}
   304  	// If charset is unspecified.
   305  	if len(tp.Charset) == 0 {
   306  		tp.Charset = chs
   307  		cln := charset.CollationBin
   308  		if chs != charset.CharsetBin {
   309  			var err error
   310  			cln, err = charset.GetDefaultCollation(chs)
   311  			if err != nil {
   312  				v.err = err
   313  			}
   314  		}
   315  		tp.Collate = cln
   316  	}
   317  	x.SetType(tp)
   318  }
   319  
   320  // The return type of a CASE expression is the compatible aggregated type of all return values,
   321  // but also depends on the context in which it is used.
   322  // If used in a string context, the result is returned as a string.
   323  // If used in a numeric context, the result is returned as a decimal, real, or integer value.
   324  func (v *typeInferrer) handleCaseExpr(x *ast.CaseExpr) {
   325  	var currType types.FieldType
   326  	for _, w := range x.WhenClauses {
   327  		t := w.Result.GetType()
   328  		if currType.Tp == mysql.TypeUnspecified {
   329  			currType = *t
   330  			continue
   331  		}
   332  		mtp := types.MergeFieldType(currType.Tp, t.Tp)
   333  		if mtp == t.Tp && mtp != currType.Tp {
   334  			currType.Charset = t.Charset
   335  			currType.Collate = t.Collate
   336  		}
   337  		currType.Tp = mtp
   338  
   339  	}
   340  	if x.ElseClause != nil {
   341  		t := x.ElseClause.GetType()
   342  		if currType.Tp == mysql.TypeUnspecified {
   343  			currType = *t
   344  		} else {
   345  			mtp := types.MergeFieldType(currType.Tp, t.Tp)
   346  			if mtp == t.Tp && mtp != currType.Tp {
   347  				currType.Charset = t.Charset
   348  				currType.Collate = t.Collate
   349  			}
   350  			currType.Tp = mtp
   351  		}
   352  	}
   353  	x.SetType(&currType)
   354  	// TODO: We need a better way to set charset/collation
   355  	x.Type.Charset, x.Type.Collate = types.DefaultCharsetForType(x.Type.Tp)
   356  }