github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/sem/tree/overload.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package tree
    12  
    13  import (
    14  	"bytes"
    15  	"context"
    16  	"fmt"
    17  	"math"
    18  
    19  	"github.com/cockroachdb/cockroach/pkg/server/telemetry"
    20  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    21  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
    22  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    23  	"github.com/cockroachdb/cockroach/pkg/util/log"
    24  	"github.com/cockroachdb/errors"
    25  )
    26  
// SpecializedVectorizedBuiltin is used to map overloads
// to the vectorized operator that is specific to
// that implementation of the builtin function.
//
// The zero value is reserved (see the constant list below, which skips it),
// so an Overload that leaves its SpecializedVecBuiltin field unset does not
// match any named constant.
type SpecializedVectorizedBuiltin int
    31  
// TODO (rohany): What is the best place to put this list?
// I want to put it in builtins or exec, but those create an import
// cycle with exec. tree is imported by both of them, so
// this package seems like a good place to do it.

// Keep this list alphabetized so that it is easy to manage.
const (
	// The zero value is skipped so that it never names a specialized builtin.
	_ SpecializedVectorizedBuiltin = iota
	// SubstringStringIntInt presumably identifies the substring overload that
	// takes (string, int, int) arguments -- TODO confirm against the builtin
	// definition that sets it.
	SubstringStringIntInt
)
    42  
// Overload is one of the overloads of a built-in function.
// Each FunctionDefinition may contain one or more overloads.
type Overload struct {
	// Types is the parameter type list of the overload; it is what params()
	// returns.
	Types TypeList
	// ReturnType computes the return type of the overload; it is what
	// returnType() returns. FixedReturnType() tolerates it being nil.
	ReturnType ReturnTyper
	// Volatility is not interpreted anywhere in this part of the file.
	// NOTE(review): presumably classifies the overload's side effects and
	// determinism -- confirm against the Volatility type.
	Volatility Volatility

	// PreferredOverload determines overload resolution as follows.
	// When multiple overloads are eligible based on types even after all of
	// the heuristics to pick one have been used, if one of the overloads is a
	// Overload with the `PreferredOverload` flag set to true it can be selected
	// rather than returning a no-such-method error.
	// This should generally be avoided -- avoiding introducing ambiguous
	// overloads in the first place is a much better solution -- and only done
	// after consultation with @knz @nvanbenschoten.
	PreferredOverload bool

	// Info is a description of the function, which is surfaced on the CockroachDB
	// docs site on the "Functions and Operators" page. Descriptions typically use
	// third-person with the function as an implicit subject (e.g. "Calculates
	// infinity"), but should focus more on ease of understanding so other structures
	// might be more appropriate.
	Info string

	// The following fields hold the implementation of the overload.
	// NOTE(review): presumably only the field matching the function's class
	// (aggregate, window, scalar, generator) is set -- confirm with callers.
	AggregateFunc func([]*types.T, *EvalContext, Datums) AggregateFunc
	WindowFunc    func([]*types.T, *EvalContext) WindowFunc
	Fn            func(*EvalContext, Datums) (Datum, error)
	Generator     GeneratorFactory

	// SQLFn must be set for overloads of type SQLClass. It should return a SQL
	// statement which will be executed as a common table expression in the query.
	SQLFn func(*EvalContext, Datums) (string, error)

	// counter, if non-nil, should be incremented upon successful
	// type check of expressions using this overload.
	counter telemetry.Counter

	// SpecializedVecBuiltin is used to let the vectorized engine
	// know when an Overload has a specialized vectorized operator.
	SpecializedVecBuiltin SpecializedVectorizedBuiltin
}
    84  
    85  // params implements the overloadImpl interface.
    86  func (b Overload) params() TypeList { return b.Types }
    87  
    88  // returnType implements the overloadImpl interface.
    89  func (b Overload) returnType() ReturnTyper { return b.ReturnType }
    90  
    91  // preferred implements the overloadImpl interface.
    92  func (b Overload) preferred() bool { return b.PreferredOverload }
    93  
    94  // FixedReturnType returns a fixed type that the function returns, returning Any
    95  // if the return type is based on the function's arguments.
    96  func (b Overload) FixedReturnType() *types.T {
    97  	if b.ReturnType == nil {
    98  		return nil
    99  	}
   100  	return returnTypeToFixedType(b.ReturnType)
   101  }
   102  
   103  // Signature returns a human-readable signature.
   104  // If simplify is bool, tuple-returning functions with just
   105  // 1 tuple element unwrap the return type in the signature.
   106  func (b Overload) Signature(simplify bool) string {
   107  	retType := b.FixedReturnType()
   108  	if simplify {
   109  		if retType.Family() == types.TupleFamily && len(retType.TupleContents()) == 1 {
   110  			retType = retType.TupleContents()[0]
   111  		}
   112  	}
   113  	return fmt.Sprintf("(%s) -> %s", b.Types.String(), retType)
   114  }
   115  
   116  // overloadImpl is an implementation of an overloaded function. It provides
   117  // access to the parameter type list  and the return type of the implementation.
   118  //
   119  // This is a more general type than Overload defined above, because it also
   120  // works with the built-in binary and unary operators.
   121  type overloadImpl interface {
   122  	params() TypeList
   123  	returnType() ReturnTyper
   124  	// allows manually resolving preference between multiple compatible overloads
   125  	preferred() bool
   126  }
   127  
   128  var _ overloadImpl = &Overload{}
   129  var _ overloadImpl = &UnaryOp{}
   130  var _ overloadImpl = &BinOp{}
   131  
   132  // GetParamsAndReturnType gets the parameters and return type of an
   133  // overloadImpl.
   134  func GetParamsAndReturnType(impl overloadImpl) (TypeList, ReturnTyper) {
   135  	return impl.params(), impl.returnType()
   136  }
   137  
   138  // TypeList is a list of types representing a function parameter list.
   139  type TypeList interface {
   140  	// Match checks if all types in the TypeList match the corresponding elements in types.
   141  	Match(types []*types.T) bool
   142  	// MatchAt checks if the parameter type at index i of the TypeList matches type typ.
   143  	// In all implementations, types.Null will match with each parameter type, allowing
   144  	// NULL values to be used as arguments.
   145  	MatchAt(typ *types.T, i int) bool
   146  	// matchLen checks that the TypeList can support l parameters.
   147  	MatchLen(l int) bool
   148  	// getAt returns the type at the given index in the TypeList, or nil if the TypeList
   149  	// cannot have a parameter at index i.
   150  	GetAt(i int) *types.T
   151  	// Length returns the number of types in the list
   152  	Length() int
   153  	// Types returns a realized copy of the list. variadic lists return a list of size one.
   154  	Types() []*types.T
   155  	// String returns a human readable signature
   156  	String() string
   157  }
   158  
   159  var _ TypeList = ArgTypes{}
   160  var _ TypeList = HomogeneousType{}
   161  var _ TypeList = VariadicType{}
   162  
