github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/ast/functions.go (about) 1 // Copyright 2015 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package ast 15 16 import ( 17 "bytes" 18 "fmt" 19 "strings" 20 21 "github.com/insionng/yougam/libraries/juju/errors" 22 "github.com/insionng/yougam/libraries/pingcap/tidb/model" 23 "github.com/insionng/yougam/libraries/pingcap/tidb/util/distinct" 24 "github.com/insionng/yougam/libraries/pingcap/tidb/util/types" 25 ) 26 27 var ( 28 _ FuncNode = &AggregateFuncExpr{} 29 _ FuncNode = &FuncCallExpr{} 30 _ FuncNode = &FuncCastExpr{} 31 ) 32 33 // UnquoteString is not quoted when printed. 34 type UnquoteString string 35 36 // FuncCallExpr is for function expression. 37 type FuncCallExpr struct { 38 funcNode 39 // FnName is the function name. 40 FnName model.CIStr 41 // Args is the function args. 42 Args []ExprNode 43 } 44 45 // Accept implements Node interface. 46 func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) { 47 newNode, skipChildren := v.Enter(n) 48 if skipChildren { 49 return v.Leave(newNode) 50 } 51 n = newNode.(*FuncCallExpr) 52 for i, val := range n.Args { 53 node, ok := val.Accept(v) 54 if !ok { 55 return n, false 56 } 57 n.Args[i] = node.(ExprNode) 58 } 59 return v.Leave(n) 60 } 61 62 // CastFunctionType is the type for cast function. 
63 type CastFunctionType int 64 65 // CastFunction types 66 const ( 67 CastFunction CastFunctionType = iota + 1 68 CastConvertFunction 69 CastBinaryOperator 70 ) 71 72 // FuncCastExpr is the cast function converting value to another type, e.g, cast(expr AS signed). 73 // See https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html 74 type FuncCastExpr struct { 75 funcNode 76 // Expr is the expression to be converted. 77 Expr ExprNode 78 // Tp is the conversion type. 79 Tp *types.FieldType 80 // Cast, Convert and Binary share this struct. 81 FunctionType CastFunctionType 82 } 83 84 // Accept implements Node Accept interface. 85 func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) { 86 newNode, skipChildren := v.Enter(n) 87 if skipChildren { 88 return v.Leave(newNode) 89 } 90 n = newNode.(*FuncCastExpr) 91 node, ok := n.Expr.Accept(v) 92 if !ok { 93 return n, false 94 } 95 n.Expr = node.(ExprNode) 96 return v.Leave(n) 97 } 98 99 // TrimDirectionType is the type for trim direction. 100 type TrimDirectionType int 101 102 const ( 103 // TrimBothDefault trims from both direction by default. 104 TrimBothDefault TrimDirectionType = iota 105 // TrimBoth trims from both direction with explicit notation. 106 TrimBoth 107 // TrimLeading trims from left. 108 TrimLeading 109 // TrimTrailing trims from right. 110 TrimTrailing 111 ) 112 113 // DateArithType is type for DateArith type. 114 type DateArithType byte 115 116 const ( 117 // DateAdd is to run adddate or date_add function option. 118 // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate 119 // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add 120 DateAdd DateArithType = iota + 1 121 // DateSub is to run subdate or date_sub function option. 
122 // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate 123 // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub 124 DateSub 125 ) 126 127 // DateArithInterval is the struct of DateArith interval part. 128 type DateArithInterval struct { 129 Unit string 130 Interval ExprNode 131 } 132 133 const ( 134 // AggFuncCount is the name of Count function. 135 AggFuncCount = "count" 136 // AggFuncSum is the name of Sum function. 137 AggFuncSum = "sum" 138 // AggFuncAvg is the name of Avg function. 139 AggFuncAvg = "avg" 140 // AggFuncFirstRow is the name of FirstRowColumn function. 141 AggFuncFirstRow = "firstrow" 142 // AggFuncMax is the name of max function. 143 AggFuncMax = "max" 144 // AggFuncMin is the name of min function. 145 AggFuncMin = "min" 146 // AggFuncGroupConcat is the name of group_concat function. 147 AggFuncGroupConcat = "group_concat" 148 ) 149 150 // AggregateFuncExpr represents aggregate function expression. 151 type AggregateFuncExpr struct { 152 funcNode 153 // F is the function name. 154 F string 155 // Args is the function args. 156 Args []ExprNode 157 // If distinct is true, the function only aggregate distinct values. 158 // For example, column c1 values are "1", "2", "2", "sum(c1)" is "5", 159 // but "sum(distinct c1)" is "3". 160 Distinct bool 161 162 CurrentGroup string 163 // contextPerGroupMap is used to store aggregate evaluation context. 164 // Each entry for a group. 165 contextPerGroupMap map[string](*AggEvaluateContext) 166 } 167 168 // Accept implements Node Accept interface. 
func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) {
	entered, skip := v.Enter(n)
	if skip {
		return v.Leave(entered)
	}
	agg := entered.(*AggregateFuncExpr)
	// Visit every argument in order, splicing back the node each visit yields.
	for idx, arg := range agg.Args {
		visited, ok := arg.Accept(v)
		if !ok {
			return agg, false
		}
		agg.Args[idx] = visited.(ExprNode)
	}
	return v.Leave(agg)
}

