github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/sem/tree/overload.go (about) 1 // Copyright 2016 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package tree 12 13 import ( 14 "bytes" 15 "context" 16 "fmt" 17 "math" 18 "strings" 19 "sync" 20 21 "github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgcode" 22 "github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgerror" 23 "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/volatility" 24 "github.com/cockroachdb/cockroachdb-parser/pkg/sql/types" 25 "github.com/cockroachdb/cockroachdb-parser/pkg/util/intsets" 26 "github.com/cockroachdb/errors" 27 "github.com/cockroachdb/redact" 28 "github.com/lib/pq/oid" 29 ) 30 31 // SpecializedVectorizedBuiltin is used to map overloads 32 // to the vectorized operator that is specific to 33 // that implementation of the builtin function. 34 type SpecializedVectorizedBuiltin int 35 36 // TODO (rohany): What is the best place to put this list? 37 // I want to put it in builtins or exec, but those create an import 38 // cycle with exec. tree is imported by both of them, so 39 // this package seems like a good place to do it. 40 41 // Keep this list alphabetized so that it is easy to manage. 42 const ( 43 _ SpecializedVectorizedBuiltin = iota 44 SubstringStringIntInt 45 CrdbInternalRangeStats 46 ) 47 48 // AggregateOverload is an opaque type which is used to box an eval.AggregateOverload. 49 type AggregateOverload interface { 50 // Aggregate is just a marker method so folks don't think you can just shove 51 // anything here. It ought to be an eval.AggregateOverload. 52 Aggregate() // marker interface 53 } 54 55 // WindowOverload is an opaque type which is used to box an eval.WindowOverload. 56 type WindowOverload interface { 57 // Window is just a marker method so folks don't think you can just shove 58 // anything here. It ought to be an eval.WindowOverload. 59 Window() 60 } 61 62 // FnWithExprsOverload is an opaque type used to box an 63 // eval.FnWithExprsOverload. 64 type FnWithExprsOverload interface { 65 FnWithExprs() 66 } 67 68 // FnOverload is an opaque type used to box an eval.FnOverload. 69 // 70 // TODO(ajwerner): Give this a marker method and convert all usages. 71 // This is onerous at time of writing because there are so many. 72 type FnOverload interface{} 73 74 // GeneratorOverload is an opaque type used to box an eval.GeneratorOverload. 75 type GeneratorOverload interface { 76 Generator() 77 } 78 79 // GeneratorWithExprsOverload is an opaque type used to box an eval.GeneratorWithExprsOverload. 80 type GeneratorWithExprsOverload interface { 81 GeneratorWithExprs() 82 } 83 84 // SQLFnOverload is an opaque type used to box an eval.SQLFnOverload. 85 type SQLFnOverload interface { 86 SQLFn() 87 } 88 89 // FunctionClass specifies the class of the builtin function. 90 type FunctionClass int 91 92 const ( 93 // NormalClass is a standard builtin function. 94 NormalClass FunctionClass = iota 95 // AggregateClass is a builtin aggregate function. 96 AggregateClass 97 // WindowClass is a builtin window function. 98 WindowClass 99 // GeneratorClass is a builtin generator function. 100 GeneratorClass 101 // SQLClass is a builtin function that executes a SQL statement as a side 102 // effect of the function call. 103 // 104 // For example, AddGeometryColumn is a SQLClass function that executes an 105 // ALTER TABLE ... ADD COLUMN statement to add a geometry column to an 106 // existing table. It returns metadata about the column added. 107 // 108 // All builtin functions of this class should include a definition for 109 // Overload.SQLFn, which returns the SQL statement to be executed. They 110 // should also include a definition for Overload.Fn, which is executed 111 // like a NormalClass function and returns a Datum. 112 SQLClass 113 ) 114 115 // String returns the string representation of the function class. 116 func (c FunctionClass) String() string { 117 switch c { 118 case NormalClass: 119 return "normal" 120 case AggregateClass: 121 return "aggregate" 122 case WindowClass: 123 return "window" 124 case GeneratorClass: 125 return "generator" 126 case SQLClass: 127 return "SQL" 128 default: 129 panic(errors.AssertionFailedf("unexpected class %d", c)) 130 } 131 } 132 133 // RoutineType specifies the type of routine represented by an overload. 134 type RoutineType uint8 135 136 const ( 137 // BuiltinRoutine is a builtin function. 138 BuiltinRoutine RoutineType = 1 << iota 139 // UDFRoutine is a user-defined function. 140 UDFRoutine 141 // ProcedureRoutine is a user-defined procedure. 142 ProcedureRoutine 143 ) 144 145 // String returns the string representation of the routine type. 146 func (t RoutineType) String() string { 147 switch t { 148 case BuiltinRoutine: 149 return "builtin" 150 case UDFRoutine: 151 return "udf" 152 case ProcedureRoutine: 153 return "procedure" 154 default: 155 panic(errors.AssertionFailedf("unexpected routine type %d", t)) 156 } 157 } 158 159 // Overload is one of the overloads of a built-in function. 160 // Each FunctionDefinition may contain one or more overloads. 161 type Overload struct { 162 Types TypeList 163 ReturnType ReturnTyper 164 Volatility volatility.V 165 166 // PreferredOverload determines overload resolution as follows. 167 // When multiple overloads are eligible based on types even after all of of 168 // the heuristics to pick one have been used, if one of the overloads is a 169 // Overload with the `PreferredOverload` flag set to true it can be selected 170 // rather than returning a no-such-method error. 171 // This should generally be avoided -- avoiding introducing ambiguous 172 // overloads in the first place is a much better solution -- and only done 173 // after consultation with @knz @nvanbenschoten. 174 PreferredOverload bool 175 176 // Info is a description of the function, which is surfaced on the CockroachDB 177 // docs site on the "Functions and Operators" page. Descriptions typically use 178 // third-person with the function as an implicit subject (e.g. "Calculates 179 // infinity"), but should focus more on ease of understanding so other structures 180 // might be more appropriate. 181 Info string 182 183 AggregateFunc AggregateOverload 184 WindowFunc WindowOverload 185 186 // Only one of the "Fn", "FnWithExprs", "Generate", "GeneratorWithExprs", 187 // "SQLFn" and "Body" attributes can be set. 188 189 // Class is the kind of built-in function (normal/aggregate/window/etc.) 190 Class FunctionClass 191 192 // Fn is the normal builtin implementation function. It's for functions that 193 // take in Datums and return a Datum. 194 // 195 // The opaque wrapper needs to be type asserted into eval.FnOverload. 196 Fn FnOverload 197 198 // FnWithExprs is for builtins that need access to their arguments as Exprs 199 // and not pre-evaluated Datums, but is otherwise identical to Fn. 200 FnWithExprs FnWithExprsOverload 201 202 // Generator is for SRFs. SRFs take Datums and return multiple rows of Datums. 203 Generator GeneratorOverload 204 205 // GeneratorWithExprs is for SRFs that need access to their arguments as Exprs 206 // and not pre-evaluated Datums, but is otherwise identical to Generator. 207 GeneratorWithExprs GeneratorWithExprsOverload 208 209 // SQLFn must be set for overloads of type SQLClass. It should return a SQL 210 // statement which will be executed as a common table expression in the query. 211 SQLFn SQLFnOverload 212 213 // OnTypeCheck is called every time this overload is type checked. 214 // This is a pointer so that it can be set in a builtinsregistry hook, which 215 // gets a copy of the overload struct. 216 OnTypeCheck *func() 217 218 // SpecializedVecBuiltin is used to let the vectorized engine 219 // know when an Overload has a specialized vectorized operator. 220 SpecializedVecBuiltin SpecializedVectorizedBuiltin 221 222 // IgnoreVolatilityCheck ignores checking the functions overload's 223 // Volatility against Postgres's Volatility at test time. 224 // This should be used with caution. 225 IgnoreVolatilityCheck bool 226 227 // Oid is the cached oidHasher.BuiltinOid result for this Overload. It's 228 // populated at init-time. 229 Oid oid.Oid 230 231 // DistsqlBlocklist is set to true when a function cannot be evaluated in 232 // DistSQL. One example is when the type information for function arguments 233 // cannot be recovered. 234 DistsqlBlocklist bool 235 236 // CalledOnNullInput is set to true when a function is called when any of 237 // its inputs are NULL. When true, the function implementation must be able 238 // to handle NULL arguments. 239 // 240 // When set to false, the function will directly result in NULL in the 241 // presence of any NULL arguments without evaluating the function's 242 // implementation. Therefore, if the function is expected to produce 243 // side-effects with a NULL argument, CalledOnNullInput must be true. Note 244 // that if this behavior changes so that CalledOnNullInput=false functions 245 // can produce side-effects, the FoldFunctionWithNullArg optimizer rule must 246 // be changed to avoid folding those functions. 247 // 248 // NOTE: when set, a function should be prepared for any of its arguments to 249 // be NULL and should act accordingly. 250 CalledOnNullInput bool 251 252 // FunctionProperties are the properties of this overload. 253 FunctionProperties 254 255 // Type indicates if the overload represents a built-in function, a 256 // user-defined function, or a user-define procedure. 257 Type RoutineType 258 // Body is the SQL string body of a function. It can be set even if Type is 259 // BuiltinRoutine if a builtin function is defined using a SQL string. 260 Body string 261 // UDFContainsOnlySignature is only set to true for Overload signatures cached 262 // in a Schema descriptor, which means that the full UDF descriptor need to be 263 // fetched to get more info, e.g. function Body. 264 UDFContainsOnlySignature bool 265 // ReturnSet is set to true when a user-defined function is defined to return 266 // a set of values. 267 ReturnSet bool 268 // Version is the descriptor version of the descriptor used to construct 269 // this version of the function overload. Only used for UDFs. 270 Version uint64 271 // Language is the function language that was used to define the UDF. 272 // This is currently either SQL or PL/pgSQL. 273 Language RoutineLanguage 274 } 275 276 // params implements the overloadImpl interface. 277 func (b Overload) params() TypeList { return b.Types } 278 279 // returnType implements the overloadImpl interface. 280 func (b Overload) returnType() ReturnTyper { return b.ReturnType } 281 282 // preferred implements the overloadImpl interface. 283 func (b Overload) preferred() bool { return b.PreferredOverload } 284 285 // FixedReturnType returns a fixed type that the function returns, returning Any 286 // if the return type is based on the function's arguments. 287 func (b Overload) FixedReturnType() *types.T { 288 if b.ReturnType == nil { 289 return nil 290 } 291 return returnTypeToFixedType(b.ReturnType, nil) 292 } 293 294 // InferReturnTypeFromInputArgTypes returns the type that the function returns, 295 // inferring the type based on the function's inputTypes if necessary. 296 func (b Overload) InferReturnTypeFromInputArgTypes(inputTypes []*types.T) *types.T { 297 retTyp := b.FixedReturnType() 298 // If the output type of the function depends on its inputs, then 299 // the output of FixedReturnType will be ambiguous. In the ambiguous 300 // cases, use the information about the input types to construct the 301 // appropriate output type. The tree.ReturnTyper interface is 302 // []tree.TypedExpr -> *types.T, so construct the []tree.TypedExpr 303 // from the types that we know are the inputs. Note that we don't 304 // try to create datums of each input type, and instead use this 305 // "TypedDummy" construct. This is because some types don't have resident 306 // members (like an ENUM with no values), and we shouldn't error out 307 // trying to infer the return type in those cases. 308 if retTyp.IsAmbiguous() { 309 args := make([]TypedExpr, len(inputTypes)) 310 for i, t := range inputTypes { 311 args[i] = &TypedDummy{Typ: t} 312 } 313 // Evaluate ReturnType with the fake input set of arguments. 314 retTyp = returnTypeToFixedType(b.ReturnType, args) 315 } 316 return retTyp 317 } 318 319 // IsGenerator returns true if the function is a set returning function (SRF). 320 func (b Overload) IsGenerator() bool { 321 return b.Generator != nil || b.GeneratorWithExprs != nil 322 } 323 324 // HasSQLBody returns true if the function was defined using a SQL string body. 325 // This is the case for user-defined functions and some builtins. 326 func (b Overload) HasSQLBody() bool { 327 return b.Type == UDFRoutine || b.Type == ProcedureRoutine || b.Body != "" 328 } 329 330 // Signature returns a human-readable signature. 331 // If simplify is bool, tuple-returning functions with just 332 // 1 tuple element unwrap the return type in the signature. 333 func (b Overload) Signature(simplify bool) string { 334 retType := b.FixedReturnType() 335 if simplify { 336 if retType.Family() == types.TupleFamily && len(retType.TupleContents()) == 1 { 337 retType = retType.TupleContents()[0] 338 } 339 } 340 return fmt.Sprintf("(%s) -> %s", b.Types.String(), retType) 341 } 342 343 // overloadImpl is an implementation of an overloaded function. It provides 344 // access to the parameter type list and the return type of the implementation. 345 // 346 // This is a more general type than Overload defined above, because it also 347 // works with the built-in binary and unary operators. 348 type overloadImpl interface { 349 params() TypeList 350 returnType() ReturnTyper 351 // allows manually resolving preference between multiple compatible overloads. 352 preferred() bool 353 } 354 355 var _ overloadImpl = &Overload{} 356 var _ overloadImpl = &UnaryOp{} 357 var _ overloadImpl = &BinOp{} 358 var _ overloadImpl = &CmpOp{} 359 360 // GetParamsAndReturnType gets the parameters and return type of an 361 // overloadImpl. 362 func GetParamsAndReturnType(impl overloadImpl) (TypeList, ReturnTyper) { 363 return impl.params(), impl.returnType() 364 } 365 366 // TypeList is a list of types representing a function parameter list. 367 type TypeList interface { 368 // Match checks if all types in the TypeList match the corresponding elements in types. 369 Match(types []*types.T) bool 370 // MatchIdentical is similar to match but checks that the types are identical matches, 371 //instead of equivalent matches. See types.T.Equivalent and types.T.Identical. 372 MatchIdentical(types []*types.T) bool 373 // MatchAt checks if the parameter type at index i of the TypeList matches type typ. 374 // In all implementations, types.Null will match with each parameter type, allowing 375 // NULL values to be used as arguments. 376 MatchAt(typ *types.T, i int) bool 377 // MatchAtIdentical is similar to MatchAt but checks that the type at index i of 378 // the Typelist is identical to typ. 379 MatchAtIdentical(typ *types.T, i int) bool 380 // MatchLen checks that the TypeList can support l parameters. 381 MatchLen(l int) bool 382 // GetAt returns the type at the given index in the TypeList, or nil if the TypeList 383 // cannot have a parameter at index i. 384 GetAt(i int) *types.T 385 // Length returns the number of types in the list 386 Length() int 387 // Types returns a realized copy of the list. variadic lists return a list of size one. 388 Types() []*types.T 389 // String returns a human readable signature 390 String() string 391 } 392 393 var _ TypeList = ParamTypes{} 394 var _ TypeList = HomogeneousType{} 395 var _ TypeList = VariadicType{} 396 397 // ParamTypes is a list of function parameter names and their types. 398 type ParamTypes []ParamType 399 400 // ParamType encapsulate a function parameter name and type. 401 type ParamType struct { 402 Name string 403 Typ *types.T 404 } 405 406 // Match is part of the TypeList interface. 407 func (p ParamTypes) Match(types []*types.T) bool { 408 if len(types) != len(p) { 409 return false 410 } 411 for i := range types { 412 if !p.MatchAt(types[i], i) { 413 return false 414 } 415 } 416 return true 417 } 418 419 // MatchIdentical is part of the TypeList interface. 420 func (p ParamTypes) MatchIdentical(types []*types.T) bool { 421 if len(types) != len(p) { 422 return false 423 } 424 for i := range types { 425 if !p.MatchAtIdentical(types[i], i) { 426 return false 427 } 428 } 429 return true 430 } 431 432 // MatchAt is part of the TypeList interface. 433 func (p ParamTypes) MatchAt(typ *types.T, i int) bool { 434 // The parameterized types for Tuples are checked in the type checking 435 // routines before getting here, so we only need to check if the parameter 436 // type is p types.TUPLE below. This allows us to avoid defining overloads 437 // for types.Tuple{}, types.Tuple{types.Any}, types.Tuple{types.Any, types.Any}, 438 // etc. for Tuple operators. 439 if typ.Family() == types.TupleFamily { 440 typ = types.AnyTuple 441 } 442 return i < len(p) && (typ.Family() == types.UnknownFamily || p[i].Typ.Equivalent(typ)) 443 } 444 445 // MatchAtIdentical is part of the TypeList interface. 446 func (p ParamTypes) MatchAtIdentical(typ *types.T, i int) bool { 447 return i < len(p) && (typ.Family() == types.UnknownFamily || p[i].Typ.Identical(typ)) 448 } 449 450 // MatchLen is part of the TypeList interface. 451 func (p ParamTypes) MatchLen(l int) bool { 452 return len(p) == l 453 } 454 455 // GetAt is part of the TypeList interface. 456 func (p ParamTypes) GetAt(i int) *types.T { 457 return p[i].Typ 458 } 459 460 // SetAt is part of the TypeList interface. 461 func (p ParamTypes) SetAt(i int, name string, t *types.T) { 462 p[i].Name = name 463 p[i].Typ = t 464 } 465 466 // Length is part of the TypeList interface. 467 func (p ParamTypes) Length() int { 468 return len(p) 469 } 470 471 // Types is part of the TypeList interface. 472 func (p ParamTypes) Types() []*types.T { 473 n := len(p) 474 ret := make([]*types.T, n) 475 for i, s := range p { 476 ret[i] = s.Typ 477 } 478 return ret 479 } 480 481 func (p ParamTypes) String() string { 482 var s strings.Builder 483 for i, param := range p { 484 if i > 0 { 485 s.WriteString(", ") 486 } 487 s.WriteString(param.Name) 488 s.WriteString(": ") 489 s.WriteString(param.Typ.String()) 490 } 491 return s.String() 492 } 493 494 // HomogeneousType is a TypeList implementation that accepts any arguments, as 495 // long as all are the same type or NULL. The homogeneous constraint is enforced 496 // in typeCheckOverloadedExprs. 497 type HomogeneousType struct{} 498 499 // Match is part of the TypeList interface. 500 func (HomogeneousType) Match(types []*types.T) bool { 501 return true 502 } 503 504 // MatchIdentical is part of the TypeList interface. 505 func (HomogeneousType) MatchIdentical(types []*types.T) bool { 506 return true 507 } 508 509 // MatchAt is part of the TypeList interface. 510 func (HomogeneousType) MatchAt(typ *types.T, i int) bool { 511 return true 512 } 513 514 // MatchAtIdentical is part of the TypeList interface. 515 func (HomogeneousType) MatchAtIdentical(typ *types.T, i int) bool { 516 return true 517 } 518 519 // MatchLen is part of the TypeList interface. 520 func (HomogeneousType) MatchLen(l int) bool { 521 return true 522 } 523 524 // GetAt is part of the TypeList interface. 525 func (HomogeneousType) GetAt(i int) *types.T { 526 return types.Any 527 } 528 529 // Length is part of the TypeList interface. 530 func (HomogeneousType) Length() int { 531 return 1 532 } 533 534 // Types is part of the TypeList interface. 535 func (HomogeneousType) Types() []*types.T { 536 return []*types.T{types.Any} 537 } 538 539 func (HomogeneousType) String() string { 540 return "anyelement..." 541 } 542 543 // VariadicType is a TypeList implementation which accepts a fixed number of 544 // arguments at the beginning and an arbitrary number of homogenous arguments 545 // at the end. 546 type VariadicType struct { 547 FixedTypes []*types.T 548 VarType *types.T 549 } 550 551 // Match is part of the TypeList interface. 552 func (v VariadicType) Match(types []*types.T) bool { 553 for i := range types { 554 if !v.MatchAt(types[i], i) { 555 return false 556 } 557 } 558 return true 559 } 560 561 // MatchIdentical is part of the TypeList interface. 562 func (VariadicType) MatchIdentical(types []*types.T) bool { 563 return true 564 } 565 566 // MatchAt is part of the TypeList interface. 567 func (v VariadicType) MatchAt(typ *types.T, i int) bool { 568 if i < len(v.FixedTypes) { 569 return typ.Family() == types.UnknownFamily || v.FixedTypes[i].Equivalent(typ) 570 } 571 return typ.Family() == types.UnknownFamily || v.VarType.Equivalent(typ) 572 } 573 574 // MatchAtIdentical is part of the TypeList interface. 575 // TODO(mgartner): This will be incorrect once we add support for variadic UDFs 576 func (VariadicType) MatchAtIdentical(typ *types.T, i int) bool { 577 return true 578 } 579 580 // MatchLen is part of the TypeList interface. 581 func (v VariadicType) MatchLen(l int) bool { 582 return l >= len(v.FixedTypes) 583 } 584 585 // GetAt is part of the TypeList interface. 586 func (v VariadicType) GetAt(i int) *types.T { 587 if i < len(v.FixedTypes) { 588 return v.FixedTypes[i] 589 } 590 return v.VarType 591 } 592 593 // Length is part of the TypeList interface. 594 func (v VariadicType) Length() int { 595 return len(v.FixedTypes) + 1 596 } 597 598 // Types is part of the TypeList interface. 599 func (v VariadicType) Types() []*types.T { 600 result := make([]*types.T, len(v.FixedTypes)+1) 601 copy(result, v.FixedTypes) 602 result[len(result)-1] = v.VarType 603 return result 604 } 605 606 func (v VariadicType) String() string { 607 var s bytes.Buffer 608 for i, t := range v.FixedTypes { 609 if i != 0 { 610 s.WriteString(", ") 611 } 612 s.WriteString(t.String()) 613 } 614 if len(v.FixedTypes) > 0 { 615 s.WriteString(", ") 616 } 617 fmt.Fprintf(&s, "%s...", v.VarType) 618 return s.String() 619 } 620 621 // UnknownReturnType is returned from ReturnTypers when the arguments provided are 622 // not sufficient to determine a return type. This is necessary for cases like overload 623 // resolution, where the argument types are not resolved yet so the type-level function 624 // will be called without argument types. If a ReturnTyper returns unknownReturnType, 625 // then the candidate function set cannot be refined. This means that only ReturnTypers 626 // that never return unknownReturnType, like those created with FixedReturnType, can 627 // help reduce overload ambiguity. 628 var UnknownReturnType *types.T 629 630 // ReturnTyper defines the type-level function in which a builtin function's return type 631 // is determined. ReturnTypers should make sure to return unknownReturnType when necessary. 632 type ReturnTyper func(args []TypedExpr) *types.T 633 634 // FixedReturnType functions simply return a fixed type, independent of argument types. 635 func FixedReturnType(typ *types.T) ReturnTyper { 636 return func(args []TypedExpr) *types.T { return typ } 637 } 638 639 // IdentityReturnType creates a returnType that is a projection of the idx'th 640 // argument type. 641 func IdentityReturnType(idx int) ReturnTyper { 642 return func(args []TypedExpr) *types.T { 643 if len(args) == 0 { 644 return UnknownReturnType 645 } 646 return args[idx].ResolvedType() 647 } 648 } 649 650 // ArrayOfFirstNonNullReturnType returns an array type from the first non-null 651 // type in the argument list. 652 func ArrayOfFirstNonNullReturnType() ReturnTyper { 653 return func(args []TypedExpr) *types.T { 654 if len(args) == 0 { 655 return UnknownReturnType 656 } 657 for _, arg := range args { 658 if t := arg.ResolvedType(); t.Family() != types.UnknownFamily { 659 return types.MakeArray(t) 660 } 661 } 662 return types.Unknown 663 } 664 } 665 666 // FirstNonNullReturnType returns the type of the first non-null argument, or 667 // types.Unknown if all arguments are null. There must be at least one argument, 668 // or else FirstNonNullReturnType returns UnknownReturnType. This method is used 669 // with HomogeneousType functions, in which all arguments have been checked to 670 // have the same type (or be null). 671 func FirstNonNullReturnType() ReturnTyper { 672 return func(args []TypedExpr) *types.T { 673 if len(args) == 0 { 674 return UnknownReturnType 675 } 676 for _, arg := range args { 677 if t := arg.ResolvedType(); t.Family() != types.UnknownFamily { 678 return t 679 } 680 } 681 return types.Unknown 682 } 683 } 684 685 func returnTypeToFixedType(s ReturnTyper, inputTyps []TypedExpr) *types.T { 686 if t := s(inputTyps); t != UnknownReturnType { 687 return t 688 } 689 return types.Any 690 } 691 692 type overloadTypeChecker struct { 693 overloads []overloadImpl 694 params []TypeList 695 overloadIdxs []uint8 // index into overloads 696 exprs []Expr 697 typedExprs []TypedExpr 698 resolvableIdxs intsets.Fast // index into exprs/typedExprs 699 constIdxs intsets.Fast // index into exprs/typedExprs 700 placeholderIdxs intsets.Fast // index into exprs/typedExprs 701 overloadsIdxArr [16]uint8 702 } 703 704 var overloadTypeCheckerPool = sync.Pool{ 705 New: func() interface{} { 706 var s overloadTypeChecker 707 s.overloadIdxs = s.overloadsIdxArr[:0] 708 return &s 709 }, 710 } 711 712 // getOverloadTypeChecker initialized an overloadTypeChecker from the pool. The 713 // returned object should be returned to the pool via its release method. 714 func getOverloadTypeChecker(o overloadSet, exprs ...Expr) *overloadTypeChecker { 715 s := overloadTypeCheckerPool.Get().(*overloadTypeChecker) 716 n := o.len() 717 if n > cap(s.overloads) { 718 s.overloads = make([]overloadImpl, n) 719 s.params = make([]TypeList, n) 720 } else { 721 s.overloads = s.overloads[:n] 722 s.params = s.params[:n] 723 } 724 for i := 0; i < n; i++ { 725 got := o.get(i) 726 s.overloads[i] = got 727 s.params[i] = got.params() 728 } 729 s.exprs = append(s.exprs, exprs...) 730 return s 731 } 732 733 func (s *overloadTypeChecker) release() { 734 for i := range s.overloads { 735 s.overloads[i] = nil 736 } 737 s.overloads = s.overloads[:0] 738 for i := range s.params { 739 s.params[i] = nil 740 } 741 s.params = s.params[:0] 742 for i := range s.exprs { 743 s.exprs[i] = nil 744 } 745 s.exprs = s.exprs[:0] 746 for i := range s.typedExprs { 747 s.typedExprs[i] = nil 748 } 749 s.typedExprs = s.typedExprs[:0] 750 s.overloadIdxs = s.overloadIdxs[:0] 751 s.resolvableIdxs = intsets.Fast{} 752 s.constIdxs = intsets.Fast{} 753 s.placeholderIdxs = intsets.Fast{} 754 overloadTypeCheckerPool.Put(s) 755 } 756 757 type overloadSet interface { 758 len() int 759 get(i int) overloadImpl 760 } 761 762 // typeCheckOverloadedExprs determines the correct overload to use for the given set of 763 // expression parameters, along with an optional desired return type. It returns the expression 764 // parameters after being type checked, along with a slice of candidate overloadImpls. The 765 // slice may have length: 766 // 767 // 0: overload resolution failed because no compatible overloads were found 768 // 1: overload resolution succeeded 769 // 2+: overload resolution failed because of ambiguity 770 // 771 // The inBinOp parameter denotes whether this type check is occurring within a binary operator, 772 // in which case we may need to make a guess that the two parameters are of the same type if one 773 // of them is NULL. 774 func (s *overloadTypeChecker) typeCheckOverloadedExprs( 775 ctx context.Context, semaCtx *SemaContext, desired *types.T, inBinOp bool, 776 ) (_ error) { 777 numOverloads := len(s.overloads) 778 if numOverloads > math.MaxUint8 { 779 return errors.AssertionFailedf("too many overloads (%d > 255)", numOverloads) 780 } 781 782 // Special-case the HomogeneousType overload. We determine its return type by checking that 783 // all parameters have the same type. 784 for i := range s.params { 785 // Only one overload can be provided if it has parameters with HomogeneousType. 786 if _, ok := s.params[i].(HomogeneousType); ok { 787 if numOverloads > 1 { 788 return errors.AssertionFailedf( 789 "only one overload can have HomogeneousType parameters") 790 } 791 typedExprs, _, err := typeCheckSameTypedExprs(ctx, semaCtx, desired, s.exprs...) 792 if err != nil { 793 return err 794 } 795 s.typedExprs = typedExprs 796 s.overloadIdxs = append(s.overloadIdxs[:0], uint8(i)) 797 return nil 798 } 799 } 800 801 // Hold the resolved type expressions of the provided exprs, in order. 802 if cap(s.typedExprs) >= len(s.exprs) { 803 s.typedExprs = s.typedExprs[:len(s.exprs)] 804 } else { 805 s.typedExprs = make([]TypedExpr, len(s.exprs)) 806 } 807 s.constIdxs, s.placeholderIdxs, s.resolvableIdxs = typeCheckSplitExprs(s.exprs) 808 809 // If no overloads are provided, just type check parameters and return. 810 if numOverloads == 0 { 811 for i, ok := s.resolvableIdxs.Next(0); ok; i, ok = s.resolvableIdxs.Next(i + 1) { 812 typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, types.Any) 813 if err != nil { 814 return pgerror.Wrapf(err, pgcode.InvalidParameterValue, 815 "error type checking resolved expression:") 816 } 817 s.typedExprs[i] = typ 818 } 819 if err := defaultTypeCheck(ctx, semaCtx, s, false); err != nil { 820 return err 821 } 822 s.overloadIdxs = s.overloadIdxs[:0] 823 return nil 824 } 825 826 if cap(s.overloadIdxs) < numOverloads { 827 s.overloadIdxs = make([]uint8, 0, numOverloads) 828 } 829 s.overloadIdxs = s.overloadIdxs[:numOverloads] 830 for i := range s.overloadIdxs { 831 s.overloadIdxs[i] = uint8(i) 832 } 833 834 // Filter out incorrect parameter length overloads. 835 exprsLen := len(s.exprs) 836 matchLen := func(params TypeList) bool { return params.MatchLen(exprsLen) } 837 s.overloadIdxs = filterParams(s.overloadIdxs, s.params, matchLen) 838 839 // Filter out overloads which constants cannot become. 840 for i, ok := s.constIdxs.Next(0); ok; i, ok = s.constIdxs.Next(i + 1) { 841 constExpr := s.exprs[i].(Constant) 842 filter := func(params TypeList) bool { 843 return canConstantBecome(constExpr, params.GetAt(i)) 844 } 845 s.overloadIdxs = filterParams(s.overloadIdxs, s.params, filter) 846 } 847 848 // TODO(nvanbenschoten): We should add a filtering step here to filter 849 // out impossible candidates based on identical parameters. For instance, 850 // f(int, float) is not a possible candidate for the expression f($1, $1). 851 852 // Filter out overloads on resolved types. This includes resolved placeholders 853 // and any other resolvable exprs. 854 ambiguousCollatedTypes := false 855 var typeableIdxs intsets.Fast 856 for i, ok := s.resolvableIdxs.Next(0); ok; i, ok = s.resolvableIdxs.Next(i + 1) { 857 typeableIdxs.Add(i) 858 } 859 for i, ok := s.placeholderIdxs.Next(0); ok; i, ok = s.placeholderIdxs.Next(i + 1) { 860 if !semaCtx.isUnresolvedPlaceholder(s.exprs[i]) { 861 typeableIdxs.Add(i) 862 } 863 } 864 for i, ok := typeableIdxs.Next(0); ok; i, ok = typeableIdxs.Next(i + 1) { 865 paramDesired := types.Any 866 867 // If all remaining candidates require the same type for this parameter, 868 // begin desiring that type for the corresponding argument expression. 869 // Note that this is always the case when we have a single overload left. 870 var sameType *types.T 871 for _, ovIdx := range s.overloadIdxs { 872 ov := s.overloads[ovIdx] 873 typ := ov.params().GetAt(i) 874 if sameType == nil { 875 sameType = typ 876 } else if !typ.Identical(sameType) { 877 sameType = nil 878 break 879 } 880 } 881 // Don't allow ambiguous types to be desired, this prevents for instance 882 // AnyCollatedString from trumping the concrete collated type. 883 if sameType != nil { 884 paramDesired = sameType 885 } 886 typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, paramDesired) 887 if err != nil { 888 return err 889 } 890 s.typedExprs[i] = typ 891 if typ.ResolvedType() == types.AnyCollatedString { 892 ambiguousCollatedTypes = true 893 } 894 rt := typ.ResolvedType() 895 s.overloadIdxs = filterParams(s.overloadIdxs, s.params, func( 896 params TypeList, 897 ) bool { 898 return params.MatchAt(rt, i) 899 }) 900 } 901 902 // If we typed any exprs as AnyCollatedString but have a concrete collated 903 // string type, redo the types using the contrete type. Note we're probably 904 // still lacking full compliance with PG on collation handling: 905 // https://www.postgresql.org/docs/current/collation.html#id-1.6.11.4.4 906 if ambiguousCollatedTypes { 907 var concreteType *types.T 908 for i, ok := typeableIdxs.Next(0); ok; i, ok = typeableIdxs.Next(i + 1) { 909 typ := s.typedExprs[i].ResolvedType() 910 if typ != types.AnyCollatedString { 911 concreteType = typ 912 break 913 } 914 } 915 if concreteType != nil { 916 for i, ok := typeableIdxs.Next(0); ok; i, ok = typeableIdxs.Next(i + 1) { 917 if s.typedExprs[i].ResolvedType() == types.AnyCollatedString { 918 typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, concreteType) 919 if err != nil { 920 return err 921 } 922 s.typedExprs[i] = typ 923 } 924 } 925 } 926 } 927 928 // At this point, all remaining overload candidates accept the argument list, 929 // so we begin checking for a single remainig candidate implementation to choose. 930 // In case there is more than one candidate remaining, the following code uses 931 // heuristics to find a most preferable candidate. 932 if ok, err := checkReturn(ctx, semaCtx, s); ok { 933 return err 934 } 935 936 // The first heuristic is to prefer candidates that return the desired type, 937 // if a desired type was provided. 938 if desired.Family() != types.AnyFamily { 939 s.overloadIdxs = filterOverloads(s.overloadIdxs, s.overloads, func( 940 o overloadImpl, 941 ) bool { 942 // For now, we only filter on the return type for overloads with 943 // fixed return types. This could be improved, but is not currently 944 // critical because we have no cases of functions with multiple 945 // overloads that do not all expose FixedReturnTypes. 946 if t := o.returnType()(nil); t != UnknownReturnType { 947 return t.Equivalent(desired) 948 } 949 return true 950 }) 951 if ok, err := checkReturn(ctx, semaCtx, s); ok { 952 return err 953 } 954 } 955 956 var homogeneousTyp *types.T 957 if !typeableIdxs.Empty() { 958 idx, _ := typeableIdxs.Next(0) 959 homogeneousTyp = s.typedExprs[idx].ResolvedType() 960 for i, ok := typeableIdxs.Next(idx); ok; i, ok = typeableIdxs.Next(i + 1) { 961 if !homogeneousTyp.Equivalent(s.typedExprs[i].ResolvedType()) { 962 homogeneousTyp = nil 963 break 964 } 965 } 966 } 967 968 if !s.constIdxs.Empty() { 969 allConstantsAreHomogenous := false 970 if ok, err := filterAttempt(ctx, semaCtx, s, func() { 971 // The second heuristic is to prefer candidates where all constants can 972 // become a homogeneous type, if all resolvable expressions became one. 973 // This is only possible if resolvable expressions were resolved 974 // homogeneously up to this point. 975 if homogeneousTyp != nil { 976 allConstantsAreHomogenous = true 977 for i, ok := s.constIdxs.Next(0); ok; i, ok = s.constIdxs.Next(i + 1) { 978 if !canConstantBecome(s.exprs[i].(Constant), homogeneousTyp) { 979 allConstantsAreHomogenous = false 980 break 981 } 982 } 983 if allConstantsAreHomogenous { 984 for i, ok := s.constIdxs.Next(0); ok; i, ok = s.constIdxs.Next(i + 1) { 985 filter := func(params TypeList) bool { 986 return params.GetAt(i).Equivalent(homogeneousTyp) 987 } 988 s.overloadIdxs = filterParams(s.overloadIdxs, s.params, filter) 989 } 990 } 991 } 992 }); ok { 993 return err 994 } 995 996 if ok, err := filterAttempt(ctx, semaCtx, s, func() { 997 // The third heuristic is to prefer candidates where all constants can 998 // become their "natural" types. 999 for i, ok := s.constIdxs.Next(0); ok; i, ok = s.constIdxs.Next(i + 1) { 1000 natural := naturalConstantType(s.exprs[i].(Constant)) 1001 if natural != nil { 1002 filter := func(params TypeList) bool { 1003 return params.GetAt(i).Equivalent(natural) 1004 } 1005 s.overloadIdxs = filterParams(s.overloadIdxs, s.params, filter) 1006 } 1007 } 1008 }); ok { 1009 return err 1010 } 1011 1012 // At this point, it's worth seeing if we have constants that can't actually 1013 // parse as the type that canConstantBecome claims they can. For example, 1014 // every string literal will report that it can become an interval, but most 1015 // string literals do not encode valid intervals. This may uncover some 1016 // overloads with invalid type signatures. 1017 // 1018 // This parsing is sufficiently expensive (see the comment on 1019 // StrVal.AvailableTypes) that we wait until now, when we've eliminated most 1020 // overloads from consideration, so that we only need to check each constant 1021 // against a limited set of types. We can't hold off on this parsing any 1022 // longer, though: the remaining heuristics are overly aggressive and will 1023 // falsely reject the only valid overload in some cases. 1024 // 1025 // This case is broken into two parts. We first attempt to use the 1026 // information about the homogeneity of our constants collected by previous 1027 // heuristic passes. If: 1028 // * all our constants are homogeneous 1029 // * we only have a single overload left 1030 // * the constant overload parameters are homogeneous as well 1031 // then match this overload with the homogeneous constants. Otherwise, 1032 // continue to filter overloads by whether or not the constants can parse 1033 // into the desired types of the overloads. 1034 // This first case is important when resolving overloads for operations 1035 // between user-defined types, where we need to propagate the concrete 1036 // resolved type information over to the constants, rather than attempting 1037 // to resolve constants as the placeholder type for the user defined type 1038 // family (like `AnyEnum`). 1039 if len(s.overloadIdxs) == 1 && allConstantsAreHomogenous { 1040 overloadParamsAreHomogenous := true 1041 p := s.overloads[s.overloadIdxs[0]].params() 1042 for i, ok := s.constIdxs.Next(0); ok; i, ok = s.constIdxs.Next(i + 1) { 1043 if !p.GetAt(i).Equivalent(homogeneousTyp) { 1044 overloadParamsAreHomogenous = false 1045 break 1046 } 1047 } 1048 if overloadParamsAreHomogenous { 1049 // Type check our constants using the homogeneous type rather than 1050 // the type in overload parameter. This lets us type check user defined 1051 // types with a concrete type instance, rather than an ambiguous type. 1052 for i, ok := s.constIdxs.Next(0); ok; i, ok = s.constIdxs.Next(i + 1) { 1053 typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, homogeneousTyp) 1054 if err != nil { 1055 return err 1056 } 1057 s.typedExprs[i] = typ 1058 } 1059 _, err := checkReturnPlaceholdersAtIdx(ctx, semaCtx, s, s.overloadIdxs[0]) 1060 return err 1061 } 1062 } 1063 for i, ok := s.constIdxs.Next(0); ok; i, ok = s.constIdxs.Next(i + 1) { 1064 constExpr := s.exprs[i].(Constant) 1065 filter := func(params TypeList) bool { 1066 semaCtx := MakeSemaContext() 1067 _, err := constExpr.ResolveAsType(ctx, &semaCtx, params.GetAt(i)) 1068 return err == nil 1069 } 1070 s.overloadIdxs = filterParams(s.overloadIdxs, s.params, filter) 1071 } 1072 if ok, err := checkReturn(ctx, semaCtx, s); ok { 1073 return err 1074 } 1075 1076 // The fourth heuristic is to prefer candidates that accept the "best" 1077 // mutual type in the resolvable type set of all constants. 1078 if bestConstType, ok := commonConstantType(s.exprs, s.constIdxs); ok { 1079 // In case all overloads are filtered out at this step, 1080 // keep track of previous overload indexes to return ambiguous error (>1 overloads) 1081 // instead of unsupported error (0 overloads) when applicable. 1082 prevOverloadIdxs := s.overloadIdxs 1083 for i, ok := s.constIdxs.Next(0); ok; i, ok = s.constIdxs.Next(i + 1) { 1084 filter := func(params TypeList) bool { 1085 return params.GetAt(i).Equivalent(bestConstType) 1086 } 1087 s.overloadIdxs = filterParams(s.overloadIdxs, s.params, filter) 1088 } 1089 if ok, err := checkReturn(ctx, semaCtx, s); ok { 1090 if len(s.overloadIdxs) == 0 { 1091 s.overloadIdxs = prevOverloadIdxs 1092 } 1093 return err 1094 } 1095 if homogeneousTyp != nil { 1096 if !homogeneousTyp.Equivalent(bestConstType) { 1097 homogeneousTyp = nil 1098 } 1099 } else { 1100 homogeneousTyp = bestConstType 1101 } 1102 } 1103 } 1104 1105 // The fifth heuristic is to defer to preferred candidates, if one has been 1106 // specified in the overload list. 1107 if ok, err := filterAttempt(ctx, semaCtx, s, func() { 1108 s.overloadIdxs = filterOverloads( 1109 s.overloadIdxs, s.overloads, overloadImpl.preferred, 1110 ) 1111 }); ok { 1112 return err 1113 } 1114 1115 // The sixth heuristic is to prefer candidates where all placeholders can be 1116 // given the same type as all constants and resolvable expressions. This is 1117 // only possible if all constants and resolvable expressions were resolved 1118 // homogeneously up to this point. 1119 if homogeneousTyp != nil && !s.placeholderIdxs.Empty() { 1120 // Before we continue, try to propagate the homogeneous type to the 1121 // placeholders. This might not have happened yet, if the overloads' 1122 // parameter types are ambiguous (like in the case of tuple-tuple binary 1123 // operators). 1124 for i, ok := s.placeholderIdxs.Next(0); ok; i, ok = s.placeholderIdxs.Next(i + 1) { 1125 if _, err := s.exprs[i].TypeCheck(ctx, semaCtx, homogeneousTyp); err != nil { 1126 return err 1127 } 1128 filter := func(params TypeList) bool { 1129 return params.GetAt(i).Equivalent(homogeneousTyp) 1130 } 1131 s.overloadIdxs = filterParams(s.overloadIdxs, s.params, filter) 1132 } 1133 if ok, err := checkReturn(ctx, semaCtx, s); ok { 1134 return err 1135 } 1136 } 1137 1138 // This is a total hack for AnyEnum whilst we don't have postgres type resolution. 1139 // This enables AnyEnum array ops to not need a cast, e.g. array['a']::enum[] = '{a}'. 1140 // If we have one remaining candidate containing AnyEnum, cast all remaining 1141 // arguments to a known enum and check that the rest match. This is a poor man's 1142 // implicit cast / postgres "same argument" resolution clone. 1143 if len(s.overloadIdxs) == 1 { 1144 params := s.overloads[s.overloadIdxs[0]].params() 1145 var knownEnum *types.T 1146 1147 // Check we have all "AnyEnum" (or "AnyEnum" array) arguments and that 1148 // one argument is typed with an enum. 1149 attemptAnyEnumCast := func() bool { 1150 for i := 0; i < params.Length(); i++ { 1151 typ := params.GetAt(i) 1152 // Note we are deliberately looking at whether the built-in takes in 1153 // AnyEnum as an argument, not the exprs given to the overload itself. 1154 if !(typ.Identical(types.AnyEnum) || typ.Identical(types.MakeArray(types.AnyEnum))) { 1155 return false 1156 } 1157 if s.typedExprs[i] != nil { 1158 // Assign the known enum if it was previously unassigned. 1159 // Otherwise, double check it matches a previously defined enum. 1160 posEnum := s.typedExprs[i].ResolvedType() 1161 if !posEnum.UserDefined() { 1162 return false 1163 } 1164 if posEnum.Family() == types.ArrayFamily { 1165 posEnum = posEnum.ArrayContents() 1166 } 1167 if knownEnum == nil { 1168 knownEnum = posEnum 1169 } else if !posEnum.Identical(knownEnum) { 1170 return false 1171 } 1172 } 1173 } 1174 return knownEnum != nil 1175 }() 1176 1177 // If we have all arguments as AnyEnum, and we know at least one of the 1178 // enum's actual type, try type cast the rest. 1179 if attemptAnyEnumCast { 1180 // Copy exprs to prevent any overwrites of underlying s.exprs array later. 1181 sCopy := *s 1182 sCopy.exprs = make([]Expr, len(s.exprs)) 1183 copy(sCopy.exprs, s.exprs) 1184 1185 if ok, err := filterAttempt(ctx, semaCtx, &sCopy, func() { 1186 work := func(idx int) { 1187 p := params.GetAt(idx) 1188 typCast := knownEnum 1189 if p.Family() == types.ArrayFamily { 1190 typCast = types.MakeArray(knownEnum) 1191 } 1192 sCopy.exprs[idx] = &CastExpr{Expr: sCopy.exprs[idx], Type: typCast, SyntaxMode: CastShort} 1193 } 1194 for i, ok := s.constIdxs.Next(0); ok; i, ok = s.constIdxs.Next(i + 1) { 1195 work(i) 1196 } 1197 for i, ok := s.placeholderIdxs.Next(0); ok; i, ok = s.placeholderIdxs.Next(i + 1) { 1198 work(i) 1199 } 1200 }); ok { 1201 s.exprs = sCopy.exprs 1202 s.typedExprs = sCopy.typedExprs 1203 s.overloadIdxs = append(s.overloadIdxs[:0], sCopy.overloadIdxs...) 1204 return err 1205 } 1206 } 1207 } 1208 1209 // In a binary expression, in the case of one of the arguments being untyped NULL, 1210 // we prefer overloads where we infer the type of the NULL to be the same as the 1211 // other argument. This is used to differentiate the behavior of 1212 // STRING[] || NULL and STRING || NULL. 1213 if inBinOp && len(s.exprs) == 2 { 1214 if ok, err := filterAttempt(ctx, semaCtx, s, func() { 1215 var err error 1216 left := s.typedExprs[0] 1217 if left == nil { 1218 left, err = s.exprs[0].TypeCheck(ctx, semaCtx, types.Any) 1219 if err != nil { 1220 return 1221 } 1222 } 1223 right := s.typedExprs[1] 1224 if right == nil { 1225 right, err = s.exprs[1].TypeCheck(ctx, semaCtx, types.Any) 1226 if err != nil { 1227 return 1228 } 1229 } 1230 leftType := left.ResolvedType() 1231 rightType := right.ResolvedType() 1232 leftIsNull := leftType.Family() == types.UnknownFamily 1233 rightIsNull := rightType.Family() == types.UnknownFamily 1234 oneIsNull := (leftIsNull || rightIsNull) && !(leftIsNull && rightIsNull) 1235 if oneIsNull { 1236 if leftIsNull { 1237 leftType = rightType 1238 } 1239 if rightIsNull { 1240 rightType = leftType 1241 } 1242 filter := func(params TypeList) bool { 1243 return params.GetAt(0).Equivalent(leftType) && 1244 params.GetAt(1).Equivalent(rightType) 1245 } 1246 s.overloadIdxs = filterParams(s.overloadIdxs, s.params, filter) 1247 } 1248 }); ok { 1249 return err 1250 } 1251 } 1252 1253 // After the previous heuristic, in a binary expression, in the case of one of the arguments being untyped 1254 // NULL, we prefer overloads where we infer the type of the NULL to be a STRING. This is used 1255 // to choose INT || NULL::STRING over INT || NULL::INT[]. 1256 if inBinOp && len(s.exprs) == 2 { 1257 if ok, err := filterAttempt(ctx, semaCtx, s, func() { 1258 var err error 1259 left := s.typedExprs[0] 1260 if left == nil { 1261 left, err = s.exprs[0].TypeCheck(ctx, semaCtx, types.Any) 1262 if err != nil { 1263 return 1264 } 1265 } 1266 right := s.typedExprs[1] 1267 if right == nil { 1268 right, err = s.exprs[1].TypeCheck(ctx, semaCtx, types.Any) 1269 if err != nil { 1270 return 1271 } 1272 } 1273 leftType := left.ResolvedType() 1274 rightType := right.ResolvedType() 1275 leftIsNull := leftType.Family() == types.UnknownFamily 1276 rightIsNull := rightType.Family() == types.UnknownFamily 1277 oneIsNull := (leftIsNull || rightIsNull) && !(leftIsNull && rightIsNull) 1278 if oneIsNull { 1279 if leftIsNull { 1280 leftType = types.String 1281 } 1282 if rightIsNull { 1283 rightType = types.String 1284 } 1285 filter := func(params TypeList) bool { 1286 return params.GetAt(0).Equivalent(leftType) && 1287 params.GetAt(1).Equivalent(rightType) 1288 } 1289 s.overloadIdxs = filterParams(s.overloadIdxs, s.params, filter) 1290 } 1291 }); ok { 1292 return err 1293 } 1294 } 1295 1296 if err := defaultTypeCheck(ctx, semaCtx, s, numOverloads > 0); err != nil { 1297 return err 1298 } 1299 1300 return nil 1301 } 1302 1303 // filterAttempt attempts to filter the overloads down to a single candidate. 1304 // If it succeeds, it will return true, along with the overload (in a slice for 1305 // convenience) and a possible error. If it fails, it will return false and 1306 // undo any filtering performed during the attempt. 1307 func filterAttempt( 1308 ctx context.Context, semaCtx *SemaContext, s *overloadTypeChecker, attempt func(), 1309 ) (ok bool, _ error) { 1310 before := s.overloadIdxs 1311 attempt() 1312 if len(s.overloadIdxs) == 1 { 1313 ok, err := checkReturn(ctx, semaCtx, s) 1314 if err != nil { 1315 return false, err 1316 } 1317 if ok { 1318 return true, err 1319 } 1320 } 1321 s.overloadIdxs = before 1322 return false, nil 1323 } 1324 1325 func filterParams(idxs []uint8, params []TypeList, fn func(i TypeList) bool) []uint8 { 1326 i, n := 0, len(idxs) 1327 for i < n { 1328 if fn(params[idxs[i]]) { 1329 i++ 1330 } else if n--; i != n { 1331 idxs[i], idxs[n] = idxs[n], idxs[i] 1332 } 1333 } 1334 return idxs[:n] 1335 } 1336 1337 func filterOverloads(idxs []uint8, params []overloadImpl, fn func(overloadImpl) bool) []uint8 { 1338 i, n := 0, len(idxs) 1339 for i < n { 1340 if fn(params[idxs[i]]) { 1341 i++ 1342 } else if n--; i != n { 1343 idxs[i], idxs[n] = idxs[n], idxs[i] 1344 } 1345 } 1346 return idxs[:n] 1347 } 1348 1349 // defaultTypeCheck type checks the constant and placeholder expressions without a preference 1350 // and adds them to the type checked slice. 1351 func defaultTypeCheck( 1352 ctx context.Context, semaCtx *SemaContext, s *overloadTypeChecker, errorOnPlaceholders bool, 1353 ) error { 1354 for i, ok := s.constIdxs.Next(0); ok; i, ok = s.constIdxs.Next(i + 1) { 1355 typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, types.Any) 1356 if err != nil { 1357 return pgerror.Wrapf(err, pgcode.InvalidParameterValue, 1358 "error type checking constant value") 1359 } 1360 s.typedExprs[i] = typ 1361 } 1362 for i, ok := s.placeholderIdxs.Next(0); ok; i, ok = s.placeholderIdxs.Next(i + 1) { 1363 if errorOnPlaceholders { 1364 if _, err := s.exprs[i].TypeCheck(ctx, semaCtx, types.Any); err != nil { 1365 return err 1366 } 1367 } 1368 // If we dont want to error on args, avoid type checking them without a desired type. 1369 s.typedExprs[i] = StripParens(s.exprs[i]).(*Placeholder) 1370 } 1371 return nil 1372 } 1373 1374 // checkReturn checks the number of remaining overloaded function 1375 // implementations. 1376 // Returns ok=true if we should stop overload resolution, and returning either 1377 // 1. the chosen overload in a slice, or 1378 // 2. nil, 1379 // along with the typed arguments. 1380 // This modifies values within s as scratch slices, but only in the case where 1381 // it returns true, which signals to the calling function that it should 1382 // immediately return, so any mutations to s are irrelevant. 1383 func checkReturn( 1384 ctx context.Context, semaCtx *SemaContext, s *overloadTypeChecker, 1385 ) (ok bool, _ error) { 1386 switch len(s.overloadIdxs) { 1387 case 0: 1388 if err := defaultTypeCheck(ctx, semaCtx, s, false); err != nil { 1389 return false, err 1390 } 1391 return true, nil 1392 1393 case 1: 1394 idx := s.overloadIdxs[0] 1395 p := s.params[idx] 1396 for i, ok := s.constIdxs.Next(0); ok; i, ok = s.constIdxs.Next(i + 1) { 1397 des := p.GetAt(i) 1398 typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, des) 1399 if err != nil { 1400 return false, pgerror.Wrapf( 1401 err, pgcode.InvalidParameterValue, 1402 "error type checking constant value", 1403 ) 1404 } 1405 if des != nil && !typ.ResolvedType().Equivalent(des) { 1406 return false, errors.AssertionFailedf( 1407 "desired constant value type %s but set type %s", 1408 redact.Safe(des), redact.Safe(typ.ResolvedType()), 1409 ) 1410 } 1411 s.typedExprs[i] = typ 1412 } 1413 1414 return checkReturnPlaceholdersAtIdx(ctx, semaCtx, s, idx) 1415 1416 default: 1417 return false, nil 1418 } 1419 } 1420 1421 // checkReturnPlaceholdersAtIdx checks that the placeholders for the 1422 // overload at the input index are valid. It has the same return values 1423 // as checkReturn. 1424 func checkReturnPlaceholdersAtIdx( 1425 ctx context.Context, semaCtx *SemaContext, s *overloadTypeChecker, idx uint8, 1426 ) (ok bool, _ error) { 1427 p := s.params[idx] 1428 for i, ok := s.placeholderIdxs.Next(0); ok; i, ok = s.placeholderIdxs.Next(i + 1) { 1429 des := p.GetAt(i) 1430 typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, des) 1431 if err != nil { 1432 if des.IsAmbiguous() { 1433 return false, nil 1434 } 1435 return false, err 1436 } 1437 if typ.ResolvedType().IsAmbiguous() { 1438 return false, nil 1439 } 1440 s.typedExprs[i] = typ 1441 } 1442 s.overloadIdxs = append(s.overloadIdxs[:0], idx) 1443 return true, nil 1444 } 1445 1446 func formatCandidates(prefix string, candidates []overloadImpl, filter []uint8) string { 1447 var buf bytes.Buffer 1448 for _, idx := range filter { 1449 candidate := candidates[idx] 1450 buf.WriteString(prefix) 1451 buf.WriteByte('(') 1452 params := candidate.params() 1453 tLen := params.Length() 1454 inputTyps := make([]TypedExpr, tLen) 1455 for i := 0; i < tLen; i++ { 1456 t := params.GetAt(i) 1457 inputTyps[i] = &TypedDummy{Typ: t} 1458 if i > 0 { 1459 buf.WriteString(", ") 1460 } 1461 buf.WriteString(t.String()) 1462 } 1463 buf.WriteString(") -> ") 1464 buf.WriteString(returnTypeToFixedType(candidate.returnType(), inputTyps).String()) 1465 if candidate.preferred() { 1466 buf.WriteString(" [preferred]") 1467 } 1468 buf.WriteByte('\n') 1469 } 1470 return buf.String() 1471 } 1472 1473 // TODO(chengxiong): unify this method with Overload.Signature method if possible. 1474 func getFuncSig(expr *FuncExpr, typedInputExprs []TypedExpr, desiredType *types.T) string { 1475 typeNames := make([]string, 0, len(expr.Exprs)) 1476 for _, expr := range typedInputExprs { 1477 typeNames = append(typeNames, expr.ResolvedType().String()) 1478 } 1479 var desStr string 1480 if desiredType.Family() != types.AnyFamily { 1481 desStr = fmt.Sprintf(" (desired <%s>)", desiredType) 1482 } 1483 return fmt.Sprintf("%s(%s)%s", &expr.Func, strings.Join(typeNames, ", "), desStr) 1484 }