github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/plan/function/function.go (about)

     1  // Copyright 2021 - 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"context"
    19  	"math"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    22  	"github.com/matrixorigin/matrixone/pkg/container/types"
    23  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    24  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    25  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    26  )
    27  
    28  const (
    29  	// ScalarNull means of scalar NULL
    30  	// which can meet each required type.
    31  	// e.g.
    32  	// if we input a SQL `select built_in_function(columnA, NULL);`, and columnA is int64 column.
    33  	// it will use [types.T_int64, ScalarNull] to match function when we were building the query plan.
    34  	ScalarNull = types.T_any
    35  )
    36  
    37  var (
    38  	// an empty type structure just for return when we couldn't meet any function.
    39  	emptyType = types.Type{}
    40  
    41  	// AndFunctionEncodedID is the encoded overload id of And(bool, bool)
    42  	// used to make an AndExpr
    43  	AndFunctionEncodedID = EncodeOverloadID(AND, 0)
    44  	AndFunctionName      = "and"
    45  )
    46  
// Functions groups every overload registered under a single function name,
// together with the function id, its plan flags, its explain layout and an
// optional custom type-check routine.
type Functions struct {
	// Id is the function id. Presumably the same value functionIdRegister
	// maps this function's name to — TODO confirm.
	Id int

	// Flag holds plan.Function_FuncFlag bits for the function as a whole.
	// NOTE(review): not read in this file; per-overload flag checks go
	// through Function.TestFlag instead.
	Flag plan.Function_FuncFlag

	// Layout adapts to plan/function.go and is used for `explain SQL`.
	Layout FuncExplainLayout

	// TypeCheckFn checks whether the input parameter types satisfy one of the
	// overloads and returns that overload's index. If an implicit type
	// conversion is needed, it also returns the target types.
	// When TypeCheckFn is nil, TypeCheck falls back to normalTypeCheck.
	TypeCheckFn func(overloads []Function, inputs []types.T) (overloadIndex int32, ts []types.T)

	// Overloads lists every overload of the function. An overload index
	// (as returned by TypeCheck or decoded by DecodeOverloadID) is a
	// position in this slice.
	Overloads []Function
}
    64  
    65  // TypeCheck do type check work for a function,
    66  // if the input params matched one of function's overloads.
    67  // returns overload-index-number, target-type
    68  // just set target-type nil if there is no need to do implicit-type-conversion for parameters
    69  func (fs *Functions) TypeCheck(args []types.T) (int32, []types.T) {
    70  	if fs.TypeCheckFn == nil {
    71  		return normalTypeCheck(fs.Overloads, args)
    72  	}
    73  	return fs.TypeCheckFn(fs.Overloads, args)
    74  }
    75  
    76  func normalTypeCheck(overloads []Function, inputs []types.T) (overloadIndex int32, ts []types.T) {
    77  	matched := make([]int32, 0, 4)   // function overload which can be matched directly
    78  	byCast := make([]int32, 0, 4)    // function overload which can be matched according to type cast
    79  	convertCost := make([]int, 0, 4) // records the cost of conversion for byCast
    80  	for i, f := range overloads {
    81  		c, cost := tryToMatch(inputs, f.Args)
    82  		switch c {
    83  		case matchedDirectly:
    84  			matched = append(matched, int32(i))
    85  		case matchedByConvert:
    86  			byCast = append(byCast, int32(i))
    87  			convertCost = append(convertCost, cost)
    88  		case matchedFailed:
    89  			continue
    90  		}
    91  	}
    92  	if len(matched) == 1 {
    93  		return matched[0], nil
    94  	} else if len(matched) == 0 && len(byCast) > 0 {
    95  		// choose the overload with the least number of conversions
    96  		min, index := math.MaxInt32, 0
    97  		for j := range convertCost {
    98  			if convertCost[j] < min {
    99  				index = j
   100  				min = convertCost[j]
   101  			}
   102  		}
   103  		return byCast[index], overloads[byCast[index]].Args
   104  	} else if len(matched) > 1 {
   105  		// if contains any scalar null as param, just return the first matched.
   106  		for j := range inputs {
   107  			if inputs[j] == ScalarNull {
   108  				return matched[0], nil
   109  			}
   110  		}
   111  		return tooManyFunctionsMatched, nil
   112  	}
   113  	return wrongFunctionParameters, nil
   114  }
   115  
