github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/optimizer/typeinferer.go (about) 1 // Copyright 2015 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package optimizer 15 16 import ( 17 "strings" 18 19 "github.com/insionng/yougam/libraries/pingcap/tidb/ast" 20 "github.com/insionng/yougam/libraries/pingcap/tidb/mysql" 21 "github.com/insionng/yougam/libraries/pingcap/tidb/parser/opcode" 22 "github.com/insionng/yougam/libraries/pingcap/tidb/util/charset" 23 "github.com/insionng/yougam/libraries/pingcap/tidb/util/types" 24 ) 25 26 // InferType infers result type for ast.ExprNode. 27 func InferType(node ast.Node) error { 28 var inferrer typeInferrer 29 // TODO: get the default charset from ctx 30 inferrer.defaultCharset = "utf8" 31 node.Accept(&inferrer) 32 return inferrer.err 33 } 34 35 type typeInferrer struct { 36 err error 37 defaultCharset string 38 } 39 40 func (v *typeInferrer) Enter(in ast.Node) (out ast.Node, skipChildren bool) { 41 return in, false 42 } 43 44 func (v *typeInferrer) Leave(in ast.Node) (out ast.Node, ok bool) { 45 switch x := in.(type) { 46 case *ast.AggregateFuncExpr: 47 v.aggregateFunc(x) 48 case *ast.BetweenExpr: 49 x.SetType(types.NewFieldType(mysql.TypeLonglong)) 50 x.Type.Charset = charset.CharsetBin 51 x.Type.Collate = charset.CollationBin 52 case *ast.BinaryOperationExpr: 53 v.binaryOperation(x) 54 case *ast.CaseExpr: 55 v.handleCaseExpr(x) 56 case *ast.ColumnNameExpr: 57 x.SetType(&x.Refer.Column.FieldType) 58 case *ast.CompareSubqueryExpr: 59 x.SetType(types.NewFieldType(mysql.TypeLonglong)) 60 x.Type.Charset = charset.CharsetBin 61 x.Type.Collate = charset.CollationBin 62 case *ast.ExistsSubqueryExpr: 63 x.SetType(types.NewFieldType(mysql.TypeLonglong)) 64 x.Type.Charset = charset.CharsetBin 65 x.Type.Collate = charset.CollationBin 66 case *ast.FuncCallExpr: 67 v.handleFuncCallExpr(x) 68 case *ast.FuncCastExpr: 69 // Copy a new field type. 70 tp := *x.Tp 71 x.SetType(&tp) 72 if len(x.Type.Charset) == 0 { 73 x.Type.Charset, x.Type.Collate = types.DefaultCharsetForType(x.Type.Tp) 74 } 75 case *ast.IsNullExpr: 76 x.SetType(types.NewFieldType(mysql.TypeLonglong)) 77 x.Type.Charset = charset.CharsetBin 78 x.Type.Collate = charset.CollationBin 79 case *ast.IsTruthExpr: 80 x.SetType(types.NewFieldType(mysql.TypeLonglong)) 81 x.Type.Charset = charset.CharsetBin 82 x.Type.Collate = charset.CollationBin 83 case *ast.ParamMarkerExpr: 84 x.SetType(types.DefaultTypeForValue(x.GetValue())) 85 case *ast.ParenthesesExpr: 86 x.SetType(x.Expr.GetType()) 87 case *ast.PatternInExpr: 88 x.SetType(types.NewFieldType(mysql.TypeLonglong)) 89 x.Type.Charset = charset.CharsetBin 90 x.Type.Collate = charset.CollationBin 91 case *ast.PatternLikeExpr: 92 x.SetType(types.NewFieldType(mysql.TypeLonglong)) 93 x.Type.Charset = charset.CharsetBin 94 x.Type.Collate = charset.CollationBin 95 case *ast.PatternRegexpExpr: 96 x.SetType(types.NewFieldType(mysql.TypeLonglong)) 97 x.Type.Charset = charset.CharsetBin 98 x.Type.Collate = charset.CollationBin 99 case *ast.SelectStmt: 100 v.selectStmt(x) 101 case *ast.UnaryOperationExpr: 102 v.unaryOperation(x) 103 case *ast.ValueExpr: 104 v.handleValueExpr(x) 105 case *ast.ValuesExpr: 106 v.handleValuesExpr(x) 107 case *ast.VariableExpr: 108 x.SetType(types.NewFieldType(mysql.TypeVarString)) 109 x.Type.Charset = v.defaultCharset 110 cln, err := charset.GetDefaultCollation(v.defaultCharset) 111 if err != nil { 112 v.err = err 113 } 114 x.Type.Collate = cln 115 // TODO: handle all expression types. 116 } 117 return in, true 118 } 119 120 func (v *typeInferrer) selectStmt(x *ast.SelectStmt) { 121 rf := x.GetResultFields() 122 for _, val := range rf { 123 // column ID is 0 means it is not a real column from table, but a temporary column, 124 // so its type is not pre-defined, we need to set it. 125 if val.Column.ID == 0 && val.Expr.GetType() != nil { 126 val.Column.FieldType = *(val.Expr.GetType()) 127 } 128 } 129 } 130 131 func (v *typeInferrer) aggregateFunc(x *ast.AggregateFuncExpr) { 132 name := strings.ToLower(x.F) 133 switch name { 134 case ast.AggFuncCount: 135 ft := types.NewFieldType(mysql.TypeLonglong) 136 ft.Flen = 21 137 ft.Charset = charset.CharsetBin 138 ft.Collate = charset.CollationBin 139 x.SetType(ft) 140 case ast.AggFuncMax, ast.AggFuncMin: 141 x.SetType(x.Args[0].GetType()) 142 case ast.AggFuncSum, ast.AggFuncAvg: 143 ft := types.NewFieldType(mysql.TypeNewDecimal) 144 ft.Charset = charset.CharsetBin 145 ft.Collate = charset.CollationBin 146 x.SetType(ft) 147 case ast.AggFuncGroupConcat: 148 ft := types.NewFieldType(mysql.TypeVarString) 149 ft.Charset = v.defaultCharset 150 cln, err := charset.GetDefaultCollation(v.defaultCharset) 151 if err != nil { 152 v.err = err 153 } 154 ft.Collate = cln 155 x.SetType(ft) 156 } 157 } 158 159 func (v *typeInferrer) binaryOperation(x *ast.BinaryOperationExpr) { 160 switch x.Op { 161 case opcode.AndAnd, opcode.OrOr, opcode.LogicXor: 162 x.Type = types.NewFieldType(mysql.TypeLonglong) 163 case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: 164 x.Type = types.NewFieldType(mysql.TypeLonglong) 165 case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor: 166 x.Type = types.NewFieldType(mysql.TypeLonglong) 167 x.Type.Flag |= mysql.UnsignedFlag 168 case opcode.IntDiv: 169 x.Type = types.NewFieldType(mysql.TypeLonglong) 170 case opcode.Plus, opcode.Minus, opcode.Mul, opcode.Mod: 171 if x.L.GetType() != nil && x.R.GetType() != nil { 172 xTp := mergeArithType(x.L.GetType().Tp, x.R.GetType().Tp) 173 x.Type = types.NewFieldType(xTp) 174 leftUnsigned := x.L.GetType().Flag & mysql.UnsignedFlag 175 rightUnsigned := x.R.GetType().Flag & mysql.UnsignedFlag 176 // If both operands are unsigned, result is unsigned. 177 x.Type.Flag |= (leftUnsigned & rightUnsigned) 178 } 179 case opcode.Div: 180 if x.L.GetType() != nil && x.R.GetType() != nil { 181 xTp := mergeArithType(x.L.GetType().Tp, x.R.GetType().Tp) 182 if xTp == mysql.TypeLonglong { 183 xTp = mysql.TypeDecimal 184 } 185 x.Type = types.NewFieldType(xTp) 186 } 187 } 188 x.Type.Charset = charset.CharsetBin 189 x.Type.Collate = charset.CollationBin 190 } 191 192 func mergeArithType(a, b byte) byte { 193 switch a { 194 case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat: 195 return mysql.TypeDouble 196 } 197 switch b { 198 case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat: 199 return mysql.TypeDouble 200 } 201 if a == mysql.TypeNewDecimal || b == mysql.TypeNewDecimal { 202 return mysql.TypeNewDecimal 203 } 204 return mysql.TypeLonglong 205 } 206 207 func (v *typeInferrer) unaryOperation(x *ast.UnaryOperationExpr) { 208 switch x.Op { 209 case opcode.Not: 210 x.Type = types.NewFieldType(mysql.TypeLonglong) 211 case opcode.BitNeg: 212 x.Type = types.NewFieldType(mysql.TypeLonglong) 213 x.Type.Flag |= mysql.UnsignedFlag 214 case opcode.Plus: 215 x.Type = x.V.GetType() 216 case opcode.Minus: 217 x.Type = types.NewFieldType(mysql.TypeLonglong) 218 if x.V.GetType() != nil { 219 switch x.V.GetType().Tp { 220 case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat: 221 x.Type.Tp = mysql.TypeDouble 222 case mysql.TypeNewDecimal: 223 x.Type.Tp = mysql.TypeNewDecimal 224 } 225 } 226 } 227 x.Type.Charset = charset.CharsetBin 228 x.Type.Collate = charset.CollationBin 229 } 230 231 func (v *typeInferrer) handleValueExpr(x *ast.ValueExpr) { 232 tp := types.DefaultTypeForValue(x.GetValue()) 233 // Set charset and collation 234 x.SetType(tp) 235 } 236 237 func (v *typeInferrer) handleValuesExpr(x *ast.ValuesExpr) { 238 x.SetType(x.Column.GetType()) 239 } 240 241 func (v *typeInferrer) getFsp(x *ast.FuncCallExpr) int { 242 if len(x.Args) == 1 { 243 a := x.Args[0].GetValue() 244 fsp, err := types.ToInt64(a) 245 if err != nil { 246 v.err = err 247 } 248 return int(fsp) 249 } 250 return 0 251 } 252 253 func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) { 254 var ( 255 tp *types.FieldType 256 chs = charset.CharsetBin 257 ) 258 switch x.FnName.L { 259 case "abs", "ifnull", "nullif": 260 tp = x.Args[0].GetType() 261 // TODO: We should cover all types. 262 if x.FnName.L == "abs" && tp.Tp == mysql.TypeDatetime { 263 tp = types.NewFieldType(mysql.TypeDouble) 264 } 265 case "pow", "power", "rand": 266 tp = types.NewFieldType(mysql.TypeDouble) 267 case "curdate", "current_date", "date": 268 tp = types.NewFieldType(mysql.TypeDate) 269 case "curtime", "current_time": 270 tp = types.NewFieldType(mysql.TypeDuration) 271 tp.Decimal = v.getFsp(x) 272 case "current_timestamp", "date_arith": 273 tp = types.NewFieldType(mysql.TypeDatetime) 274 case "microsecond", "second", "minute", "hour", "day", "week", "month", "year", 275 "dayofweek", "dayofmonth", "dayofyear", "weekday", "weekofyear", "yearweek", 276 "found_rows", "length", "extract", "locate": 277 tp = types.NewFieldType(mysql.TypeLonglong) 278 case "now", "sysdate": 279 tp = types.NewFieldType(mysql.TypeDatetime) 280 tp.Decimal = v.getFsp(x) 281 case "dayname", "version", "database", "user", "current_user", 282 "concat", "concat_ws", "left", "lcase", "lower", "repeat", 283 "replace", "ucase", "upper", "convert", "substring", 284 "substring_index", "trim", "ltrim", "rtrim", "reverse": 285 tp = types.NewFieldType(mysql.TypeVarString) 286 chs = v.defaultCharset 287 case "strcmp", "isnull": 288 tp = types.NewFieldType(mysql.TypeLonglong) 289 case "connection_id": 290 tp = types.NewFieldType(mysql.TypeLonglong) 291 tp.Flag |= mysql.UnsignedFlag 292 case "if": 293 // TODO: fix this 294 // See: https://dev.mysql.com/doc/refman/5.5/en/control-flow-functions.html#function_if 295 // The default return type of IF() (which may matter when it is stored into a temporary table) is calculated as follows. 296 // Expression Return Value 297 // expr2 or expr3 returns a string string 298 // expr2 or expr3 returns a floating-point value floating-point 299 // expr2 or expr3 returns an integer integer 300 tp = x.Args[1].GetType() 301 default: 302 tp = types.NewFieldType(mysql.TypeUnspecified) 303 } 304 // If charset is unspecified. 305 if len(tp.Charset) == 0 { 306 tp.Charset = chs 307 cln := charset.CollationBin 308 if chs != charset.CharsetBin { 309 var err error 310 cln, err = charset.GetDefaultCollation(chs) 311 if err != nil { 312 v.err = err 313 } 314 } 315 tp.Collate = cln 316 } 317 x.SetType(tp) 318 } 319 320 // The return type of a CASE expression is the compatible aggregated type of all return values, 321 // but also depends on the context in which it is used. 322 // If used in a string context, the result is returned as a string. 323 // If used in a numeric context, the result is returned as a decimal, real, or integer value. 324 func (v *typeInferrer) handleCaseExpr(x *ast.CaseExpr) { 325 var currType types.FieldType 326 for _, w := range x.WhenClauses { 327 t := w.Result.GetType() 328 if currType.Tp == mysql.TypeUnspecified { 329 currType = *t 330 continue 331 } 332 mtp := types.MergeFieldType(currType.Tp, t.Tp) 333 if mtp == t.Tp && mtp != currType.Tp { 334 currType.Charset = t.Charset 335 currType.Collate = t.Collate 336 } 337 currType.Tp = mtp 338 339 } 340 if x.ElseClause != nil { 341 t := x.ElseClause.GetType() 342 if currType.Tp == mysql.TypeUnspecified { 343 currType = *t 344 } else { 345 mtp := types.MergeFieldType(currType.Tp, t.Tp) 346 if mtp == t.Tp && mtp != currType.Tp { 347 currType.Charset = t.Charset 348 currType.Collate = t.Collate 349 } 350 currType.Tp = mtp 351 } 352 } 353 x.SetType(&currType) 354 // TODO: We need a better way to set charset/collation 355 x.Type.Charset, x.Type.Collate = types.DefaultCharsetForType(x.Type.Tp) 356 }