github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/sem/tree/function_definition.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package tree
    12  
    13  import (
    14  	"strings"
    15  
    16  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgcode"
    17  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgerror"
    18  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/catconstants"
    19  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/catid"
    20  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/types"
    21  	"github.com/cockroachdb/errors"
    22  	"github.com/lib/pq/oid"
    23  )
    24  
    25  // FunctionDefinition implements a reference to the (possibly several)
    26  // overloads for a built-in function.
    27  // TODO(Chengxiong): Remove this struct entirely. Instead, use overloads from
    28  // function resolution or use "GetBuiltinProperties" if the need is to only look
    29  // at builtin functions(there are such existing use cases). Also change "Name"
    30  // of ResolvedFunctionDefinition to Name type.
    31  type FunctionDefinition struct {
    32  	// Name is the short name of the function.
    33  	Name string
    34  
    35  	// Definition is the set of overloads for this function name.
    36  	Definition []*Overload
    37  
    38  	// FunctionProperties are the properties common to all overloads.
    39  	FunctionProperties
    40  }
    41  
    42  // ResolvedFunctionDefinition is similar to FunctionDefinition but with all the
    43  // overloads qualified with schema name.
    44  type ResolvedFunctionDefinition struct {
    45  	// Name is the name of the function and not the name of the schema. And, it's
    46  	// not qualified.
    47  	Name string
    48  
    49  	Overloads []QualifiedOverload
    50  }
    51  
    52  type qualifiedOverloads []QualifiedOverload
    53  
    54  func (qo qualifiedOverloads) len() int {
    55  	return len(qo)
    56  }
    57  
    58  func (qo qualifiedOverloads) get(i int) overloadImpl {
    59  	return qo[i].Overload
    60  }
    61  
    62  // QualifiedOverload is a wrapper of Overload prefixed with a schema name.
    63  // It indicates that the overload is defined with the specified schema.
    64  type QualifiedOverload struct {
    65  	Schema string
    66  	*Overload
    67  }
    68  
    69  // MakeQualifiedOverload creates a new QualifiedOverload.
    70  func MakeQualifiedOverload(schema string, overload *Overload) QualifiedOverload {
    71  	return QualifiedOverload{Schema: schema, Overload: overload}
    72  }
    73  
    74  // FunctionProperties defines the properties of the built-in
    75  // functions that are common across all overloads.
    76  type FunctionProperties struct {
    77  	// UnsupportedWithIssue, if non-zero indicates the built-in is not
    78  	// really supported; the name is a placeholder. Value -1 just says
    79  	// "not supported" without an issue to link; values > 0 provide an
    80  	// issue number to link.
    81  	UnsupportedWithIssue int
    82  
    83  	// Undocumented, when set to true, indicates that the built-in function is
    84  	// hidden from documentation. This is currently used to hide experimental
    85  	// functionality as it is being developed.
    86  	Undocumented bool
    87  
    88  	// Private, when set to true, indicates the built-in function is not
    89  	// available for use by user queries. This is currently used by some
    90  	// aggregates due to issue #10495. Private functions are implicitly
    91  	// considered undocumented.
    92  	Private bool
    93  
    94  	// DistsqlBlocklist is set to true when a function depends on
    95  	// members of the EvalContext that are not marshaled by DistSQL
    96  	// (e.g. planner). Currently used for DistSQL to determine if
    97  	// expressions can be evaluated on a different node without sending
    98  	// over the EvalContext.
    99  	//
   100  	// TODO(andrei): Get rid of the planner from the EvalContext and then we can
   101  	// get rid of this blocklist.
   102  	DistsqlBlocklist bool
   103  
   104  	// Category is used to generate documentation strings.
   105  	Category string
   106  
   107  	// AvailableOnPublicSchema indicates whether the function can be resolved
   108  	// if it is found on the public schema.
   109  	AvailableOnPublicSchema bool
   110  
   111  	// ReturnLabels can be used to override the return column name of a
   112  	// function in a FROM clause.
   113  	// This satisfies a Postgres quirk where some json functions have
   114  	// different return labels when used in SELECT or FROM clause.
   115  	ReturnLabels []string
   116  
   117  	// AmbiguousReturnType is true if the builtin's return type can't be
   118  	// determined without extra context. This is used for formatting builtins
   119  	// with the FmtParsable directive.
   120  	AmbiguousReturnType bool
   121  
   122  	// HasSequenceArguments is true if the builtin function takes in a sequence
   123  	// name (string) and can be used in a scalar expression.
   124  	// TODO(richardjcai): When implicit casting is supported, these builtins
   125  	// should take RegClass as the arg type for the sequence name instead of
   126  	// string, we will add a dependency on all RegClass types used in a view.
   127  	HasSequenceArguments bool
   128  
   129  	// CompositeInsensitive indicates that this function returns equal results
   130  	// when evaluated on equal inputs. This is a non-trivial property for
   131  	// composite types which can be equal but not identical
   132  	// (e.g. decimals 1.0 and 1.00). For example, converting a decimal to string
   133  	// is not CompositeInsensitive.
   134  	//
   135  	// See memo.CanBeCompositeSensitive.
   136  	CompositeInsensitive bool
   137  
   138  	// VectorizeStreaming indicates that the function is of "streaming" nature
   139  	// from the perspective of the vectorized execution engine.
   140  	VectorizeStreaming bool
   141  
   142  	// ReturnsRecordType indicates that this function is a record-returning
   143  	// function, which implies that it's unusable without a corresponding type
   144  	// alias.
   145  	//
   146  	// For example, consider the case of json_to_record('{"a":"b", "c":"d"}').
   147  	// This function returns an error unless it as an `AS t(a,b,c)` declaration,
   148  	// since its definition is to pick out the JSON attributes within the input
   149  	// that match, by name, to the columns in the aliased record type.
   150  	ReturnsRecordType bool
   151  }
   152  
   153  // ShouldDocument returns whether the built-in function should be included in
   154  // external-facing documentation.
   155  func (fp *FunctionProperties) ShouldDocument() bool {
   156  	return !(fp.Undocumented || fp.Private)
   157  }
   158  
   159  // NewFunctionDefinition allocates a function definition corresponding
   160  // to the given built-in definition.
   161  func NewFunctionDefinition(
   162  	name string, props *FunctionProperties, def []Overload,
   163  ) *FunctionDefinition {
   164  	overloads := make([]*Overload, len(def))
   165  
   166  	for i := range def {
   167  		if def[i].PreferredOverload {
   168  			// Builtins with a preferred overload are always ambiguous.
   169  			props.AmbiguousReturnType = true
   170  			break
   171  		}
   172  	}
   173  
   174  	for i := range def {
   175  		def[i].FunctionProperties = *props
   176  		overloads[i] = &def[i]
   177  	}
   178  	return &FunctionDefinition{
   179  		Name:               name,
   180  		Definition:         overloads,
   181  		FunctionProperties: *props,
   182  	}
   183  }
   184  
   185  // FunDefs holds pre-allocated FunctionDefinition instances
   186  // for every builtin function. Initialized by builtins.init().
   187  //
   188  // Note that this is extremely similar to the set stored in builtinsregistry.
   189  // The hope is to remove this map at some point in the future as we delegate
   190  // function definition resolution to interfaces defined in the SemaContext.
   191  var FunDefs map[string]*FunctionDefinition
   192  
   193  // ResolvedBuiltinFuncDefs holds pre-allocated ResolvedFunctionDefinition
   194  // instances. Keys of the map is schema qualified function names.
   195  var ResolvedBuiltinFuncDefs map[string]*ResolvedFunctionDefinition
   196  
   197  // OidToBuiltinName contains a map from the hashed OID of all builtin functions
   198  // to their name.
   199  var OidToBuiltinName map[oid.Oid]string
   200  
   201  // OidToQualifiedBuiltinOverload is a map from builtin function OID to an
   202  // qualified overload.
   203  var OidToQualifiedBuiltinOverload map[oid.Oid]QualifiedOverload
   204  
   205  // Format implements the NodeFormatter interface.
   206  // FunctionDefinitions should always be builtin functions, so we do not need to
   207  // anonymize them, even if the flag is set.
   208  func (fd *FunctionDefinition) Format(ctx *FmtCtx) {
   209  	ctx.WriteString(fd.Name)
   210  }
   211  
   212  // String implements the Stringer interface.
   213  func (fd *FunctionDefinition) String() string { return AsString(fd) }
   214  
   215  // Format implements the NodeFormatter interface.
   216  // ResolvedFunctionDefinitions should always be builtin functions, so we do not
   217  // need to anonymize them, even if the flag is set.
   218  func (fd *ResolvedFunctionDefinition) Format(ctx *FmtCtx) {
   219  	// This is necessary when deserializing function expressions for SHOW CREATE
   220  	// statements. When deserializing a function expression with function OID
   221  	// references, it's guaranteed that there'll be always one overload resolved.
   222  	// There is no need to show prefix for builtin functions since we don't
   223  	// serialize them.
   224  	if len(fd.Overloads) == 1 && catid.IsOIDUserDefined(fd.Overloads[0].Oid) {
   225  		ctx.WriteString(fd.Overloads[0].Schema)
   226  		ctx.WriteString(".")
   227  	}
   228  	ctx.WriteString(fd.Name)
   229  }
   230  
   231  // String implements the Stringer interface.
   232  func (fd *ResolvedFunctionDefinition) String() string { return AsString(fd) }
   233  
   234  // MergeWith is used to merge two UDF definitions with same name.
   235  func (fd *ResolvedFunctionDefinition) MergeWith(
   236  	another *ResolvedFunctionDefinition,
   237  ) (*ResolvedFunctionDefinition, error) {
   238  	if fd == nil {
   239  		return another, nil
   240  	}
   241  	if another == nil {
   242  		return fd, nil
   243  	}
   244  
   245  	if fd.Name != another.Name {
   246  		return nil, errors.Newf("cannot merge function definition of %q with %q", fd.Name, another.Name)
   247  	}
   248  
   249  	return &ResolvedFunctionDefinition{
   250  		Name:      fd.Name,
   251  		Overloads: combineOverloads(fd.Overloads, another.Overloads),
   252  	}, nil
   253  }
   254  
   255  // MatchOverload searches an overload which has exactly the same parameter
   256  // types. The overload from the most significant schema is returned. If
   257  // paramTypes==nil, an error is returned if the function name is not unique in
   258  // the most significant schema. If paramTypes is not nil, an error with
   259  // ErrRoutineUndefined cause is returned if not matched found. Overloads that
   260  // don't match the types in routineType are ignored.
   261  func (fd *ResolvedFunctionDefinition) MatchOverload(
   262  	paramTypes []*types.T, explicitSchema string, searchPath SearchPath, routineType RoutineType,
   263  ) (QualifiedOverload, error) {
   264  	matched := func(ol QualifiedOverload, schema string) bool {
   265  		if ol.Type == UDFRoutine || ol.Type == ProcedureRoutine {
   266  			return schema == ol.Schema && (paramTypes == nil || ol.params().MatchIdentical(paramTypes))
   267  		}
   268  		return schema == ol.Schema && (paramTypes == nil || ol.params().Match(paramTypes))
   269  	}
   270  	typeNames := func() string {
   271  		ns := make([]string, len(paramTypes))
   272  		for i, t := range paramTypes {
   273  			ns[i] = t.Name()
   274  		}
   275  		return strings.Join(ns, ",")
   276  	}
   277  
   278  	found := false
   279  	ret := make([]QualifiedOverload, 0, len(fd.Overloads))
   280  
   281  	findMatches := func(schema string) {
   282  		for i := range fd.Overloads {
   283  			if matched(fd.Overloads[i], schema) {
   284  				found = true
   285  				ret = append(ret, fd.Overloads[i])
   286  			}
   287  		}
   288  	}
   289  
   290  	if explicitSchema != "" {
   291  		findMatches(explicitSchema)
   292  	} else {
   293  		for i, n := 0, searchPath.NumElements(); i < n; i++ {
   294  			if findMatches(searchPath.GetSchema(i)); found {
   295  				break
   296  			}
   297  		}
   298  	}
   299  
   300  	if len(ret) == 1 && ret[0].Type&routineType == 0 {
   301  		if routineType == ProcedureRoutine {
   302  			return QualifiedOverload{}, pgerror.Newf(
   303  				pgcode.WrongObjectType, "%s(%s) is not a procedure", fd.Name, typeNames())
   304  		} else {
   305  			return QualifiedOverload{}, pgerror.Newf(
   306  				pgcode.WrongObjectType, "%s(%s) is not a function", fd.Name, typeNames())
   307  		}
   308  	}
   309  
   310  	// Filter out overloads that don't match the requested type.
   311  	i := 0
   312  	for _, o := range ret {
   313  		if ret[i].Type&routineType != 0 {
   314  			ret[i] = o
   315  			i++
   316  		}
   317  	}
   318  	// Clear non-matching overloads.
   319  	for j := i; j < len(ret); j++ {
   320  		ret[j] = QualifiedOverload{}
   321  	}
   322  	// Truncate the slice.
   323  	ret = ret[:i]
   324  
   325  	if len(ret) == 0 {
   326  		if routineType == ProcedureRoutine {
   327  			return QualifiedOverload{}, errors.Mark(
   328  				pgerror.Newf(pgcode.UndefinedFunction, "procedure %s(%s) does not exist", fd.Name, typeNames()),
   329  				ErrRoutineUndefined,
   330  			)
   331  		} else {
   332  			return QualifiedOverload{}, errors.Mark(
   333  				pgerror.Newf(pgcode.UndefinedFunction, "function %s(%s) does not exist", fd.Name, typeNames()),
   334  				ErrRoutineUndefined,
   335  			)
   336  		}
   337  	}
   338  
   339  	if len(ret) > 1 {
   340  		if routineType == ProcedureRoutine {
   341  			return QualifiedOverload{}, pgerror.Newf(pgcode.AmbiguousFunction, "procedure name %q is not unique", fd.Name)
   342  		} else {
   343  			return QualifiedOverload{}, pgerror.Newf(pgcode.AmbiguousFunction, "function name %q is not unique", fd.Name)
   344  		}
   345  	}
   346  	return ret[0], nil
   347  }
   348  
   349  func combineOverloads(a, b []QualifiedOverload) []QualifiedOverload {
   350  	return append(append(make([]QualifiedOverload, 0, len(a)+len(b)), a...), b...)
   351  }
   352  
   353  // GetClass returns function class by checking each overload's Class and returns
   354  // the homogeneous Class value if all overloads are the same Class. Ambiguous
   355  // error is returned if there is any overload with different Class.
   356  //
   357  // TODO(chengxiong,mgartner): make sure that, at places of the use cases of this
   358  // method, function is resolved to one overload, so that we can get rid of this
   359  // function and similar methods below.
   360  func (fd *ResolvedFunctionDefinition) GetClass() (FunctionClass, error) {
   361  	if len(fd.Overloads) < 1 {
   362  		return 0, errors.AssertionFailedf("no overloads found for function %s", fd.Name)
   363  	}
   364  	ret := fd.Overloads[0].Class
   365  	for i := range fd.Overloads {
   366  		if fd.Overloads[i].Class != ret {
   367  			return 0, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous function class on %s", fd.Name)
   368  		}
   369  	}
   370  	return ret, nil
   371  }
   372  
   373  // GetReturnLabel returns function ReturnLabel by checking each overload and
   374  // returns a ReturnLabel if all overloads have a ReturnLabel of the same length.
   375  // Ambiguous error is returned if there is any overload has ReturnLabel of a
   376  // different length. This is good enough since we don't create UDF with
   377  // ReturnLabel.
   378  func (fd *ResolvedFunctionDefinition) GetReturnLabel() ([]string, error) {
   379  	if len(fd.Overloads) < 1 {
   380  		return nil, errors.AssertionFailedf("no overloads found for function %s", fd.Name)
   381  	}
   382  	ret := fd.Overloads[0].ReturnLabels
   383  	for i := range fd.Overloads {
   384  		if len(ret) != len(fd.Overloads[i].ReturnLabels) {
   385  			return nil, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous function return label on %s", fd.Name)
   386  		}
   387  	}
   388  	return ret, nil
   389  }
   390  
   391  // GetHasSequenceArguments returns function's HasSequenceArguments flag by
   392  // checking each overload's HasSequenceArguments flag. Ambiguous error is
   393  // returned if there is any overload has a different flag.
   394  func (fd *ResolvedFunctionDefinition) GetHasSequenceArguments() (bool, error) {
   395  	if len(fd.Overloads) < 1 {
   396  		return false, errors.AssertionFailedf("no overloads found for function %s", fd.Name)
   397  	}
   398  	ret := fd.Overloads[0].HasSequenceArguments
   399  	for i := range fd.Overloads {
   400  		if ret != fd.Overloads[i].HasSequenceArguments {
   401  			return false, pgerror.Newf(pgcode.AmbiguousFunction, "ambiguous function sequence argument on %s", fd.Name)
   402  		}
   403  	}
   404  	return ret, nil
   405  }
   406  
   407  // QualifyBuiltinFunctionDefinition qualified all overloads in a function
   408  // definition with a schema name. Note that this function can only be used for
   409  // builtin function.
   410  func QualifyBuiltinFunctionDefinition(
   411  	def *FunctionDefinition, schema string,
   412  ) *ResolvedFunctionDefinition {
   413  	ret := &ResolvedFunctionDefinition{
   414  		Name:      def.Name,
   415  		Overloads: make([]QualifiedOverload, 0, len(def.Definition)),
   416  	}
   417  	for _, o := range def.Definition {
   418  		ret.Overloads = append(
   419  			ret.Overloads,
   420  			MakeQualifiedOverload(schema, o),
   421  		)
   422  	}
   423  	return ret
   424  }
   425  
   426  // GetBuiltinFuncDefinitionOrFail is similar to GetBuiltinFuncDefinition but
   427  // returns an error if function is not found.
   428  func GetBuiltinFuncDefinitionOrFail(
   429  	fName RoutineName, searchPath SearchPath,
   430  ) (*ResolvedFunctionDefinition, error) {
   431  	def, err := GetBuiltinFuncDefinition(fName, searchPath)
   432  	if err != nil {
   433  		return nil, err
   434  	}
   435  	if def == nil {
   436  		forError := fName // prevent fName from escaping
   437  		return nil, errors.Mark(
   438  			pgerror.Newf(pgcode.UndefinedFunction, "unknown function: %s()", ErrString(&forError)),
   439  			ErrRoutineUndefined,
   440  		)
   441  	}
   442  	return def, nil
   443  }
   444  
   445  // GetBuiltinFunctionByOIDOrFail retrieves a builtin function by OID.
   446  func GetBuiltinFunctionByOIDOrFail(oid oid.Oid) (*ResolvedFunctionDefinition, error) {
   447  	ol, ok := OidToQualifiedBuiltinOverload[oid]
   448  	if !ok {
   449  		return nil, errors.Mark(
   450  			pgerror.Newf(pgcode.UndefinedFunction, "function %d not found", oid),
   451  			ErrRoutineUndefined,
   452  		)
   453  	}
   454  	fd := &ResolvedFunctionDefinition{
   455  		Name:      OidToBuiltinName[oid],
   456  		Overloads: []QualifiedOverload{ol},
   457  	}
   458  	return fd, nil
   459  }
   460  
   461  // GetBuiltinFuncDefinition search for a builtin function given a function name
   462  // and a search path. If function name is prefixed, only the builtin functions
   463  // in the specific schema are searched. Otherwise, all schemas on the given
   464  // searchPath are searched. A nil is returned if no function is found. It's
   465  // caller's choice to error out if function not found.
   466  //
   467  // In theory, this function returns an error only when the search path iterator
   468  // errors which won't happen since the iterating function never errors out. But
   469  // error is still checked and return from the function signature just in case
   470  // we change the iterating function in the future.
   471  func GetBuiltinFuncDefinition(
   472  	fName RoutineName, searchPath SearchPath,
   473  ) (*ResolvedFunctionDefinition, error) {
   474  	if fName.ExplicitSchema {
   475  		return ResolvedBuiltinFuncDefs[fName.Schema()+"."+fName.Object()], nil
   476  	}
   477  
   478  	// First try that if we can get function directly with the function name.
   479  	// There is a case where the part[0] of the name is a qualified string when
   480  	// the qualified name is double quoted as a single name like "schema.fn".
   481  	if def, ok := ResolvedBuiltinFuncDefs[fName.Object()]; ok {
   482  		return def, nil
   483  	}
   484  
   485  	// Then try if it's in pg_catalog.
   486  	if def, ok := ResolvedBuiltinFuncDefs[catconstants.PgCatalogName+"."+fName.Object()]; ok {
   487  		return def, nil
   488  	}
   489  
   490  	// If not in pg_catalog, go through search path.
   491  	var resolvedDef *ResolvedFunctionDefinition
   492  	for i, n := 0, searchPath.NumElements(); i < n; i++ {
   493  		schema := searchPath.GetSchema(i)
   494  		fullName := schema + "." + fName.Object()
   495  		if def, ok := ResolvedBuiltinFuncDefs[fullName]; ok {
   496  			resolvedDef = def
   497  			break
   498  		}
   499  	}
   500  
   501  	return resolvedDef, nil
   502  }