github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/ast/functions.go (about)

     1  // Copyright 2015 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
        // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package ast
    15  
    16  import (
    17  	"bytes"
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/insionng/yougam/libraries/juju/errors"
    22  	"github.com/insionng/yougam/libraries/pingcap/tidb/model"
    23  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/distinct"
    24  	"github.com/insionng/yougam/libraries/pingcap/tidb/util/types"
    25  )
    26  
    27  var (
    28  	_ FuncNode = &AggregateFuncExpr{}
    29  	_ FuncNode = &FuncCallExpr{}
    30  	_ FuncNode = &FuncCastExpr{}
    31  )
    32  
    33  // UnquoteString is not quoted when printed.
    34  type UnquoteString string
    35  
    36  // FuncCallExpr is for function expression.
    37  type FuncCallExpr struct {
    38  	funcNode
    39  	// FnName is the function name.
    40  	FnName model.CIStr
    41  	// Args is the function args.
    42  	Args []ExprNode
    43  }
    44  
    45  // Accept implements Node interface.
    46  func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) {
    47  	newNode, skipChildren := v.Enter(n)
    48  	if skipChildren {
    49  		return v.Leave(newNode)
    50  	}
    51  	n = newNode.(*FuncCallExpr)
    52  	for i, val := range n.Args {
    53  		node, ok := val.Accept(v)
    54  		if !ok {
    55  			return n, false
    56  		}
    57  		n.Args[i] = node.(ExprNode)
    58  	}
    59  	return v.Leave(n)
    60  }
    61  
    62  // CastFunctionType is the type for cast function.
    63  type CastFunctionType int
    64  
    65  // CastFunction types
    66  const (
    67  	CastFunction CastFunctionType = iota + 1
    68  	CastConvertFunction
    69  	CastBinaryOperator
    70  )
    71  
    72  // FuncCastExpr is the cast function converting value to another type, e.g, cast(expr AS signed).
    73  // See https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html
    74  type FuncCastExpr struct {
    75  	funcNode
    76  	// Expr is the expression to be converted.
    77  	Expr ExprNode
    78  	// Tp is the conversion type.
    79  	Tp *types.FieldType
    80  	// Cast, Convert and Binary share this struct.
    81  	FunctionType CastFunctionType
    82  }
    83  
    84  // Accept implements Node Accept interface.
    85  func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) {
    86  	newNode, skipChildren := v.Enter(n)
    87  	if skipChildren {
    88  		return v.Leave(newNode)
    89  	}
    90  	n = newNode.(*FuncCastExpr)
    91  	node, ok := n.Expr.Accept(v)
    92  	if !ok {
    93  		return n, false
    94  	}
    95  	n.Expr = node.(ExprNode)
    96  	return v.Leave(n)
    97  }
    98  
    99  // TrimDirectionType is the type for trim direction.
   100  type TrimDirectionType int
   101  
   102  const (
   103  	// TrimBothDefault trims from both direction by default.
   104  	TrimBothDefault TrimDirectionType = iota
   105  	// TrimBoth trims from both direction with explicit notation.
   106  	TrimBoth
   107  	// TrimLeading trims from left.
   108  	TrimLeading
   109  	// TrimTrailing trims from right.
   110  	TrimTrailing
   111  )
   112  
   113  // DateArithType is type for DateArith type.
   114  type DateArithType byte
   115  
   116  const (
   117  	// DateAdd is to run adddate or date_add function option.
   118  	// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate
   119  	// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
   120  	DateAdd DateArithType = iota + 1
   121  	// DateSub is to run subdate or date_sub function option.
   122  	// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate
   123  	// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub
   124  	DateSub
   125  )
   126  
   127  // DateArithInterval is the struct of DateArith interval part.
   128  type DateArithInterval struct {
   129  	Unit     string
   130  	Interval ExprNode
   131  }
   132  
   133  const (
   134  	// AggFuncCount is the name of Count function.
   135  	AggFuncCount = "count"
   136  	// AggFuncSum is the name of Sum function.
   137  	AggFuncSum = "sum"
   138  	// AggFuncAvg is the name of Avg function.
   139  	AggFuncAvg = "avg"
   140  	// AggFuncFirstRow is the name of FirstRowColumn function.
   141  	AggFuncFirstRow = "firstrow"
   142  	// AggFuncMax is the name of max function.
   143  	AggFuncMax = "max"
   144  	// AggFuncMin is the name of min function.
   145  	AggFuncMin = "min"
   146  	// AggFuncGroupConcat is the name of group_concat function.
   147  	AggFuncGroupConcat = "group_concat"
   148  )
   149  
   150  // AggregateFuncExpr represents aggregate function expression.
   151  type AggregateFuncExpr struct {
   152  	funcNode
   153  	// F is the function name.
   154  	F string
   155  	// Args is the function args.
   156  	Args []ExprNode
   157  	// If distinct is true, the function only aggregate distinct values.
   158  	// For example, column c1 values are "1", "2", "2",  "sum(c1)" is "5",
   159  	// but "sum(distinct c1)" is "3".
   160  	Distinct bool
   161  
   162  	CurrentGroup string
   163  	// contextPerGroupMap is used to store aggregate evaluation context.
   164  	// Each entry for a group.
   165  	contextPerGroupMap map[string](*AggEvaluateContext)
   166  }
   167  
   168  // Accept implements Node Accept interface.
   169  func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) {
   170  	newNode, skipChildren := v.Enter(n)
   171  	if skipChildren {
   172  		return v.Leave(newNode)
   173  	}
   174  	n = newNode.(*AggregateFuncExpr)
   175  	for i, val := range n.Args {
   176  		node, ok := val.Accept(v)
   177  		if !ok {
   178  			return n, false
   179  		}
   180  		n.Args[i] = node.(ExprNode)
   181  	}
   182  	return v.Leave(n)
   183  }
   184  
   185  // Clear clears aggregate computing context.
   186  func (n *AggregateFuncExpr) Clear() {
   187  	n.CurrentGroup = ""
   188  	n.contextPerGroupMap = nil
   189  }
   190  
   191  // Update is used for update aggregate context.
   192  func (n *AggregateFuncExpr) Update() error {
   193  	name := strings.ToLower(n.F)
   194  	switch name {
   195  	case AggFuncCount:
   196  		return n.updateCount()
   197  	case AggFuncFirstRow:
   198  		return n.updateFirstRow()
   199  	case AggFuncGroupConcat:
   200  		return n.updateGroupConcat()
   201  	case AggFuncMax:
   202  		return n.updateMaxMin(true)
   203  	case AggFuncMin:
   204  		return n.updateMaxMin(false)
   205  	case AggFuncSum, AggFuncAvg:
   206  		return n.updateSum()
   207  	}
   208  	return nil
   209  }
   210  
   211  // GetContext gets aggregate evaluation context for the current group.
   212  // If it is nil, add a new context into contextPerGroupMap.
   213  func (n *AggregateFuncExpr) GetContext() *AggEvaluateContext {
   214  	if n.contextPerGroupMap == nil {
   215  		n.contextPerGroupMap = make(map[string](*AggEvaluateContext))
   216  	}
   217  	if _, ok := n.contextPerGroupMap[n.CurrentGroup]; !ok {
   218  		c := &AggEvaluateContext{}
   219  		if n.Distinct {
   220  			c.distinctChecker = distinct.CreateDistinctChecker()
   221  		}
   222  		n.contextPerGroupMap[n.CurrentGroup] = c
   223  	}
   224  	return n.contextPerGroupMap[n.CurrentGroup]
   225  }
   226  
   227  func (n *AggregateFuncExpr) updateCount() error {
   228  	ctx := n.GetContext()
   229  	vals := make([]interface{}, 0, len(n.Args))
   230  	for _, a := range n.Args {
   231  		value := a.GetValue()
   232  		if value == nil {
   233  			return nil
   234  		}
   235  		vals = append(vals, value)
   236  	}
   237  	if n.Distinct {
   238  		d, err := ctx.distinctChecker.Check(vals)
   239  		if err != nil {
   240  			return errors.Trace(err)
   241  		}
   242  		if !d {
   243  			return nil
   244  		}
   245  	}
   246  	ctx.Count++
   247  	return nil
   248  }
   249  
   250  func (n *AggregateFuncExpr) updateFirstRow() error {
   251  	ctx := n.GetContext()
   252  	if ctx.evaluated {
   253  		return nil
   254  	}
   255  	if len(n.Args) != 1 {
   256  		return errors.New("Wrong number of args for AggFuncFirstRow")
   257  	}
   258  	ctx.Value = n.Args[0].GetValue()
   259  	ctx.evaluated = true
   260  	return nil
   261  }
   262  
   263  func (n *AggregateFuncExpr) updateMaxMin(max bool) error {
   264  	ctx := n.GetContext()
   265  	if len(n.Args) != 1 {
   266  		return errors.New("Wrong number of args for AggFuncFirstRow")
   267  	}
   268  	v := n.Args[0].GetValue()
   269  	if !ctx.evaluated {
   270  		ctx.Value = v
   271  		ctx.evaluated = true
   272  		return nil
   273  	}
   274  	c, err := types.Compare(ctx.Value, v)
   275  	if err != nil {
   276  		return errors.Trace(err)
   277  	}
   278  	if max {
   279  		if c == -1 {
   280  			ctx.Value = v
   281  		}
   282  	} else {
   283  		if c == 1 {
   284  			ctx.Value = v
   285  		}
   286  
   287  	}
   288  	return nil
   289  }
   290  
   291  func (n *AggregateFuncExpr) updateSum() error {
   292  	ctx := n.GetContext()
   293  	a := n.Args[0]
   294  	value := a.GetValue()
   295  	if value == nil {
   296  		return nil
   297  	}
   298  	if n.Distinct {
   299  		d, err := ctx.distinctChecker.Check([]interface{}{value})
   300  		if err != nil {
   301  			return errors.Trace(err)
   302  		}
   303  		if !d {
   304  			return nil
   305  		}
   306  	}
   307  	var err error
   308  	ctx.Value, err = types.CalculateSum(ctx.Value, value)
   309  	if err != nil {
   310  		return errors.Trace(err)
   311  	}
   312  	ctx.Count++
   313  	return nil
   314  }
   315  
   316  func (n *AggregateFuncExpr) updateGroupConcat() error {
   317  	ctx := n.GetContext()
   318  	vals := make([]interface{}, 0, len(n.Args))
   319  	for _, a := range n.Args {
   320  		value := a.GetValue()
   321  		if value == nil {
   322  			return nil
   323  		}
   324  		vals = append(vals, value)
   325  	}
   326  	if n.Distinct {
   327  		d, err := ctx.distinctChecker.Check(vals)
   328  		if err != nil {
   329  			return errors.Trace(err)
   330  		}
   331  		if !d {
   332  			return nil
   333  		}
   334  	}
   335  	if ctx.Buffer == nil {
   336  		ctx.Buffer = &bytes.Buffer{}
   337  	} else {
   338  		// now use comma separator
   339  		ctx.Buffer.WriteString(",")
   340  	}
   341  	for _, val := range vals {
   342  		ctx.Buffer.WriteString(fmt.Sprintf("%v", val))
   343  	}
   344  	// TODO: if total length is greater than global var group_concat_max_len, truncate it.
   345  	return nil
   346  }
   347  
   348  // AggregateFuncExtractor visits Expr tree.
   349  // It converts ColunmNameExpr to AggregateFuncExpr and collects AggregateFuncExpr.
   350  type AggregateFuncExtractor struct {
   351  	inAggregateFuncExpr bool
   352  	// AggFuncs is the collected AggregateFuncExprs.
   353  	AggFuncs   []*AggregateFuncExpr
   354  	extracting bool
   355  }
   356  
   357  // Enter implements Visitor interface.
   358  func (a *AggregateFuncExtractor) Enter(n Node) (node Node, skipChildren bool) {
   359  	switch n.(type) {
   360  	case *AggregateFuncExpr:
   361  		a.inAggregateFuncExpr = true
   362  	case *SelectStmt, *InsertStmt, *DeleteStmt, *UpdateStmt:
   363  		// Enter a new context, skip it.
   364  		// For example: select sum(c) + c + exists(select c from t) from t;
   365  		if a.extracting {
   366  			return n, true
   367  		}
   368  	}
   369  	a.extracting = true
   370  	return n, false
   371  }
   372  
   373  // Leave implements Visitor interface.
   374  func (a *AggregateFuncExtractor) Leave(n Node) (node Node, ok bool) {
   375  	switch v := n.(type) {
   376  	case *AggregateFuncExpr:
   377  		a.inAggregateFuncExpr = false
   378  		a.AggFuncs = append(a.AggFuncs, v)
   379  	case *ColumnNameExpr:
   380  		// compose new AggregateFuncExpr
   381  		if !a.inAggregateFuncExpr {
   382  			// For example: select sum(c) + c from t;
   383  			// The c in sum() should be evaluated for each row.
   384  			// The c after plus should be evaluated only once.
   385  			agg := &AggregateFuncExpr{
   386  				F:    AggFuncFirstRow,
   387  				Args: []ExprNode{v},
   388  			}
   389  			agg.SetFlag((v.GetFlag() | FlagHasAggregateFunc))
   390  			a.AggFuncs = append(a.AggFuncs, agg)
   391  			return agg, true
   392  		}
   393  	}
   394  	return n, true
   395  }
   396  
   397  // AggEvaluateContext is used to store intermediate result when caculation aggregate functions.
   398  type AggEvaluateContext struct {
   399  	distinctChecker *distinct.Checker
   400  	Count           int64
   401  	Value           interface{}
   402  	Buffer          *bytes.Buffer // Buffer is used for group_concat.
   403  	evaluated       bool
   404  }