github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqlparse/tidbparser/ast/functions.go (about)

     1  // Copyright 2015 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package ast
    15  
import (
	"fmt"
	"io"
	"strings"

	"github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/model"
	"github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/types"
)
    23  
    24  var (
    25  	_ FuncNode = &AggregateFuncExpr{}
    26  	_ FuncNode = &FuncCallExpr{}
    27  	_ FuncNode = &FuncCastExpr{}
    28  )
    29  
// Names of scalar functions and operators. Every value is lower case;
// FuncCallExpr.specialFormatArgs matches some of them (date_add, date_sub,
// adddate, subdate, timestampadd, timestampdiff) against FuncCallExpr.FnName.L.
// The first group appears to hold internal names for operators and special
// constructs (e.g. "plus", "ge", "isnull") rather than SQL-callable function
// names. NOTE(review): confirm against the parser.
const (
	LogicAnd   = "and"
	Cast       = "cast"
	LeftShift  = "leftshift"
	RightShift = "rightshift"
	LogicOr    = "or"
	GE         = "ge"
	LE         = "le"
	EQ         = "eq"
	NE         = "ne"
	LT         = "lt"
	GT         = "gt"
	Plus       = "plus"
	Minus      = "minus"
	And        = "bitand"
	Or         = "bitor"
	Mod        = "mod"
	Xor        = "bitxor"
	Div        = "div"
	Mul        = "mul"
	UnaryNot   = "not" // Avoid name conflict with Not in github/pingcap/check.
	BitNeg     = "bitneg"
	IntDiv     = "intdiv"
	LogicXor   = "xor"
	NullEQ     = "nulleq"
	UnaryPlus  = "unaryplus"
	UnaryMinus = "unaryminus"
	In         = "in"
	Like       = "like"
	Case       = "case"
	Regexp     = "regexp"
	IsNull     = "isnull"
	IsTruth    = "istrue"  // Avoid name conflict with IsTrue in github/pingcap/check.
	IsFalsity  = "isfalse" // Avoid name conflict with IsFalse in github/pingcap/check.
	RowFunc    = "row"
	SetVar     = "setvar"
	GetVar     = "getvar"
	Values     = "values"
	BitCount   = "bit_count"
	GetParam   = "getparam"

	// common functions
	Coalesce = "coalesce"
	Greatest = "greatest"
	Least    = "least"
	Interval = "interval"

	// math functions
	Abs      = "abs"
	Acos     = "acos"
	Asin     = "asin"
	Atan     = "atan"
	Atan2    = "atan2"
	Ceil     = "ceil"
	Ceiling  = "ceiling"
	Conv     = "conv"
	Cos      = "cos"
	Cot      = "cot"
	CRC32    = "crc32"
	Degrees  = "degrees"
	Exp      = "exp"
	Floor    = "floor"
	Ln       = "ln"
	Log      = "log"
	Log2     = "log2"
	Log10    = "log10"
	PI       = "pi"
	Pow      = "pow"
	Power    = "power"
	Radians  = "radians"
	Rand     = "rand"
	Round    = "round"
	Sign     = "sign"
	Sin      = "sin"
	Sqrt     = "sqrt"
	Tan      = "tan"
	Truncate = "truncate"

	// time functions
	AddDate          = "adddate"
	AddTime          = "addtime"
	ConvertTz        = "convert_tz"
	Curdate          = "curdate"
	CurrentDate      = "current_date"
	CurrentTime      = "current_time"
	CurrentTimestamp = "current_timestamp"
	Curtime          = "curtime"
	Date             = "date"
	DateLiteral      = "dateliteral"
	DateAdd          = "date_add"
	DateFormat       = "date_format"
	DateSub          = "date_sub"
	DateDiff         = "datediff"
	Day              = "day"
	DayName          = "dayname"
	DayOfMonth       = "dayofmonth"
	DayOfWeek        = "dayofweek"
	DayOfYear        = "dayofyear"
	Extract          = "extract"
	FromDays         = "from_days"
	FromUnixTime     = "from_unixtime"
	GetFormat        = "get_format"
	Hour             = "hour"
	LocalTime        = "localtime"
	LocalTimestamp   = "localtimestamp"
	MakeDate         = "makedate"
	MakeTime         = "maketime"
	MicroSecond      = "microsecond"
	Minute           = "minute"
	Month            = "month"
	MonthName        = "monthname"
	Now              = "now"
	PeriodAdd        = "period_add"
	PeriodDiff       = "period_diff"
	Quarter          = "quarter"
	SecToTime        = "sec_to_time"
	Second           = "second"
	StrToDate        = "str_to_date"
	SubDate          = "subdate"
	SubTime          = "subtime"
	Sysdate          = "sysdate"
	Time             = "time"
	TimeLiteral      = "timeliteral"
	TimeFormat       = "time_format"
	TimeToSec        = "time_to_sec"
	TimeDiff         = "timediff"
	Timestamp        = "timestamp"
	TimestampLiteral = "timestampliteral"
	TimestampAdd     = "timestampadd"
	TimestampDiff    = "timestampdiff"
	ToDays           = "to_days"
	ToSeconds        = "to_seconds"
	UnixTimestamp    = "unix_timestamp"
	UTCDate          = "utc_date"
	UTCTime          = "utc_time"
	UTCTimestamp     = "utc_timestamp"
	Week             = "week"
	Weekday          = "weekday"
	WeekOfYear       = "weekofyear"
	Year             = "year"
	YearWeek         = "yearweek"
	LastDay          = "last_day"

	// string functions
	ASCII           = "ascii"
	Bin             = "bin"
	Concat          = "concat"
	ConcatWS        = "concat_ws"
	Convert         = "convert"
	Elt             = "elt"
	ExportSet       = "export_set"
	Field           = "field"
	Format          = "format"
	FromBase64      = "from_base64"
	InsertFunc      = "insert_func"
	Instr           = "instr"
	Lcase           = "lcase"
	Left            = "left"
	Length          = "length"
	LoadFile        = "load_file"
	Locate          = "locate"
	Lower           = "lower"
	Lpad            = "lpad"
	LTrim           = "ltrim"
	MakeSet         = "make_set"
	Mid             = "mid"
	Oct             = "oct"
	Ord             = "ord"
	Position        = "position"
	Quote           = "quote"
	Repeat          = "repeat"
	Replace         = "replace"
	Reverse         = "reverse"
	Right           = "right"
	RTrim           = "rtrim"
	Space           = "space"
	Strcmp          = "strcmp"
	Substring       = "substring"
	Substr          = "substr"
	SubstringIndex  = "substring_index"
	ToBase64        = "to_base64"
	Trim            = "trim"
	Upper           = "upper"
	Ucase           = "ucase"
	Hex             = "hex"
	Unhex           = "unhex"
	Rpad            = "rpad"
	BitLength       = "bit_length"
	CharFunc        = "char_func"
	CharLength      = "char_length"
	CharacterLength = "character_length"
	FindInSet       = "find_in_set"

	// information functions
	Benchmark    = "benchmark"
	Charset      = "charset"
	Coercibility = "coercibility"
	Collation    = "collation"
	ConnectionID = "connection_id"
	CurrentUser  = "current_user"
	Database     = "database"
	FoundRows    = "found_rows"
	LastInsertId = "last_insert_id"
	RowCount     = "row_count"
	Schema       = "schema"
	SessionUser  = "session_user"
	SystemUser   = "system_user"
	User         = "user"
	Version      = "version"
	TiDBVersion  = "tidb_version"

	// control functions
	If     = "if"
	Ifnull = "ifnull"
	Nullif = "nullif"

	// miscellaneous functions
	AnyValue        = "any_value"
	DefaultFunc     = "default_func"
	InetAton        = "inet_aton"
	InetNtoa        = "inet_ntoa"
	Inet6Aton       = "inet6_aton"
	Inet6Ntoa       = "inet6_ntoa"
	IsFreeLock      = "is_free_lock"
	IsIPv4          = "is_ipv4"
	IsIPv4Compat    = "is_ipv4_compat"
	IsIPv4Mapped    = "is_ipv4_mapped"
	IsIPv6          = "is_ipv6"
	IsUsedLock      = "is_used_lock"
	MasterPosWait   = "master_pos_wait"
	NameConst       = "name_const"
	ReleaseAllLocks = "release_all_locks"
	Sleep           = "sleep"
	UUID            = "uuid"
	UUIDShort       = "uuid_short"
	// get_lock() and release_lock() are parsed but do nothing.
	// They exist to prevent errors in Ruby's ActiveRecord migrations.
	GetLock     = "get_lock"
	ReleaseLock = "release_lock"

	// encryption and compression functions
	AesDecrypt               = "aes_decrypt"
	AesEncrypt               = "aes_encrypt"
	Compress                 = "compress"
	Decode                   = "decode"
	DesDecrypt               = "des_decrypt"
	DesEncrypt               = "des_encrypt"
	Encode                   = "encode"
	Encrypt                  = "encrypt"
	MD5                      = "md5"
	OldPassword              = "old_password"
	PasswordFunc             = "password_func"
	RandomBytes              = "random_bytes"
	SHA1                     = "sha1"
	SHA                      = "sha"
	SHA2                     = "sha2"
	Uncompress               = "uncompress"
	UncompressedLength       = "uncompressed_length"
	ValidatePasswordStrength = "validate_password_strength"

	// json functions
	JSONType     = "json_type"
	JSONExtract  = "json_extract"
	JSONUnquote  = "json_unquote"
	JSONArray    = "json_array"
	JSONObject   = "json_object"
	JSONMerge    = "json_merge"
	JSONValid    = "json_valid"
	JSONSet      = "json_set"
	JSONInsert   = "json_insert"
	JSONReplace  = "json_replace"
	JSONRemove   = "json_remove"
	JSONContains = "json_contains"
)
   305  
