github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/plan/function/function.go (about)

     1  // Copyright 2021 - 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    22  	"github.com/matrixorigin/matrixone/pkg/container/types"
    23  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    24  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    25  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    26  )
    27  
// allSupportedFunctions is the table of every registered function, indexed by
// the function's id (FuncNew.functionId). Slots that nothing was registered
// into keep the zero FuncNew value.
var allSupportedFunctions [1000]FuncNew
    29  
    30  // register all supported functions.
    31  func initAllSupportedFunctions() {
    32  	for _, fn := range supportedOperators {
    33  		allSupportedFunctions[fn.functionId] = fn
    34  	}
    35  	for _, fn := range supportedStringBuiltIns {
    36  		allSupportedFunctions[fn.functionId] = fn
    37  	}
    38  	for _, fn := range supportedDateAndTimeBuiltIns {
    39  		allSupportedFunctions[fn.functionId] = fn
    40  	}
    41  	for _, fn := range supportedMathBuiltIns {
    42  		allSupportedFunctions[fn.functionId] = fn
    43  	}
    44  	for _, fn := range supportedArrayOperations {
    45  		allSupportedFunctions[fn.functionId] = fn
    46  	}
    47  	for _, fn := range supportedControlBuiltIns {
    48  		allSupportedFunctions[fn.functionId] = fn
    49  	}
    50  	for _, fn := range supportedOthersBuiltIns {
    51  		allSupportedFunctions[fn.functionId] = fn
    52  	}
    53  
    54  	for _, fn := range supportedWindowInNewFramework {
    55  		for _, ov := range fn.Overloads {
    56  			ov.aggFramework.aggRegister(encodeOverloadID(int32(fn.functionId), int32(ov.overloadId)))
    57  		}
    58  		allSupportedFunctions[fn.functionId] = fn
    59  	}
    60  	for _, fn := range supportedAggInNewFramework {
    61  		for _, ov := range fn.Overloads {
    62  			ov.aggFramework.aggRegister(encodeOverloadID(int32(fn.functionId), int32(ov.overloadId)))
    63  		}
    64  		allSupportedFunctions[fn.functionId] = fn
    65  	}
    66  }
    67  
    68  func GetFunctionIsAggregateByName(name string) bool {
    69  	fid, exists := getFunctionIdByNameWithoutErr(name)
    70  	if !exists {
    71  		return false
    72  	}
    73  	f := allSupportedFunctions[fid]
    74  	return f.isAggregate()
    75  }
    76  
    77  func GetFunctionIsWinFunByName(name string) bool {
    78  	fid, exists := getFunctionIdByNameWithoutErr(name)
    79  	if !exists {
    80  		return false
    81  	}
    82  	f := allSupportedFunctions[fid]
    83  	return f.isWindow()
    84  }
    85  
    86  func GetFunctionIsWinOrderFunByName(name string) bool {
    87  	fid, exists := getFunctionIdByNameWithoutErr(name)
    88  	if !exists {
    89  		return false
    90  	}
    91  	f := allSupportedFunctions[fid]
    92  	return f.isWindowOrder()
    93  }
    94  
    95  func GetFunctionIsWinOrderFunById(overloadID int64) bool {
    96  	fid, _ := DecodeOverloadID(overloadID)
    97  	return allSupportedFunctions[fid].isWindowOrder()
    98  }
    99  
   100  func GetFunctionIsZonemappableById(ctx context.Context, overloadID int64) (bool, error) {
   101  	fid, oIndex := DecodeOverloadID(overloadID)
   102  	if int(fid) >= len(allSupportedFunctions) || int(fid) != allSupportedFunctions[fid].functionId {
   103  		return false, moerr.NewInvalidInput(ctx, "function overload id not found")
   104  	}
   105  	f := allSupportedFunctions[fid]
   106  	if f.Overloads[oIndex].volatile {
   107  		return false, nil
   108  	}
   109  	return f.testFlag(plan.Function_ZONEMAPPABLE), nil
   110  }
   111  
   112  func GetFunctionById(ctx context.Context, overloadID int64) (f overload, err error) {
   113  	fid, oIndex := DecodeOverloadID(overloadID)
   114  	if int(fid) >= len(allSupportedFunctions) || int(fid) != allSupportedFunctions[fid].functionId {
   115  		return overload{}, moerr.NewInvalidInput(ctx, "function overload id not found")
   116  	}
   117  	return allSupportedFunctions[fid].Overloads[oIndex], nil
   118  }
   119  
   120  func GetLayoutById(ctx context.Context, overloadID int64) (FuncExplainLayout, error) {
   121  	fid, _ := DecodeOverloadID(overloadID)
   122  	if int(fid) >= len(allSupportedFunctions) || int(fid) != allSupportedFunctions[fid].functionId {
   123  		return 0, moerr.NewInvalidInput(ctx, "function overload id not found")
   124  	}
   125  	return allSupportedFunctions[fid].layout, nil
   126  }
   127  
   128  func GetFunctionByIdWithoutError(overloadID int64) (f overload, exists bool) {
   129  	fid, oIndex := DecodeOverloadID(overloadID)
   130  	if int(fid) >= len(allSupportedFunctions) || int(fid) != allSupportedFunctions[fid].functionId {
   131  		return overload{}, false
   132  	}
   133  	return allSupportedFunctions[fid].Overloads[oIndex], true
   134  }
   135  
   136  func GetFunctionByName(ctx context.Context, name string, args []types.Type) (r FuncGetResult, err error) {
   137  	r.fid, err = getFunctionIdByName(ctx, name)
   138  	if err != nil {
   139  		return r, err
   140  	}
   141  	f := allSupportedFunctions[r.fid]
   142  	if len(f.Overloads) == 0 || f.checkFn == nil {
   143  		return r, moerr.NewNYI(ctx, "should implement the function %s", name)
   144  	}
   145  
   146  	check := f.checkFn(f.Overloads, args)
   147  	switch check.status {
   148  	case succeedMatched:
   149  		r.overloadId = int32(check.idx)
   150  		r.retType = f.Overloads[r.overloadId].retType(args)
   151  		r.cannotRunInParallel = f.Overloads[r.overloadId].cannotParallel
   152  
   153  	case succeedWithCast:
   154  		r.overloadId = int32(check.idx)
   155  		r.needCast = true
   156  		r.targetTypes = check.finalType
   157  		r.retType = f.Overloads[r.overloadId].retType(r.targetTypes)
   158  		r.cannotRunInParallel = f.Overloads[r.overloadId].cannotParallel
   159  
   160  	case failedFunctionParametersWrong:
   161  		if f.isFunction() {
   162  			err = moerr.NewInvalidArg(ctx, fmt.Sprintf("function %s", name), args)
   163  		} else {
   164  			err = moerr.NewInvalidArg(ctx, fmt.Sprintf("operator %s", name), args)
   165  		}
   166  
   167  	case failedAggParametersWrong:
   168  		err = moerr.NewInvalidArg(ctx, fmt.Sprintf("aggregate function %s", name), args)
   169  
   170  	case failedTooManyFunctionMatched:
   171  		err = moerr.NewInvalidArg(ctx, fmt.Sprintf("too many overloads matched %s", name), args)
   172  	}
   173  
   174  	return r, err
   175  }
   176  
   177  // RunFunctionDirectly runs a function directly without any protections.
   178  // It is dangerous and should be used only when you are sure that the overloadID is correct and the inputs are valid.
   179  func RunFunctionDirectly(proc *process.Process, overloadID int64, inputs []*vector.Vector, length int) (*vector.Vector, error) {
   180  	f, err := GetFunctionById(proc.Ctx, overloadID)
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  
   185  	mp := proc.Mp()
   186  	inputTypes := make([]types.Type, len(inputs))
   187  	for i := range inputTypes {
   188  		inputTypes[i] = *inputs[i].GetType()
   189  	}
   190  
   191  	result := vector.NewFunctionResultWrapper(proc.GetVector, proc.PutVector, f.retType(inputTypes), mp)
   192  
   193  	fold := true
   194  	evaluateLength := length
   195  	if !f.CannotFold() && !f.IsRealTimeRelated() {
   196  		for _, param := range inputs {
   197  			if !param.IsConst() {
   198  				fold = false
   199  			}
   200  		}
   201  		if fold {
   202  			evaluateLength = 1
   203  		}
   204  	}
   205  
   206  	if err = result.PreExtendAndReset(evaluateLength); err != nil {
   207  		result.Free()
   208  		return nil, err
   209  	}
   210  	exec, execFree := f.GetExecuteMethod()
   211  	if err = exec(inputs, result, proc, evaluateLength); err != nil {
   212  		result.Free()
   213  		if execFree != nil {
   214  			// NOTE: execFree is only applicable for serial and serial_full.
   215  			// if execFree is not nil, then make sure to call it after exec() is done.
   216  			_ = execFree()
   217  		}
   218  		return nil, err
   219  	}
   220  	if execFree != nil {
   221  		// NOTE: execFree is only applicable for serial and serial_full.
   222  		// if execFree is not nil, then make sure to call it after exec() is done.
   223  		_ = execFree()
   224  	}
   225  
   226  	vec := result.GetResultVector()
   227  	if fold {
   228  		// ToConst is a confused method. it just returns a new pointer to the same memory.
   229  		// so we need to duplicate it.
   230  		cvec, er := vec.ToConst(0, length, mp).Dup(mp)
   231  		result.Free()
   232  		if er != nil {
   233  			return nil, er
   234  		}
   235  		return cvec, nil
   236  	}
   237  	return vec, nil
   238  }
   239  
   240  func GetAggFunctionNameByID(overloadID int64) string {
   241  	f, exist := GetFunctionByIdWithoutError(overloadID)
   242  	if !exist {
   243  		return "unknown function"
   244  	}
   245  	return f.aggFramework.str
   246  }
   247  
   248  // DeduceNotNullable helps optimization sometimes.
   249  // deduce notNullable for function
   250  // for example, create table t1(c1 int not null, c2 int, c3 int not null ,c4 int);
   251  // sql select c1+1, abs(c2), cast(c3 as varchar(10)) from t1 where c1=c3;
   252  // we can deduce that c1+1, cast c3 and c1=c3 is notNullable, abs(c2) is nullable.
   253  func DeduceNotNullable(overloadID int64, args []*plan.Expr) bool {
   254  	fid, _ := DecodeOverloadID(overloadID)
   255  	if allSupportedFunctions[fid].testFlag(plan.Function_PRODUCE_NO_NULL) {
   256  		return true
   257  	}
   258  
   259  	for _, arg := range args {
   260  		if !arg.Typ.NotNullable {
   261  			return false
   262  		}
   263  	}
   264  	return true
   265  }
   266  
   267  type FuncGetResult struct {
   268  	fid        int32
   269  	overloadId int32
   270  	retType    types.Type
   271  
   272  	cannotRunInParallel bool
   273  
   274  	needCast    bool
   275  	targetTypes []types.Type
   276  }
   277  
   278  func (fr *FuncGetResult) GetEncodedOverloadID() (overloadID int64) {
   279  	return encodeOverloadID(fr.fid, fr.overloadId)
   280  }
   281  
   282  func (fr *FuncGetResult) ShouldDoImplicitTypeCast() (typs []types.Type, should bool) {
   283  	return fr.targetTypes, fr.needCast
   284  }
   285  
   286  func (fr *FuncGetResult) GetReturnType() types.Type {
   287  	return fr.retType
   288  }
   289  
   290  func (fr *FuncGetResult) CannotRunInParallel() bool {
   291  	return fr.cannotRunInParallel
   292  }
   293  
   294  func encodeOverloadID(fid, overloadId int32) (overloadID int64) {
   295  	overloadID = int64(fid)
   296  	overloadID = overloadID << 32
   297  	overloadID |= int64(overloadId)
   298  	return
   299  }
   300  
   301  func DecodeOverloadID(overloadID int64) (fid int32, oIndex int32) {
   302  	base := overloadID
   303  	oIndex = int32(overloadID)
   304  	fid = int32(base >> 32)
   305  	return fid, oIndex
   306  }
   307  
   308  func getFunctionIdByName(ctx context.Context, name string) (int32, error) {
   309  	if fid, ok := functionIdRegister[name]; ok {
   310  		return fid, nil
   311  	}
   312  	return -1, moerr.NewNotSupported(ctx, "function or operator '%s'", name)
   313  }
   314  
   315  func getFunctionIdByNameWithoutErr(name string) (int32, bool) {
   316  	fid, exist := functionIdRegister[name]
   317  	return fid, exist
   318  }
   319  
