github.com/XiaoMi/Gaea@v1.2.5/parser/ast/functions.go (about)

     1  // Copyright 2015 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package ast
    15  
    16  import (
    17  	"fmt"
    18  	"io"
    19  	"strings"
    20  
    21  	"github.com/pingcap/errors"
    22  
    23  	"github.com/XiaoMi/Gaea/parser/format"
    24  	"github.com/XiaoMi/Gaea/parser/model"
    25  	"github.com/XiaoMi/Gaea/parser/types"
    26  )
    27  
    28  var (
    29  	_ FuncNode = &AggregateFuncExpr{}
    30  	_ FuncNode = &FuncCallExpr{}
    31  	_ FuncNode = &FuncCastExpr{}
    32  	_ FuncNode = &WindowFuncExpr{}
    33  )
    34  
    35  // List scalar function names.
    36  const (
    37  	LogicAnd   = "and"
    38  	Cast       = "cast"
    39  	LeftShift  = "leftshift"
    40  	RightShift = "rightshift"
    41  	LogicOr    = "or"
    42  	GE         = "ge"
    43  	LE         = "le"
    44  	EQ         = "eq"
    45  	NE         = "ne"
    46  	LT         = "lt"
    47  	GT         = "gt"
    48  	Plus       = "plus"
    49  	Minus      = "minus"
    50  	And        = "bitand"
    51  	Or         = "bitor"
    52  	Mod        = "mod"
    53  	Xor        = "bitxor"
    54  	Div        = "div"
    55  	Mul        = "mul"
    56  	UnaryNot   = "not" // Avoid name conflict with Not in github/pingcap/check.
    57  	BitNeg     = "bitneg"
    58  	IntDiv     = "intdiv"
    59  	LogicXor   = "xor"
    60  	NullEQ     = "nulleq"
    61  	UnaryPlus  = "unaryplus"
    62  	UnaryMinus = "unaryminus"
    63  	In         = "in"
    64  	Like       = "like"
    65  	Case       = "case"
    66  	Regexp     = "regexp"
    67  	IsNull     = "isnull"
    68  	IsTruth    = "istrue"  // Avoid name conflict with IsTrue in github/pingcap/check.
    69  	IsFalsity  = "isfalse" // Avoid name conflict with IsFalse in github/pingcap/check.
    70  	RowFunc    = "row"
    71  	SetVar     = "setvar"
    72  	GetVar     = "getvar"
    73  	Values     = "values"
    74  	BitCount   = "bit_count"
    75  	GetParam   = "getparam"
    76  
    77  	// common functions
    78  	Coalesce = "coalesce"
    79  	Greatest = "greatest"
    80  	Least    = "least"
    81  	Interval = "interval"
    82  
    83  	// math functions
    84  	Abs      = "abs"
    85  	Acos     = "acos"
    86  	Asin     = "asin"
    87  	Atan     = "atan"
    88  	Atan2    = "atan2"
    89  	Ceil     = "ceil"
    90  	Ceiling  = "ceiling"
    91  	Conv     = "conv"
    92  	Cos      = "cos"
    93  	Cot      = "cot"
    94  	CRC32    = "crc32"
    95  	Degrees  = "degrees"
    96  	Exp      = "exp"
    97  	Floor    = "floor"
    98  	Ln       = "ln"
    99  	Log      = "log"
   100  	Log2     = "log2"
   101  	Log10    = "log10"
   102  	PI       = "pi"
   103  	Pow      = "pow"
   104  	Power    = "power"
   105  	Radians  = "radians"
   106  	Rand     = "rand"
   107  	Round    = "round"
   108  	Sign     = "sign"
   109  	Sin      = "sin"
   110  	Sqrt     = "sqrt"
   111  	Tan      = "tan"
   112  	Truncate = "truncate"
   113  
   114  	// time functions
   115  	AddDate          = "adddate"
   116  	AddTime          = "addtime"
   117  	ConvertTz        = "convert_tz"
   118  	Curdate          = "curdate"
   119  	CurrentDate      = "current_date"
   120  	CurrentTime      = "current_time"
   121  	CurrentTimestamp = "current_timestamp"
   122  	Curtime          = "curtime"
   123  	Date             = "date"
   124  	DateLiteral      = "dateliteral"
   125  	DateAdd          = "date_add"
   126  	DateFormat       = "date_format"
   127  	DateSub          = "date_sub"
   128  	DateDiff         = "datediff"
   129  	Day              = "day"
   130  	DayName          = "dayname"
   131  	DayOfMonth       = "dayofmonth"
   132  	DayOfWeek        = "dayofweek"
   133  	DayOfYear        = "dayofyear"
   134  	Extract          = "extract"
   135  	FromDays         = "from_days"
   136  	FromUnixTime     = "from_unixtime"
   137  	GetFormat        = "get_format"
   138  	Hour             = "hour"
   139  	LocalTime        = "localtime"
   140  	LocalTimestamp   = "localtimestamp"
   141  	MakeDate         = "makedate"
   142  	MakeTime         = "maketime"
   143  	MicroSecond      = "microsecond"
   144  	Minute           = "minute"
   145  	Month            = "month"
   146  	MonthName        = "monthname"
   147  	Now              = "now"
   148  	PeriodAdd        = "period_add"
   149  	PeriodDiff       = "period_diff"
   150  	Quarter          = "quarter"
   151  	SecToTime        = "sec_to_time"
   152  	Second           = "second"
   153  	StrToDate        = "str_to_date"
   154  	SubDate          = "subdate"
   155  	SubTime          = "subtime"
   156  	Sysdate          = "sysdate"
   157  	Time             = "time"
   158  	TimeLiteral      = "timeliteral"
   159  	TimeFormat       = "time_format"
   160  	TimeToSec        = "time_to_sec"
   161  	TimeDiff         = "timediff"
   162  	Timestamp        = "timestamp"
   163  	TimestampLiteral = "timestampliteral"
   164  	TimestampAdd     = "timestampadd"
   165  	TimestampDiff    = "timestampdiff"
   166  	ToDays           = "to_days"
   167  	ToSeconds        = "to_seconds"
   168  	UnixTimestamp    = "unix_timestamp"
   169  	UTCDate          = "utc_date"
   170  	UTCTime          = "utc_time"
   171  	UTCTimestamp     = "utc_timestamp"
   172  	Week             = "week"
   173  	Weekday          = "weekday"
   174  	WeekOfYear       = "weekofyear"
   175  	Year             = "year"
   176  	YearWeek         = "yearweek"
   177  	LastDay          = "last_day"
   178  	TiDBParseTso     = "tidb_parse_tso"
   179  
   180  	// string functions
   181  	ASCII           = "ascii"
   182  	Bin             = "bin"
   183  	Concat          = "concat"
   184  	ConcatWS        = "concat_ws"
   185  	Convert         = "convert"
   186  	Elt             = "elt"
   187  	ExportSet       = "export_set"
   188  	Field           = "field"
   189  	Format          = "format"
   190  	FromBase64      = "from_base64"
   191  	InsertFunc      = "insert_func"
   192  	Instr           = "instr"
   193  	Lcase           = "lcase"
   194  	Left            = "left"
   195  	Length          = "length"
   196  	LoadFile        = "load_file"
   197  	Locate          = "locate"
   198  	Lower           = "lower"
   199  	Lpad            = "lpad"
   200  	LTrim           = "ltrim"
   201  	MakeSet         = "make_set"
   202  	Mid             = "mid"
   203  	Oct             = "oct"
   204  	Ord             = "ord"
   205  	Position        = "position"
   206  	Quote           = "quote"
   207  	Repeat          = "repeat"
   208  	Replace         = "replace"
   209  	Reverse         = "reverse"
   210  	Right           = "right"
   211  	RTrim           = "rtrim"
   212  	Space           = "space"
   213  	Strcmp          = "strcmp"
   214  	Substring       = "substring"
   215  	Substr          = "substr"
   216  	SubstringIndex  = "substring_index"
   217  	ToBase64        = "to_base64"
   218  	Trim            = "trim"
   219  	Upper           = "upper"
   220  	Ucase           = "ucase"
   221  	Hex             = "hex"
   222  	Unhex           = "unhex"
   223  	Rpad            = "rpad"
   224  	BitLength       = "bit_length"
   225  	CharFunc        = "char_func"
   226  	CharLength      = "char_length"
   227  	CharacterLength = "character_length"
   228  	FindInSet       = "find_in_set"
   229  
   230  	// information functions
   231  	Benchmark      = "benchmark"
   232  	Charset        = "charset"
   233  	Coercibility   = "coercibility"
   234  	Collation      = "collation"
   235  	ConnectionID   = "connection_id"
   236  	CurrentUser    = "current_user"
   237  	Database       = "database"
   238  	FoundRows      = "found_rows"
   239  	LastInsertId   = "last_insert_id"
   240  	RowCount       = "row_count"
   241  	Schema         = "schema"
   242  	SessionUser    = "session_user"
   243  	SystemUser     = "system_user"
   244  	User           = "user"
   245  	Version        = "version"
   246  	TiDBVersion    = "tidb_version"
   247  	TiDBIsDDLOwner = "tidb_is_ddl_owner"
   248  
   249  	// control functions
   250  	If     = "if"
   251  	Ifnull = "ifnull"
   252  	Nullif = "nullif"
   253  
   254  	// miscellaneous functions
   255  	AnyValue        = "any_value"
   256  	DefaultFunc     = "default_func"
   257  	InetAton        = "inet_aton"
   258  	InetNtoa        = "inet_ntoa"
   259  	Inet6Aton       = "inet6_aton"
   260  	Inet6Ntoa       = "inet6_ntoa"
   261  	IsFreeLock      = "is_free_lock"
   262  	IsIPv4          = "is_ipv4"
   263  	IsIPv4Compat    = "is_ipv4_compat"
   264  	IsIPv4Mapped    = "is_ipv4_mapped"
   265  	IsIPv6          = "is_ipv6"
   266  	IsUsedLock      = "is_used_lock"
   267  	MasterPosWait   = "master_pos_wait"
   268  	NameConst       = "name_const"
   269  	ReleaseAllLocks = "release_all_locks"
   270  	Sleep           = "sleep"
   271  	UUID            = "uuid"
   272  	UUIDShort       = "uuid_short"
   273  	// get_lock() and release_lock() is parsed but do nothing.
   274  	// It is used for preventing error in Ruby's activerecord migrations.
   275  	GetLock     = "get_lock"
   276  	ReleaseLock = "release_lock"
   277  
   278  	// encryption and compression functions
   279  	AesDecrypt               = "aes_decrypt"
   280  	AesEncrypt               = "aes_encrypt"
   281  	Compress                 = "compress"
   282  	Decode                   = "decode"
   283  	DesDecrypt               = "des_decrypt"
   284  	DesEncrypt               = "des_encrypt"
   285  	Encode                   = "encode"
   286  	Encrypt                  = "encrypt"
   287  	MD5                      = "md5"
   288  	OldPassword              = "old_password"
   289  	PasswordFunc             = "password_func"
   290  	RandomBytes              = "random_bytes"
   291  	SHA1                     = "sha1"
   292  	SHA                      = "sha"
   293  	SHA2                     = "sha2"
   294  	Uncompress               = "uncompress"
   295  	UncompressedLength       = "uncompressed_length"
   296  	ValidatePasswordStrength = "validate_password_strength"
   297  
   298  	// json functions
   299  	JSONType          = "json_type"
   300  	JSONExtract       = "json_extract"
   301  	JSONUnquote       = "json_unquote"
   302  	JSONArray         = "json_array"
   303  	JSONObject        = "json_object"
   304  	JSONMerge         = "json_merge"
   305  	JSONSet           = "json_set"
   306  	JSONInsert        = "json_insert"
   307  	JSONReplace       = "json_replace"
   308  	JSONRemove        = "json_remove"
   309  	JSONContains      = "json_contains"
   310  	JSONContainsPath  = "json_contains_path"
   311  	JSONValid         = "json_valid"
   312  	JSONArrayAppend   = "json_array_append"
   313  	JSONArrayInsert   = "json_array_insert"
   314  	JSONMergePatch    = "json_merge_patch"
   315  	JSONMergePreserve = "json_merge_preserve"
   316  	JSONPretty        = "json_pretty"
   317  	JSONQuote         = "json_quote"
   318  	JSONSearch        = "json_search"
   319  	JSONStorageSize   = "json_storage_size"
   320  	JSONDepth         = "json_depth"
   321  	JSONKeys          = "json_keys"
   322  	JSONLength        = "json_length"
   323  )
   324  
   325  // FuncCallExpr is for function expression.
   326  type FuncCallExpr struct {
   327  	funcNode
   328  	// FnName is the function name.
   329  	FnName model.CIStr
   330  	// Args is the function args.
   331  	Args []ExprNode
   332  }
   333  
   334  // Restore implements Node interface.
   335  func (n *FuncCallExpr) Restore(ctx *format.RestoreCtx) error {
   336  	ctx.WriteKeyWord(n.FnName.O)
   337  	ctx.WritePlain("(")
   338  	switch n.FnName.L {
   339  	case "convert":
   340  		if err := n.Args[0].Restore(ctx); err != nil {
   341  			return errors.Annotatef(err, "An error occurred while restore FuncCastExpr.Expr")
   342  		}
   343  		ctx.WriteKeyWord(" USING ")
   344  		ctx.WriteKeyWord(n.Args[1].GetType().Charset)
   345  	case "adddate", "subdate", "date_add", "date_sub":
   346  		if err := n.Args[0].Restore(ctx); err != nil {
   347  			return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
   348  		}
   349  		ctx.WritePlain(", ")
   350  		ctx.WriteKeyWord("INTERVAL ")
   351  		if err := n.Args[1].Restore(ctx); err != nil {
   352  			return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
   353  		}
   354  		ctx.WritePlain(" ")
   355  		ctx.WriteKeyWord(n.Args[2].(ValueExpr).GetString())
   356  	case "extract":
   357  		ctx.WriteKeyWord(n.Args[0].(ValueExpr).GetString())
   358  		ctx.WriteKeyWord(" FROM ")
   359  		if err := n.Args[1].Restore(ctx); err != nil {
   360  			return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
   361  		}
   362  	case "get_format":
   363  		ctx.WriteKeyWord(n.Args[0].(ValueExpr).GetString())
   364  		ctx.WritePlain(", ")
   365  		if err := n.Args[1].Restore(ctx); err != nil {
   366  			return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
   367  		}
   368  	case "position":
   369  		if err := n.Args[0].Restore(ctx); err != nil {
   370  			return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
   371  		}
   372  		ctx.WriteKeyWord(" IN ")
   373  		if err := n.Args[1].Restore(ctx); err != nil {
   374  			return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
   375  		}
   376  	case "trim":
   377  		switch len(n.Args) {
   378  		case 1:
   379  			if err := n.Args[0].Restore(ctx); err != nil {
   380  				return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
   381  			}
   382  		case 2:
   383  			if err := n.Args[1].Restore(ctx); err != nil {
   384  				return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
   385  			}
   386  			ctx.WriteKeyWord(" FROM ")
   387  			if err := n.Args[0].Restore(ctx); err != nil {
   388  				return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
   389  			}
   390  		case 3:
   391  			switch fmt.Sprint(n.Args[2].(ValueExpr).GetValue()) {
   392  			case "3":
   393  				ctx.WriteKeyWord("TRAILING ")
   394  			case "2":
   395  				ctx.WriteKeyWord("LEADING ")
   396  			case "0", "1":
   397  				ctx.WriteKeyWord("BOTH ")
   398  			}
   399  			if n.Args[1].(ValueExpr).GetValue() != nil {
   400  				if err := n.Args[1].Restore(ctx); err != nil {
   401  					return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
   402  				}
   403  				ctx.WritePlain(" ")
   404  			}
   405  			ctx.WriteKeyWord("FROM ")
   406  			if err := n.Args[0].Restore(ctx); err != nil {
   407  				return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
   408  			}
   409  		}
   410  	case "timestampdiff", "timestampadd":
   411  		ctx.WriteKeyWord(n.Args[0].(ValueExpr).GetString())
   412  		for i := 1; i < len(n.Args); {
   413  			ctx.WritePlain(", ")
   414  			if err := n.Args[i].Restore(ctx); err != nil {
   415  				return errors.Annotatef(err, "An error occurred while restore FuncCallExpr")
   416  			}
   417  			i++
   418  		}
   419  	default:
   420  		for i, argv := range n.Args {
   421  			if i != 0 {
   422  				ctx.WritePlain(", ")
   423  			}
   424  			if err := argv.Restore(ctx); err != nil {
   425  				return errors.Annotatef(err, "An error occurred while restore FuncCallExpr.Args %d", i)
   426  			}
   427  		}
   428  	}
   429  	ctx.WritePlain(")")
   430  	return nil
   431  }
   432  
   433  // Format the ExprNode into a Writer.
   434  func (n *FuncCallExpr) Format(w io.Writer) {
   435  	fmt.Fprintf(w, "%s(", n.FnName.L)
   436  	if !n.specialFormatArgs(w) {
   437  		for i, arg := range n.Args {
   438  			arg.Format(w)
   439  			if i != len(n.Args)-1 {
   440  				fmt.Fprint(w, ", ")
   441  			}
   442  		}
   443  	}
   444  	fmt.Fprint(w, ")")
   445  }
   446  
   447  // specialFormatArgs formats argument list for some special functions.
   448  func (n *FuncCallExpr) specialFormatArgs(w io.Writer) bool {
   449  	switch n.FnName.L {
   450  	case DateAdd, DateSub, AddDate, SubDate:
   451  		n.Args[0].Format(w)
   452  		fmt.Fprint(w, ", INTERVAL ")
   453  		n.Args[1].Format(w)
   454  		fmt.Fprintf(w, " %s", n.Args[2].(ValueExpr).GetDatumString())
   455  		return true
   456  	case TimestampAdd, TimestampDiff:
   457  		fmt.Fprintf(w, "%s, ", n.Args[0].(ValueExpr).GetDatumString())
   458  		n.Args[1].Format(w)
   459  		fmt.Fprint(w, ", ")
   460  		n.Args[2].Format(w)
   461  		return true
   462  	}
   463  	return false
   464  }
   465  
   466  // Accept implements Node interface.
   467  func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) {
   468  	newNode, skipChildren := v.Enter(n)
   469  	if skipChildren {
   470  		return v.Leave(newNode)
   471  	}
   472  	n = newNode.(*FuncCallExpr)
   473  	for i, val := range n.Args {
   474  		node, ok := val.Accept(v)
   475  		if !ok {
   476  			return n, false
   477  		}
   478  		n.Args[i] = node.(ExprNode)
   479  	}
   480  	return v.Leave(n)
   481  }
   482  
   483  // CastFunctionType is the type for cast function.
   484  type CastFunctionType int
   485  
   486  // CastFunction types
   487  const (
   488  	CastFunction CastFunctionType = iota + 1
   489  	CastConvertFunction
   490  	CastBinaryOperator
   491  )
   492  
   493  // FuncCastExpr is the cast function converting value to another type, e.g, cast(expr AS signed).
   494  // See https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html
   495  type FuncCastExpr struct {
   496  	funcNode
   497  	// Expr is the expression to be converted.
   498  	Expr ExprNode
   499  	// Tp is the conversion type.
   500  	Tp *types.FieldType
   501  	// FunctionType is either Cast, Convert or Binary.
   502  	FunctionType CastFunctionType
   503  }
   504  
   505  // Restore implements Node interface.
   506  func (n *FuncCastExpr) Restore(ctx *format.RestoreCtx) error {
   507  	switch n.FunctionType {
   508  	case CastFunction:
   509  		ctx.WriteKeyWord("CAST")
   510  		ctx.WritePlain("(")
   511  		if err := n.Expr.Restore(ctx); err != nil {
   512  			return errors.Annotatef(err, "An error occurred while restore FuncCastExpr.Expr")
   513  		}
   514  		ctx.WriteKeyWord(" AS ")
   515  		n.Tp.FormatAsCastType(ctx.In)
   516  		ctx.WritePlain(")")
   517  	case CastConvertFunction:
   518  		ctx.WriteKeyWord("CONVERT")
   519  		ctx.WritePlain("(")
   520  		if err := n.Expr.Restore(ctx); err != nil {
   521  			return errors.Annotatef(err, "An error occurred while restore FuncCastExpr.Expr")
   522  		}
   523  		ctx.WritePlain(", ")
   524  		n.Tp.FormatAsCastType(ctx.In)
   525  		ctx.WritePlain(")")
   526  	case CastBinaryOperator:
   527  		ctx.WriteKeyWord("BINARY ")
   528  		if err := n.Expr.Restore(ctx); err != nil {
   529  			return errors.Annotatef(err, "An error occurred while restore FuncCastExpr.Expr")
   530  		}
   531  	}
   532  	return nil
   533  }
   534  
   535  // Format the ExprNode into a Writer.
   536  func (n *FuncCastExpr) Format(w io.Writer) {
   537  	switch n.FunctionType {
   538  	case CastFunction:
   539  		fmt.Fprint(w, "CAST(")
   540  		n.Expr.Format(w)
   541  		fmt.Fprint(w, " AS ")
   542  		n.Tp.FormatAsCastType(w)
   543  		fmt.Fprint(w, ")")
   544  	case CastConvertFunction:
   545  		fmt.Fprint(w, "CONVERT(")
   546  		n.Expr.Format(w)
   547  		fmt.Fprint(w, ", ")
   548  		n.Tp.FormatAsCastType(w)
   549  		fmt.Fprint(w, ")")
   550  	case CastBinaryOperator:
   551  		fmt.Fprint(w, "BINARY ")
   552  		n.Expr.Format(w)
   553  	}
   554  }
   555  
   556  // Accept implements Node Accept interface.
   557  func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) {
   558  	newNode, skipChildren := v.Enter(n)
   559  	if skipChildren {
   560  		return v.Leave(newNode)
   561  	}
   562  	n = newNode.(*FuncCastExpr)
   563  	node, ok := n.Expr.Accept(v)
   564  	if !ok {
   565  		return n, false
   566  	}
   567  	n.Expr = node.(ExprNode)
   568  	return v.Leave(n)
   569  }
   570  
   571  // TrimDirectionType is the type for trim direction.
   572  type TrimDirectionType int
   573  
   574  const (
   575  	// TrimBothDefault trims from both direction by default.
   576  	TrimBothDefault TrimDirectionType = iota
   577  	// TrimBoth trims from both direction with explicit notation.
   578  	TrimBoth
   579  	// TrimLeading trims from left.
   580  	TrimLeading
   581  	// TrimTrailing trims from right.
   582  	TrimTrailing
   583  )
   584  
   585  // DateArithType is type for DateArith type.
   586  type DateArithType byte
   587  
   588  const (
   589  	// DateArithAdd is to run adddate or date_add function option.
   590  	// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate
   591  	// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
   592  	DateArithAdd DateArithType = iota + 1
   593  	// DateArithSub is to run subdate or date_sub function option.
   594  	// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate
   595  	// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub
   596  	DateArithSub
   597  )
   598  
   599  const (
   600  	// AggFuncCount is the name of Count function.
   601  	AggFuncCount = "count"
   602  	// AggFuncSum is the name of Sum function.
   603  	AggFuncSum = "sum"
   604  	// AggFuncAvg is the name of Avg function.
   605  	AggFuncAvg = "avg"
   606  	// AggFuncFirstRow is the name of FirstRowColumn function.
   607  	AggFuncFirstRow = "firstrow"
   608  	// AggFuncMax is the name of max function.
   609  	AggFuncMax = "max"
   610  	// AggFuncMin is the name of min function.
   611  	AggFuncMin = "min"
   612  	// AggFuncGroupConcat is the name of group_concat function.
   613  	AggFuncGroupConcat = "group_concat"
   614  	// AggFuncBitOr is the name of bit_or function.
   615  	AggFuncBitOr = "bit_or"
   616  	// AggFuncBitXor is the name of bit_xor function.
   617  	AggFuncBitXor = "bit_xor"
   618  	// AggFuncBitAnd is the name of bit_and function.
   619  	AggFuncBitAnd = "bit_and"
   620  	// AggFuncVarPop is the name of var_pop function
   621  	AggFuncVarPop = "var_pop"
   622  	// AggFuncVarSamp is the name of var_samp function
   623  	AggFuncVarSamp = "var_samp"
   624  	// AggFuncStddevPop is the name of stddev_pop function
   625  	AggFuncStddevPop = "stddev_pop"
   626  	// AggFuncStddevSamp is the name of stddev_samp function
   627  	AggFuncStddevSamp = "stddev_samp"
   628  )
   629  
   630  // AggregateFuncExpr represents aggregate function expression.
   631  type AggregateFuncExpr struct {
   632  	funcNode
   633  	// F is the function name.
   634  	F string
   635  	// Args is the function args.
   636  	Args []ExprNode
   637  	// Distinct is true, function hence only aggregate distinct values.
   638  	// For example, column c1 values are "1", "2", "2",  "sum(c1)" is "5",
   639  	// but "sum(distinct c1)" is "3".
   640  	Distinct bool
   641  }
   642  
   643  // Restore implements Node interface.
   644  func (n *AggregateFuncExpr) Restore(ctx *format.RestoreCtx) error {
   645  	ctx.WriteKeyWord(n.F)
   646  	ctx.WritePlain("(")
   647  	if n.Distinct {
   648  		ctx.WriteKeyWord("DISTINCT ")
   649  	}
   650  	switch strings.ToLower(n.F) {
   651  	case "group_concat":
   652  		for i := 0; i < len(n.Args)-1; i++ {
   653  			if i != 0 {
   654  				ctx.WritePlain(", ")
   655  			}
   656  			if err := n.Args[i].Restore(ctx); err != nil {
   657  				return errors.Annotatef(err, "An error occurred while restore AggregateFuncExpr.Args[%d]", i)
   658  			}
   659  		}
   660  		ctx.WriteKeyWord(" SEPARATOR ")
   661  		if err := n.Args[len(n.Args)-1].Restore(ctx); err != nil {
   662  			return errors.Annotate(err, "An error occurred while restore AggregateFuncExpr.Args SEPARATOR")
   663  		}
   664  	default:
   665  		for i, argv := range n.Args {
   666  			if i != 0 {
   667  				ctx.WritePlain(", ")
   668  			}
   669  			if err := argv.Restore(ctx); err != nil {
   670  				return errors.Annotatef(err, "An error occurred while restore AggregateFuncExpr.Args[%d]", i)
   671  			}
   672  		}
   673  	}
   674  	ctx.WritePlain(")")
   675  	return nil
   676  }
   677  
   678  // Format the ExprNode into a Writer.
   679  func (n *AggregateFuncExpr) Format(w io.Writer) {
   680  	panic("Not implemented")
   681  }
   682  
   683  // Accept implements Node Accept interface.
   684  func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) {
   685  	newNode, skipChildren := v.Enter(n)
   686  	if skipChildren {
   687  		return v.Leave(newNode)
   688  	}
   689  	n = newNode.(*AggregateFuncExpr)
   690  	for i, val := range n.Args {
   691  		node, ok := val.Accept(v)
   692  		if !ok {
   693  			return n, false
   694  		}
   695  		n.Args[i] = node.(ExprNode)
   696  	}
   697  	return v.Leave(n)
   698  }
   699  
   700  const (
   701  	// WindowFuncRowNumber is the name of row_number function.
   702  	WindowFuncRowNumber = "row_number"
   703  	// WindowFuncRank is the name of rank function.
   704  	WindowFuncRank = "rank"
   705  	// WindowFuncDenseRank is the name of dense_rank function.
   706  	WindowFuncDenseRank = "dense_rank"
   707  	// WindowFuncCumeDist is the name of cume_dist function.
   708  	WindowFuncCumeDist = "cume_dist"
   709  	// WindowFuncPercentRank is the name of percent_rank function.
   710  	WindowFuncPercentRank = "percent_rank"
   711  	// WindowFuncNtile is the name of ntile function.
   712  	WindowFuncNtile = "ntile"
   713  	// WindowFuncLead is the name of lead function.
   714  	WindowFuncLead = "lead"
   715  	// WindowFuncLag is the name of lag function.
   716  	WindowFuncLag = "lag"
   717  	// WindowFuncFirstValue is the name of first_value function.
   718  	WindowFuncFirstValue = "first_value"
   719  	// WindowFuncLastValue is the name of last_value function.
   720  	WindowFuncLastValue = "last_value"
   721  	// WindowFuncNthValue is the name of nth_value function.
   722  	WindowFuncNthValue = "nth_value"
   723  )
   724  
   725  // WindowFuncExpr represents window function expression.
   726  type WindowFuncExpr struct {
   727  	funcNode
   728  
   729  	// F is the function name.
   730  	F string
   731  	// Args is the function args.
   732  	Args []ExprNode
   733  	// Distinct cannot be true for most window functions, except `max` and `min`.
   734  	// We need to raise error if it is not allowed to be true.
   735  	Distinct bool
   736  	// IgnoreNull indicates how to handle null value.
   737  	// MySQL only supports `RESPECT NULLS`, so we need to raise error if it is true.
   738  	IgnoreNull bool
   739  	// FromLast indicates the calculation direction of this window function.
   740  	// MySQL only supports calculation from first, so we need to raise error if it is true.
   741  	FromLast bool
   742  	// Spec is the specification of this window.
   743  	Spec WindowSpec
   744  }
   745  
   746  // Restore implements Node interface.
   747  func (n *WindowFuncExpr) Restore(ctx *format.RestoreCtx) error {
   748  	ctx.WriteKeyWord(n.F)
   749  	ctx.WritePlain("(")
   750  	for i, v := range n.Args {
   751  		if i != 0 {
   752  			ctx.WritePlain(", ")
   753  		} else if n.Distinct {
   754  			ctx.WriteKeyWord("DISTINCT ")
   755  		}
   756  		if err := v.Restore(ctx); err != nil {
   757  			return errors.Annotatef(err, "An error occurred while restore WindowFuncExpr.Args[%d]", i)
   758  		}
   759  	}
   760  	ctx.WritePlain(")")
   761  	if n.FromLast {
   762  		ctx.WriteKeyWord(" FROM LAST")
   763  	}
   764  	if n.IgnoreNull {
   765  		ctx.WriteKeyWord(" IGNORE NULLS")
   766  	}
   767  	ctx.WriteKeyWord(" OVER ")
   768  	if err := n.Spec.Restore(ctx); err != nil {
   769  		return errors.Annotate(err, "An error occurred while restore WindowFuncExpr.Spec")
   770  	}
   771  
   772  	return nil
   773  }
   774  
   775  // Format formats the window function expression into a Writer.
   776  func (n *WindowFuncExpr) Format(w io.Writer) {
   777  	panic("Not implemented")
   778  }
   779  
   780  // Accept implements Node Accept interface.
   781  func (n *WindowFuncExpr) Accept(v Visitor) (Node, bool) {
   782  	newNode, skipChildren := v.Enter(n)
   783  	if skipChildren {
   784  		return v.Leave(newNode)
   785  	}
   786  	n = newNode.(*WindowFuncExpr)
   787  	for i, val := range n.Args {
   788  		node, ok := val.Accept(v)
   789  		if !ok {
   790  			return n, false
   791  		}
   792  		n.Args[i] = node.(ExprNode)
   793  	}
   794  	node, ok := n.Spec.Accept(v)
   795  	if !ok {
   796  		return n, false
   797  	}
   798  	n.Spec = *node.(*WindowSpec)
   799  	return v.Leave(n)
   800  }