// Function is a single overload of a built-in function, an aggregate
// function or an operator.
type Function struct {
	// Index is this overload's position among all overloads that share the
	// same function name.
	Index int32

	// Volatile functions cannot be folded. GetFunctionIsMonotonicById also
	// treats a volatile overload as non-monotonic.
	Volatile bool

	// RealTimeRelated functions cannot be folded inside a prepared statement.
	RealTimeRelated bool

	// AppendHideArg reports whether the function needs a hidden parameter
	// appended, such as 'uuid'. Read by GetFunctionAppendHideArgByID.
	AppendHideArg bool

	// Args are the declared parameter types that overload matching compares
	// against. ReturnTyp is the declared return type, used when
	// FlexibleReturnType is nil.
	Args      []types.T
	ReturnTyp types.T

	// Fn implements the built-in function or operator: it receives the
	// argument vectors and returns the result vector. VecFn invokes it.
	Fn func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error)

	// AggregateInfo is related information about an aggregate function.
	// NOTE(review): not read in this file; its meaning is defined elsewhere.
	AggregateInfo int

	// Info records a description of the function overload, used for printing.
	Info string

	// flag holds plan.Function_FuncFlag bits for this overload; query it
	// with TestFlag.
	flag plan.Function_FuncFlag

	// layout adapts to plan/function.go and is used for `explain SQL`.
	// Returned by GetLayout.
	layout FuncExplainLayout

	// UseNewFramework presumably selects NewFn over Fn — not read in this
	// file; TODO confirm.
	// ResultWillNotNull, when true, makes ReturnType report nullable == false.
	// FlexibleReturnType, when non-nil, computes the return type from the
	// actual argument types instead of ReturnTyp.
	// ParameterMustScalar presumably marks, per parameter, whether a constant
	// (scalar) argument is required — not read in this file; TODO confirm.
	// NewFn is an alternative implementation signature writing into a result
	// wrapper — not invoked in this file.
	UseNewFramework     bool
	ResultWillNotNull   bool
	FlexibleReturnType  func(parameters []types.Type) types.Type
	ParameterMustScalar []bool
	NewFn               func(parameters []*vector.Vector, result vector.FunctionResultWrapper, proc *process.Process, length int) error
}
   155  
   156  func (f *Function) TestFlag(funcFlag plan.Function_FuncFlag) bool {
   157  	return f.flag&funcFlag != 0
   158  }
   159  
   160  func (f *Function) GetLayout() FuncExplainLayout {
   161  	return f.layout
   162  }
   163  
   164  // ReturnType return result-type of function, and the result is nullable
   165  // if nullable is false, function won't return a vector with null value.
   166  func (f Function) ReturnType(args []types.Type) (typ types.Type, nullable bool) {
   167  	if f.FlexibleReturnType != nil {
   168  		return f.FlexibleReturnType(args), !f.ResultWillNotNull
   169  	}
   170  	return f.ReturnTyp.ToType(), !f.ResultWillNotNull
   171  }
   172  
   173  func (f Function) VecFn(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) {
   174  	if f.Fn == nil {
   175  		return nil, moerr.NewInternalError(proc.Ctx, "no function")
   176  	}
   177  	return f.Fn(vs, proc)
   178  }
   179  
   180  func (f Function) IsAggregate() bool {
   181  	return f.TestFlag(plan.Function_AGG)
   182  }
   183  
   184  func (f Function) isFunction() bool {
   185  	return f.GetLayout() == STANDARD_FUNCTION || f.GetLayout() >= NOPARAMETER_FUNCTION
   186  }
   187  