// FuncNew stores all information about a function:
// the unique id that identifies it, the class flags it belongs to,
// and all of its overloads.
type FuncNew struct {
	// unique id of the function. It is also the function's index in
	// allSupportedFunctions.
	functionId int

	// class flags of the function; tested bit-wise by testFlag.
	class plan.Function_FuncFlag

	// All overloads of the function.
	Overloads []overload

	// checkFn checks whether the input types match the requirements of the function.
	// If they match, it returns the index of the corresponding overload. If a type
	// conversion is required, the required types are returned at the same time.
	checkFn func(overloads []overload, inputs []types.Type) checkResult

	// layout is used for `explain SQL`.
	layout FuncExplainLayout
}
   341  
// executeLogicOfOverload is the execution logic of an overload. It receives
// the parameter vectors, the wrapper that receives the result, the process it
// runs in, and the number of rows to evaluate.
type executeLogicOfOverload func(parameters []*vector.Vector,
	result vector.FunctionResultWrapper,
	proc *process.Process, length int) error

// executeFreeOfOverload frees the resources allocated by the execution logic.
// It is mainly used by SERIAL and SERIAL_FULL.
// NOTE: right now, callers do not propagate an error when the free logic fails. The error return is
// still part of the signature in case it is needed in the future.
type executeFreeOfOverload func() error

