github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/optgen/lang/compiler.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package lang 12 13 import ( 14 "bytes" 15 "fmt" 16 17 "github.com/cockroachdb/errors" 18 ) 19 20 // CompiledExpr is the result of Optgen scanning, parsing, and semantic 21 // analysis. It contains the set of definitions and rules that were compiled 22 // from the Optgen input files. 23 type CompiledExpr struct { 24 Defines DefineSetExpr 25 Rules RuleSetExpr 26 DefineTags []string 27 defineIndex map[string]*DefineExpr 28 matchIndex map[string]RuleSetExpr 29 } 30 31 // LookupDefine returns the DefineExpr with the given name. 32 func (c *CompiledExpr) LookupDefine(name string) *DefineExpr { 33 return c.defineIndex[name] 34 } 35 36 // LookupMatchingDefines returns the set of define expressions which either 37 // exactly match the given name, or else have a tag that matches the given 38 // name. If no matches can be found, then LookupMatchingDefines returns nil. 39 func (c *CompiledExpr) LookupMatchingDefines(name string) DefineSetExpr { 40 var defines DefineSetExpr 41 define := c.LookupDefine(name) 42 if define != nil { 43 defines = append(defines, define) 44 } else { 45 // Name might be a tag name, so find all defines with that tag. 46 for _, define := range c.Defines { 47 if define.Tags.Contains(name) { 48 defines = append(defines, define) 49 } 50 } 51 } 52 return defines 53 } 54 55 // LookupMatchingRules returns the set of rules that match the given opname at 56 // the top-level, or nil if none do. For example, "InnerJoin" would match this 57 // rule: 58 // [CommuteJoin] 59 // (InnerJoin $r:* $s:*) => (InnerJoin $s $r) 60 func (c *CompiledExpr) LookupMatchingRules(name string) RuleSetExpr { 61 return c.matchIndex[name] 62 } 63 64 func (c *CompiledExpr) String() string { 65 var buf bytes.Buffer 66 buf.WriteString("(Compiled\n") 67 68 writeIndent(&buf, 1) 69 buf.WriteString("(Defines\n") 70 for _, define := range c.Defines { 71 writeIndent(&buf, 2) 72 define.Format(&buf, 2) 73 buf.WriteString("\n") 74 } 75 writeIndent(&buf, 1) 76 buf.WriteString(")\n") 77 78 writeIndent(&buf, 1) 79 buf.WriteString("(Rules\n") 80 for _, rule := range c.Rules { 81 writeIndent(&buf, 2) 82 rule.Format(&buf, 2) 83 buf.WriteString("\n") 84 } 85 writeIndent(&buf, 1) 86 buf.WriteString(")\n") 87 88 buf.WriteString(")\n") 89 90 return buf.String() 91 } 92 93 // Compiler compiles Optgen language input files and builds a CompiledExpr 94 // result from them. Compilation consists of scanning/parsing the files, which 95 // produces an AST, and is followed by semantic analysis and limited rewrites 96 // on the AST which the compiler performs. 97 type Compiler struct { 98 parser *Parser 99 compiled *CompiledExpr 100 errors []error 101 } 102 103 // NewCompiler constructs a new instance of the Optgen compiler, with the 104 // specified list of file paths as its input files. The Compile method 105 // must be called in order to compile the input files. 106 func NewCompiler(files ...string) *Compiler { 107 compiled := &CompiledExpr{ 108 defineIndex: make(map[string]*DefineExpr), 109 matchIndex: make(map[string]RuleSetExpr), 110 } 111 return &Compiler{parser: NewParser(files...), compiled: compiled} 112 } 113 114 // SetFileResolver overrides the default method of opening input files. The 115 // default resolver will use os.Open to open input files from disk. Callers 116 // can use this method to open input files in some other way. 117 func (c *Compiler) SetFileResolver(resolver FileResolver) { 118 // Forward call to the parser, which actually calls the resolver. 119 c.parser.SetFileResolver(resolver) 120 } 121 122 // Compile parses and compiles the input files and returns the resulting 123 // CompiledExpr. If there are errors, then Compile returns nil, and the errors 124 // are returned by the Errors function. 125 func (c *Compiler) Compile() *CompiledExpr { 126 root := c.parser.Parse() 127 if root == nil { 128 c.errors = c.parser.Errors() 129 return nil 130 } 131 132 if !c.compileDefines(root.Defines) { 133 return nil 134 } 135 136 if !c.compileRules(root.Rules) { 137 return nil 138 } 139 return c.compiled 140 } 141 142 // Errors returns the collection of errors that occurred during compilation. If 143 // no errors occurred, then Errors returns nil. 144 func (c *Compiler) Errors() []error { 145 return c.errors 146 } 147 148 func (c *Compiler) compileDefines(defines DefineSetExpr) bool { 149 c.compiled.Defines = defines 150 151 unique := make(map[TagExpr]bool) 152 153 for _, define := range defines { 154 // Record the define in the index for fast lookup. 155 name := string(define.Name) 156 _, ok := c.compiled.defineIndex[name] 157 if ok { 158 c.addErr(define.Source(), fmt.Errorf("duplicate '%s' define statement", name)) 159 } 160 161 c.compiled.defineIndex[name] = define 162 163 // Determine unique set of tags. 164 for _, tag := range define.Tags { 165 if !unique[tag] { 166 c.compiled.DefineTags = append(c.compiled.DefineTags, string(tag)) 167 unique[tag] = true 168 } 169 } 170 } 171 172 return true 173 } 174 175 func (c *Compiler) compileRules(rules RuleSetExpr) bool { 176 unique := make(map[StringExpr]bool) 177 178 for _, rule := range rules { 179 // Ensure that rule names are unique. 180 _, ok := unique[rule.Name] 181 if ok { 182 c.addErr(rule.Source(), fmt.Errorf("duplicate rule name '%s'", rule.Name)) 183 } 184 unique[rule.Name] = true 185 186 var ruleCompiler ruleCompiler 187 ruleCompiler.compile(c, rule) 188 } 189 190 // Index compiled rules by the op that they match at the top-level of the 191 // rule. 192 for _, rule := range c.compiled.Rules { 193 name := rule.Match.SingleName() 194 existing := c.compiled.matchIndex[name] 195 c.compiled.matchIndex[name] = append(existing, rule) 196 } 197 198 return len(c.errors) == 0 199 } 200 201 func (c *Compiler) addErr(src *SourceLoc, err error) { 202 if src != nil { 203 err = fmt.Errorf("%s: %s", src, err.Error()) 204 } 205 c.errors = append(c.errors, err) 206 } 207 208 // ruleCompiler compiles a single rule. It is a separate struct in order to 209 // keep state for the ruleContentCompiler. 210 type ruleCompiler struct { 211 compiler *Compiler 212 compiled *CompiledExpr 213 rule *RuleExpr 214 215 // bindings tracks variable bindings in order to ensure uniqueness and to 216 // infer types. 217 bindings map[StringExpr]DataType 218 219 // opName keeps the root match name in order to compile the OpName built-in 220 // function. 221 opName *NameExpr 222 } 223 224 func (c *ruleCompiler) compile(compiler *Compiler, rule *RuleExpr) { 225 c.compiler = compiler 226 c.compiled = compiler.compiled 227 c.rule = rule 228 229 if _, ok := rule.Match.Name.(*FuncExpr); ok { 230 // Function name is itself a function that dynamically determines name. 231 c.compiler.addErr(rule.Match.Source(), errors.New("cannot match dynamic name")) 232 return 233 } 234 235 // Expand root rules that match multiple operators into a separate match 236 // expression for each matching operator. 237 for _, name := range rule.Match.NameChoice() { 238 defines := c.compiled.LookupMatchingDefines(string(name)) 239 if len(defines) == 0 { 240 // No defines with that tag found, which is not allowed. 241 defines = nil 242 c.compiler.addErr(rule.Match.Source(), fmt.Errorf("unrecognized match name '%s'", name)) 243 } 244 245 for _, define := range defines { 246 // Create a rule for every matching define. 247 c.expandRule(NameExpr(define.Name)) 248 } 249 } 250 } 251 252 // expandRule rewrites the current rule to match the given opname rather than 253 // a list of names or a tag name. This transformation makes it easier for the 254 // code generator, since all rules will never match more than one op at the 255 // top-level. 256 func (c *ruleCompiler) expandRule(opName NameExpr) { 257 // Remember current error count in order to detect whether ruleContentCompiler 258 // adds additional errors. 259 errCntBefore := len(c.compiler.errors) 260 261 // Remember the root opname in case it's needed to compile the OpName 262 // built-in function. 263 c.opName = &opName 264 c.bindings = make(map[StringExpr]DataType) 265 266 // Construct new match expression that matches a single name. 267 match := &FuncExpr{Src: c.rule.Match.Src, Name: &NamesExpr{opName}} 268 match.Args = append(match.Args, c.rule.Match.Args...) 269 270 compiler := ruleContentCompiler{compiler: c, src: c.rule.Src, matchPattern: true} 271 match = compiler.compile(match).(*FuncExpr) 272 273 compiler = ruleContentCompiler{compiler: c, src: c.rule.Src, matchPattern: false} 274 replace := compiler.compile(c.rule.Replace) 275 276 newRule := &RuleExpr{ 277 Src: c.rule.Src, 278 Name: c.rule.Name, 279 Comments: c.rule.Comments, 280 Tags: c.rule.Tags, 281 Match: match, 282 Replace: replace, 283 } 284 285 // Infer data types for expressions within the match and replace patterns. 286 // Do this only if the rule triggered no errors. 287 if errCntBefore == len(c.compiler.errors) { 288 c.inferTypes(newRule.Match, AnyDataType) 289 c.inferTypes(newRule.Replace, AnyDataType) 290 } 291 292 c.compiled.Rules = append(c.compiled.Rules, newRule) 293 } 294 295 // inferTypes walks the tree and annotates it with inferred data types. It 296 // reports any typing errors it encounters. Each expression is annotated with 297 // either its "bottom-up" type which it infers from its inputs, or else its 298 // "top-down" type (the suggested argument), which is passed down from its 299 // ancestor(s). Each operator has its own rules of which to use. 300 func (c *ruleCompiler) inferTypes(e Expr, suggested DataType) { 301 switch t := e.(type) { 302 case *FuncExpr: 303 var defType *DefineSetDataType 304 if t.HasDynamicName() { 305 // Special-case the OpName built-in function. 306 nameFunc, ok := t.Name.(*CustomFuncExpr) 307 if !ok || nameFunc.Name != "OpName" { 308 panic(fmt.Sprintf("%s not allowed as dynamic function name", t.Name)) 309 } 310 311 // Inherit type of the opname target. 312 label := nameFunc.Args[0].(*RefExpr).Label 313 t.Typ = c.bindings[label] 314 if t.Typ == nil { 315 panic(fmt.Sprintf("$%s does not have its type set", label)) 316 } 317 318 defType, ok = t.Typ.(*DefineSetDataType) 319 if !ok { 320 err := errors.New("cannot infer type of construction expression") 321 c.compiler.addErr(nameFunc.Args[0].Source(), err) 322 break 323 } 324 325 // If the OpName refers to a single operator, rewrite it as a simple 326 // static name. 327 if len(defType.Defines) == 1 { 328 name := NameExpr(defType.Defines[0].Name) 329 t.Name = &name 330 } 331 } else { 332 // Construct list of defines that can be matched. 333 names := t.NameChoice() 334 defines := make(DefineSetExpr, 0, len(names)) 335 for _, name := range names { 336 defines = append(defines, c.compiled.LookupMatchingDefines(string(name))...) 337 } 338 339 // Set the data type of the function. 340 defType = &DefineSetDataType{Defines: defines} 341 t.Typ = defType 342 } 343 344 // First define in list is considered the "prototype" that all others 345 // match. The matching is checked in ruleContentCompiler.compileFunc. 346 prototype := defType.Defines[0] 347 348 if len(t.Args) > len(prototype.Fields) { 349 err := fmt.Errorf("%s has too many args", e.Op()) 350 c.compiler.addErr(t.Source(), err) 351 break 352 } 353 354 // Recurse on name and arguments. 355 c.inferTypes(t.Name, AnyDataType) 356 for i, arg := range t.Args { 357 suggested := &ExternalDataType{Name: string(prototype.Fields[i].Type)} 358 c.inferTypes(arg, suggested) 359 } 360 361 case *CustomFuncExpr: 362 // Return type of custom function isn't known, but might be inferred from 363 // context in which it's used. 364 t.Typ = suggested 365 366 // Recurse on arguments, passing AnyDataType as suggested type, because 367 // no information is known about their types. 368 for _, arg := range t.Args { 369 c.inferTypes(arg, AnyDataType) 370 } 371 372 case *BindExpr: 373 // Set type of binding to type of its target. 374 c.inferTypes(t.Target, suggested) 375 t.Typ = t.Target.InferredType() 376 377 // Update type in bindings map. 378 c.bindings[t.Label] = t.Typ 379 380 case *RefExpr: 381 // Set type of ref to type of its binding or the suggested type. 382 typ := c.bindings[t.Label] 383 if typ == nil { 384 panic(fmt.Sprintf("$%s does not have its type set", t.Label)) 385 } 386 t.Typ = mostRestrictiveDataType(typ, suggested) 387 388 case *AndExpr: 389 // Assign most restrictive type to And expression. 390 c.inferTypes(t.Left, suggested) 391 c.inferTypes(t.Right, suggested) 392 if DoTypesContradict(t.Left.InferredType(), t.Right.InferredType()) { 393 err := fmt.Errorf("match patterns contradict one another; both cannot match") 394 c.compiler.addErr(t.Source(), err) 395 } 396 t.Typ = mostRestrictiveDataType(t.Left.InferredType(), t.Right.InferredType()) 397 398 case *NotExpr: 399 // Fall back on suggested type, since only type that doesn't match is known. 400 c.inferTypes(t.Input, suggested) 401 t.Typ = suggested 402 403 case *ListExpr: 404 // Assign most restrictive type to list expression. 405 t.Typ = mostRestrictiveDataType(ListDataType, suggested) 406 for _, item := range t.Items { 407 c.inferTypes(item, AnyDataType) 408 } 409 410 case *AnyExpr: 411 t.Typ = suggested 412 413 case *StringExpr, *NumberExpr, *ListAnyExpr, *NameExpr, *NamesExpr: 414 // Type already known; nothing to infer. 415 416 default: 417 panic(fmt.Sprintf("unhandled expression: %s", t)) 418 } 419 } 420 421 // ruleContentCompiler is the workhorse of rule compilation. It is recursively 422 // constructed on the stack in order to keep scoping context for match and 423 // construct expressions. Semantics can change depending on the context. 424 type ruleContentCompiler struct { 425 compiler *ruleCompiler 426 427 // src is the source location of the nearest match or construct expression, 428 // and is used when the source location isn't otherwise available. 429 src *SourceLoc 430 431 // matchPattern is true when compiling in the scope of a match pattern, and 432 // false when compiling in the scope of a replace pattern. 433 matchPattern bool 434 435 // customFunc is true when compiling in the scope of a custom match or 436 // replace function, and false when compiling in the scope of an op matcher 437 // or op constructor. 438 customFunc bool 439 } 440 441 func (c *ruleContentCompiler) compile(e Expr) Expr { 442 // Recurse into match or construct operator separately, since they will need 443 // to ceate new context before visiting arguments. 444 switch t := e.(type) { 445 case *FuncExpr: 446 return c.compileFunc(t) 447 448 case *BindExpr: 449 return c.compileBind(t) 450 451 case *RefExpr: 452 if c.matchPattern && !c.customFunc { 453 c.addDisallowedErr(t, "cannot use variable references") 454 } else { 455 // Check that referenced variable exists. 456 _, ok := c.compiler.bindings[t.Label] 457 if !ok { 458 c.addErr(t, fmt.Errorf("unrecognized variable name '%s'", t.Label)) 459 } 460 } 461 462 case *ListExpr: 463 if c.matchPattern && c.customFunc { 464 c.addDisallowedErr(t, "cannot use lists") 465 } else { 466 c.compileList(t) 467 } 468 469 case *AndExpr, *NotExpr: 470 if !c.matchPattern || c.customFunc { 471 c.addDisallowedErr(t, "cannot use boolean expressions") 472 } 473 474 case *NameExpr: 475 if c.matchPattern && !c.customFunc { 476 c.addErr(t, fmt.Errorf("cannot match literal name '%s'", *t)) 477 } else { 478 define := c.compiler.compiled.LookupDefine(string(*t)) 479 if define == nil { 480 c.addErr(t, fmt.Errorf("%s is not an operator name", *t)) 481 } 482 } 483 484 case *AnyExpr: 485 if !c.matchPattern || c.customFunc { 486 c.addDisallowedErr(t, "cannot use wildcard matcher") 487 } 488 } 489 490 // Pre-order traversal. 491 return e.Visit(c.compile) 492 } 493 494 func (c *ruleContentCompiler) compileBind(bind *BindExpr) Expr { 495 // Ensure that binding labels are unique. 496 _, ok := c.compiler.bindings[bind.Label] 497 if ok { 498 c.addErr(bind, fmt.Errorf("duplicate bind label '%s'", bind.Label)) 499 } 500 501 // Initialize binding before visiting, since it might be recursively 502 // referenced, as in: 503 // 504 // $input:* & (Func $input) 505 // 506 c.compiler.bindings[bind.Label] = AnyDataType 507 newBind := bind.Visit(c.compile).(*BindExpr) 508 509 return newBind 510 } 511 512 func (c *ruleContentCompiler) compileList(list *ListExpr) { 513 foundNotAny := false 514 for _, item := range list.Items { 515 if item.Op() == ListAnyOp { 516 if !c.matchPattern { 517 c.addErr(list, errors.New("list constructor cannot use '...'")) 518 } 519 } else { 520 if c.matchPattern && foundNotAny { 521 c.addErr(item, errors.New("list matcher cannot contain multiple expressions")) 522 break 523 } 524 foundNotAny = true 525 } 526 } 527 } 528 529 func (c *ruleContentCompiler) compileFunc(fn *FuncExpr) Expr { 530 // Create nested context and recurse into children. 531 nested := ruleContentCompiler{ 532 compiler: c.compiler, 533 src: fn.Source(), 534 matchPattern: c.matchPattern, 535 } 536 537 funcName := fn.Name 538 539 if nameExpr, ok := funcName.(*FuncExpr); ok { 540 // Function name is itself a function that dynamically determines name. 541 if c.matchPattern { 542 c.addErr(fn, errors.New("cannot match dynamic name")) 543 } 544 545 funcName = c.compileFunc(nameExpr) 546 } else { 547 // Ensure that all function names are defined and check whether this is a 548 // custom match function invocation. 549 names, ok := c.checkNames(fn) 550 if !ok { 551 return nil 552 } 553 554 // Normalize single name into NameExpr rather than NamesExpr. 555 if len(names) == 1 { 556 funcName = &names[0] 557 } else { 558 funcName = &names 559 } 560 561 var prototype *DefineExpr 562 for _, name := range names { 563 defines := c.compiler.compiled.LookupMatchingDefines(string(name)) 564 if defines != nil { 565 // Ensure that each operator has at least as many operands as the 566 // given function has arguments. The types of those arguments must 567 // be the same across all the operators. 568 for _, define := range defines { 569 if len(define.Fields) < len(fn.Args) { 570 c.addErr(fn, fmt.Errorf("%s has only %d fields", define.Name, len(define.Fields))) 571 continue 572 } 573 574 if prototype == nil { 575 // Save the first define in order to compare it against all 576 // others. 577 prototype = define 578 continue 579 } 580 581 for i := range fn.Args { 582 if define.Fields[i].Type != prototype.Fields[i].Type { 583 c.addErr(fn, fmt.Errorf("%s and %s fields do not have same types", 584 define.Name, prototype.Name)) 585 } 586 } 587 } 588 } else { 589 // This must be an invocation of a custom function, because there is 590 // no matching define. 591 if len(names) != 1 { 592 c.addErr(fn, errors.New("custom function cannot have multiple names")) 593 return fn 594 } 595 596 // Handle built-in functions. 597 if name == "OpName" { 598 opName, ok := c.compileOpName(fn) 599 if ok { 600 return opName 601 } 602 603 // Fall through and create OpName as a CustomFuncExpr. It may 604 // be rewritten during type inference if it can be proved it 605 // always constructs a single operator. 606 } 607 608 nested.customFunc = true 609 } 610 } 611 } 612 613 if c.matchPattern && c.customFunc && !nested.customFunc { 614 c.addErr(fn, errors.New("custom function name cannot be an operator name")) 615 return fn 616 } 617 618 args := fn.Args.Visit(nested.compile).(*SliceExpr) 619 620 if nested.customFunc { 621 // Create a CustomFuncExpr to make it easier to distinguish between 622 // op matchers and and custom function invocations. 623 return &CustomFuncExpr{Name: *funcName.(*NameExpr), Args: *args, Src: fn.Source()} 624 } 625 return &FuncExpr{Name: funcName, Args: *args, Src: fn.Source()} 626 } 627 628 // checkNames ensures that all function names are valid operator names or tag 629 // names, and that they are legal in the current context. checkNames returns 630 // the list of names as a NameExpr, as well as a boolean indicating whether they 631 // passed all validity checks. 632 func (c *ruleContentCompiler) checkNames(fn *FuncExpr) (names NamesExpr, ok bool) { 633 switch t := fn.Name.(type) { 634 case *NamesExpr: 635 names = *t 636 case *NameExpr: 637 names = NamesExpr{*t} 638 default: 639 // Name dynamically derived by function. 640 return NamesExpr{}, false 641 } 642 643 // Don't allow replace pattern to have multiple names or a tag name. 644 if !c.matchPattern { 645 if len(names) != 1 { 646 c.addErr(fn, errors.New("constructor cannot have multiple names")) 647 return NamesExpr{}, false 648 } 649 650 defines := c.compiler.compiled.LookupMatchingDefines(string(names[0])) 651 if len(defines) == 0 { 652 // Must be custom function name. 653 return names, true 654 } 655 656 define := c.compiler.compiled.LookupDefine(string(names[0])) 657 if define == nil { 658 c.addErr(fn, fmt.Errorf("construct name cannot be a tag")) 659 return NamesExpr{}, false 660 } 661 } 662 663 return names, true 664 } 665 666 func (c *ruleContentCompiler) compileOpName(fn *FuncExpr) (_ Expr, ok bool) { 667 if len(fn.Args) > 1 { 668 c.addErr(fn, fmt.Errorf("too many arguments to OpName function")) 669 return fn, false 670 } 671 672 if len(fn.Args) == 0 { 673 // No args to OpName function refers to top-level match operator. 674 return c.compiler.opName, true 675 } 676 677 // Otherwise expect a single variable reference argument. 678 _, ok = fn.Args[0].(*RefExpr) 679 if !ok { 680 c.addErr(fn, fmt.Errorf("invalid OpName argument: argument must be a variable reference")) 681 return fn, false 682 } 683 684 return fn, false 685 } 686 687 // addDisallowedErr creates an error prefixed by one of the following strings, 688 // depending on the context: 689 // match pattern 690 // replace pattern 691 // custom match function 692 // custom replace function 693 func (c *ruleContentCompiler) addDisallowedErr(loc Expr, disallowed string) { 694 if c.matchPattern { 695 if c.customFunc { 696 c.addErr(loc, fmt.Errorf("custom match function %s", disallowed)) 697 } else { 698 c.addErr(loc, fmt.Errorf("match pattern %s", disallowed)) 699 } 700 } else { 701 if c.customFunc { 702 c.addErr(loc, fmt.Errorf("custom replace function %s", disallowed)) 703 } else { 704 c.addErr(loc, fmt.Errorf("replace pattern %s", disallowed)) 705 } 706 } 707 } 708 709 func (c *ruleContentCompiler) addErr(loc Expr, err error) { 710 src := loc.Source() 711 if src == nil { 712 src = c.src 713 } 714 c.compiler.compiler.addErr(src, err) 715 } 716 717 // mostRestrictiveDataType returns the more restrictive of the two data types, 718 // or the left data type if they are equally restrictive. 719 func mostRestrictiveDataType(left, right DataType) DataType { 720 if IsTypeMoreRestrictive(right, left) { 721 return right 722 } 723 return left 724 }