github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/sem/tree/overload.go (about) 1 // Copyright 2016 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package tree 12 13 import ( 14 "bytes" 15 "context" 16 "fmt" 17 "math" 18 19 "github.com/cockroachdb/cockroach/pkg/server/telemetry" 20 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" 21 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" 22 "github.com/cockroachdb/cockroach/pkg/sql/types" 23 "github.com/cockroachdb/cockroach/pkg/util/log" 24 "github.com/cockroachdb/errors" 25 ) 26 27 // SpecializedVectorizedBuiltin is used to map overloads 28 // to the vectorized operator that is specific to 29 // that implementation of the builtin function. 30 type SpecializedVectorizedBuiltin int 31 32 // TODO (rohany): What is the best place to put this list? 33 // I want to put it in builtins or exec, but those create an import 34 // cycle with exec. tree is imported by both of them, so 35 // this package seems like a good place to do it. 36 37 // Keep this list alphabetized so that it is easy to manage. 38 const ( 39 _ SpecializedVectorizedBuiltin = iota 40 SubstringStringIntInt 41 ) 42 43 // Overload is one of the overloads of a built-in function. 44 // Each FunctionDefinition may contain one or more overloads. 45 type Overload struct { 46 Types TypeList 47 ReturnType ReturnTyper 48 Volatility Volatility 49 50 // PreferredOverload determines overload resolution as follows. 51 // When multiple overloads are eligible based on types even after all of of 52 // the heuristics to pick one have been used, if one of the overloads is a 53 // Overload with the `PreferredOverload` flag set to true it can be selected 54 // rather than returning a no-such-method error. 55 // This should generally be avoided -- avoiding introducing ambiguous 56 // overloads in the first place is a much better solution -- and only done 57 // after consultation with @knz @nvanbenschoten. 58 PreferredOverload bool 59 60 // Info is a description of the function, which is surfaced on the CockroachDB 61 // docs site on the "Functions and Operators" page. Descriptions typically use 62 // third-person with the function as an implicit subject (e.g. "Calculates 63 // infinity"), but should focus more on ease of understanding so other structures 64 // might be more appropriate. 65 Info string 66 67 AggregateFunc func([]*types.T, *EvalContext, Datums) AggregateFunc 68 WindowFunc func([]*types.T, *EvalContext) WindowFunc 69 Fn func(*EvalContext, Datums) (Datum, error) 70 Generator GeneratorFactory 71 72 // SQLFn must be set for overloads of type SQLClass. It should return a SQL 73 // statement which will be executed as a common table expression in the query. 74 SQLFn func(*EvalContext, Datums) (string, error) 75 76 // counter, if non-nil, should be incremented upon successful 77 // type check of expressions using this overload. 78 counter telemetry.Counter 79 80 // SpecializedVecBuiltin is used to let the vectorized engine 81 // know when an Overload has a specialized vectorized operator. 82 SpecializedVecBuiltin SpecializedVectorizedBuiltin 83 } 84 85 // params implements the overloadImpl interface. 86 func (b Overload) params() TypeList { return b.Types } 87 88 // returnType implements the overloadImpl interface. 89 func (b Overload) returnType() ReturnTyper { return b.ReturnType } 90 91 // preferred implements the overloadImpl interface. 92 func (b Overload) preferred() bool { return b.PreferredOverload } 93 94 // FixedReturnType returns a fixed type that the function returns, returning Any 95 // if the return type is based on the function's arguments. 96 func (b Overload) FixedReturnType() *types.T { 97 if b.ReturnType == nil { 98 return nil 99 } 100 return returnTypeToFixedType(b.ReturnType) 101 } 102 103 // Signature returns a human-readable signature. 104 // If simplify is bool, tuple-returning functions with just 105 // 1 tuple element unwrap the return type in the signature. 106 func (b Overload) Signature(simplify bool) string { 107 retType := b.FixedReturnType() 108 if simplify { 109 if retType.Family() == types.TupleFamily && len(retType.TupleContents()) == 1 { 110 retType = retType.TupleContents()[0] 111 } 112 } 113 return fmt.Sprintf("(%s) -> %s", b.Types.String(), retType) 114 } 115 116 // overloadImpl is an implementation of an overloaded function. It provides 117 // access to the parameter type list and the return type of the implementation. 118 // 119 // This is a more general type than Overload defined above, because it also 120 // works with the built-in binary and unary operators. 121 type overloadImpl interface { 122 params() TypeList 123 returnType() ReturnTyper 124 // allows manually resolving preference between multiple compatible overloads 125 preferred() bool 126 } 127 128 var _ overloadImpl = &Overload{} 129 var _ overloadImpl = &UnaryOp{} 130 var _ overloadImpl = &BinOp{} 131 132 // GetParamsAndReturnType gets the parameters and return type of an 133 // overloadImpl. 134 func GetParamsAndReturnType(impl overloadImpl) (TypeList, ReturnTyper) { 135 return impl.params(), impl.returnType() 136 } 137 138 // TypeList is a list of types representing a function parameter list. 139 type TypeList interface { 140 // Match checks if all types in the TypeList match the corresponding elements in types. 141 Match(types []*types.T) bool 142 // MatchAt checks if the parameter type at index i of the TypeList matches type typ. 143 // In all implementations, types.Null will match with each parameter type, allowing 144 // NULL values to be used as arguments. 145 MatchAt(typ *types.T, i int) bool 146 // matchLen checks that the TypeList can support l parameters. 147 MatchLen(l int) bool 148 // getAt returns the type at the given index in the TypeList, or nil if the TypeList 149 // cannot have a parameter at index i. 150 GetAt(i int) *types.T 151 // Length returns the number of types in the list 152 Length() int 153 // Types returns a realized copy of the list. variadic lists return a list of size one. 154 Types() []*types.T 155 // String returns a human readable signature 156 String() string 157 } 158 159 var _ TypeList = ArgTypes{} 160 var _ TypeList = HomogeneousType{} 161 var _ TypeList = VariadicType{} 162 163 // ArgTypes is very similar to ArgTypes except it allows keeping a string 164 // name for each argument as well and using those when printing the 165 // human-readable signature. 166 type ArgTypes []struct { 167 Name string 168 Typ *types.T 169 } 170 171 // Match is part of the TypeList interface. 172 func (a ArgTypes) Match(types []*types.T) bool { 173 if len(types) != len(a) { 174 return false 175 } 176 for i := range types { 177 if !a.MatchAt(types[i], i) { 178 return false 179 } 180 } 181 return true 182 } 183 184 // MatchAt is part of the TypeList interface. 185 func (a ArgTypes) MatchAt(typ *types.T, i int) bool { 186 // The parameterized types for Tuples are checked in the type checking 187 // routines before getting here, so we only need to check if the argument 188 // type is a types.TUPLE below. This allows us to avoid defining overloads 189 // for types.Tuple{}, types.Tuple{types.Any}, types.Tuple{types.Any, types.Any}, 190 // etc. for Tuple operators. 191 if typ.Family() == types.TupleFamily { 192 typ = types.AnyTuple 193 } 194 return i < len(a) && (typ.Family() == types.UnknownFamily || a[i].Typ.Equivalent(typ)) 195 } 196 197 // MatchLen is part of the TypeList interface. 198 func (a ArgTypes) MatchLen(l int) bool { 199 return len(a) == l 200 } 201 202 // GetAt is part of the TypeList interface. 203 func (a ArgTypes) GetAt(i int) *types.T { 204 return a[i].Typ 205 } 206 207 // Length is part of the TypeList interface. 208 func (a ArgTypes) Length() int { 209 return len(a) 210 } 211 212 // Types is part of the TypeList interface. 213 func (a ArgTypes) Types() []*types.T { 214 n := len(a) 215 ret := make([]*types.T, n) 216 for i, s := range a { 217 ret[i] = s.Typ 218 } 219 return ret 220 } 221 222 func (a ArgTypes) String() string { 223 var s bytes.Buffer 224 for i, arg := range a { 225 if i > 0 { 226 s.WriteString(", ") 227 } 228 s.WriteString(arg.Name) 229 s.WriteString(": ") 230 s.WriteString(arg.Typ.String()) 231 } 232 return s.String() 233 } 234 235 // HomogeneousType is a TypeList implementation that accepts any arguments, as 236 // long as all are the same type or NULL. The homogeneous constraint is enforced 237 // in typeCheckOverloadedExprs. 238 type HomogeneousType struct{} 239 240 // Match is part of the TypeList interface. 241 func (HomogeneousType) Match(types []*types.T) bool { 242 return true 243 } 244 245 // MatchAt is part of the TypeList interface. 246 func (HomogeneousType) MatchAt(typ *types.T, i int) bool { 247 return true 248 } 249 250 // MatchLen is part of the TypeList interface. 251 func (HomogeneousType) MatchLen(l int) bool { 252 return true 253 } 254 255 // GetAt is part of the TypeList interface. 256 func (HomogeneousType) GetAt(i int) *types.T { 257 return types.Any 258 } 259 260 // Length is part of the TypeList interface. 261 func (HomogeneousType) Length() int { 262 return 1 263 } 264 265 // Types is part of the TypeList interface. 266 func (HomogeneousType) Types() []*types.T { 267 return []*types.T{types.Any} 268 } 269 270 func (HomogeneousType) String() string { 271 return "anyelement..." 272 } 273 274 // VariadicType is a TypeList implementation which accepts a fixed number of 275 // arguments at the beginning and an arbitrary number of homogenous arguments 276 // at the end. 277 type VariadicType struct { 278 FixedTypes []*types.T 279 VarType *types.T 280 } 281 282 // Match is part of the TypeList interface. 283 func (v VariadicType) Match(types []*types.T) bool { 284 for i := range types { 285 if !v.MatchAt(types[i], i) { 286 return false 287 } 288 } 289 return true 290 } 291 292 // MatchAt is part of the TypeList interface. 293 func (v VariadicType) MatchAt(typ *types.T, i int) bool { 294 if i < len(v.FixedTypes) { 295 return typ.Family() == types.UnknownFamily || v.FixedTypes[i].Equivalent(typ) 296 } 297 return typ.Family() == types.UnknownFamily || v.VarType.Equivalent(typ) 298 } 299 300 // MatchLen is part of the TypeList interface. 301 func (v VariadicType) MatchLen(l int) bool { 302 return l >= len(v.FixedTypes) 303 } 304 305 // GetAt is part of the TypeList interface. 306 func (v VariadicType) GetAt(i int) *types.T { 307 if i < len(v.FixedTypes) { 308 return v.FixedTypes[i] 309 } 310 return v.VarType 311 } 312 313 // Length is part of the TypeList interface. 314 func (v VariadicType) Length() int { 315 return len(v.FixedTypes) + 1 316 } 317 318 // Types is part of the TypeList interface. 319 func (v VariadicType) Types() []*types.T { 320 result := make([]*types.T, len(v.FixedTypes)+1) 321 for i := range v.FixedTypes { 322 result[i] = v.FixedTypes[i] 323 } 324 result[len(result)-1] = v.VarType 325 return result 326 } 327 328 func (v VariadicType) String() string { 329 var s bytes.Buffer 330 for i, t := range v.FixedTypes { 331 if i != 0 { 332 s.WriteString(", ") 333 } 334 s.WriteString(t.String()) 335 } 336 if len(v.FixedTypes) > 0 { 337 s.WriteString(", ") 338 } 339 fmt.Fprintf(&s, "%s...", v.VarType) 340 return s.String() 341 } 342 343 // UnknownReturnType is returned from ReturnTypers when the arguments provided are 344 // not sufficient to determine a return type. This is necessary for cases like overload 345 // resolution, where the argument types are not resolved yet so the type-level function 346 // will be called without argument types. If a ReturnTyper returns unknownReturnType, 347 // then the candidate function set cannot be refined. This means that only ReturnTypers 348 // that never return unknownReturnType, like those created with FixedReturnType, can 349 // help reduce overload ambiguity. 350 var UnknownReturnType *types.T 351 352 // ReturnTyper defines the type-level function in which a builtin function's return type 353 // is determined. ReturnTypers should make sure to return unknownReturnType when necessary. 354 type ReturnTyper func(args []TypedExpr) *types.T 355 356 // FixedReturnType functions simply return a fixed type, independent of argument types. 357 func FixedReturnType(typ *types.T) ReturnTyper { 358 return func(args []TypedExpr) *types.T { return typ } 359 } 360 361 // IdentityReturnType creates a returnType that is a projection of the idx'th 362 // argument type. 363 func IdentityReturnType(idx int) ReturnTyper { 364 return func(args []TypedExpr) *types.T { 365 if len(args) == 0 { 366 return UnknownReturnType 367 } 368 return args[idx].ResolvedType() 369 } 370 } 371 372 // ArrayOfFirstNonNullReturnType returns an array type from the first non-null 373 // type in the argument list. 374 func ArrayOfFirstNonNullReturnType() ReturnTyper { 375 return func(args []TypedExpr) *types.T { 376 if len(args) == 0 { 377 return UnknownReturnType 378 } 379 for _, arg := range args { 380 if t := arg.ResolvedType(); t.Family() != types.UnknownFamily { 381 return types.MakeArray(t) 382 } 383 } 384 return types.Unknown 385 } 386 } 387 388 // FirstNonNullReturnType returns the type of the first non-null argument, or 389 // types.Unknown if all arguments are null. There must be at least one argument, 390 // or else FirstNonNullReturnType returns UnknownReturnType. This method is used 391 // with HomogeneousType functions, in which all arguments have been checked to 392 // have the same type (or be null). 393 func FirstNonNullReturnType() ReturnTyper { 394 return func(args []TypedExpr) *types.T { 395 if len(args) == 0 { 396 return UnknownReturnType 397 } 398 for _, arg := range args { 399 if t := arg.ResolvedType(); t.Family() != types.UnknownFamily { 400 return t 401 } 402 } 403 return types.Unknown 404 } 405 } 406 407 func returnTypeToFixedType(s ReturnTyper) *types.T { 408 if t := s(nil); t != UnknownReturnType { 409 return t 410 } 411 return types.Any 412 } 413 414 type typeCheckOverloadState struct { 415 overloads []overloadImpl 416 overloadIdxs []uint8 // index into overloads 417 exprs []Expr 418 typedExprs []TypedExpr 419 resolvableIdxs []int // index into exprs/typedExprs 420 constIdxs []int // index into exprs/typedExprs 421 placeholderIdxs []int // index into exprs/typedExprs 422 } 423 424 // typeCheckOverloadedExprs determines the correct overload to use for the given set of 425 // expression parameters, along with an optional desired return type. It returns the expression 426 // parameters after being type checked, along with a slice of candidate overloadImpls. The 427 // slice may have length: 428 // 0: overload resolution failed because no compatible overloads were found 429 // 1: overload resolution succeeded 430 // 2+: overload resolution failed because of ambiguity 431 // The inBinOp parameter denotes whether this type check is occurring within a binary operator, 432 // in which case we may need to make a guess that the two parameters are of the same type if one 433 // of them is NULL. 434 func typeCheckOverloadedExprs( 435 ctx context.Context, 436 semaCtx *SemaContext, 437 desired *types.T, 438 overloads []overloadImpl, 439 inBinOp bool, 440 exprs ...Expr, 441 ) ([]TypedExpr, []overloadImpl, error) { 442 if len(overloads) > math.MaxUint8 { 443 return nil, nil, errors.AssertionFailedf("too many overloads (%d > 255)", len(overloads)) 444 } 445 446 var s typeCheckOverloadState 447 s.exprs = exprs 448 s.overloads = overloads 449 450 // Special-case the HomogeneousType overload. We determine its return type by checking that 451 // all parameters have the same type. 452 for i, overload := range overloads { 453 // Only one overload can be provided if it has parameters with HomogeneousType. 454 if _, ok := overload.params().(HomogeneousType); ok { 455 if len(overloads) > 1 { 456 return nil, nil, errors.AssertionFailedf( 457 "only one overload can have HomogeneousType parameters") 458 } 459 typedExprs, _, err := TypeCheckSameTypedExprs(ctx, semaCtx, desired, exprs...) 460 if err != nil { 461 return nil, nil, err 462 } 463 return typedExprs, overloads[i : i+1], nil 464 } 465 } 466 467 // Hold the resolved type expressions of the provided exprs, in order. 468 s.typedExprs = make([]TypedExpr, len(exprs)) 469 s.constIdxs, s.placeholderIdxs, s.resolvableIdxs = typeCheckSplitExprs(ctx, semaCtx, exprs) 470 471 // If no overloads are provided, just type check parameters and return. 472 if len(overloads) == 0 { 473 for _, i := range s.resolvableIdxs { 474 typ, err := exprs[i].TypeCheck(ctx, semaCtx, types.Any) 475 if err != nil { 476 return nil, nil, pgerror.Wrapf(err, pgcode.InvalidParameterValue, 477 "error type checking resolved expression:") 478 } 479 s.typedExprs[i] = typ 480 } 481 if err := defaultTypeCheck(ctx, semaCtx, &s, false); err != nil { 482 return nil, nil, err 483 } 484 return s.typedExprs, nil, nil 485 } 486 487 s.overloadIdxs = make([]uint8, len(overloads)) 488 for i := 0; i < len(overloads); i++ { 489 s.overloadIdxs[i] = uint8(i) 490 } 491 492 // Filter out incorrect parameter length overloads. 493 s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, 494 func(o overloadImpl) bool { 495 return o.params().MatchLen(len(exprs)) 496 }) 497 498 // Filter out overloads which constants cannot become. 499 for _, i := range s.constIdxs { 500 constExpr := exprs[i].(Constant) 501 s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, 502 func(o overloadImpl) bool { 503 return canConstantBecome(constExpr, o.params().GetAt(i)) 504 }) 505 } 506 507 // TODO(nvanbenschoten): We should add a filtering step here to filter 508 // out impossible candidates based on identical parameters. For instance, 509 // f(int, float) is not a possible candidate for the expression f($1, $1). 510 511 // Filter out overloads on resolved types. 512 for _, i := range s.resolvableIdxs { 513 paramDesired := types.Any 514 if len(s.overloadIdxs) == 1 { 515 // Once we get down to a single overload candidate, begin desiring its 516 // parameter types for the corresponding argument expressions. 517 paramDesired = s.overloads[s.overloadIdxs[0]].params().GetAt(i) 518 } 519 typ, err := exprs[i].TypeCheck(ctx, semaCtx, paramDesired) 520 if err != nil { 521 return nil, nil, err 522 } 523 s.typedExprs[i] = typ 524 s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, 525 func(o overloadImpl) bool { 526 return o.params().MatchAt(typ.ResolvedType(), i) 527 }) 528 } 529 530 // At this point, all remaining overload candidates accept the argument list, 531 // so we begin checking for a single remaining candidate implementation to choose. 532 // In case there is more than one candidate remaining, the following code uses 533 // heuristics to find a most preferable candidate. 534 if ok, typedExprs, fns, err := checkReturn(ctx, semaCtx, &s); ok { 535 return typedExprs, fns, err 536 } 537 538 // The first heuristic is to prefer candidates that return the desired type. 539 if desired.Family() != types.AnyFamily { 540 s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, 541 func(o overloadImpl) bool { 542 // For now, we only filter on the return type for overloads with 543 // fixed return types. This could be improved, but is not currently 544 // critical because we have no cases of functions with multiple 545 // overloads that do not all expose FixedReturnTypes. 546 if t := o.returnType()(nil); t != UnknownReturnType { 547 return t.Equivalent(desired) 548 } 549 return true 550 }) 551 if ok, typedExprs, fns, err := checkReturn(ctx, semaCtx, &s); ok { 552 return typedExprs, fns, err 553 } 554 } 555 556 var homogeneousTyp *types.T 557 if len(s.resolvableIdxs) > 0 { 558 homogeneousTyp = s.typedExprs[s.resolvableIdxs[0]].ResolvedType() 559 for _, i := range s.resolvableIdxs[1:] { 560 if !homogeneousTyp.Equivalent(s.typedExprs[i].ResolvedType()) { 561 homogeneousTyp = nil 562 break 563 } 564 } 565 } 566 567 if len(s.constIdxs) > 0 { 568 allConstantsAreHomogenous := false 569 if ok, typedExprs, fns, err := filterAttempt(ctx, semaCtx, &s, func() { 570 // The second heuristic is to prefer candidates where all constants can 571 // become a homogeneous type, if all resolvable expressions became one. 572 // This is only possible if resolvable expressions were resolved 573 // homogeneously up to this point. 574 if homogeneousTyp != nil { 575 allConstantsAreHomogenous = true 576 for _, i := range s.constIdxs { 577 if !canConstantBecome(exprs[i].(Constant), homogeneousTyp) { 578 allConstantsAreHomogenous = false 579 break 580 } 581 } 582 if allConstantsAreHomogenous { 583 for _, i := range s.constIdxs { 584 s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, 585 func(o overloadImpl) bool { 586 return o.params().GetAt(i).Equivalent(homogeneousTyp) 587 }) 588 } 589 } 590 } 591 }); ok { 592 return typedExprs, fns, err 593 } 594 595 if ok, typedExprs, fns, err := filterAttempt(ctx, semaCtx, &s, func() { 596 // The third heuristic is to prefer candidates where all constants can 597 // become their "natural" types. 598 for _, i := range s.constIdxs { 599 natural := naturalConstantType(exprs[i].(Constant)) 600 if natural != nil { 601 s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, 602 func(o overloadImpl) bool { 603 return o.params().GetAt(i).Equivalent(natural) 604 }) 605 } 606 } 607 }); ok { 608 return typedExprs, fns, err 609 } 610 611 // At this point, it's worth seeing if we have constants that can't actually 612 // parse as the type that canConstantBecome claims they can. For example, 613 // every string literal will report that it can become an interval, but most 614 // string literals do not encode valid intervals. This may uncover some 615 // overloads with invalid type signatures. 616 // 617 // This parsing is sufficiently expensive (see the comment on 618 // StrVal.AvailableTypes) that we wait until now, when we've eliminated most 619 // overloads from consideration, so that we only need to check each constant 620 // against a limited set of types. We can't hold off on this parsing any 621 // longer, though: the remaining heuristics are overly aggressive and will 622 // falsely reject the only valid overload in some cases. 623 // 624 // This case is broken into two parts. We first attempt to use the 625 // information about the homogeneity of our constants collected by previous 626 // heuristic passes. If: 627 // * all our constants are homogeneous 628 // * we only have a single overload left 629 // * the constant overload parameters are homogeneous as well 630 // then match this overload with the homogeneous constants. Otherwise, 631 // continue to filter overloads by whether or not the constants can parse 632 // into the desired types of the overloads. 633 // This first case is important when resolving overloads for operations 634 // between user-defined types, where we need to propagate the concrete 635 // resolved type information over to the constants, rather than attempting 636 // to resolve constants as the placeholder type for the user defined type 637 // family (like `AnyEnum`). 638 if len(s.overloadIdxs) == 1 && allConstantsAreHomogenous { 639 overloadParamsAreHomogenous := true 640 p := s.overloads[s.overloadIdxs[0]].params() 641 for _, i := range s.constIdxs { 642 if !p.GetAt(i).Equivalent(homogeneousTyp) { 643 overloadParamsAreHomogenous = false 644 break 645 } 646 } 647 if overloadParamsAreHomogenous { 648 // Type check our constants using the homogeneous type rather than 649 // the type in overload parameter. This lets us type check user defined 650 // types with a concrete type instance, rather than an ambiguous type. 651 for _, i := range s.constIdxs { 652 typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, homogeneousTyp) 653 if err != nil { 654 return nil, nil, err 655 } 656 s.typedExprs[i] = typ 657 } 658 _, typedExprs, fn, err := checkReturnPlaceholdersAtIdx(ctx, semaCtx, &s, int(s.overloadIdxs[0])) 659 return typedExprs, fn, err 660 } 661 } 662 for _, i := range s.constIdxs { 663 constExpr := exprs[i].(Constant) 664 s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, 665 func(o overloadImpl) bool { 666 semaCtx := MakeSemaContext() 667 _, err := constExpr.ResolveAsType(&semaCtx, o.params().GetAt(i)) 668 return err == nil 669 }) 670 } 671 if ok, typedExprs, fn, err := checkReturn(ctx, semaCtx, &s); ok { 672 return typedExprs, fn, err 673 } 674 675 // The fourth heuristic is to prefer candidates that accepts the "best" 676 // mutual type in the resolvable type set of all constants. 677 if bestConstType, ok := commonConstantType(s.exprs, s.constIdxs); ok { 678 for _, i := range s.constIdxs { 679 s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, 680 func(o overloadImpl) bool { 681 return o.params().GetAt(i).Equivalent(bestConstType) 682 }) 683 } 684 if ok, typedExprs, fns, err := checkReturn(ctx, semaCtx, &s); ok { 685 return typedExprs, fns, err 686 } 687 if homogeneousTyp != nil { 688 if !homogeneousTyp.Equivalent(bestConstType) { 689 homogeneousTyp = nil 690 } 691 } else { 692 homogeneousTyp = bestConstType 693 } 694 } 695 } 696 697 // The fifth heuristic is to prefer candidates where all placeholders can be 698 // given the same type as all constants and resolvable expressions. This is 699 // only possible if all constants and resolvable expressions were resolved 700 // homogeneously up to this point. 701 if homogeneousTyp != nil && len(s.placeholderIdxs) > 0 { 702 // Before we continue, try to propagate the homogeneous type to the 703 // placeholders. This might not have happened yet, if the overloads' 704 // parameter types are ambiguous (like in the case of tuple-tuple binary 705 // operators). 706 for _, i := range s.placeholderIdxs { 707 if _, err := exprs[i].TypeCheck(ctx, semaCtx, homogeneousTyp); err != nil { 708 return nil, nil, err 709 } 710 s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, 711 func(o overloadImpl) bool { 712 return o.params().GetAt(i).Equivalent(homogeneousTyp) 713 }) 714 } 715 if ok, typedExprs, fns, err := checkReturn(ctx, semaCtx, &s); ok { 716 return typedExprs, fns, err 717 } 718 } 719 720 // In a binary expression, in the case of one of the arguments being untyped NULL, 721 // we prefer overloads where we infer the type of the NULL to be the same as the 722 // other argument. This is used to differentiate the behavior of 723 // STRING[] || NULL and STRING || NULL. 724 if inBinOp && len(s.exprs) == 2 { 725 if ok, typedExprs, fns, err := filterAttempt(ctx, semaCtx, &s, func() { 726 var err error 727 left := s.typedExprs[0] 728 if left == nil { 729 left, err = s.exprs[0].TypeCheck(ctx, semaCtx, types.Any) 730 if err != nil { 731 return 732 } 733 } 734 right := s.typedExprs[1] 735 if right == nil { 736 right, err = s.exprs[1].TypeCheck(ctx, semaCtx, types.Any) 737 if err != nil { 738 return 739 } 740 } 741 leftType := left.ResolvedType() 742 rightType := right.ResolvedType() 743 leftIsNull := leftType.Family() == types.UnknownFamily 744 rightIsNull := rightType.Family() == types.UnknownFamily 745 oneIsNull := (leftIsNull || rightIsNull) && !(leftIsNull && rightIsNull) 746 if oneIsNull { 747 if leftIsNull { 748 leftType = rightType 749 } 750 if rightIsNull { 751 rightType = leftType 752 } 753 s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, 754 func(o overloadImpl) bool { 755 return o.params().GetAt(0).Equivalent(leftType) && 756 o.params().GetAt(1).Equivalent(rightType) 757 }) 758 } 759 }); ok { 760 return typedExprs, fns, err 761 } 762 } 763 764 // The final heuristic is to defer to preferred candidates, if available. 765 if ok, typedExprs, fns, err := filterAttempt(ctx, semaCtx, &s, func() { 766 s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, func(o overloadImpl) bool { 767 return o.preferred() 768 }) 769 }); ok { 770 return typedExprs, fns, err 771 } 772 773 if err := defaultTypeCheck(ctx, semaCtx, &s, len(s.overloads) > 0); err != nil { 774 return nil, nil, err 775 } 776 777 possibleOverloads := make([]overloadImpl, len(s.overloadIdxs)) 778 for i, o := range s.overloadIdxs { 779 possibleOverloads[i] = s.overloads[o] 780 } 781 return s.typedExprs, possibleOverloads, nil 782 } 783 784 // filterAttempt attempts to filter the overloads down to a single candidate. 785 // If it succeeds, it will return true, along with the overload (in a slice for 786 // convenience) and a possible error. If it fails, it will return false and 787 // undo any filtering performed during the attempt. 788 func filterAttempt( 789 ctx context.Context, semaCtx *SemaContext, s *typeCheckOverloadState, attempt func(), 790 ) (ok bool, _ []TypedExpr, _ []overloadImpl, _ error) { 791 before := s.overloadIdxs 792 attempt() 793 if len(s.overloadIdxs) == 1 { 794 ok, typedExprs, fns, err := checkReturn(ctx, semaCtx, s) 795 if err != nil { 796 return false, nil, nil, err 797 } 798 if ok { 799 return true, typedExprs, fns, err 800 } 801 } 802 s.overloadIdxs = before 803 return false, nil, nil, nil 804 } 805 806 // filterOverloads filters overloads which do not satisfy the predicate. 807 func filterOverloads( 808 overloads []overloadImpl, overloadIdxs []uint8, fn func(overloadImpl) bool, 809 ) []uint8 { 810 for i := 0; i < len(overloadIdxs); { 811 if fn(overloads[overloadIdxs[i]]) { 812 i++ 813 } else { 814 overloadIdxs[i], overloadIdxs[len(overloadIdxs)-1] = overloadIdxs[len(overloadIdxs)-1], overloadIdxs[i] 815 overloadIdxs = overloadIdxs[:len(overloadIdxs)-1] 816 } 817 } 818 return overloadIdxs 819 } 820 821 // defaultTypeCheck type checks the constant and placeholder expressions without a preference 822 // and adds them to the type checked slice. 823 func defaultTypeCheck( 824 ctx context.Context, semaCtx *SemaContext, s *typeCheckOverloadState, errorOnPlaceholders bool, 825 ) error { 826 for _, i := range s.constIdxs { 827 typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, types.Any) 828 if err != nil { 829 return pgerror.Wrapf(err, pgcode.InvalidParameterValue, 830 "error type checking constant value") 831 } 832 s.typedExprs[i] = typ 833 } 834 for _, i := range s.placeholderIdxs { 835 if errorOnPlaceholders { 836 _, err := s.exprs[i].TypeCheck(ctx, semaCtx, types.Any) 837 return err 838 } 839 // If we dont want to error on args, avoid type checking them without a desired type. 840 s.typedExprs[i] = StripParens(s.exprs[i]).(*Placeholder) 841 } 842 return nil 843 } 844 845 // checkReturn checks the number of remaining overloaded function 846 // implementations. 847 // Returns ok=true if we should stop overload resolution, and returning either 848 // 1. the chosen overload in a slice, or 849 // 2. nil, 850 // along with the typed arguments. 851 // This modifies values within s as scratch slices, but only in the case where 852 // it returns true, which signals to the calling function that it should 853 // immediately return, so any mutations to s are irrelevant. 854 func checkReturn( 855 ctx context.Context, semaCtx *SemaContext, s *typeCheckOverloadState, 856 ) (ok bool, _ []TypedExpr, _ []overloadImpl, _ error) { 857 switch len(s.overloadIdxs) { 858 case 0: 859 if err := defaultTypeCheck(ctx, semaCtx, s, false); err != nil { 860 return false, nil, nil, err 861 } 862 return true, s.typedExprs, nil, nil 863 864 case 1: 865 idx := s.overloadIdxs[0] 866 o := s.overloads[idx] 867 p := o.params() 868 for _, i := range s.constIdxs { 869 des := p.GetAt(i) 870 typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, des) 871 if err != nil { 872 return false, s.typedExprs, nil, pgerror.Wrapf( 873 err, pgcode.InvalidParameterValue, 874 "error type checking constant value", 875 ) 876 } 877 if des != nil && !typ.ResolvedType().Equivalent(des) { 878 return false, nil, nil, errors.AssertionFailedf( 879 "desired constant value type %s but set type %s", 880 log.Safe(des), log.Safe(typ.ResolvedType()), 881 ) 882 } 883 s.typedExprs[i] = typ 884 } 885 886 return checkReturnPlaceholdersAtIdx(ctx, semaCtx, s, int(idx)) 887 888 default: 889 return false, nil, nil, nil 890 } 891 } 892 893 // checkReturnPlaceholdersAtIdx checks that the placeholders for the 894 // overload at the input index are valid. It has the same return values 895 // as checkReturn. 896 func checkReturnPlaceholdersAtIdx( 897 ctx context.Context, semaCtx *SemaContext, s *typeCheckOverloadState, idx int, 898 ) (bool, []TypedExpr, []overloadImpl, error) { 899 o := s.overloads[idx] 900 p := o.params() 901 for _, i := range s.placeholderIdxs { 902 des := p.GetAt(i) 903 typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, des) 904 if err != nil { 905 if des.IsAmbiguous() { 906 return false, nil, nil, nil 907 } 908 return false, nil, nil, err 909 } 910 s.typedExprs[i] = typ 911 } 912 return true, s.typedExprs, s.overloads[idx : idx+1], nil 913 } 914 915 func formatCandidates(prefix string, candidates []overloadImpl) string { 916 var buf bytes.Buffer 917 for _, candidate := range candidates { 918 buf.WriteString(prefix) 919 buf.WriteByte('(') 920 params := candidate.params() 921 tLen := params.Length() 922 for i := 0; i < tLen; i++ { 923 t := params.GetAt(i) 924 if i > 0 { 925 buf.WriteString(", ") 926 } 927 buf.WriteString(t.String()) 928 } 929 buf.WriteString(") -> ") 930 buf.WriteString(returnTypeToFixedType(candidate.returnType()).String()) 931 if candidate.preferred() { 932 buf.WriteString(" [preferred]") 933 } 934 buf.WriteByte('\n') 935 } 936 return buf.String() 937 }