// aggregationLogicOfOverload holds the aggregation-specific parts of an overload.
type aggregationLogicOfOverload struct {
	// aggregation-related string used in error messages.
	str string

	// aggRegister registers the aggregation under the given encoded overload id.
	aggRegister func(overloadID int64)
}
   359  
// overload is one overload of a function.
// It stores all information about the execution logic.
type overload struct {
	// id of this overload within its function.
	overloadId int

	// args records some type information about this overload.
	// In most cases it records, in order, which parameter types the overload requires.
	// For example,
	//		args can be `{int64, int64}` for one overload of the `pow` function.
	//		this means the overload can accept {int64, int64} as its input.
	// However, it is not necessarily the type directly required by the overload;
	// what it means depends on the logic of the function's checkFn.
	args []types.T

	// retType computes the return type of the overload.
	// parameters are the types actually received when the overload is executed.
	retType func(parameters []types.Type) types.Type

	// newOp builds the execution logic.
	newOp func() executeLogicOfOverload

	// newOpWithFree builds the execution logic together with its free logic.
	// NOTE: use either newOp or newOpWithFree. GetExecuteMethod prefers
	// newOpWithFree when both are set.
	newOpWithFree func() (executeLogicOfOverload, executeFreeOfOverload)

	// In fact, the function framework does not directly run aggregate functions and window functions.
	// These two flags mark whether the overload is one of them.
	isAgg        bool
	isWin        bool
	aggFramework aggregationLogicOfOverload

	// if true, the overload is unable to run in parallel.
	// For example,
	//		rand(1) cannot run in parallel because it should use the same rand seed.
	//
	// TODO: there is no good place to use this in plan yet, so the attribute is not effective.
	cannotParallel bool

	// if true, the overload cannot be folded.
	volatile bool
	// if realTimeRelated, the overload cannot be folded when `Prepare`.
	realTimeRelated bool
}
   403  
   404  func (ov *overload) CannotFold() bool {
   405  	return ov.volatile
   406  }
   407  
   408  func (ov *overload) IsRealTimeRelated() bool {
   409  	return ov.realTimeRelated
   410  }
   411  
   412  func (ov *overload) IsAgg() bool {
   413  	return ov.isAgg
   414  }
   415  
   416  func (ov *overload) CannotExecuteInParallel() bool {
   417  	return ov.cannotParallel
   418  }
   419  
   420  func (ov *overload) GetExecuteMethod() (executeLogicOfOverload, executeFreeOfOverload) {
   421  	if ov.newOpWithFree != nil {
   422  		fn, fnFree := ov.newOpWithFree()
   423  		return fn, fnFree
   424  	}
   425  
   426  	fn := ov.newOp()
   427  	return fn, nil
   428  }
   429  
   430  func (ov *overload) GetReturnTypeMethod() func(parameters []types.Type) types.Type {
   431  	return ov.retType
   432  }
   433  
   434  func (ov *overload) IsWin() bool {
   435  	return ov.isWin
   436  }
   437  
   438  func (fn *FuncNew) isFunction() bool {
   439  	return fn.layout == STANDARD_FUNCTION || fn.layout >= NOPARAMETER_FUNCTION
   440  }
   441  
   442  func (fn *FuncNew) isAggregate() bool {
   443  	return fn.testFlag(plan.Function_AGG)
   444  }
   445  
   446  func (fn *FuncNew) isWindow() bool {
   447  	return fn.testFlag(plan.Function_WIN_ORDER) || fn.testFlag(plan.Function_WIN_VALUE) || fn.testFlag(plan.Function_AGG)
   448  }
   449  
   450  func (fn *FuncNew) isWindowOrder() bool {
   451  	return fn.testFlag(plan.Function_WIN_ORDER)
   452  }
   453  
   454  func (fn *FuncNew) testFlag(funcFlag plan.Function_FuncFlag) bool {
   455  	return fn.class&funcFlag != 0
   456  }
   457  