// functionRegister records the information about all the operators,
// built-in functions and aggregate functions. It is indexed by function id:
// GetFunctionByID looks up functionRegister[fid] with the fid decoded from an
// overload id.
//
// For use in other packages, see GetFunctionByID and GetFunctionByName.
// NOTE(review): declared without an initializer here; presumably populated
// elsewhere in this package — confirm.
var functionRegister []Functions
   193  
   194  // get function id from map functionIdRegister, see functionIds.go
   195  func fromNameToFunctionIdWithoutError(name string) (int32, bool) {
   196  	if fid, ok := functionIdRegister[name]; ok {
   197  		return fid, true
   198  	}
   199  	return -1, false
   200  }
   201  
   202  // get function id from map functionIdRegister, see functionIds.go
   203  func fromNameToFunctionId(ctx context.Context, name string) (int32, error) {
   204  	if fid, ok := functionIdRegister[name]; ok {
   205  		return fid, nil
   206  	}
   207  	return -1, moerr.NewNotSupported(ctx, "function or operator '%s'", name)
   208  }
   209  
   210  // EncodeOverloadID convert function-id and overload-index to be an overloadID
   211  // the high 32-bit is function-id, the low 32-bit is overload-index
   212  func EncodeOverloadID(fid int32, index int32) (overloadID int64) {
   213  	overloadID = int64(fid)
   214  	overloadID = overloadID << 32
   215  	overloadID |= int64(index)
   216  	return overloadID
   217  }
   218  
   219  // DecodeOverloadID convert overload id to be function-id and overload-index
   220  func DecodeOverloadID(overloadID int64) (fid int32, index int32) {
   221  	base := overloadID
   222  	index = int32(overloadID)
   223  	fid = int32(base >> 32)
   224  	return fid, index
   225  }
   226  
   227  // GetFunctionByIDWithoutError get function structure by its index id.
   228  func GetFunctionByIDWithoutError(overloadID int64) (*Function, bool) {
   229  	fid, overloadIndex := DecodeOverloadID(overloadID)
   230  	if int(fid) < len(functionRegister) {
   231  		fs := functionRegister[fid].Overloads
   232  		return &fs[overloadIndex], true
   233  	} else {
   234  		return nil, false
   235  	}
   236  }
   237  
   238  // GetFunctionByID get function structure by its index id.
   239  func GetFunctionByID(ctx context.Context, overloadID int64) (*Function, error) {
   240  	fid, overloadIndex := DecodeOverloadID(overloadID)
   241  	if int(fid) < len(functionRegister) {
   242  		fs := functionRegister[fid].Overloads
   243  		return &fs[overloadIndex], nil
   244  	} else {
   245  		return nil, moerr.NewInvalidInput(ctx, "function overload id not found")
   246  	}
   247  }
   248  
   249  // deduce notNullable for function
   250  // for example, create table t1(c1 int not null, c2 int, c3 int not null ,c4 int);
   251  // sql select c1+1, abs(c2), cast(c3 as varchar(10)) from t1 where c1=c3;
   252  // we can deduce that c1+1, cast c3 and c1=c3 is notNullable, abs(c2) is nullable
   253  // this message helps optimization sometimes
   254  func DeduceNotNullable(overloadID int64, args []*plan.Expr) bool {
   255  	function, _ := GetFunctionByIDWithoutError(overloadID)
   256  	if function.TestFlag(plan.Function_PRODUCE_NO_NULL) {
   257  		return true
   258  	}
   259  
   260  	for _, arg := range args {
   261  		if !arg.Typ.NotNullable {
   262  			return false
   263  		}
   264  	}
   265  	return true
   266  }
   267  
   268  func GetFunctionIsAggregateByName(name string) bool {
   269  	fid, exists := fromNameToFunctionIdWithoutError(name)
   270  	if !exists {
   271  		return false
   272  	}
   273  	fs := functionRegister[fid].Overloads
   274  	return len(fs) > 0 && fs[0].IsAggregate()
   275  }
   276  
   277  // Check whether the function needs to append a hidden parameter
   278  func GetFunctionAppendHideArgByID(overloadID int64) bool {
   279  	function, exists := GetFunctionByIDWithoutError(overloadID)
   280  	if !exists {
   281  		return false
   282  	}
   283  	return function.AppendHideArg
   284  }
   285  
   286  func GetFunctionIsMonotonicById(ctx context.Context, overloadID int64) (bool, error) {
   287  	function, err := GetFunctionByID(ctx, overloadID)
   288  	if err != nil {
   289  		return false, err
   290  	}
   291  	// if function cann't be fold, we think that will be not monotonic
   292  	if function.Volatile {
   293  		return false, nil
   294  	}
   295  	isMonotonic := function.TestFlag(plan.Function_MONOTONIC)
   296  	return isMonotonic, nil
   297  }
   298  
   299  // GetFunctionByName check a function exist or not according to input function name and arg types,
   300  // if matches,
   301  // return the encoded overload id and the overload's return type
   302  // and final converted argument types( it will be nil if there's no need to do type level-up work).
   303  func GetFunctionByName(ctx context.Context, name string, args []types.Type) (int64, types.Type, []types.Type, error) {
   304  	fid, err := fromNameToFunctionId(ctx, name)
   305  	if err != nil {
   306  		return -1, emptyType, nil, err
   307  	}
   308  	fs := functionRegister[fid]
   309  
   310  	argTs := getOidSlice(args)
   311  	index, targetTs := fs.TypeCheck(argTs)
   312  
   313  	// if implicit type conversion happens, set the right precision for target types.
   314  	targetTypes := getTypeSlice(targetTs)
   315  	rewriteTypesIfNecessary(targetTypes, args)
   316  
   317  	var finalTypes []types.Type
   318  	if targetTs != nil {
   319  		finalTypes = targetTypes
   320  	} else {
   321  		finalTypes = args
   322  	}
   323  
   324  	// deal the failed situations
   325  	switch index {
   326  	case wrongFunctionParameters:
   327  		ArgsToPrint := getOidSlice(finalTypes) // arg information to print for error message
   328  		if len(fs.Overloads) > 0 && fs.Overloads[0].isFunction() {
   329  			return -1, emptyType, nil, moerr.NewInvalidArg(ctx, "function "+name, ArgsToPrint)
   330  		}
   331  		return -1, emptyType, nil, moerr.NewInvalidArg(ctx, "operator "+name, ArgsToPrint)
   332  	case tooManyFunctionsMatched:
   333  		return -1, emptyType, nil, moerr.NewInvalidArg(ctx, "too many overloads matched "+name, args)
   334  	case wrongFuncParamForAgg:
   335  		ArgsToPrint := getOidSlice(finalTypes)
   336  		return -1, emptyType, nil, moerr.NewInvalidArg(ctx, "aggregate function "+name, ArgsToPrint)
   337  	}
   338  
   339  	// make the real return type of function overload.
   340  	rt := getRealReturnType(fid, fs.Overloads[index], finalTypes)
   341  
   342  	return EncodeOverloadID(fid, index), rt, targetTypes, nil
   343  }
   344  
   345  func ensureBinaryOperatorWithSamePrecision(targets []types.Type, hasSet []bool) {
   346  	if len(targets) == 2 && targets[0].Oid == targets[1].Oid {
   347  		if hasSet[0] && !hasSet[1] { // precision follow the left-part
   348  			copyType(&targets[1], &targets[0])
   349  			hasSet[1] = true
   350  		} else if !hasSet[0] && hasSet[1] { // precision follow the right-part
   351  			copyType(&targets[0], &targets[1])
   352  			hasSet[0] = true
   353  		}
   354  	}
   355  }
   356  