// ArgTypes is a TypeList implementation for a fixed-length parameter list. It
// keeps a string name alongside the type of each argument and uses those
// names when printing the human-readable signature.
type ArgTypes []struct {
	// Name is the argument name shown in the signature.
	Name string
	// Typ is the type the argument must be equivalent to.
	Typ *types.T
}
   170  
   171  // Match is part of the TypeList interface.
   172  func (a ArgTypes) Match(types []*types.T) bool {
   173  	if len(types) != len(a) {
   174  		return false
   175  	}
   176  	for i := range types {
   177  		if !a.MatchAt(types[i], i) {
   178  			return false
   179  		}
   180  	}
   181  	return true
   182  }
   183  
   184  // MatchAt is part of the TypeList interface.
   185  func (a ArgTypes) MatchAt(typ *types.T, i int) bool {
   186  	// The parameterized types for Tuples are checked in the type checking
   187  	// routines before getting here, so we only need to check if the argument
   188  	// type is a types.TUPLE below. This allows us to avoid defining overloads
   189  	// for types.Tuple{}, types.Tuple{types.Any}, types.Tuple{types.Any, types.Any},
   190  	// etc. for Tuple operators.
   191  	if typ.Family() == types.TupleFamily {
   192  		typ = types.AnyTuple
   193  	}
   194  	return i < len(a) && (typ.Family() == types.UnknownFamily || a[i].Typ.Equivalent(typ))
   195  }
   196  
   197  // MatchLen is part of the TypeList interface.
   198  func (a ArgTypes) MatchLen(l int) bool {
   199  	return len(a) == l
   200  }
   201  
   202  // GetAt is part of the TypeList interface.
   203  func (a ArgTypes) GetAt(i int) *types.T {
   204  	return a[i].Typ
   205  }
   206  
   207  // Length is part of the TypeList interface.
   208  func (a ArgTypes) Length() int {
   209  	return len(a)
   210  }
   211  
   212  // Types is part of the TypeList interface.
   213  func (a ArgTypes) Types() []*types.T {
   214  	n := len(a)
   215  	ret := make([]*types.T, n)
   216  	for i, s := range a {
   217  		ret[i] = s.Typ
   218  	}
   219  	return ret
   220  }
   221  
   222  func (a ArgTypes) String() string {
   223  	var s bytes.Buffer
   224  	for i, arg := range a {
   225  		if i > 0 {
   226  			s.WriteString(", ")
   227  		}
   228  		s.WriteString(arg.Name)
   229  		s.WriteString(": ")
   230  		s.WriteString(arg.Typ.String())
   231  	}
   232  	return s.String()
   233  }
   234  
   235  // HomogeneousType is a TypeList implementation that accepts any arguments, as
   236  // long as all are the same type or NULL. The homogeneous constraint is enforced
   237  // in typeCheckOverloadedExprs.
   238  type HomogeneousType struct{}
   239  
   240  // Match is part of the TypeList interface.
   241  func (HomogeneousType) Match(types []*types.T) bool {
   242  	return true
   243  }
   244  
   245  // MatchAt is part of the TypeList interface.
   246  func (HomogeneousType) MatchAt(typ *types.T, i int) bool {
   247  	return true
   248  }
   249  
   250  // MatchLen is part of the TypeList interface.
   251  func (HomogeneousType) MatchLen(l int) bool {
   252  	return true
   253  }
   254  
   255  // GetAt is part of the TypeList interface.
   256  func (HomogeneousType) GetAt(i int) *types.T {
   257  	return types.Any
   258  }
   259  
   260  // Length is part of the TypeList interface.
   261  func (HomogeneousType) Length() int {
   262  	return 1
   263  }
   264  
   265  // Types is part of the TypeList interface.
   266  func (HomogeneousType) Types() []*types.T {
   267  	return []*types.T{types.Any}
   268  }
   269  
   270  func (HomogeneousType) String() string {
   271  	return "anyelement..."
   272  }
   273  