// overloadCheckSituation is the status reported by a function's checkFn after
// matching the input types against its overloads.
type overloadCheckSituation int

const (
	// an overload matched the inputs as they are.
	succeedMatched overloadCheckSituation = 0
	// an overload matched, provided the inputs are cast to the returned types.
	succeedWithCast overloadCheckSituation = -1
	// no match; reported as an invalid function or operator argument.
	failedFunctionParametersWrong overloadCheckSituation = -2
	// no match; reported as an invalid aggregate function argument.
	failedAggParametersWrong overloadCheckSituation = -3
	// more than one overload matched.
	failedTooManyFunctionMatched overloadCheckSituation = -4
)

// checkResult is the value returned by a function's checkFn.
type checkResult struct {
	status overloadCheckSituation

	// if matched: idx is the index of the matched overload, and finalType
	// holds the cast target types when status is succeedWithCast.
	idx       int
	finalType []types.Type
}
   475  
   476  func newCheckResultWithSuccess(overloadId int) checkResult {
   477  	return checkResult{status: succeedMatched, idx: overloadId}
   478  }
   479  
   480  func newCheckResultWithFailure(status overloadCheckSituation) checkResult {
   481  	return checkResult{status: status}
   482  }
   483  
   484  func newCheckResultWithCast(overloadId int, castType []types.Type) checkResult {
   485  	return checkResult{
   486  		status:    succeedWithCast,
   487  		idx:       overloadId,
   488  		finalType: castType,
   489  	}
   490  }