github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/memo/typing.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package memo
    12  
    13  import (
    14  	"github.com/cockroachdb/cockroach/pkg/sql/opt"
    15  	"github.com/cockroachdb/cockroach/pkg/sql/sem/builtins"
    16  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    17  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    18  	"github.com/cockroachdb/cockroach/pkg/util/log"
    19  	"github.com/cockroachdb/errors"
    20  )
    21  
    22  // InferType derives the type of the given scalar expression and stores it in
    23  // the expression's Type field. Depending upon the operator, the type may be
    24  // fixed, or it may be dependent upon the expression children.
    25  func InferType(mem *Memo, e opt.ScalarExpr) *types.T {
    26  	// Special-case Variable, since it's the only expression that needs the memo.
    27  	if e.Op() == opt.VariableOp {
    28  		return typeVariable(mem, e)
    29  	}
    30  
    31  	fn := typingFuncMap[e.Op()]
    32  	if fn == nil {
    33  		panic(errors.AssertionFailedf("type inference for %v is not yet implemented", log.Safe(e.Op())))
    34  	}
    35  	return fn(e)
    36  }
    37  
    38  // InferUnaryType infers the return type of a unary operator, given the type of
    39  // its input.
    40  func InferUnaryType(op opt.Operator, inputType *types.T) *types.T {
    41  	unaryOp := opt.UnaryOpReverseMap[op]
    42  
    43  	// Find the unary op that matches the type of the expression's child.
    44  	for _, op := range tree.UnaryOps[unaryOp] {
    45  		o := op.(*tree.UnaryOp)
    46  		if inputType.Equivalent(o.Typ) {
    47  			return o.ReturnType
    48  		}
    49  	}
    50  	panic(errors.AssertionFailedf("could not find type for unary expression %s", log.Safe(op)))
    51  }
    52  
    53  // InferBinaryType infers the return type of a binary expression, given the type
    54  // of its inputs.
    55  func InferBinaryType(op opt.Operator, leftType, rightType *types.T) *types.T {
    56  	o, ok := FindBinaryOverload(op, leftType, rightType)
    57  	if !ok {
    58  		panic(errors.AssertionFailedf("could not find type for binary expression %s", log.Safe(op)))
    59  	}
    60  	return o.ReturnType
    61  }
    62  
    63  // InferWhensType returns the type of a CASE expression, which is
    64  // of the form:
    65  //   CASE [ <cond> ]
    66  //       WHEN <condval1> THEN <expr1>
    67  //     [ WHEN <condval2> THEN <expr2> ] ...
    68  //     [ ELSE <expr> ]
    69  //   END
    70  // The type is equal to the type of the WHEN <condval> THEN <expr> clauses, or
    71  // the type of the ELSE <expr> value if all the previous types are unknown.
    72  func InferWhensType(whens ScalarListExpr, orElse opt.ScalarExpr) *types.T {
    73  	for _, when := range whens {
    74  		childType := when.DataType()
    75  		if childType.Family() != types.UnknownFamily {
    76  			return childType
    77  		}
    78  	}
    79  	return orElse.DataType()
    80  }
    81  
    82  // BinaryOverloadExists returns true if the given binary operator exists with the
    83  // given arguments.
    84  func BinaryOverloadExists(op opt.Operator, leftType, rightType *types.T) bool {
    85  	_, ok := FindBinaryOverload(op, leftType, rightType)
    86  	return ok
    87  }
    88  
    89  // BinaryAllowsNullArgs returns true if the given binary operator allows null
    90  // arguments, and cannot therefore be folded away to null.
    91  func BinaryAllowsNullArgs(op opt.Operator, leftType, rightType *types.T) bool {
    92  	o, ok := FindBinaryOverload(op, leftType, rightType)
    93  	if !ok {
    94  		panic(errors.AssertionFailedf("could not find overload for binary expression %s", log.Safe(op)))
    95  	}
    96  	return o.NullableArgs
    97  }
    98  
    99  // AggregateOverloadExists returns whether or not the given operator has a
   100  // unary overload which takes the given type as input.
   101  func AggregateOverloadExists(agg opt.Operator, typ *types.T) bool {
   102  	name := opt.AggregateOpReverseMap[agg]
   103  	_, overloads := builtins.GetBuiltinProperties(name)
   104  	for _, o := range overloads {
   105  		if o.Types.MatchAt(typ, 0) {
   106  			return true
   107  		}
   108  	}
   109  	return false
   110  }
   111  
   112  func findOverload(e opt.ScalarExpr, name string) (overload *tree.Overload, ok bool) {
   113  	_, overloads := builtins.GetBuiltinProperties(name)
   114  	for o := range overloads {
   115  		overload = &overloads[o]
   116  		matches := true
   117  		for i, n := 0, e.ChildCount(); i < n; i++ {
   118  			typ := e.Child(i).(opt.ScalarExpr).DataType()
   119  			if !overload.Types.MatchAt(typ, i) {
   120  				matches = false
   121  				break
   122  			}
   123  		}
   124  		if matches {
   125  			return overload, true
   126  		}
   127  	}
   128  	return nil, false
   129  }
   130  
   131  // FindWindowOverload finds a window function overload that matches the
   132  // given window function expression. It panics if no match can be found.
   133  func FindWindowOverload(e opt.ScalarExpr) (name string, overload *tree.Overload) {
   134  	name = opt.WindowOpReverseMap[e.Op()]
   135  	overload, ok := findOverload(e, name)
   136  	if ok {
   137  		return name, overload
   138  	}
   139  	// NB: all aggregate functions can be used as window functions.
   140  	return FindAggregateOverload(e)
   141  }
   142  
   143  // FindAggregateOverload finds an aggregate function overload that matches the
   144  // given aggregate function expression. It panics if no match can be found.
   145  func FindAggregateOverload(e opt.ScalarExpr) (name string, overload *tree.Overload) {
   146  	name = opt.AggregateOpReverseMap[e.Op()]
   147  	overload, ok := findOverload(e, name)
   148  	if ok {
   149  		return name, overload
   150  	}
   151  	panic(errors.AssertionFailedf("could not find overload for %s aggregate", name))
   152  }
   153  
   154  type typingFunc func(e opt.ScalarExpr) *types.T
   155  
   156  // typingFuncMap is a lookup table from scalar operator type to a function
   157  // which returns the data type of an instance of that operator.
   158  var typingFuncMap map[opt.Operator]typingFunc
   159  
   160  func init() {
   161  	typingFuncMap = make(map[opt.Operator]typingFunc)
   162  	typingFuncMap[opt.PlaceholderOp] = typeAsTypedExpr
   163  	typingFuncMap[opt.UnsupportedExprOp] = typeAsTypedExpr
   164  	typingFuncMap[opt.CoalesceOp] = typeCoalesce
   165  	typingFuncMap[opt.CaseOp] = typeCase
   166  	typingFuncMap[opt.WhenOp] = typeWhen
   167  	typingFuncMap[opt.CastOp] = typeCast
   168  	typingFuncMap[opt.SubqueryOp] = typeSubquery
   169  	typingFuncMap[opt.ColumnAccessOp] = typeColumnAccess
   170  	typingFuncMap[opt.IndirectionOp] = typeIndirection
   171  	typingFuncMap[opt.CollateOp] = typeCollate
   172  	typingFuncMap[opt.ArrayFlattenOp] = typeArrayFlatten
   173  	typingFuncMap[opt.IfErrOp] = typeIfErr
   174  
   175  	// Override default typeAsAggregate behavior for aggregate functions with
   176  	// a large number of possible overloads or where ReturnType depends on
   177  	// argument types.
   178  	typingFuncMap[opt.ArrayAggOp] = typeArrayAgg
   179  	typingFuncMap[opt.MaxOp] = typeAsFirstArg
   180  	typingFuncMap[opt.MinOp] = typeAsFirstArg
   181  	typingFuncMap[opt.ConstAggOp] = typeAsFirstArg
   182  	typingFuncMap[opt.ConstNotNullAggOp] = typeAsFirstArg
   183  	typingFuncMap[opt.AnyNotNullAggOp] = typeAsFirstArg
   184  	typingFuncMap[opt.FirstAggOp] = typeAsFirstArg
   185  
   186  	typingFuncMap[opt.LagOp] = typeAsFirstArg
   187  	typingFuncMap[opt.LeadOp] = typeAsFirstArg
   188  	typingFuncMap[opt.NthValueOp] = typeAsFirstArg
   189  
   190  	// Modifiers for aggregations pass through their argument.
   191  	typingFuncMap[opt.AggDistinctOp] = typeAsFirstArg
   192  	typingFuncMap[opt.AggFilterOp] = typeAsFirstArg
   193  	typingFuncMap[opt.WindowFromOffsetOp] = typeAsFirstArg
   194  	typingFuncMap[opt.WindowToOffsetOp] = typeAsFirstArg
   195  
   196  	for _, op := range opt.BinaryOperators {
   197  		typingFuncMap[op] = typeAsBinary
   198  	}
   199  
   200  	for _, op := range opt.UnaryOperators {
   201  		typingFuncMap[op] = typeAsUnary
   202  	}
   203  
   204  	for _, op := range opt.AggregateOperators {
   205  		// Fill in any that are not already added to the typingFuncMap above.
   206  		if typingFuncMap[op] == nil {
   207  			typingFuncMap[op] = typeAsAggregate
   208  		}
   209  	}
   210  
   211  	for _, op := range opt.WindowOperators {
   212  		if typingFuncMap[op] == nil {
   213  			typingFuncMap[op] = typeAsWindow
   214  		}
   215  	}
   216  }
   217  
   218  // typeVariable returns the type of a variable expression, which is stored in
   219  // the query metadata and accessed by column id.
   220  func typeVariable(mem *Memo, e opt.ScalarExpr) *types.T {
   221  	variable := e.(*VariableExpr)
   222  	typ := mem.Metadata().ColumnMeta(variable.Col).Type
   223  	if typ == nil {
   224  		panic(errors.AssertionFailedf("column %d does not have type", log.Safe(variable.Col)))
   225  	}
   226  	return typ
   227  }
   228  
   229  // typeArrayAgg returns an array type with element type equal to the type of the
   230  // aggregate expression's first (and only) argument.
   231  func typeArrayAgg(e opt.ScalarExpr) *types.T {
   232  	arrayAgg := e.(*ArrayAggExpr)
   233  	typ := arrayAgg.Input.DataType()
   234  	return types.MakeArray(typ)
   235  }
   236  
   237  // typeIndirection returns the type of the element of the array.
   238  func typeIndirection(e opt.ScalarExpr) *types.T {
   239  	return e.Child(0).(opt.ScalarExpr).DataType().ArrayContents()
   240  }
   241  
   242  // typeCollate returns the collated string typed with the given locale.
   243  func typeCollate(e opt.ScalarExpr) *types.T {
   244  	locale := e.(*CollateExpr).Locale
   245  	return types.MakeCollatedString(types.String, locale)
   246  }
   247  
   248  // typeArrayFlatten returns the type of the subquery as an array.
   249  func typeArrayFlatten(e opt.ScalarExpr) *types.T {
   250  	input := e.Child(0).(RelExpr)
   251  	colID := e.(*ArrayFlattenExpr).RequestedCol
   252  	return types.MakeArray(input.Memo().Metadata().ColumnMeta(colID).Type)
   253  }
   254  
   255  // typeIfErr returns the type of the IfErrExpr. The type is boolean if
   256  // there is no OrElse, and the type of Cond/OrElse otherwise.
   257  func typeIfErr(e opt.ScalarExpr) *types.T {
   258  	if e.(*IfErrExpr).OrElse.ChildCount() == 0 {
   259  		return types.Bool
   260  	}
   261  	return e.(*IfErrExpr).Cond.DataType()
   262  }
   263  
   264  // typeAsFirstArg returns the type of the expression's 0th argument.
   265  func typeAsFirstArg(e opt.ScalarExpr) *types.T {
   266  	return e.Child(0).(opt.ScalarExpr).DataType()
   267  }
   268  
   269  // typeAsTypedExpr returns the resolved type of the private field, with the
   270  // assumption that it is a tree.TypedExpr.
   271  func typeAsTypedExpr(e opt.ScalarExpr) *types.T {
   272  	return e.Private().(tree.TypedExpr).ResolvedType()
   273  }
   274  
   275  // typeAsUnary returns the type of a unary expression by hooking into the sql
   276  // semantics code that searches for unary operator overloads.
   277  func typeAsUnary(e opt.ScalarExpr) *types.T {
   278  	return InferUnaryType(e.Op(), e.Child(0).(opt.ScalarExpr).DataType())
   279  }
   280  
   281  // typeAsBinary returns the type of a binary expression by hooking into the sql
   282  // semantics code that searches for binary operator overloads.
   283  func typeAsBinary(e opt.ScalarExpr) *types.T {
   284  	leftType := e.Child(0).(opt.ScalarExpr).DataType()
   285  	rightType := e.Child(1).(opt.ScalarExpr).DataType()
   286  	return InferBinaryType(e.Op(), leftType, rightType)
   287  }
   288  
   289  // typeAsAggregate returns the type of an aggregate expression by hooking into
   290  // the sql semantics code that searches for aggregate operator overloads.
   291  func typeAsAggregate(e opt.ScalarExpr) *types.T {
   292  	// Only handle cases where the return type is not dependent on argument
   293  	// types (i.e. pass nil to the ReturnTyper). Aggregates with return types
   294  	// that depend on argument types are handled separately.
   295  	_, overload := FindAggregateOverload(e)
   296  	t := overload.ReturnType(nil)
   297  	if t == tree.UnknownReturnType {
   298  		panic(errors.AssertionFailedf("unknown aggregate return type. e:\n%s", e))
   299  	}
   300  	return t
   301  }
   302  
   303  // typeAsWindow returns the type of a window function expression similar to
   304  // typeAsAggregate.
   305  func typeAsWindow(e opt.ScalarExpr) *types.T {
   306  	_, overload := FindWindowOverload(e)
   307  	t := overload.ReturnType(nil)
   308  	if t == tree.UnknownReturnType {
   309  		panic(errors.AssertionFailedf("unknown window return type. e:\n%s", e))
   310  	}
   311  
   312  	return t
   313  }
   314  
   315  // typeCoalesce returns the type of a coalesce expression, which is equal to
   316  // the type of its first non-null child.
   317  func typeCoalesce(e opt.ScalarExpr) *types.T {
   318  	for _, arg := range e.(*CoalesceExpr).Args {
   319  		childType := arg.DataType()
   320  		if childType.Family() != types.UnknownFamily {
   321  			return childType
   322  		}
   323  	}
   324  	return types.Unknown
   325  }
   326  
   327  // typeCase returns the type of a CASE expression, which is
   328  // of the form:
   329  //   CASE [ <cond> ]
   330  //       WHEN <condval1> THEN <expr1>
   331  //     [ WHEN <condval2> THEN <expr2> ] ...
   332  //     [ ELSE <expr> ]
   333  //   END
   334  // The type is equal to the type of the WHEN <condval> THEN <expr> clauses, or
   335  // the type of the ELSE <expr> value if all the previous types are unknown.
   336  func typeCase(e opt.ScalarExpr) *types.T {
   337  	caseExpr := e.(*CaseExpr)
   338  	return InferWhensType(caseExpr.Whens, caseExpr.OrElse)
   339  }
   340  
   341  // typeWhen returns the type of a WHEN <condval> THEN <expr> clause inside a
   342  // CASE statement.
   343  func typeWhen(e opt.ScalarExpr) *types.T {
   344  	return e.(*WhenExpr).Value.DataType()
   345  }
   346  
   347  // typeCast returns the type of a CAST operator.
   348  func typeCast(e opt.ScalarExpr) *types.T {
   349  	return e.(*CastExpr).Typ
   350  }
   351  
   352  // typeSubquery returns the type of a subquery, which is equal to the type of
   353  // its first (and only) column.
   354  func typeSubquery(e opt.ScalarExpr) *types.T {
   355  	input := e.Child(0).(RelExpr)
   356  	colID := input.Relational().OutputCols.SingleColumn()
   357  	return input.Memo().Metadata().ColumnMeta(colID).Type
   358  }
   359  
   360  func typeColumnAccess(e opt.ScalarExpr) *types.T {
   361  	colAccess := e.(*ColumnAccessExpr)
   362  	typ := colAccess.Input.DataType()
   363  	return typ.TupleContents()[colAccess.Idx]
   364  }
   365  
   366  // FindBinaryOverload finds the correct type signature overload for the
   367  // specified binary operator, given the types of its inputs. If an overload is
   368  // found, FindBinaryOverload returns true, plus a pointer to the overload.
   369  // If an overload is not found, FindBinaryOverload returns false.
   370  func FindBinaryOverload(op opt.Operator, leftType, rightType *types.T) (_ *tree.BinOp, ok bool) {
   371  	bin := opt.BinaryOpReverseMap[op]
   372  
   373  	// Find the binary op that matches the type of the expression's left and
   374  	// right children. No more than one match should ever be found. The
   375  	// TestTypingBinaryAssumptions test ensures this will be the case even if
   376  	// new operators or overloads are added.
   377  	for _, binOverloads := range tree.BinOps[bin] {
   378  		o := binOverloads.(*tree.BinOp)
   379  
   380  		if leftType.Family() == types.UnknownFamily {
   381  			if rightType.Equivalent(o.RightType) {
   382  				return o, true
   383  			}
   384  		} else if rightType.Family() == types.UnknownFamily {
   385  			if leftType.Equivalent(o.LeftType) {
   386  				return o, true
   387  			}
   388  		} else {
   389  			if leftType.Equivalent(o.LeftType) && rightType.Equivalent(o.RightType) {
   390  				return o, true
   391  			}
   392  		}
   393  	}
   394  	return nil, false
   395  }
   396  
   397  // FindUnaryOverload finds the correct type signature overload for the
   398  // specified unary operator, given the type of its input. If an overload is
   399  // found, FindUnaryOverload returns true, plus a pointer to the overload.
   400  // If an overload is not found, FindUnaryOverload returns false.
   401  func FindUnaryOverload(op opt.Operator, typ *types.T) (_ *tree.UnaryOp, ok bool) {
   402  	unary := opt.UnaryOpReverseMap[op]
   403  
   404  	for _, unaryOverloads := range tree.UnaryOps[unary] {
   405  		o := unaryOverloads.(*tree.UnaryOp)
   406  		if o.Typ.Equivalent(typ) {
   407  			return o, true
   408  		}
   409  	}
   410  	return nil, false
   411  }
   412  
   413  // FindComparisonOverload finds the correct type signature overload for the
   414  // specified comparison operator, given the types of its inputs. If an overload
   415  // is found, FindComparisonOverload returns a pointer to the overload and
   416  // ok=true. It also returns "flipped" and "not" flags. The "flipped" flag
   417  // indicates whether the original left and right operands should be flipped
   418  // with the returned overload. The "not" flag indicates whether the result of
   419  // the comparison operation should be negated. If an overload is not found,
   420  // FindComparisonOverload returns ok=false.
   421  func FindComparisonOverload(
   422  	op opt.Operator, leftType, rightType *types.T,
   423  ) (_ *tree.CmpOp, flipped, not, ok bool) {
   424  	op, flipped, not = NormalizeComparison(op)
   425  	comp := opt.ComparisonOpReverseMap[op]
   426  
   427  	if flipped {
   428  		leftType, rightType = rightType, leftType
   429  	}
   430  
   431  	// Find the comparison op that matches the type of the expression's left and
   432  	// right children. No more than one match should ever be found. The
   433  	// TestTypingComparisonAssumptions test ensures this will be the case even if
   434  	// new operators or overloads are added.
   435  	for _, cmpOverloads := range tree.CmpOps[comp] {
   436  		o := cmpOverloads.(*tree.CmpOp)
   437  
   438  		if leftType.Family() == types.UnknownFamily {
   439  			if rightType.Equivalent(o.RightType) {
   440  				return o, flipped, not, true
   441  			}
   442  		} else if rightType.Family() == types.UnknownFamily {
   443  			if leftType.Equivalent(o.LeftType) {
   444  				return o, flipped, not, true
   445  			}
   446  		} else {
   447  			if leftType.Equivalent(o.LeftType) && rightType.Equivalent(o.RightType) {
   448  				return o, flipped, not, true
   449  			}
   450  		}
   451  	}
   452  	return nil, false, false, false
   453  }
   454  
   455  // NormalizeComparison maps a given comparison operator into an equivalent
   456  // operator that exists in the tree.CmpOps map, returning this new operator,
   457  // along with "flipped" and "not" flags. The "flipped" flag indicates whether
   458  // the left and right operands should be flipped with the new operator. The
   459  // "not" flag indicates whether the result of the comparison operation should
   460  // be negated.
   461  func NormalizeComparison(op opt.Operator) (newOp opt.Operator, flipped, not bool) {
   462  	switch op {
   463  	case opt.NeOp:
   464  		// Ne(left, right) is implemented as !Eq(left, right).
   465  		return opt.EqOp, false, true
   466  	case opt.GtOp:
   467  		// Gt(left, right) is implemented as Lt(right, left)
   468  		return opt.LtOp, true, false
   469  	case opt.GeOp:
   470  		// Ge(left, right) is implemented as Le(right, left)
   471  		return opt.LeOp, true, false
   472  	case opt.NotInOp:
   473  		// NotIn(left, right) is implemented as !In(left, right)
   474  		return opt.InOp, false, true
   475  	case opt.NotLikeOp:
   476  		// NotLike(left, right) is implemented as !Like(left, right)
   477  		return opt.LikeOp, false, true
   478  	case opt.NotILikeOp:
   479  		// NotILike(left, right) is implemented as !ILike(left, right)
   480  		return opt.ILikeOp, false, true
   481  	case opt.NotSimilarToOp:
   482  		// NotSimilarTo(left, right) is implemented as !SimilarTo(left, right)
   483  		return opt.SimilarToOp, false, true
   484  	case opt.NotRegMatchOp:
   485  		// NotRegMatch(left, right) is implemented as !RegMatch(left, right)
   486  		return opt.RegMatchOp, false, true
   487  	case opt.NotRegIMatchOp:
   488  		// NotRegIMatch(left, right) is implemented as !RegIMatch(left, right)
   489  		return opt.RegIMatchOp, false, true
   490  	case opt.IsNotOp:
   491  		// IsNot(left, right) is implemented as !Is(left, right)
   492  		return opt.IsOp, false, true
   493  	}
   494  	return op, false, false
   495  }