// VariadicType is a TypeList implementation which accepts a fixed number of
// arguments at the beginning and an arbitrary number of homogeneous arguments
// at the end.
type VariadicType struct {
	// FixedTypes are the types of the leading, required arguments.
	FixedTypes []*types.T
	// VarType is the type of every argument after the fixed ones.
	VarType *types.T
}
   281  
   282  // Match is part of the TypeList interface.
   283  func (v VariadicType) Match(types []*types.T) bool {
   284  	for i := range types {
   285  		if !v.MatchAt(types[i], i) {
   286  			return false
   287  		}
   288  	}
   289  	return true
   290  }
   291  
   292  // MatchAt is part of the TypeList interface.
   293  func (v VariadicType) MatchAt(typ *types.T, i int) bool {
   294  	if i < len(v.FixedTypes) {
   295  		return typ.Family() == types.UnknownFamily || v.FixedTypes[i].Equivalent(typ)
   296  	}
   297  	return typ.Family() == types.UnknownFamily || v.VarType.Equivalent(typ)
   298  }
   299  
   300  // MatchLen is part of the TypeList interface.
   301  func (v VariadicType) MatchLen(l int) bool {
   302  	return l >= len(v.FixedTypes)
   303  }
   304  
   305  // GetAt is part of the TypeList interface.
   306  func (v VariadicType) GetAt(i int) *types.T {
   307  	if i < len(v.FixedTypes) {
   308  		return v.FixedTypes[i]
   309  	}
   310  	return v.VarType
   311  }
   312  
   313  // Length is part of the TypeList interface.
   314  func (v VariadicType) Length() int {
   315  	return len(v.FixedTypes) + 1
   316  }
   317  
   318  // Types is part of the TypeList interface.
   319  func (v VariadicType) Types() []*types.T {
   320  	result := make([]*types.T, len(v.FixedTypes)+1)
   321  	for i := range v.FixedTypes {
   322  		result[i] = v.FixedTypes[i]
   323  	}
   324  	result[len(result)-1] = v.VarType
   325  	return result
   326  }
   327  
   328  func (v VariadicType) String() string {
   329  	var s bytes.Buffer
   330  	for i, t := range v.FixedTypes {
   331  		if i != 0 {
   332  			s.WriteString(", ")
   333  		}
   334  		s.WriteString(t.String())
   335  	}
   336  	if len(v.FixedTypes) > 0 {
   337  		s.WriteString(", ")
   338  	}
   339  	fmt.Fprintf(&s, "%s...", v.VarType)
   340  	return s.String()
   341  }
   342  
// UnknownReturnType is returned from ReturnTypers when the arguments provided are
// not sufficient to determine a return type. This is necessary for cases like overload
// resolution, where the argument types are not resolved yet so the type-level function
// will be called without argument types. If a ReturnTyper returns UnknownReturnType,
// then the candidate function set cannot be refined. This means that only ReturnTypers
// that never return UnknownReturnType, like those created with FixedReturnType, can
// help reduce overload ambiguity.
//
// The variable is declared without an initializer, so it is the nil *types.T.
var UnknownReturnType *types.T
   351  