// Clear clears aggregate computing context: it forgets the current group key
// and drops every per-group evaluation context.
func (n *AggregateFuncExpr) Clear() {
	n.contextPerGroupMap = nil
	n.CurrentGroup = ""
}

// Update folds the current argument values into the context of the current
// group. F is matched case-insensitively; unknown names are a no-op.
func (n *AggregateFuncExpr) Update() error {
	switch strings.ToLower(n.F) {
	case AggFuncCount:
		return n.updateCount()
	case AggFuncFirstRow:
		return n.updateFirstRow()
	case AggFuncGroupConcat:
		return n.updateGroupConcat()
	case AggFuncMax:
		return n.updateMaxMin(true)
	case AggFuncMin:
		return n.updateMaxMin(false)
	case AggFuncSum, AggFuncAvg:
		// Avg shares Sum's accumulator; both keep a running sum and a count.
		return n.updateSum()
	default:
		return nil
	}
}

// GetContext gets aggregate evaluation context for the current group.
// If it is nil, add a new context into contextPerGroupMap.
func (n *AggregateFuncExpr) GetContext() *AggEvaluateContext {
	if n.contextPerGroupMap == nil {
		n.contextPerGroupMap = make(map[string]*AggEvaluateContext)
	}
	if ctx, ok := n.contextPerGroupMap[n.CurrentGroup]; ok {
		return ctx
	}
	ctx := &AggEvaluateContext{}
	if n.Distinct {
		// Each group gets its own checker so DISTINCT is evaluated per group.
		ctx.distinctChecker = distinct.CreateDistinctChecker()
	}
	n.contextPerGroupMap[n.CurrentGroup] = ctx
	return ctx
}

// updateCount increments the current group's Count unless some argument is
// NULL (nil). With Distinct set, it also skips tuples the group's distinct
// checker has already seen.
func (n *AggregateFuncExpr) updateCount() error {
	ctx := n.GetContext()
	vals := make([]interface{}, 0, len(n.Args))
	for _, a := range n.Args {
		value := a.GetValue()
		if value == nil {
			return nil
		}
		vals = append(vals, value)
	}
	if n.Distinct {
		d, err := ctx.distinctChecker.Check(vals)
		if err != nil {
			return errors.Trace(err)
		}
		if !d {
			return nil
		}
	}
	ctx.Count++
	return nil
}

// updateFirstRow records the single argument's value the first time it is
// called for a group and ignores every later row of that group.
func (n *AggregateFuncExpr) updateFirstRow() error {
	ctx := n.GetContext()
	if ctx.evaluated {
		return nil
	}
	if len(n.Args) != 1 {
		return errors.New("Wrong number of args for AggFuncFirstRow")
	}
	ctx.Value = n.Args[0].GetValue()
	ctx.evaluated = true
	return nil
}

// updateMaxMin keeps the largest (isMax) or smallest (!isMax) argument value
// seen so far for the current group. NULL (nil) inputs are ignored, as MySQL's
// MAX()/MIN() do. If every input is NULL, ctx.Value stays nil.
func (n *AggregateFuncExpr) updateMaxMin(isMax bool) error {
	ctx := n.GetContext()
	if len(n.Args) != 1 {
		if isMax {
			return errors.New("Wrong number of args for AggFuncMax")
		}
		return errors.New("Wrong number of args for AggFuncMin")
	}
	v := n.Args[0].GetValue()
	if v == nil {
		return nil
	}
	if !ctx.evaluated {
		ctx.Value = v
		ctx.evaluated = true
		return nil
	}
	c, err := types.Compare(ctx.Value, v)
	if err != nil {
		return errors.Trace(err)
	}
	// Compare the sign rather than exact -1/1 so any negative/positive result works.
	if (isMax && c < 0) || (!isMax && c > 0) {
		ctx.Value = v
	}
	return nil
}