// rewriteTypesIfNecessary fills in width/scale/precision for the implicit-cast
// target types chosen by type checking, using the original argument types
// (sources) as reference. It does nothing when targets is empty.
//
// The steps, in order:
//  1. find the largest scale among the decimal64/decimal128 sources;
//  2. at every position whose target is decimal64/decimal128, raise the
//     source's scale to that maximum, even if the source itself is not a
//     decimal. NOTE(review): this writes into sources, i.e. the caller's
//     slice (GetFunctionByName passes its args here);
//  3. when the target is decimal64/decimal128/timestamp/time and the source is
//     not char/varchar/blob/text, copy width/scale/precision from the source
//     into the target and mark it as set;
//  4. for a two-argument call whose targets share an oid, let an unset side
//     follow the set side (ensureBinaryOperatorWithSamePrecision);
//  5. give every target that is still unset and is not ScalarNull its
//     default precision (setDefaultPrecision).
//
// NOTE(review): assumes len(targets) == len(sources), because the loops index
// each slice with the other's range.
func rewriteTypesIfNecessary(targets []types.Type, sources []types.Type) {
	if len(targets) != 0 {
		// hasSet[i] records whether targets[i] already received its
		// width/scale/precision from sources[i].
		hasSet := make([]bool, len(sources))

		// ensure that we will not lose the original scale
		maxScale := int32(0)
		for i := range sources {
			if sources[i].Oid == types.T_decimal64 || sources[i].Oid == types.T_decimal128 {
				if sources[i].Scale > maxScale {
					maxScale = sources[i].Scale
				}
			}
		}
		// Raise the source scale wherever a decimal target will copy from it
		// (step 3), so every decimal target ends up with at least maxScale.
		for i := range sources {
			if targets[i].Oid == types.T_decimal64 || targets[i].Oid == types.T_decimal128 {
				if sources[i].Scale < maxScale {
					sources[i].Scale = maxScale
				}
			}
		}

		for i := range targets {
			oid1, oid2 := sources[i].Oid, targets[i].Oid
			// ensure that we will not lose the original precision.
			// String-like sources are skipped: their width/scale/precision are
			// not copied into the target.
			if oid2 == types.T_decimal64 || oid2 == types.T_decimal128 || oid2 == types.T_timestamp || oid2 == types.T_time {
				if oid1 != types.T_char && oid1 != types.T_varchar && oid1 != types.T_blob && oid1 != types.T_text {
					copyType(&targets[i], &sources[i])
					hasSet[i] = true
				}
			}
		}
		ensureBinaryOperatorWithSamePrecision(targets, hasSet)
		// Anything still unset falls back to the per-type defaults;
		// ScalarNull targets are left as they are.
		for i := range targets {
			if !hasSet[i] && targets[i].Oid != ScalarNull {
				setDefaultPrecision(&targets[i])
			}
		}
	}
}
   396  
   397  // set default precision / scalar / width for a type
   398  func setDefaultPrecision(typ *types.Type) {
   399  	if typ.Oid == types.T_decimal64 {
   400  		typ.Scale = 0
   401  		typ.Width = 18
   402  	} else if typ.Oid == types.T_decimal128 {
   403  		typ.Scale = 0
   404  		typ.Width = 38
   405  	} else if typ.Oid == types.T_timestamp {
   406  		typ.Precision = 6
   407  	} else if typ.Oid == types.T_datetime {
   408  		typ.Precision = 6
   409  	} else if typ.Oid == types.T_time {
   410  		typ.Precision = 6
   411  	}
   412  	typ.Size = int32(typ.Oid.TypeLen())
   413  }
   414  
   415  func getRealReturnType(fid int32, f Function, realArgs []types.Type) types.Type {
   416  	if f.IsAggregate() {
   417  		switch fid {
   418  		case MIN, MAX:
   419  			if realArgs[0].Oid != ScalarNull {
   420  				return realArgs[0]
   421  			}
   422  		}
   423  	}
   424  	if f.FlexibleReturnType != nil {
   425  		return f.FlexibleReturnType(realArgs)
   426  	}
   427  	rt := f.ReturnTyp.ToType()
   428  	for i := range realArgs {
   429  		if realArgs[i].Oid == rt.Oid {
   430  			copyType(&rt, &realArgs[i])
   431  			checkTypeWidth(realArgs, &rt)
   432  			break
   433  		}
   434  		if types.T(rt.Oid) == types.T_decimal128 && types.T(realArgs[i].Oid) == types.T_decimal64 {
   435  			copyType(&rt, &realArgs[i])
   436  		}
   437  	}
   438  	return rt
   439  }
   440  
   441  func checkTypeWidth(realArgs []types.Type, rt *types.Type) {
   442  	for i := range realArgs {
   443  		if realArgs[i].Oid == rt.Oid && rt.Width < realArgs[i].Width {
   444  			rt.Width = realArgs[i].Width
   445  		}
   446  	}
   447  }
   448  
   449  func copyType(dst, src *types.Type) {
   450  	dst.Width = src.Width
   451  	dst.Scale = src.Scale
   452  	dst.Precision = src.Precision
   453  }
   454  
   455  func getOidSlice(ts []types.Type) []types.T {
   456  	ret := make([]types.T, len(ts))
   457  	for i := range ts {
   458  		ret[i] = ts[i].Oid
   459  	}
   460  	return ret
   461  }
   462  
   463  func getTypeSlice(ts []types.T) []types.Type {
   464  	ret := make([]types.Type, len(ts))
   465  	for i := range ts {
   466  		ret[i] = ts[i].ToType()
   467  	}
   468  	return ret
   469  }