// ReturnTyper defines the type-level function in which a builtin function's return type
// is determined. ReturnTypers should make sure to return UnknownReturnType when necessary.
// Within this file a ReturnTyper is also invoked with nil args (see
// returnTypeToFixedType), so implementations must tolerate an empty argument list.
type ReturnTyper func(args []TypedExpr) *types.T
   355  
   356  // FixedReturnType functions simply return a fixed type, independent of argument types.
   357  func FixedReturnType(typ *types.T) ReturnTyper {
   358  	return func(args []TypedExpr) *types.T { return typ }
   359  }
   360  
   361  // IdentityReturnType creates a returnType that is a projection of the idx'th
   362  // argument type.
   363  func IdentityReturnType(idx int) ReturnTyper {
   364  	return func(args []TypedExpr) *types.T {
   365  		if len(args) == 0 {
   366  			return UnknownReturnType
   367  		}
   368  		return args[idx].ResolvedType()
   369  	}
   370  }
   371  
   372  // ArrayOfFirstNonNullReturnType returns an array type from the first non-null
   373  // type in the argument list.
   374  func ArrayOfFirstNonNullReturnType() ReturnTyper {
   375  	return func(args []TypedExpr) *types.T {
   376  		if len(args) == 0 {
   377  			return UnknownReturnType
   378  		}
   379  		for _, arg := range args {
   380  			if t := arg.ResolvedType(); t.Family() != types.UnknownFamily {
   381  				return types.MakeArray(t)
   382  			}
   383  		}
   384  		return types.Unknown
   385  	}
   386  }
   387  
   388  // FirstNonNullReturnType returns the type of the first non-null argument, or
   389  // types.Unknown if all arguments are null. There must be at least one argument,
   390  // or else FirstNonNullReturnType returns UnknownReturnType. This method is used
   391  // with HomogeneousType functions, in which all arguments have been checked to
   392  // have the same type (or be null).
   393  func FirstNonNullReturnType() ReturnTyper {
   394  	return func(args []TypedExpr) *types.T {
   395  		if len(args) == 0 {
   396  			return UnknownReturnType
   397  		}
   398  		for _, arg := range args {
   399  			if t := arg.ResolvedType(); t.Family() != types.UnknownFamily {
   400  				return t
   401  			}
   402  		}
   403  		return types.Unknown
   404  	}
   405  }
   406  
   407  func returnTypeToFixedType(s ReturnTyper) *types.T {
   408  	if t := s(nil); t != UnknownReturnType {
   409  		return t
   410  	}
   411  	return types.Any
   412  }
   413  
// typeCheckOverloadState carries the working state of one overload resolution
// performed by typeCheckOverloadedExprs: the full candidate list, the indexes
// of the candidates that are still eligible, and the argument expressions
// split by how they can be type checked.
type typeCheckOverloadState struct {
	overloads       []overloadImpl
	overloadIdxs    []uint8 // index into overloads
	exprs           []Expr
	typedExprs      []TypedExpr
	resolvableIdxs  []int // index into exprs/typedExprs
	constIdxs       []int // index into exprs/typedExprs
	placeholderIdxs []int // index into exprs/typedExprs
}
   423  
