github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/plan/function/function.go (about) 1 // Copyright 2021 - 2022 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "context" 19 "math" 20 21 "github.com/matrixorigin/matrixone/pkg/common/moerr" 22 "github.com/matrixorigin/matrixone/pkg/container/types" 23 "github.com/matrixorigin/matrixone/pkg/container/vector" 24 "github.com/matrixorigin/matrixone/pkg/pb/plan" 25 "github.com/matrixorigin/matrixone/pkg/vm/process" 26 ) 27 28 const ( 29 // ScalarNull means of scalar NULL 30 // which can meet each required type. 31 // e.g. 32 // if we input a SQL `select built_in_function(columnA, NULL);`, and columnA is int64 column. 33 // it will use [types.T_int64, ScalarNull] to match function when we were building the query plan. 34 ScalarNull = types.T_any 35 ) 36 37 var ( 38 // an empty type structure just for return when we couldn't meet any function. 39 emptyType = types.Type{} 40 41 // AndFunctionEncodedID is the encoded overload id of And(bool, bool) 42 // used to make an AndExpr 43 AndFunctionEncodedID = EncodeOverloadID(AND, 0) 44 AndFunctionName = "and" 45 ) 46 47 // Functions records all overloads of the same function name 48 // and its function-id and type-check-function 49 type Functions struct { 50 Id int 51 52 Flag plan.Function_FuncFlag 53 54 // Layout adapt to plan/function.go, used for `explain SQL`. 55 Layout FuncExplainLayout 56 57 // TypeCheckFn checks if the input parameters can satisfy one of the overloads 58 // and returns its index id. 59 // if type convert should happen, return the target-types at the same time. 60 TypeCheckFn func(overloads []Function, inputs []types.T) (overloadIndex int32, ts []types.T) 61 62 Overloads []Function 63 } 64 65 // TypeCheck do type check work for a function, 66 // if the input params matched one of function's overloads. 67 // returns overload-index-number, target-type 68 // just set target-type nil if there is no need to do implicit-type-conversion for parameters 69 func (fs *Functions) TypeCheck(args []types.T) (int32, []types.T) { 70 if fs.TypeCheckFn == nil { 71 return normalTypeCheck(fs.Overloads, args) 72 } 73 return fs.TypeCheckFn(fs.Overloads, args) 74 } 75 76 func normalTypeCheck(overloads []Function, inputs []types.T) (overloadIndex int32, ts []types.T) { 77 matched := make([]int32, 0, 4) // function overload which can be matched directly 78 byCast := make([]int32, 0, 4) // function overload which can be matched according to type cast 79 convertCost := make([]int, 0, 4) // records the cost of conversion for byCast 80 for i, f := range overloads { 81 c, cost := tryToMatch(inputs, f.Args) 82 switch c { 83 case matchedDirectly: 84 matched = append(matched, int32(i)) 85 case matchedByConvert: 86 byCast = append(byCast, int32(i)) 87 convertCost = append(convertCost, cost) 88 case matchedFailed: 89 continue 90 } 91 } 92 if len(matched) == 1 { 93 return matched[0], nil 94 } else if len(matched) == 0 && len(byCast) > 0 { 95 // choose the overload with the least number of conversions 96 min, index := math.MaxInt32, 0 97 for j := range convertCost { 98 if convertCost[j] < min { 99 index = j 100 min = convertCost[j] 101 } 102 } 103 return byCast[index], overloads[byCast[index]].Args 104 } else if len(matched) > 1 { 105 // if contains any scalar null as param, just return the first matched. 106 for j := range inputs { 107 if inputs[j] == ScalarNull { 108 return matched[0], nil 109 } 110 } 111 return tooManyFunctionsMatched, nil 112 } 113 return wrongFunctionParameters, nil 114 } 115 116 // Function is an overload of 117 // a built-in function or an aggregate function or an operator 118 type Function struct { 119 // Index is the function's location number of all the overloads with the same functionName. 120 Index int32 121 122 // Volatile function cannot be fold 123 Volatile bool 124 125 // RealTimeRelate function cannot be folded when in prepare statement 126 RealTimeRelated bool 127 128 // whether the function needs to append a hidden parameter, such as 'uuid' 129 AppendHideArg bool 130 131 Args []types.T 132 ReturnTyp types.T 133 134 // Fn is implementation of built-in function and operator 135 // it received vector list, and return result vector. 136 Fn func(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) 137 138 // AggregateInfo is related information about aggregate function. 139 AggregateInfo int 140 141 // Info records information about the function overload used to print 142 Info string 143 144 flag plan.Function_FuncFlag 145 146 // Layout adapt to plan/function.go, used for `explain SQL`. 147 layout FuncExplainLayout 148 149 UseNewFramework bool 150 ResultWillNotNull bool 151 FlexibleReturnType func(parameters []types.Type) types.Type 152 ParameterMustScalar []bool 153 NewFn func(parameters []*vector.Vector, result vector.FunctionResultWrapper, proc *process.Process, length int) error 154 } 155 156 func (f *Function) TestFlag(funcFlag plan.Function_FuncFlag) bool { 157 return f.flag&funcFlag != 0 158 } 159 160 func (f *Function) GetLayout() FuncExplainLayout { 161 return f.layout 162 } 163 164 // ReturnType return result-type of function, and the result is nullable 165 // if nullable is false, function won't return a vector with null value. 166 func (f Function) ReturnType(args []types.Type) (typ types.Type, nullable bool) { 167 if f.FlexibleReturnType != nil { 168 return f.FlexibleReturnType(args), !f.ResultWillNotNull 169 } 170 return f.ReturnTyp.ToType(), !f.ResultWillNotNull 171 } 172 173 func (f Function) VecFn(vs []*vector.Vector, proc *process.Process) (*vector.Vector, error) { 174 if f.Fn == nil { 175 return nil, moerr.NewInternalError(proc.Ctx, "no function") 176 } 177 return f.Fn(vs, proc) 178 } 179 180 func (f Function) IsAggregate() bool { 181 return f.TestFlag(plan.Function_AGG) 182 } 183 184 func (f Function) isFunction() bool { 185 return f.GetLayout() == STANDARD_FUNCTION || f.GetLayout() >= NOPARAMETER_FUNCTION 186 } 187 188 // functionRegister records the information about 189 // all the operator, built-function and aggregate function. 190 // 191 // For use in other packages, see GetFunctionByID and GetFunctionByName 192 var functionRegister []Functions 193 194 // get function id from map functionIdRegister, see functionIds.go 195 func fromNameToFunctionIdWithoutError(name string) (int32, bool) { 196 if fid, ok := functionIdRegister[name]; ok { 197 return fid, true 198 } 199 return -1, false 200 } 201 202 // get function id from map functionIdRegister, see functionIds.go 203 func fromNameToFunctionId(ctx context.Context, name string) (int32, error) { 204 if fid, ok := functionIdRegister[name]; ok { 205 return fid, nil 206 } 207 return -1, moerr.NewNotSupported(ctx, "function or operator '%s'", name) 208 } 209 210 // EncodeOverloadID convert function-id and overload-index to be an overloadID 211 // the high 32-bit is function-id, the low 32-bit is overload-index 212 func EncodeOverloadID(fid int32, index int32) (overloadID int64) { 213 overloadID = int64(fid) 214 overloadID = overloadID << 32 215 overloadID |= int64(index) 216 return overloadID 217 } 218 219 // DecodeOverloadID convert overload id to be function-id and overload-index 220 func DecodeOverloadID(overloadID int64) (fid int32, index int32) { 221 base := overloadID 222 index = int32(overloadID) 223 fid = int32(base >> 32) 224 return fid, index 225 } 226 227 // GetFunctionByIDWithoutError get function structure by its index id. 228 func GetFunctionByIDWithoutError(overloadID int64) (*Function, bool) { 229 fid, overloadIndex := DecodeOverloadID(overloadID) 230 if int(fid) < len(functionRegister) { 231 fs := functionRegister[fid].Overloads 232 return &fs[overloadIndex], true 233 } else { 234 return nil, false 235 } 236 } 237 238 // GetFunctionByID get function structure by its index id. 239 func GetFunctionByID(ctx context.Context, overloadID int64) (*Function, error) { 240 fid, overloadIndex := DecodeOverloadID(overloadID) 241 if int(fid) < len(functionRegister) { 242 fs := functionRegister[fid].Overloads 243 return &fs[overloadIndex], nil 244 } else { 245 return nil, moerr.NewInvalidInput(ctx, "function overload id not found") 246 } 247 } 248 249 // deduce notNullable for function 250 // for example, create table t1(c1 int not null, c2 int, c3 int not null ,c4 int); 251 // sql select c1+1, abs(c2), cast(c3 as varchar(10)) from t1 where c1=c3; 252 // we can deduce that c1+1, cast c3 and c1=c3 is notNullable, abs(c2) is nullable 253 // this message helps optimization sometimes 254 func DeduceNotNullable(overloadID int64, args []*plan.Expr) bool { 255 function, _ := GetFunctionByIDWithoutError(overloadID) 256 if function.TestFlag(plan.Function_PRODUCE_NO_NULL) { 257 return true 258 } 259 260 for _, arg := range args { 261 if !arg.Typ.NotNullable { 262 return false 263 } 264 } 265 return true 266 } 267 268 func GetFunctionIsAggregateByName(name string) bool { 269 fid, exists := fromNameToFunctionIdWithoutError(name) 270 if !exists { 271 return false 272 } 273 fs := functionRegister[fid].Overloads 274 return len(fs) > 0 && fs[0].IsAggregate() 275 } 276 277 // Check whether the function needs to append a hidden parameter 278 func GetFunctionAppendHideArgByID(overloadID int64) bool { 279 function, exists := GetFunctionByIDWithoutError(overloadID) 280 if !exists { 281 return false 282 } 283 return function.AppendHideArg 284 } 285 286 func GetFunctionIsMonotonicById(ctx context.Context, overloadID int64) (bool, error) { 287 function, err := GetFunctionByID(ctx, overloadID) 288 if err != nil { 289 return false, err 290 } 291 // if function cann't be fold, we think that will be not monotonic 292 if function.Volatile { 293 return false, nil 294 } 295 isMonotonic := function.TestFlag(plan.Function_MONOTONIC) 296 return isMonotonic, nil 297 } 298 299 // GetFunctionByName check a function exist or not according to input function name and arg types, 300 // if matches, 301 // return the encoded overload id and the overload's return type 302 // and final converted argument types( it will be nil if there's no need to do type level-up work). 303 func GetFunctionByName(ctx context.Context, name string, args []types.Type) (int64, types.Type, []types.Type, error) { 304 fid, err := fromNameToFunctionId(ctx, name) 305 if err != nil { 306 return -1, emptyType, nil, err 307 } 308 fs := functionRegister[fid] 309 310 argTs := getOidSlice(args) 311 index, targetTs := fs.TypeCheck(argTs) 312 313 // if implicit type conversion happens, set the right precision for target types. 314 targetTypes := getTypeSlice(targetTs) 315 rewriteTypesIfNecessary(targetTypes, args) 316 317 var finalTypes []types.Type 318 if targetTs != nil { 319 finalTypes = targetTypes 320 } else { 321 finalTypes = args 322 } 323 324 // deal the failed situations 325 switch index { 326 case wrongFunctionParameters: 327 ArgsToPrint := getOidSlice(finalTypes) // arg information to print for error message 328 if len(fs.Overloads) > 0 && fs.Overloads[0].isFunction() { 329 return -1, emptyType, nil, moerr.NewInvalidArg(ctx, "function "+name, ArgsToPrint) 330 } 331 return -1, emptyType, nil, moerr.NewInvalidArg(ctx, "operator "+name, ArgsToPrint) 332 case tooManyFunctionsMatched: 333 return -1, emptyType, nil, moerr.NewInvalidArg(ctx, "too many overloads matched "+name, args) 334 case wrongFuncParamForAgg: 335 ArgsToPrint := getOidSlice(finalTypes) 336 return -1, emptyType, nil, moerr.NewInvalidArg(ctx, "aggregate function "+name, ArgsToPrint) 337 } 338 339 // make the real return type of function overload. 340 rt := getRealReturnType(fid, fs.Overloads[index], finalTypes) 341 342 return EncodeOverloadID(fid, index), rt, targetTypes, nil 343 } 344 345 func ensureBinaryOperatorWithSamePrecision(targets []types.Type, hasSet []bool) { 346 if len(targets) == 2 && targets[0].Oid == targets[1].Oid { 347 if hasSet[0] && !hasSet[1] { // precision follow the left-part 348 copyType(&targets[1], &targets[0]) 349 hasSet[1] = true 350 } else if !hasSet[0] && hasSet[1] { // precision follow the right-part 351 copyType(&targets[0], &targets[1]) 352 hasSet[0] = true 353 } 354 } 355 } 356 357 func rewriteTypesIfNecessary(targets []types.Type, sources []types.Type) { 358 if len(targets) != 0 { 359 hasSet := make([]bool, len(sources)) 360 361 //ensure that we will not lost the origin scale 362 maxScale := int32(0) 363 for i := range sources { 364 if sources[i].Oid == types.T_decimal64 || sources[i].Oid == types.T_decimal128 { 365 if sources[i].Scale > maxScale { 366 maxScale = sources[i].Scale 367 } 368 } 369 } 370 for i := range sources { 371 if targets[i].Oid == types.T_decimal64 || targets[i].Oid == types.T_decimal128 { 372 if sources[i].Scale < maxScale { 373 sources[i].Scale = maxScale 374 } 375 } 376 } 377 378 for i := range targets { 379 oid1, oid2 := sources[i].Oid, targets[i].Oid 380 // ensure that we will not lose the original precision. 381 if oid2 == types.T_decimal64 || oid2 == types.T_decimal128 || oid2 == types.T_timestamp || oid2 == types.T_time { 382 if oid1 != types.T_char && oid1 != types.T_varchar && oid1 != types.T_blob && oid1 != types.T_text { 383 copyType(&targets[i], &sources[i]) 384 hasSet[i] = true 385 } 386 } 387 } 388 ensureBinaryOperatorWithSamePrecision(targets, hasSet) 389 for i := range targets { 390 if !hasSet[i] && targets[i].Oid != ScalarNull { 391 setDefaultPrecision(&targets[i]) 392 } 393 } 394 } 395 } 396 397 // set default precision / scalar / width for a type 398 func setDefaultPrecision(typ *types.Type) { 399 if typ.Oid == types.T_decimal64 { 400 typ.Scale = 0 401 typ.Width = 18 402 } else if typ.Oid == types.T_decimal128 { 403 typ.Scale = 0 404 typ.Width = 38 405 } else if typ.Oid == types.T_timestamp { 406 typ.Precision = 6 407 } else if typ.Oid == types.T_datetime { 408 typ.Precision = 6 409 } else if typ.Oid == types.T_time { 410 typ.Precision = 6 411 } 412 typ.Size = int32(typ.Oid.TypeLen()) 413 } 414 415 func getRealReturnType(fid int32, f Function, realArgs []types.Type) types.Type { 416 if f.IsAggregate() { 417 switch fid { 418 case MIN, MAX: 419 if realArgs[0].Oid != ScalarNull { 420 return realArgs[0] 421 } 422 } 423 } 424 if f.FlexibleReturnType != nil { 425 return f.FlexibleReturnType(realArgs) 426 } 427 rt := f.ReturnTyp.ToType() 428 for i := range realArgs { 429 if realArgs[i].Oid == rt.Oid { 430 copyType(&rt, &realArgs[i]) 431 checkTypeWidth(realArgs, &rt) 432 break 433 } 434 if types.T(rt.Oid) == types.T_decimal128 && types.T(realArgs[i].Oid) == types.T_decimal64 { 435 copyType(&rt, &realArgs[i]) 436 } 437 } 438 return rt 439 } 440 441 func checkTypeWidth(realArgs []types.Type, rt *types.Type) { 442 for i := range realArgs { 443 if realArgs[i].Oid == rt.Oid && rt.Width < realArgs[i].Width { 444 rt.Width = realArgs[i].Width 445 } 446 } 447 } 448 449 func copyType(dst, src *types.Type) { 450 dst.Width = src.Width 451 dst.Scale = src.Scale 452 dst.Precision = src.Precision 453 } 454 455 func getOidSlice(ts []types.Type) []types.T { 456 ret := make([]types.T, len(ts)) 457 for i := range ts { 458 ret[i] = ts[i].Oid 459 } 460 return ret 461 } 462 463 func getTypeSlice(ts []types.T) []types.Type { 464 ret := make([]types.Type, len(ts)) 465 for i := range ts { 466 ret[i] = ts[i].ToType() 467 } 468 return ret 469 }