github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/memo/typing.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package memo 12 13 import ( 14 "github.com/cockroachdb/cockroach/pkg/sql/opt" 15 "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins" 16 "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" 17 "github.com/cockroachdb/cockroach/pkg/sql/types" 18 "github.com/cockroachdb/cockroach/pkg/util/log" 19 "github.com/cockroachdb/errors" 20 ) 21 22 // InferType derives the type of the given scalar expression and stores it in 23 // the expression's Type field. Depending upon the operator, the type may be 24 // fixed, or it may be dependent upon the expression children. 25 func InferType(mem *Memo, e opt.ScalarExpr) *types.T { 26 // Special-case Variable, since it's the only expression that needs the memo. 27 if e.Op() == opt.VariableOp { 28 return typeVariable(mem, e) 29 } 30 31 fn := typingFuncMap[e.Op()] 32 if fn == nil { 33 panic(errors.AssertionFailedf("type inference for %v is not yet implemented", log.Safe(e.Op()))) 34 } 35 return fn(e) 36 } 37 38 // InferUnaryType infers the return type of a unary operator, given the type of 39 // its input. 40 func InferUnaryType(op opt.Operator, inputType *types.T) *types.T { 41 unaryOp := opt.UnaryOpReverseMap[op] 42 43 // Find the unary op that matches the type of the expression's child. 44 for _, op := range tree.UnaryOps[unaryOp] { 45 o := op.(*tree.UnaryOp) 46 if inputType.Equivalent(o.Typ) { 47 return o.ReturnType 48 } 49 } 50 panic(errors.AssertionFailedf("could not find type for unary expression %s", log.Safe(op))) 51 } 52 53 // InferBinaryType infers the return type of a binary expression, given the type 54 // of its inputs. 55 func InferBinaryType(op opt.Operator, leftType, rightType *types.T) *types.T { 56 o, ok := FindBinaryOverload(op, leftType, rightType) 57 if !ok { 58 panic(errors.AssertionFailedf("could not find type for binary expression %s", log.Safe(op))) 59 } 60 return o.ReturnType 61 } 62 63 // InferWhensType returns the type of a CASE expression, which is 64 // of the form: 65 // CASE [ <cond> ] 66 // WHEN <condval1> THEN <expr1> 67 // [ WHEN <condval2> THEN <expr2> ] ... 68 // [ ELSE <expr> ] 69 // END 70 // The type is equal to the type of the WHEN <condval> THEN <expr> clauses, or 71 // the type of the ELSE <expr> value if all the previous types are unknown. 72 func InferWhensType(whens ScalarListExpr, orElse opt.ScalarExpr) *types.T { 73 for _, when := range whens { 74 childType := when.DataType() 75 if childType.Family() != types.UnknownFamily { 76 return childType 77 } 78 } 79 return orElse.DataType() 80 } 81 82 // BinaryOverloadExists returns true if the given binary operator exists with the 83 // given arguments. 84 func BinaryOverloadExists(op opt.Operator, leftType, rightType *types.T) bool { 85 _, ok := FindBinaryOverload(op, leftType, rightType) 86 return ok 87 } 88 89 // BinaryAllowsNullArgs returns true if the given binary operator allows null 90 // arguments, and cannot therefore be folded away to null. 91 func BinaryAllowsNullArgs(op opt.Operator, leftType, rightType *types.T) bool { 92 o, ok := FindBinaryOverload(op, leftType, rightType) 93 if !ok { 94 panic(errors.AssertionFailedf("could not find overload for binary expression %s", log.Safe(op))) 95 } 96 return o.NullableArgs 97 } 98 99 // AggregateOverloadExists returns whether or not the given operator has a 100 // unary overload which takes the given type as input. 101 func AggregateOverloadExists(agg opt.Operator, typ *types.T) bool { 102 name := opt.AggregateOpReverseMap[agg] 103 _, overloads := builtins.GetBuiltinProperties(name) 104 for _, o := range overloads { 105 if o.Types.MatchAt(typ, 0) { 106 return true 107 } 108 } 109 return false 110 } 111 112 func findOverload(e opt.ScalarExpr, name string) (overload *tree.Overload, ok bool) { 113 _, overloads := builtins.GetBuiltinProperties(name) 114 for o := range overloads { 115 overload = &overloads[o] 116 matches := true 117 for i, n := 0, e.ChildCount(); i < n; i++ { 118 typ := e.Child(i).(opt.ScalarExpr).DataType() 119 if !overload.Types.MatchAt(typ, i) { 120 matches = false 121 break 122 } 123 } 124 if matches { 125 return overload, true 126 } 127 } 128 return nil, false 129 } 130 131 // FindWindowOverload finds a window function overload that matches the 132 // given window function expression. It panics if no match can be found. 133 func FindWindowOverload(e opt.ScalarExpr) (name string, overload *tree.Overload) { 134 name = opt.WindowOpReverseMap[e.Op()] 135 overload, ok := findOverload(e, name) 136 if ok { 137 return name, overload 138 } 139 // NB: all aggregate functions can be used as window functions. 140 return FindAggregateOverload(e) 141 } 142 143 // FindAggregateOverload finds an aggregate function overload that matches the 144 // given aggregate function expression. It panics if no match can be found. 145 func FindAggregateOverload(e opt.ScalarExpr) (name string, overload *tree.Overload) { 146 name = opt.AggregateOpReverseMap[e.Op()] 147 overload, ok := findOverload(e, name) 148 if ok { 149 return name, overload 150 } 151 panic(errors.AssertionFailedf("could not find overload for %s aggregate", name)) 152 } 153 154 type typingFunc func(e opt.ScalarExpr) *types.T 155 156 // typingFuncMap is a lookup table from scalar operator type to a function 157 // which returns the data type of an instance of that operator. 158 var typingFuncMap map[opt.Operator]typingFunc 159 160 func init() { 161 typingFuncMap = make(map[opt.Operator]typingFunc) 162 typingFuncMap[opt.PlaceholderOp] = typeAsTypedExpr 163 typingFuncMap[opt.UnsupportedExprOp] = typeAsTypedExpr 164 typingFuncMap[opt.CoalesceOp] = typeCoalesce 165 typingFuncMap[opt.CaseOp] = typeCase 166 typingFuncMap[opt.WhenOp] = typeWhen 167 typingFuncMap[opt.CastOp] = typeCast 168 typingFuncMap[opt.SubqueryOp] = typeSubquery 169 typingFuncMap[opt.ColumnAccessOp] = typeColumnAccess 170 typingFuncMap[opt.IndirectionOp] = typeIndirection 171 typingFuncMap[opt.CollateOp] = typeCollate 172 typingFuncMap[opt.ArrayFlattenOp] = typeArrayFlatten 173 typingFuncMap[opt.IfErrOp] = typeIfErr 174 175 // Override default typeAsAggregate behavior for aggregate functions with 176 // a large number of possible overloads or where ReturnType depends on 177 // argument types. 178 typingFuncMap[opt.ArrayAggOp] = typeArrayAgg 179 typingFuncMap[opt.MaxOp] = typeAsFirstArg 180 typingFuncMap[opt.MinOp] = typeAsFirstArg 181 typingFuncMap[opt.ConstAggOp] = typeAsFirstArg 182 typingFuncMap[opt.ConstNotNullAggOp] = typeAsFirstArg 183 typingFuncMap[opt.AnyNotNullAggOp] = typeAsFirstArg 184 typingFuncMap[opt.FirstAggOp] = typeAsFirstArg 185 186 typingFuncMap[opt.LagOp] = typeAsFirstArg 187 typingFuncMap[opt.LeadOp] = typeAsFirstArg 188 typingFuncMap[opt.NthValueOp] = typeAsFirstArg 189 190 // Modifiers for aggregations pass through their argument. 191 typingFuncMap[opt.AggDistinctOp] = typeAsFirstArg 192 typingFuncMap[opt.AggFilterOp] = typeAsFirstArg 193 typingFuncMap[opt.WindowFromOffsetOp] = typeAsFirstArg 194 typingFuncMap[opt.WindowToOffsetOp] = typeAsFirstArg 195 196 for _, op := range opt.BinaryOperators { 197 typingFuncMap[op] = typeAsBinary 198 } 199 200 for _, op := range opt.UnaryOperators { 201 typingFuncMap[op] = typeAsUnary 202 } 203 204 for _, op := range opt.AggregateOperators { 205 // Fill in any that are not already added to the typingFuncMap above. 206 if typingFuncMap[op] == nil { 207 typingFuncMap[op] = typeAsAggregate 208 } 209 } 210 211 for _, op := range opt.WindowOperators { 212 if typingFuncMap[op] == nil { 213 typingFuncMap[op] = typeAsWindow 214 } 215 } 216 } 217 218 // typeVariable returns the type of a variable expression, which is stored in 219 // the query metadata and accessed by column id. 220 func typeVariable(mem *Memo, e opt.ScalarExpr) *types.T { 221 variable := e.(*VariableExpr) 222 typ := mem.Metadata().ColumnMeta(variable.Col).Type 223 if typ == nil { 224 panic(errors.AssertionFailedf("column %d does not have type", log.Safe(variable.Col))) 225 } 226 return typ 227 } 228 229 // typeArrayAgg returns an array type with element type equal to the type of the 230 // aggregate expression's first (and only) argument. 231 func typeArrayAgg(e opt.ScalarExpr) *types.T { 232 arrayAgg := e.(*ArrayAggExpr) 233 typ := arrayAgg.Input.DataType() 234 return types.MakeArray(typ) 235 } 236 237 // typeIndirection returns the type of the element of the array. 238 func typeIndirection(e opt.ScalarExpr) *types.T { 239 return e.Child(0).(opt.ScalarExpr).DataType().ArrayContents() 240 } 241 242 // typeCollate returns the collated string typed with the given locale. 243 func typeCollate(e opt.ScalarExpr) *types.T { 244 locale := e.(*CollateExpr).Locale 245 return types.MakeCollatedString(types.String, locale) 246 } 247 248 // typeArrayFlatten returns the type of the subquery as an array. 249 func typeArrayFlatten(e opt.ScalarExpr) *types.T { 250 input := e.Child(0).(RelExpr) 251 colID := e.(*ArrayFlattenExpr).RequestedCol 252 return types.MakeArray(input.Memo().Metadata().ColumnMeta(colID).Type) 253 } 254 255 // typeIfErr returns the type of the IfErrExpr. The type is boolean if 256 // there is no OrElse, and the type of Cond/OrElse otherwise. 257 func typeIfErr(e opt.ScalarExpr) *types.T { 258 if e.(*IfErrExpr).OrElse.ChildCount() == 0 { 259 return types.Bool 260 } 261 return e.(*IfErrExpr).Cond.DataType() 262 } 263 264 // typeAsFirstArg returns the type of the expression's 0th argument. 265 func typeAsFirstArg(e opt.ScalarExpr) *types.T { 266 return e.Child(0).(opt.ScalarExpr).DataType() 267 } 268 269 // typeAsTypedExpr returns the resolved type of the private field, with the 270 // assumption that it is a tree.TypedExpr. 271 func typeAsTypedExpr(e opt.ScalarExpr) *types.T { 272 return e.Private().(tree.TypedExpr).ResolvedType() 273 } 274 275 // typeAsUnary returns the type of a unary expression by hooking into the sql 276 // semantics code that searches for unary operator overloads. 277 func typeAsUnary(e opt.ScalarExpr) *types.T { 278 return InferUnaryType(e.Op(), e.Child(0).(opt.ScalarExpr).DataType()) 279 } 280 281 // typeAsBinary returns the type of a binary expression by hooking into the sql 282 // semantics code that searches for binary operator overloads. 283 func typeAsBinary(e opt.ScalarExpr) *types.T { 284 leftType := e.Child(0).(opt.ScalarExpr).DataType() 285 rightType := e.Child(1).(opt.ScalarExpr).DataType() 286 return InferBinaryType(e.Op(), leftType, rightType) 287 } 288 289 // typeAsAggregate returns the type of an aggregate expression by hooking into 290 // the sql semantics code that searches for aggregate operator overloads. 291 func typeAsAggregate(e opt.ScalarExpr) *types.T { 292 // Only handle cases where the return type is not dependent on argument 293 // types (i.e. pass nil to the ReturnTyper). Aggregates with return types 294 // that depend on argument types are handled separately. 295 _, overload := FindAggregateOverload(e) 296 t := overload.ReturnType(nil) 297 if t == tree.UnknownReturnType { 298 panic(errors.AssertionFailedf("unknown aggregate return type. e:\n%s", e)) 299 } 300 return t 301 } 302 303 // typeAsWindow returns the type of a window function expression similar to 304 // typeAsAggregate. 305 func typeAsWindow(e opt.ScalarExpr) *types.T { 306 _, overload := FindWindowOverload(e) 307 t := overload.ReturnType(nil) 308 if t == tree.UnknownReturnType { 309 panic(errors.AssertionFailedf("unknown window return type. e:\n%s", e)) 310 } 311 312 return t 313 } 314 315 // typeCoalesce returns the type of a coalesce expression, which is equal to 316 // the type of its first non-null child. 317 func typeCoalesce(e opt.ScalarExpr) *types.T { 318 for _, arg := range e.(*CoalesceExpr).Args { 319 childType := arg.DataType() 320 if childType.Family() != types.UnknownFamily { 321 return childType 322 } 323 } 324 return types.Unknown 325 } 326 327 // typeCase returns the type of a CASE expression, which is 328 // of the form: 329 // CASE [ <cond> ] 330 // WHEN <condval1> THEN <expr1> 331 // [ WHEN <condval2> THEN <expr2> ] ... 332 // [ ELSE <expr> ] 333 // END 334 // The type is equal to the type of the WHEN <condval> THEN <expr> clauses, or 335 // the type of the ELSE <expr> value if all the previous types are unknown. 336 func typeCase(e opt.ScalarExpr) *types.T { 337 caseExpr := e.(*CaseExpr) 338 return InferWhensType(caseExpr.Whens, caseExpr.OrElse) 339 } 340 341 // typeWhen returns the type of a WHEN <condval> THEN <expr> clause inside a 342 // CASE statement. 343 func typeWhen(e opt.ScalarExpr) *types.T { 344 return e.(*WhenExpr).Value.DataType() 345 } 346 347 // typeCast returns the type of a CAST operator. 348 func typeCast(e opt.ScalarExpr) *types.T { 349 return e.(*CastExpr).Typ 350 } 351 352 // typeSubquery returns the type of a subquery, which is equal to the type of 353 // its first (and only) column. 354 func typeSubquery(e opt.ScalarExpr) *types.T { 355 input := e.Child(0).(RelExpr) 356 colID := input.Relational().OutputCols.SingleColumn() 357 return input.Memo().Metadata().ColumnMeta(colID).Type 358 } 359 360 func typeColumnAccess(e opt.ScalarExpr) *types.T { 361 colAccess := e.(*ColumnAccessExpr) 362 typ := colAccess.Input.DataType() 363 return typ.TupleContents()[colAccess.Idx] 364 } 365 366 // FindBinaryOverload finds the correct type signature overload for the 367 // specified binary operator, given the types of its inputs. If an overload is 368 // found, FindBinaryOverload returns true, plus a pointer to the overload. 369 // If an overload is not found, FindBinaryOverload returns false. 370 func FindBinaryOverload(op opt.Operator, leftType, rightType *types.T) (_ *tree.BinOp, ok bool) { 371 bin := opt.BinaryOpReverseMap[op] 372 373 // Find the binary op that matches the type of the expression's left and 374 // right children. No more than one match should ever be found. The 375 // TestTypingBinaryAssumptions test ensures this will be the case even if 376 // new operators or overloads are added. 377 for _, binOverloads := range tree.BinOps[bin] { 378 o := binOverloads.(*tree.BinOp) 379 380 if leftType.Family() == types.UnknownFamily { 381 if rightType.Equivalent(o.RightType) { 382 return o, true 383 } 384 } else if rightType.Family() == types.UnknownFamily { 385 if leftType.Equivalent(o.LeftType) { 386 return o, true 387 } 388 } else { 389 if leftType.Equivalent(o.LeftType) && rightType.Equivalent(o.RightType) { 390 return o, true 391 } 392 } 393 } 394 return nil, false 395 } 396 397 // FindUnaryOverload finds the correct type signature overload for the 398 // specified unary operator, given the type of its input. If an overload is 399 // found, FindUnaryOverload returns true, plus a pointer to the overload. 400 // If an overload is not found, FindUnaryOverload returns false. 401 func FindUnaryOverload(op opt.Operator, typ *types.T) (_ *tree.UnaryOp, ok bool) { 402 unary := opt.UnaryOpReverseMap[op] 403 404 for _, unaryOverloads := range tree.UnaryOps[unary] { 405 o := unaryOverloads.(*tree.UnaryOp) 406 if o.Typ.Equivalent(typ) { 407 return o, true 408 } 409 } 410 return nil, false 411 } 412 413 // FindComparisonOverload finds the correct type signature overload for the 414 // specified comparison operator, given the types of its inputs. If an overload 415 // is found, FindComparisonOverload returns a pointer to the overload and 416 // ok=true. It also returns "flipped" and "not" flags. The "flipped" flag 417 // indicates whether the original left and right operands should be flipped 418 // with the returned overload. The "not" flag indicates whether the result of 419 // the comparison operation should be negated. If an overload is not found, 420 // FindComparisonOverload returns ok=false. 421 func FindComparisonOverload( 422 op opt.Operator, leftType, rightType *types.T, 423 ) (_ *tree.CmpOp, flipped, not, ok bool) { 424 op, flipped, not = NormalizeComparison(op) 425 comp := opt.ComparisonOpReverseMap[op] 426 427 if flipped { 428 leftType, rightType = rightType, leftType 429 } 430 431 // Find the comparison op that matches the type of the expression's left and 432 // right children. No more than one match should ever be found. The 433 // TestTypingComparisonAssumptions test ensures this will be the case even if 434 // new operators or overloads are added. 435 for _, cmpOverloads := range tree.CmpOps[comp] { 436 o := cmpOverloads.(*tree.CmpOp) 437 438 if leftType.Family() == types.UnknownFamily { 439 if rightType.Equivalent(o.RightType) { 440 return o, flipped, not, true 441 } 442 } else if rightType.Family() == types.UnknownFamily { 443 if leftType.Equivalent(o.LeftType) { 444 return o, flipped, not, true 445 } 446 } else { 447 if leftType.Equivalent(o.LeftType) && rightType.Equivalent(o.RightType) { 448 return o, flipped, not, true 449 } 450 } 451 } 452 return nil, false, false, false 453 } 454 455 // NormalizeComparison maps a given comparison operator into an equivalent 456 // operator that exists in the tree.CmpOps map, returning this new operator, 457 // along with "flipped" and "not" flags. The "flipped" flag indicates whether 458 // the left and right operands should be flipped with the new operator. The 459 // "not" flag indicates whether the result of the comparison operation should 460 // be negated. 461 func NormalizeComparison(op opt.Operator) (newOp opt.Operator, flipped, not bool) { 462 switch op { 463 case opt.NeOp: 464 // Ne(left, right) is implemented as !Eq(left, right). 465 return opt.EqOp, false, true 466 case opt.GtOp: 467 // Gt(left, right) is implemented as Lt(right, left) 468 return opt.LtOp, true, false 469 case opt.GeOp: 470 // Ge(left, right) is implemented as Le(right, left) 471 return opt.LeOp, true, false 472 case opt.NotInOp: 473 // NotIn(left, right) is implemented as !In(left, right) 474 return opt.InOp, false, true 475 case opt.NotLikeOp: 476 // NotLike(left, right) is implemented as !Like(left, right) 477 return opt.LikeOp, false, true 478 case opt.NotILikeOp: 479 // NotILike(left, right) is implemented as !ILike(left, right) 480 return opt.ILikeOp, false, true 481 case opt.NotSimilarToOp: 482 // NotSimilarTo(left, right) is implemented as !SimilarTo(left, right) 483 return opt.SimilarToOp, false, true 484 case opt.NotRegMatchOp: 485 // NotRegMatch(left, right) is implemented as !RegMatch(left, right) 486 return opt.RegMatchOp, false, true 487 case opt.NotRegIMatchOp: 488 // NotRegIMatch(left, right) is implemented as !RegIMatch(left, right) 489 return opt.RegIMatchOp, false, true 490 case opt.IsNotOp: 491 // IsNot(left, right) is implemented as !Is(left, right) 492 return opt.IsOp, false, true 493 } 494 return op, false, false 495 }