// typeCheckOverloadedExprs determines the correct overload to use for the given set of
// expression parameters, along with an optional desired return type. It returns the expression
// parameters after being type checked, along with a slice of candidate overloadImpls. The
// slice may have length:
//   0: overload resolution failed because no compatible overloads were found
//   1: overload resolution succeeded
//  2+: overload resolution failed because of ambiguity
// The inBinOp parameter denotes whether this type check is occurring within a binary operator,
// in which case we may need to make a guess that the two parameters are of the same type if one
// of them is NULL.
//
// Candidates are tracked as uint8 indexes into overloads, which is why more
// than math.MaxUint8 overloads are rejected with an assertion error.
func typeCheckOverloadedExprs(
	ctx context.Context,
	semaCtx *SemaContext,
	desired *types.T,
	overloads []overloadImpl,
	inBinOp bool,
	exprs ...Expr,
) ([]TypedExpr, []overloadImpl, error) {
	if len(overloads) > math.MaxUint8 {
		return nil, nil, errors.AssertionFailedf("too many overloads (%d > 255)", len(overloads))
	}

	var s typeCheckOverloadState
	s.exprs = exprs
	s.overloads = overloads

	// Special-case the HomogeneousType overload. We determine its return type by checking that
	// all parameters have the same type.
	for i, overload := range overloads {
		// Only one overload can be provided if it has parameters with HomogeneousType.
		if _, ok := overload.params().(HomogeneousType); ok {
			if len(overloads) > 1 {
				return nil, nil, errors.AssertionFailedf(
					"only one overload can have HomogeneousType parameters")
			}
			typedExprs, _, err := TypeCheckSameTypedExprs(ctx, semaCtx, desired, exprs...)
			if err != nil {
				return nil, nil, err
			}
			return typedExprs, overloads[i : i+1], nil
		}
	}

	// Hold the resolved type expressions of the provided exprs, in order.
	s.typedExprs = make([]TypedExpr, len(exprs))
	s.constIdxs, s.placeholderIdxs, s.resolvableIdxs = typeCheckSplitExprs(ctx, semaCtx, exprs)

	// If no overloads are provided, just type check parameters and return.
	if len(overloads) == 0 {
		for _, i := range s.resolvableIdxs {
			typ, err := exprs[i].TypeCheck(ctx, semaCtx, types.Any)
			if err != nil {
				return nil, nil, pgerror.Wrapf(err, pgcode.InvalidParameterValue,
					"error type checking resolved expression:")
			}
			s.typedExprs[i] = typ
		}
		if err := defaultTypeCheck(ctx, semaCtx, &s, false); err != nil {
			return nil, nil, err
		}
		return s.typedExprs, nil, nil
	}

	// Start with every overload as a candidate.
	s.overloadIdxs = make([]uint8, len(overloads))
	for i := 0; i < len(overloads); i++ {
		s.overloadIdxs[i] = uint8(i)
	}

	// Filter out incorrect parameter length overloads.
	s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs,
		func(o overloadImpl) bool {
			return o.params().MatchLen(len(exprs))
		})

	// Filter out overloads which constants cannot become.
	for _, i := range s.constIdxs {
		constExpr := exprs[i].(Constant)
		s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs,
			func(o overloadImpl) bool {
				return canConstantBecome(constExpr, o.params().GetAt(i))
			})
	}

	// TODO(nvanbenschoten): We should add a filtering step here to filter
	// out impossible candidates based on identical parameters. For instance,
	// f(int, float) is not a possible candidate for the expression f($1, $1).

	// Filter out overloads on resolved types.
	for _, i := range s.resolvableIdxs {
		paramDesired := types.Any
		if len(s.overloadIdxs) == 1 {
			// Once we get down to a single overload candidate, begin desiring its
			// parameter types for the corresponding argument expressions.
			paramDesired = s.overloads[s.overloadIdxs[0]].params().GetAt(i)
		}
		typ, err := exprs[i].TypeCheck(ctx, semaCtx, paramDesired)
		if err != nil {
			return nil, nil, err
		}
		s.typedExprs[i] = typ
		s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs,
			func(o overloadImpl) bool {
				return o.params().MatchAt(typ.ResolvedType(), i)
			})
	}

	// At this point, all remaining overload candidates accept the argument list,
	// so we begin checking for a single remaining candidate implementation to choose.
	// In case there is more than one candidate remaining, the following code uses
	// heuristics to find a most preferable candidate.
	if ok, typedExprs, fns, err := checkReturn(ctx, semaCtx, &s); ok {
		return typedExprs, fns, err
	}

	// The first heuristic is to prefer candidates that return the desired type.
	if desired.Family() != types.AnyFamily {
		s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs,
			func(o overloadImpl) bool {
				// For now, we only filter on the return type for overloads with
				// fixed return types. This could be improved, but is not currently
				// critical because we have no cases of functions with multiple
				// overloads that do not all expose FixedReturnTypes.
				if t := o.returnType()(nil); t != UnknownReturnType {
					return t.Equivalent(desired)
				}
				return true
			})
		if ok, typedExprs, fns, err := checkReturn(ctx, semaCtx, &s); ok {
			return typedExprs, fns, err
		}
	}

	// homogeneousTyp is the common type of all resolvable expressions, or nil
	// if there are none or if they are not all equivalent to the first one.
	var homogeneousTyp *types.T
	if len(s.resolvableIdxs) > 0 {
		homogeneousTyp = s.typedExprs[s.resolvableIdxs[0]].ResolvedType()
		for _, i := range s.resolvableIdxs[1:] {
			if !homogeneousTyp.Equivalent(s.typedExprs[i].ResolvedType()) {
				homogeneousTyp = nil
				break
			}
		}
	}

	if len(s.constIdxs) > 0 {
		// allConstantsAreHomogenous is set by the closure below and read again
		// after the third heuristic.
		allConstantsAreHomogenous := false
		if ok, typedExprs, fns, err := filterAttempt(ctx, semaCtx, &s, func() {
			// The second heuristic is to prefer candidates where all constants can
			// become a homogeneous type, if all resolvable expressions became one.
			// This is only possible if resolvable expressions were resolved
			// homogeneously up to this point.
			if homogeneousTyp != nil {
				allConstantsAreHomogenous = true
				for _, i := range s.constIdxs {
					if !canConstantBecome(exprs[i].(Constant), homogeneousTyp) {
						allConstantsAreHomogenous = false
						break
					}
				}
				if allConstantsAreHomogenous {
					for _, i := range s.constIdxs {
						s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs,
							func(o overloadImpl) bool {
								return o.params().GetAt(i).Equivalent(homogeneousTyp)
							})
					}
				}
			}
		}); ok {
			return typedExprs, fns, err
		}

		if ok, typedExprs, fns, err := filterAttempt(ctx, semaCtx, &s, func() {
			// The third heuristic is to prefer candidates where all constants can
			// become their "natural" types.
			for _, i := range s.constIdxs {
				natural := naturalConstantType(exprs[i].(Constant))
				if natural != nil {
					s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs,
						func(o overloadImpl) bool {
							return o.params().GetAt(i).Equivalent(natural)
						})
				}
			}
		}); ok {
			return typedExprs, fns, err
		}

		// At this point, it's worth seeing if we have constants that can't actually
		// parse as the type that canConstantBecome claims they can. For example,
		// every string literal will report that it can become an interval, but most
		// string literals do not encode valid intervals. This may uncover some
		// overloads with invalid type signatures.
		//
		// This parsing is sufficiently expensive (see the comment on
		// StrVal.AvailableTypes) that we wait until now, when we've eliminated most
		// overloads from consideration, so that we only need to check each constant
		// against a limited set of types. We can't hold off on this parsing any
		// longer, though: the remaining heuristics are overly aggressive and will
		// falsely reject the only valid overload in some cases.
		//
		// This case is broken into two parts. We first attempt to use the
		// information about the homogeneity of our constants collected by previous
		// heuristic passes. If:
		// * all our constants are homogeneous
		// * we only have a single overload left
		// * the constant overload parameters are homogeneous as well
		// then match this overload with the homogeneous constants. Otherwise,
		// continue to filter overloads by whether or not the constants can parse
		// into the desired types of the overloads.
		// This first case is important when resolving overloads for operations
		// between user-defined types, where we need to propagate the concrete
		// resolved type information over to the constants, rather than attempting
		// to resolve constants as the placeholder type for the user defined type
		// family (like `AnyEnum`).
		if len(s.overloadIdxs) == 1 && allConstantsAreHomogenous {
			overloadParamsAreHomogenous := true
			p := s.overloads[s.overloadIdxs[0]].params()
			for _, i := range s.constIdxs {
				if !p.GetAt(i).Equivalent(homogeneousTyp) {
					overloadParamsAreHomogenous = false
					break
				}
			}
			if overloadParamsAreHomogenous {
				// Type check our constants using the homogeneous type rather than
				// the type in overload parameter. This lets us type check user defined
				// types with a concrete type instance, rather than an ambiguous type.
				for _, i := range s.constIdxs {
					typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, homogeneousTyp)
					if err != nil {
						return nil, nil, err
					}
					s.typedExprs[i] = typ
				}
				_, typedExprs, fn, err := checkReturnPlaceholdersAtIdx(ctx, semaCtx, &s, int(s.overloadIdxs[0]))
				return typedExprs, fn, err
			}
		}
		for _, i := range s.constIdxs {
			constExpr := exprs[i].(Constant)
			s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs,
				func(o overloadImpl) bool {
					// A fresh SemaContext is used for this trial resolution; it
					// shadows the caller-provided semaCtx inside this closure.
					semaCtx := MakeSemaContext()
					_, err := constExpr.ResolveAsType(&semaCtx, o.params().GetAt(i))
					return err == nil
				})
		}
		if ok, typedExprs, fn, err := checkReturn(ctx, semaCtx, &s); ok {
			return typedExprs, fn, err
		}

		// The fourth heuristic is to prefer candidates that accept the "best"
		// mutual type in the resolvable type set of all constants.
		if bestConstType, ok := commonConstantType(s.exprs, s.constIdxs); ok {
			for _, i := range s.constIdxs {
				s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs,
					func(o overloadImpl) bool {
						return o.params().GetAt(i).Equivalent(bestConstType)
					})
			}
			if ok, typedExprs, fns, err := checkReturn(ctx, semaCtx, &s); ok {
				return typedExprs, fns, err
			}
			// Fold the constants' common type into homogeneousTyp: keep it only
			// if it agrees with the resolvable expressions' type, or adopt it
			// when no homogeneous type was established so far.
			if homogeneousTyp != nil {
				if !homogeneousTyp.Equivalent(bestConstType) {
					homogeneousTyp = nil
				}
			} else {
				homogeneousTyp = bestConstType
			}
		}
	}

	// The fifth heuristic is to prefer candidates where all placeholders can be
	// given the same type as all constants and resolvable expressions. This is
	// only possible if all constants and resolvable expressions were resolved
	// homogeneously up to this point.
	if homogeneousTyp != nil && len(s.placeholderIdxs) > 0 {
		// Before we continue, try to propagate the homogeneous type to the
		// placeholders. This might not have happened yet, if the overloads'
		// parameter types are ambiguous (like in the case of tuple-tuple binary
		// operators).
		for _, i := range s.placeholderIdxs {
			// Only the error is inspected here; the typed expression returned by
			// TypeCheck is discarded.
			if _, err := exprs[i].TypeCheck(ctx, semaCtx, homogeneousTyp); err != nil {
				return nil, nil, err
			}
			s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs,
				func(o overloadImpl) bool {
					return o.params().GetAt(i).Equivalent(homogeneousTyp)
				})
		}
		if ok, typedExprs, fns, err := checkReturn(ctx, semaCtx, &s); ok {
			return typedExprs, fns, err
		}
	}

	// In a binary expression, in the case of one of the arguments being untyped NULL,
	// we prefer overloads where we infer the type of the NULL to be the same as the
	// other argument. This is used to differentiate the behavior of
	// STRING[] || NULL and STRING || NULL.
	if inBinOp && len(s.exprs) == 2 {
		if ok, typedExprs, fns, err := filterAttempt(ctx, semaCtx, &s, func() {
			// A type-checking error on either side makes this closure return
			// without filtering; the error itself is not propagated.
			var err error
			left := s.typedExprs[0]
			if left == nil {
				left, err = s.exprs[0].TypeCheck(ctx, semaCtx, types.Any)
				if err != nil {
					return
				}
			}
			right := s.typedExprs[1]
			if right == nil {
				right, err = s.exprs[1].TypeCheck(ctx, semaCtx, types.Any)
				if err != nil {
					return
				}
			}
			leftType := left.ResolvedType()
			rightType := right.ResolvedType()
			leftIsNull := leftType.Family() == types.UnknownFamily
			rightIsNull := rightType.Family() == types.UnknownFamily
			// Exactly one of the two sides is NULL.
			oneIsNull := (leftIsNull || rightIsNull) && !(leftIsNull && rightIsNull)
			if oneIsNull {
				if leftIsNull {
					leftType = rightType
				}
				if rightIsNull {
					rightType = leftType
				}
				s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs,
					func(o overloadImpl) bool {
						return o.params().GetAt(0).Equivalent(leftType) &&
							o.params().GetAt(1).Equivalent(rightType)
					})
			}
		}); ok {
			return typedExprs, fns, err
		}
	}

	// The final heuristic is to defer to preferred candidates, if available.
	if ok, typedExprs, fns, err := filterAttempt(ctx, semaCtx, &s, func() {
		s.overloadIdxs = filterOverloads(s.overloads, s.overloadIdxs, func(o overloadImpl) bool {
			return o.preferred()
		})
	}); ok {
		return typedExprs, fns, err
	}

	// No heuristic produced a single candidate. Note that len(s.overloads) > 0
	// always holds here, because the zero-overload case returned earlier.
	if err := defaultTypeCheck(ctx, semaCtx, &s, len(s.overloads) > 0); err != nil {
		return nil, nil, err
	}

	// Report every candidate that is still eligible.
	possibleOverloads := make([]overloadImpl, len(s.overloadIdxs))
	for i, o := range s.overloadIdxs {
		possibleOverloads[i] = s.overloads[o]
	}
	return s.typedExprs, possibleOverloads, nil
}
   783  
   784  // filterAttempt attempts to filter the overloads down to a single candidate.
   785  // If it succeeds, it will return true, along with the overload (in a slice for
   786  // convenience) and a possible error. If it fails, it will return false and
   787  // undo any filtering performed during the attempt.
   788  func filterAttempt(
   789  	ctx context.Context, semaCtx *SemaContext, s *typeCheckOverloadState, attempt func(),
   790  ) (ok bool, _ []TypedExpr, _ []overloadImpl, _ error) {
   791  	before := s.overloadIdxs
   792  	attempt()
   793  	if len(s.overloadIdxs) == 1 {
   794  		ok, typedExprs, fns, err := checkReturn(ctx, semaCtx, s)
   795  		if err != nil {
   796  			return false, nil, nil, err
   797  		}
   798  		if ok {
   799  			return true, typedExprs, fns, err
   800  		}
   801  	}
   802  	s.overloadIdxs = before
   803  	return false, nil, nil, nil
   804  }
   805  
   806  // filterOverloads filters overloads which do not satisfy the predicate.
   807  func filterOverloads(
   808  	overloads []overloadImpl, overloadIdxs []uint8, fn func(overloadImpl) bool,
   809  ) []uint8 {
   810  	for i := 0; i < len(overloadIdxs); {
   811  		if fn(overloads[overloadIdxs[i]]) {
   812  			i++
   813  		} else {
   814  			overloadIdxs[i], overloadIdxs[len(overloadIdxs)-1] = overloadIdxs[len(overloadIdxs)-1], overloadIdxs[i]
   815  			overloadIdxs = overloadIdxs[:len(overloadIdxs)-1]
   816  		}
   817  	}
   818  	return overloadIdxs
   819  }
   820  
   821  // defaultTypeCheck type checks the constant and placeholder expressions without a preference
   822  // and adds them to the type checked slice.
   823  func defaultTypeCheck(
   824  	ctx context.Context, semaCtx *SemaContext, s *typeCheckOverloadState, errorOnPlaceholders bool,
   825  ) error {
   826  	for _, i := range s.constIdxs {
   827  		typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, types.Any)
   828  		if err != nil {
   829  			return pgerror.Wrapf(err, pgcode.InvalidParameterValue,
   830  				"error type checking constant value")
   831  		}
   832  		s.typedExprs[i] = typ
   833  	}
   834  	for _, i := range s.placeholderIdxs {
   835  		if errorOnPlaceholders {
   836  			_, err := s.exprs[i].TypeCheck(ctx, semaCtx, types.Any)
   837  			return err
   838  		}
   839  		// If we dont want to error on args, avoid type checking them without a desired type.
   840  		s.typedExprs[i] = StripParens(s.exprs[i]).(*Placeholder)
   841  	}
   842  	return nil
   843  }
   844  
   845  // checkReturn checks the number of remaining overloaded function
   846  // implementations.
   847  // Returns ok=true if we should stop overload resolution, and returning either
   848  // 1. the chosen overload in a slice, or
   849  // 2. nil,
   850  // along with the typed arguments.
   851  // This modifies values within s as scratch slices, but only in the case where
   852  // it returns true, which signals to the calling function that it should
   853  // immediately return, so any mutations to s are irrelevant.
   854  func checkReturn(
   855  	ctx context.Context, semaCtx *SemaContext, s *typeCheckOverloadState,
   856  ) (ok bool, _ []TypedExpr, _ []overloadImpl, _ error) {
   857  	switch len(s.overloadIdxs) {
   858  	case 0:
   859  		if err := defaultTypeCheck(ctx, semaCtx, s, false); err != nil {
   860  			return false, nil, nil, err
   861  		}
   862  		return true, s.typedExprs, nil, nil
   863  
   864  	case 1:
   865  		idx := s.overloadIdxs[0]
   866  		o := s.overloads[idx]
   867  		p := o.params()
   868  		for _, i := range s.constIdxs {
   869  			des := p.GetAt(i)
   870  			typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, des)
   871  			if err != nil {
   872  				return false, s.typedExprs, nil, pgerror.Wrapf(
   873  					err, pgcode.InvalidParameterValue,
   874  					"error type checking constant value",
   875  				)
   876  			}
   877  			if des != nil && !typ.ResolvedType().Equivalent(des) {
   878  				return false, nil, nil, errors.AssertionFailedf(
   879  					"desired constant value type %s but set type %s",
   880  					log.Safe(des), log.Safe(typ.ResolvedType()),
   881  				)
   882  			}
   883  			s.typedExprs[i] = typ
   884  		}
   885  
   886  		return checkReturnPlaceholdersAtIdx(ctx, semaCtx, s, int(idx))
   887  
   888  	default:
   889  		return false, nil, nil, nil
   890  	}
   891  }
   892  
   893  // checkReturnPlaceholdersAtIdx checks that the placeholders for the
   894  // overload at the input index are valid. It has the same return values
   895  // as checkReturn.
   896  func checkReturnPlaceholdersAtIdx(
   897  	ctx context.Context, semaCtx *SemaContext, s *typeCheckOverloadState, idx int,
   898  ) (bool, []TypedExpr, []overloadImpl, error) {
   899  	o := s.overloads[idx]
   900  	p := o.params()
   901  	for _, i := range s.placeholderIdxs {
   902  		des := p.GetAt(i)
   903  		typ, err := s.exprs[i].TypeCheck(ctx, semaCtx, des)
   904  		if err != nil {
   905  			if des.IsAmbiguous() {
   906  				return false, nil, nil, nil
   907  			}
   908  			return false, nil, nil, err
   909  		}
   910  		s.typedExprs[i] = typ
   911  	}
   912  	return true, s.typedExprs, s.overloads[idx : idx+1], nil
   913  }
   914  
   915  func formatCandidates(prefix string, candidates []overloadImpl) string {
   916  	var buf bytes.Buffer
   917  	for _, candidate := range candidates {
   918  		buf.WriteString(prefix)
   919  		buf.WriteByte('(')
   920  		params := candidate.params()
   921  		tLen := params.Length()
   922  		for i := 0; i < tLen; i++ {
   923  			t := params.GetAt(i)
   924  			if i > 0 {
   925  				buf.WriteString(", ")
   926  			}
   927  			buf.WriteString(t.String())
   928  		}
   929  		buf.WriteString(") -> ")
   930  		buf.WriteString(returnTypeToFixedType(candidate.returnType()).String())
   931  		if candidate.preferred() {
   932  			buf.WriteString(" [preferred]")
   933  		}
   934  		buf.WriteByte('\n')
   935  	}
   936  	return buf.String()
   937  }