// FuncCallExpr is a function call expression, e.g. abs(x).
type FuncCallExpr struct {
	// Embedded base node (defined elsewhere in this package).
	funcNode
	// FnName is the function name. Format writes its lower-case form
	// (FnName.L), and specialFormatArgs compares FnName.L against the
	// function-name constants declared in this file.
	FnName model.CIStr
	// Args holds the function arguments; Format writes them in slice order.
	Args []ExprNode
}
   314  
   315  // Format the ExprNode into a Writer.
   316  func (n *FuncCallExpr) Format(w io.Writer) {
   317  	fmt.Fprintf(w, "%s(", n.FnName.L)
   318  	if !n.specialFormatArgs(w) {
   319  		for i, arg := range n.Args {
   320  			arg.Format(w)
   321  			if i != len(n.Args)-1 {
   322  				fmt.Fprint(w, ", ")
   323  			}
   324  		}
   325  	}
   326  	fmt.Fprint(w, ")")
   327  }
   328  
   329  // specialFormatArgs formats argument list for some special functions.
   330  func (n *FuncCallExpr) specialFormatArgs(w io.Writer) bool {
   331  	switch n.FnName.L {
   332  	case DateAdd, DateSub, AddDate, SubDate:
   333  		n.Args[0].Format(w)
   334  		fmt.Fprint(w, ", INTERVAL ")
   335  		n.Args[1].Format(w)
   336  		fmt.Fprintf(w, " %s", n.Args[2].GetDatum().GetString())
   337  		return true
   338  	case TimestampAdd, TimestampDiff:
   339  		fmt.Fprintf(w, "%s, ", n.Args[0].GetDatum().GetString())
   340  		n.Args[1].Format(w)
   341  		fmt.Fprint(w, ", ")
   342  		n.Args[2].Format(w)
   343  		return true
   344  	}
   345  	return false
   346  }
   347  
   348  // Accept implements Node interface.
   349  func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) {
   350  	newNode, skipChildren := v.Enter(n)
   351  	if skipChildren {
   352  		return v.Leave(newNode)
   353  	}
   354  	n = newNode.(*FuncCallExpr)
   355  	for i, val := range n.Args {
   356  		node, ok := val.Accept(v)
   357  		if !ok {
   358  			return n, false
   359  		}
   360  		n.Args[i] = node.(ExprNode)
   361  	}
   362  	return v.Leave(n)
   363  }
   364  
   365  // CastFunctionType is the type for cast function.
   366  type CastFunctionType int
   367  
   368  // CastFunction types
   369  const (
   370  	CastFunction CastFunctionType = iota + 1
   371  	CastConvertFunction
   372  	CastBinaryOperator
   373  )
   374  