// updateSum adds the single argument's value to the current group's running
// sum and increments Count (used by both SUM and AVG). NULL (nil) inputs are
// skipped. With Distinct set, values already seen are skipped too.
func (n *AggregateFuncExpr) updateSum() error {
	ctx := n.GetContext()
	if len(n.Args) != 1 {
		// Guard the Args[0] access below instead of panicking on malformed input.
		return errors.Errorf("Wrong number of args for %s", n.F)
	}
	value := n.Args[0].GetValue()
	if value == nil {
		return nil
	}
	if n.Distinct {
		d, err := ctx.distinctChecker.Check([]interface{}{value})
		if err != nil {
			return errors.Trace(err)
		}
		if !d {
			return nil
		}
	}
	var err error
	ctx.Value, err = types.CalculateSum(ctx.Value, value)
	if err != nil {
		return errors.Trace(err)
	}
	ctx.Count++
	return nil
}

// updateGroupConcat appends the row's argument values (formatted with %v) to
// the current group's Buffer, separating rows with a comma. Rows with any NULL
// argument are skipped. With Distinct set, repeated tuples are skipped too.
func (n *AggregateFuncExpr) updateGroupConcat() error {
	ctx := n.GetContext()
	vals := make([]interface{}, 0, len(n.Args))
	for _, a := range n.Args {
		value := a.GetValue()
		if value == nil {
			return nil
		}
		vals = append(vals, value)
	}
	if n.Distinct {
		d, err := ctx.distinctChecker.Check(vals)
		if err != nil {
			return errors.Trace(err)
		}
		if !d {
			return nil
		}
	}
	if ctx.Buffer == nil {
		ctx.Buffer = &bytes.Buffer{}
	} else {
		// now use comma separator
		ctx.Buffer.WriteString(",")
	}
	for _, val := range vals {
		// Format straight into the buffer; no intermediate string.
		fmt.Fprintf(ctx.Buffer, "%v", val)
	}
	// TODO: if total length is greater than global var group_concat_max_len, truncate it.
	return nil
}

// AggregateFuncExtractor visits Expr tree.
// It converts ColunmNameExpr to AggregateFuncExpr and collects AggregateFuncExpr.
type AggregateFuncExtractor struct {
	// inAggregateFuncExpr is true while visiting inside an AggregateFuncExpr.
	inAggregateFuncExpr bool
	// AggFuncs is the collected AggregateFuncExprs.
	AggFuncs []*AggregateFuncExpr
	// extracting is set once the first node has been entered. After that, nested
	// statements are skipped.
	extracting bool
}

// Enter implements Visitor interface.
func (a *AggregateFuncExtractor) Enter(n Node) (node Node, skipChildren bool) {
	switch n.(type) {
	case *AggregateFuncExpr:
		a.inAggregateFuncExpr = true
	case *SelectStmt, *InsertStmt, *DeleteStmt, *UpdateStmt:
		// Enter a new context, skip it.
		// For example: select sum(c) + c + exists(select c from t) from t;
		if a.extracting {
			return n, true
		}
	}
	a.extracting = true
	return n, false
}

// Leave implements Visitor interface.
func (a *AggregateFuncExtractor) Leave(n Node) (node Node, ok bool) {
	switch v := n.(type) {
	case *AggregateFuncExpr:
		a.inAggregateFuncExpr = false
		a.AggFuncs = append(a.AggFuncs, v)
	case *ColumnNameExpr:
		if a.inAggregateFuncExpr {
			// A column inside an aggregate's arguments is evaluated per row; keep it.
			break
		}
		// A bare column outside any aggregate, e.g. the trailing c in
		// "select sum(c) + c from t", is evaluated only once per group, so it is
		// wrapped in a firstrow aggregate that replaces it in the tree.
		wrapper := &AggregateFuncExpr{
			F:    AggFuncFirstRow,
			Args: []ExprNode{v},
		}
		wrapper.SetFlag(v.GetFlag() | FlagHasAggregateFunc)
		a.AggFuncs = append(a.AggFuncs, wrapper)
		return wrapper, true
	}
	return n, true
}

// AggEvaluateContext is used to store intermediate results while calculating aggregate functions.
type AggEvaluateContext struct {
	// distinctChecker filters repeated inputs when the aggregate is DISTINCT.
	distinctChecker *distinct.Checker
	// Count is the number of accepted rows (count, sum, avg).
	Count int64
	// Value is the running result (sum, max, min, firstrow).
	Value interface{}
	// Buffer is used for group_concat.
	Buffer *bytes.Buffer
	// evaluated reports whether Value has been initialized (firstrow, max, min).
	evaluated bool
}