// FuncCastExpr converts a value to another type, e.g. CAST(expr AS SIGNED).
// Format writes it as CAST(expr AS type), CONVERT(expr, type) or BINARY expr,
// depending on FunctionType.
// See https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html
type FuncCastExpr struct {
	// Embedded base node (defined elsewhere in this package).
	funcNode
	// Expr is the expression to be converted.
	Expr ExprNode
	// Tp is the target type. Format writes it with FormatAsCastType, except
	// for the BINARY form, which does not print a type.
	Tp *types.FieldType
	// FunctionType records which syntax is used: CastFunction,
	// CastConvertFunction or CastBinaryOperator.
	FunctionType CastFunctionType
}
   386  
   387  // Format the ExprNode into a Writer.
   388  func (n *FuncCastExpr) Format(w io.Writer) {
   389  	switch n.FunctionType {
   390  	case CastFunction:
   391  		fmt.Fprint(w, "CAST(")
   392  		n.Expr.Format(w)
   393  		fmt.Fprint(w, " AS ")
   394  		n.Tp.FormatAsCastType(w)
   395  		fmt.Fprint(w, ")")
   396  	case CastConvertFunction:
   397  		fmt.Fprint(w, "CONVERT(")
   398  		n.Expr.Format(w)
   399  		fmt.Fprint(w, ", ")
   400  		n.Tp.FormatAsCastType(w)
   401  		fmt.Fprint(w, ")")
   402  	case CastBinaryOperator:
   403  		fmt.Fprint(w, "BINARY ")
   404  		n.Expr.Format(w)
   405  	}
   406  }
   407  
   408  // Accept implements Node Accept interface.
   409  func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) {
   410  	newNode, skipChildren := v.Enter(n)
   411  	if skipChildren {
   412  		return v.Leave(newNode)
   413  	}
   414  	n = newNode.(*FuncCastExpr)
   415  	node, ok := n.Expr.Accept(v)
   416  	if !ok {
   417  		return n, false
   418  	}
   419  	n.Expr = node.(ExprNode)
   420  	return v.Leave(n)
   421  }
   422  
   423  // TrimDirectionType is the type for trim direction.
   424  type TrimDirectionType int
   425  
   426  const (
   427  	// TrimBothDefault trims from both direction by default.
   428  	TrimBothDefault TrimDirectionType = iota
   429  	// TrimBoth trims from both direction with explicit notation.
   430  	TrimBoth
   431  	// TrimLeading trims from left.
   432  	TrimLeading
   433  	// TrimTrailing trims from right.
   434  	TrimTrailing
   435  )
   436  
   437  // DateArithType is type for DateArith type.
   438  type DateArithType byte
   439  
   440  const (
   441  	// DateArithAdd is to run adddate or date_add function option.
   442  	// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate
   443  	// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
   444  	DateArithAdd DateArithType = iota + 1
   445  	// DateArithSub is to run subdate or date_sub function option.
   446  	// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate
   447  	// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub
   448  	DateArithSub
   449  )
   450  
   451  const (
   452  	// AggFuncCount is the name of Count function.
   453  	AggFuncCount = "count"
   454  	// AggFuncSum is the name of Sum function.
   455  	AggFuncSum = "sum"
   456  	// AggFuncAvg is the name of Avg function.
   457  	AggFuncAvg = "avg"
   458  	// AggFuncFirstRow is the name of FirstRowColumn function.
   459  	AggFuncFirstRow = "firstrow"
   460  	// AggFuncMax is the name of max function.
   461  	AggFuncMax = "max"
   462  	// AggFuncMin is the name of min function.
   463  	AggFuncMin = "min"
   464  	// AggFuncGroupConcat is the name of group_concat function.
   465  	AggFuncGroupConcat = "group_concat"
   466  	// AggFuncBitOr is the name of bit_or function.
   467  	AggFuncBitOr = "bit_or"
   468  	// AggFuncBitXor is the name of bit_xor function.
   469  	AggFuncBitXor = "bit_xor"
   470  	// AggFuncBitAnd is the name of bit_and function.
   471  	AggFuncBitAnd = "bit_and"
   472  )
   473  
// AggregateFuncExpr represents an aggregate function call, e.g. sum(distinct c1).
type AggregateFuncExpr struct {
	// Embedded base node (defined elsewhere in this package).
	funcNode
	// F is the function name, e.g. one of the AggFunc* constants.
	F string
	// Args holds the function arguments.
	Args []ExprNode
	// Distinct is true when only distinct argument values are aggregated.
	// For example, if column c1 holds the values 1, 2, 2, then sum(c1) is 5
	// but sum(distinct c1) is 3.
	Distinct bool
}
   486  
   487  // Format the ExprNode into a Writer.
   488  func (n *AggregateFuncExpr) Format(w io.Writer) {
   489  	panic("Not implemented")
   490  }
   491  
   492  // Accept implements Node Accept interface.
   493  func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) {
   494  	newNode, skipChildren := v.Enter(n)
   495  	if skipChildren {
   496  		return v.Leave(newNode)
   497  	}
   498  	n = newNode.(*AggregateFuncExpr)
   499  	for i, val := range n.Args {
   500  		node, ok := val.Accept(v)
   501  		if !ok {
   502  			return n, false
   503  		}
   504  		n.Args[i] = node.(ExprNode)
   505  	}
   506  	return v.Leave(n)
   507  }