github.com/gagliardetto/golang-go@v0.0.0-20201020153340-53909ea70814/cmd/compile/internal/ssa/gen/rulegen.go

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build gen
     6  
     7  // This program generates Go code that applies rewrite rules to a Value.
     8  // The generated code implements a function of type func (v *Value) bool
     9  // which reports whether it did something.
    10  // Ideas stolen from Swift: http://www.hpl.hp.com/techreports/Compaq-DEC/WRL-2000-2.html
    11  
    12  package main
    13  
    14  import (
    15  	"bufio"
    16  	"bytes"
    17  	"flag"
    18  	"fmt"
    19  	"go/ast"
    20  	"go/format"
    21  	"go/parser"
    22  	"go/printer"
    23  	"go/token"
    24  	"io"
    25  	"log"
    26  	"os"
    27  	"path"
    28  	"regexp"
    29  	"sort"
    30  	"strconv"
    31  	"strings"
    32  
    33  	"golang.org/x/tools/go/ast/astutil"
    34  )
    35  
    36  // rule syntax:
    37  //  sexpr [&& extra conditions] -> [@block] sexpr
    38  //
    39  // sexpr are s-expressions (lisp-like parenthesized groupings)
    40  // sexpr ::= [variable:](opcode sexpr*)
    41  //         | variable
    42  //         | <type>
    43  //         | [auxint]
    44  //         | {aux}
    45  //
    46  // aux      ::= variable | {code}
    47  // type     ::= variable | {code}
    48  // variable ::= some token
    49  // opcode   ::= one of the opcodes from the *Ops.go files
    50  
    51  // The extra conditions are just a chunk of Go that evaluates to a boolean. They
    52  // may use variables declared in the matching sexpr. The variable "v" is predefined
    53  // to be the value matched by the entire rule.
    54  
    55  // If multiple rules match, the first one in file order is selected.
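
        // As an illustration only (not quoted from any real .rules file), a rule in
        // this syntax looks like
        //
        //      (Mul64 x (Const64 [c])) && isPowerOfTwo(c) -> (Lsh64x64 x (Const64 <typ.UInt64> [log2(c)]))
        //
        // Here x and c are variables bound by the match, isPowerOfTwo and log2 stand
        // for helper functions assumed to exist in the ssa package, <typ.UInt64> is a
        // type restriction, and [c] / [log2(c)] are auxint expressions.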
    56  
    57  var genLog = flag.Bool("log", false, "generate code that logs; for debugging only")
    58  
    59  type Rule struct {
    60  	rule string
    61  	loc  string // file name & line number
    62  }
    63  
    64  func (r Rule) String() string {
    65  	return fmt.Sprintf("rule %q at %s", r.rule, r.loc)
    66  }
    67  
    68  func normalizeSpaces(s string) string {
    69  	return strings.Join(strings.Fields(strings.TrimSpace(s)), " ")
    70  }
    71  
    72  // parse returns the matching part of the rule, additional conditions, and the result.
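        // For example (an invented rule, shown only to illustrate the split), parsing
        //
        //      (Add64 x (Const64 [0])) && x.Type.IsInteger() -> x
        //
        // yields match "(Add64 x (Const64 [0]))", cond "x.Type.IsInteger()", and result "x".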
    73  func (r Rule) parse() (match, cond, result string) {
    74  	s := strings.Split(r.rule, "->")
    75  	if len(s) != 2 {
    76  		log.Fatalf("no arrow in %s", r)
    77  	}
    78  	match = normalizeSpaces(s[0])
    79  	result = normalizeSpaces(s[1])
    80  	cond = ""
    81  	if i := strings.Index(match, "&&"); i >= 0 {
    82  		cond = normalizeSpaces(match[i+2:])
    83  		match = normalizeSpaces(match[:i])
    84  	}
    85  	return match, cond, result
    86  }
    87  
    88  func genRules(arch arch)          { genRulesSuffix(arch, "") }
    89  func genSplitLoadRules(arch arch) { genRulesSuffix(arch, "splitload") }
    90  
    91  func genRulesSuffix(arch arch, suff string) {
    92  	// Open input file.
    93  	text, err := os.Open(arch.name + suff + ".rules")
    94  	if err != nil {
    95  		if suff == "" {
    96  			// All architectures must have a plain rules file.
    97  			log.Fatalf("can't read rule file: %v", err)
    98  		}
    99  		// Some architectures have bonus rules files that others don't share. That's fine.
   100  		return
   101  	}
   102  
   103  	// oprules contains a list of rules for each block and opcode
   104  	blockrules := map[string][]Rule{}
   105  	oprules := map[string][]Rule{}
   106  
   107  	// read rule file
   108  	scanner := bufio.NewScanner(text)
   109  	rule := ""
   110  	var lineno int
   111  	var ruleLineno int // line number of "->"
   112  	for scanner.Scan() {
   113  		lineno++
   114  		line := scanner.Text()
   115  		if i := strings.Index(line, "//"); i >= 0 {
   116  			// Remove comments. Note that this isn't string safe, so
   117  			// it will truncate lines with // inside strings. Oh well.
   118  			line = line[:i]
   119  		}
   120  		rule += " " + line
   121  		rule = strings.TrimSpace(rule)
   122  		if rule == "" {
   123  			continue
   124  		}
   125  		if !strings.Contains(rule, "->") {
   126  			continue
   127  		}
   128  		if ruleLineno == 0 {
   129  			ruleLineno = lineno
   130  		}
   131  		if strings.HasSuffix(rule, "->") {
   132  			continue
   133  		}
   134  		if unbalanced(rule) {
   135  			continue
   136  		}
   137  
   138  		loc := fmt.Sprintf("%s%s.rules:%d", arch.name, suff, ruleLineno)
   139  		for _, rule2 := range expandOr(rule) {
   140  			for _, rule3 := range commute(rule2, arch) {
   141  				r := Rule{rule: rule3, loc: loc}
   142  				if rawop := strings.Split(rule3, " ")[0][1:]; isBlock(rawop, arch) {
   143  					blockrules[rawop] = append(blockrules[rawop], r)
   144  					continue
   145  				}
   146  				// Do fancier value op matching.
   147  				match, _, _ := r.parse()
   148  				op, oparch, _, _, _, _ := parseValue(match, arch, loc)
   149  				opname := fmt.Sprintf("Op%s%s", oparch, op.name)
   150  				oprules[opname] = append(oprules[opname], r)
   151  			}
   152  		}
   153  		rule = ""
   154  		ruleLineno = 0
   155  	}
   156  	if err := scanner.Err(); err != nil {
   157  		log.Fatalf("scanner failed: %v\n", err)
   158  	}
   159  	if unbalanced(rule) {
   160  		log.Fatalf("%s.rules:%d: unbalanced rule: %v\n", arch.name, lineno, rule)
   161  	}
   162  
   163  	// Order all the ops.
   164  	var ops []string
   165  	for op := range oprules {
   166  		ops = append(ops, op)
   167  	}
   168  	sort.Strings(ops)
   169  
   170  	genFile := &File{arch: arch, suffix: suff}
   171  	const chunkSize = 10
   172  	// Main rewrite routine is a switch on v.Op.
   173  	fn := &Func{kind: "Value"}
   174  
   175  	sw := &Switch{expr: exprf("v.Op")}
   176  	for _, op := range ops {
   177  		var ors []string
   178  		for chunk := 0; chunk < len(oprules[op]); chunk += chunkSize {
   179  			ors = append(ors, fmt.Sprintf("rewriteValue%s%s_%s_%d(v)", arch.name, suff, op, chunk))
   180  		}
   181  		swc := &Case{expr: exprf(op)}
   182  		swc.add(stmtf("return %s", strings.Join(ors, " || ")))
   183  		sw.add(swc)
   184  	}
   185  	fn.add(sw)
   186  	fn.add(stmtf("return false"))
   187  	genFile.add(fn)
   188  
   189  	// Generate a routine per op. Note that we don't make one giant routine
   190  	// because it is too big for some compilers.
   191  	for _, op := range ops {
   192  		rules := oprules[op]
   193  		// rr is kept between chunks, so that a following chunk checks
   194  		// that the previous one ended with a rule that wasn't
   195  		// unconditional.
   196  		var rr *RuleRewrite
   197  		for chunk := 0; chunk < len(rules); chunk += chunkSize {
   198  			endchunk := chunk + chunkSize
   199  			if endchunk > len(rules) {
   200  				endchunk = len(rules)
   201  			}
   202  			fn := &Func{
   203  				kind:   "Value",
   204  				suffix: fmt.Sprintf("_%s_%d", op, chunk),
   205  			}
   206  			fn.add(declf("b", "v.Block"))
   207  			fn.add(declf("config", "b.Func.Config"))
   208  			fn.add(declf("fe", "b.Func.fe"))
   209  			fn.add(declf("typ", "&b.Func.Config.Types"))
   210  			for _, rule := range rules[chunk:endchunk] {
   211  				if rr != nil && !rr.canFail {
   212  					log.Fatalf("unconditional rule %s is followed by other rules", rr.match)
   213  				}
   214  				rr = &RuleRewrite{loc: rule.loc}
   215  				rr.match, rr.cond, rr.result = rule.parse()
   216  				pos, _ := genMatch(rr, arch, rr.match)
   217  				if pos == "" {
   218  					pos = "v.Pos"
   219  				}
   220  				if rr.cond != "" {
   221  					rr.add(breakf("!(%s)", rr.cond))
   222  				}
   223  				genResult(rr, arch, rr.result, pos)
   224  				if *genLog {
   225  					rr.add(stmtf("logRule(%q)", rule.loc))
   226  				}
   227  				fn.add(rr)
   228  			}
   229  			if rr.canFail {
   230  				fn.add(stmtf("return false"))
   231  			}
   232  			genFile.add(fn)
   233  		}
   234  	}
   235  
   236  	// Generate block rewrite function. There are only a few block types
   237  	// so we can make this one function with a switch.
   238  	fn = &Func{kind: "Block"}
   239  	fn.add(declf("config", "b.Func.Config"))
   240  	fn.add(declf("typ", "&b.Func.Config.Types"))
   241  
   242  	sw = &Switch{expr: exprf("b.Kind")}
   243  	ops = ops[:0]
   244  	for op := range blockrules {
   245  		ops = append(ops, op)
   246  	}
   247  	sort.Strings(ops)
   248  	for _, op := range ops {
   249  		name, data := getBlockInfo(op, arch)
   250  		swc := &Case{expr: exprf("%s", name)}
   251  		for _, rule := range blockrules[op] {
   252  			swc.add(genBlockRewrite(rule, arch, data))
   253  		}
   254  		sw.add(swc)
   255  	}
   256  	fn.add(sw)
   257  	fn.add(stmtf("return false"))
   258  	genFile.add(fn)
   259  
   260  	// Remove unused imports and variables.
   261  	buf := new(bytes.Buffer)
   262  	fprint(buf, genFile)
   263  	fset := token.NewFileSet()
   264  	file, err := parser.ParseFile(fset, "", buf, parser.ParseComments)
   265  	if err != nil {
   266  		log.Fatal(err)
   267  	}
   268  	tfile := fset.File(file.Pos())
   269  
   270  	// First, use unusedInspector to find the unused declarations by their
   271  	// start position.
   272  	u := unusedInspector{unused: make(map[token.Pos]bool)}
   273  	u.node(file)
   274  
   275  	// Then, delete said nodes via astutil.Apply.
   276  	pre := func(c *astutil.Cursor) bool {
   277  		node := c.Node()
   278  		if node == nil {
   279  			return true
   280  		}
   281  		if u.unused[node.Pos()] {
   282  			c.Delete()
   283  			// Unused imports and declarations use exactly
   284  			// one line. Prevent leaving an empty line.
   285  			tfile.MergeLine(tfile.Position(node.Pos()).Line)
   286  			return false
   287  		}
   288  		return true
   289  	}
   290  	post := func(c *astutil.Cursor) bool {
   291  		switch node := c.Node().(type) {
   292  		case *ast.GenDecl:
   293  			if len(node.Specs) == 0 {
   294  				// Don't leave a broken or empty GenDecl behind,
   295  				// such as "import ()".
   296  				c.Delete()
   297  			}
   298  		}
   299  		return true
   300  	}
   301  	file = astutil.Apply(file, pre, post).(*ast.File)
   302  
   303  	// Write the well-formatted source to file
   304  	f, err := os.Create("../rewrite" + arch.name + suff + ".go")
   305  	if err != nil {
   306  		log.Fatalf("can't write output: %v", err)
   307  	}
   308  	defer f.Close()
   309  	// gofmt result; use a buffered writer, as otherwise go/format spends
   310  	// far too much time in syscalls.
   311  	bw := bufio.NewWriter(f)
   312  	if err := format.Node(bw, fset, file); err != nil {
   313  		log.Fatalf("can't format output: %v", err)
   314  	}
   315  	if err := bw.Flush(); err != nil {
   316  		log.Fatalf("can't write output: %v", err)
   317  	}
   318  	if err := f.Close(); err != nil {
   319  		log.Fatalf("can't write output: %v", err)
   320  	}
   321  }
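
        // For a concrete sense of what genRulesSuffix emits, the generated file for,
        // say, arch "AMD64" with an empty suffix has roughly this shape (a sketch only;
        // the actual op names and chunking depend on the rules):
        //
        //      func rewriteValueAMD64(v *Value) bool {
        //              switch v.Op {
        //              case OpAMD64ADDQ:
        //                      return rewriteValueAMD64_OpAMD64ADDQ_0(v) || rewriteValueAMD64_OpAMD64ADDQ_10(v)
        //              }
        //              return false
        //      }
        //
        //      func rewriteValueAMD64_OpAMD64ADDQ_0(v *Value) bool {
        //              // one "for { match; cond; result; return true }" body per rule in the chunk
        //              return false
        //      }
        //
        //      func rewriteBlockAMD64(b *Block) bool {
        //              switch b.Kind {
        //              // one case per block kind that has rules
        //              }
        //              return false
        //      }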
   322  
   323  // unusedInspector can be used to detect unused variables and imports in an
   324  // ast.Node via its node method. The result is available in the "unused" map.
   325  //
   326  // note that unusedInspector is lazy and best-effort; it only supports the node
   327  // types and patterns used by the rulegen program.
   328  type unusedInspector struct {
   329  	// scope is the current scope, which can never be nil when a declaration
   330  	// is encountered. That is, the unusedInspector.node entrypoint should
   331  	// generally be an entire file or block.
   332  	scope *scope
   333  
   334  	// unused is the resulting set of unused declared names, indexed by the
   335  	// starting position of the node that declared the name.
   336  	unused map[token.Pos]bool
   337  
   338  	// defining is the object currently being defined; this is useful so
   339  	// that if "foo := bar" is unused and removed, we can then detect if
   340  	// "bar" becomes unused as well.
   341  	defining *object
   342  }
   343  
   344  // scoped opens a new scope when called, and returns a function which closes
   345  // that same scope. When a scope is closed, unused variables are recorded.
   346  func (u *unusedInspector) scoped() func() {
   347  	outer := u.scope
   348  	u.scope = &scope{outer: outer, objects: map[string]*object{}}
   349  	return func() {
   350  		for anyUnused := true; anyUnused; {
   351  			anyUnused = false
   352  			for _, obj := range u.scope.objects {
   353  				if obj.numUses > 0 {
   354  					continue
   355  				}
   356  				u.unused[obj.pos] = true
   357  				for _, used := range obj.used {
   358  					if used.numUses--; used.numUses == 0 {
   359  						anyUnused = true
   360  					}
   361  				}
   362  				// We've decremented numUses for each of the
   363  				// objects in used. Zero this slice too, to keep
   364  				// everything consistent.
   365  				obj.used = nil
   366  			}
   367  		}
   368  		u.scope = outer
   369  	}
   370  }
   371  
   372  func (u *unusedInspector) exprs(list []ast.Expr) {
   373  	for _, x := range list {
   374  		u.node(x)
   375  	}
   376  }
   377  
   378  func (u *unusedInspector) stmts(list []ast.Stmt) {
   379  	for _, x := range list {
   380  		u.node(x)
   381  	}
   382  }
   383  
   384  func (u *unusedInspector) decls(list []ast.Decl) {
   385  	for _, x := range list {
   386  		u.node(x)
   387  	}
   388  }
   389  
   390  func (u *unusedInspector) node(node ast.Node) {
   391  	switch node := node.(type) {
   392  	case *ast.File:
   393  		defer u.scoped()()
   394  		u.decls(node.Decls)
   395  	case *ast.GenDecl:
   396  		for _, spec := range node.Specs {
   397  			u.node(spec)
   398  		}
   399  	case *ast.ImportSpec:
   400  		impPath, _ := strconv.Unquote(node.Path.Value)
   401  		name := path.Base(impPath)
   402  		u.scope.objects[name] = &object{
   403  			name: name,
   404  			pos:  node.Pos(),
   405  		}
   406  	case *ast.FuncDecl:
   407  		u.node(node.Type)
   408  		if node.Body != nil {
   409  			u.node(node.Body)
   410  		}
   411  	case *ast.FuncType:
   412  		if node.Params != nil {
   413  			u.node(node.Params)
   414  		}
   415  		if node.Results != nil {
   416  			u.node(node.Results)
   417  		}
   418  	case *ast.FieldList:
   419  		for _, field := range node.List {
   420  			u.node(field)
   421  		}
   422  	case *ast.Field:
   423  		u.node(node.Type)
   424  
   425  	// statements
   426  
   427  	case *ast.BlockStmt:
   428  		defer u.scoped()()
   429  		u.stmts(node.List)
   430  	case *ast.IfStmt:
   431  		if node.Init != nil {
   432  			u.node(node.Init)
   433  		}
   434  		u.node(node.Cond)
   435  		u.node(node.Body)
   436  		if node.Else != nil {
   437  			u.node(node.Else)
   438  		}
   439  	case *ast.ForStmt:
   440  		if node.Init != nil {
   441  			u.node(node.Init)
   442  		}
   443  		if node.Cond != nil {
   444  			u.node(node.Cond)
   445  		}
   446  		if node.Post != nil {
   447  			u.node(node.Post)
   448  		}
   449  		u.node(node.Body)
   450  	case *ast.SwitchStmt:
   451  		if node.Init != nil {
   452  			u.node(node.Init)
   453  		}
   454  		if node.Tag != nil {
   455  			u.node(node.Tag)
   456  		}
   457  		u.node(node.Body)
   458  	case *ast.CaseClause:
   459  		u.exprs(node.List)
   460  		defer u.scoped()()
   461  		u.stmts(node.Body)
   462  	case *ast.BranchStmt:
   463  	case *ast.ExprStmt:
   464  		u.node(node.X)
   465  	case *ast.AssignStmt:
   466  		if node.Tok != token.DEFINE {
   467  			u.exprs(node.Rhs)
   468  			u.exprs(node.Lhs)
   469  			break
   470  		}
   471  		if len(node.Lhs) != 1 {
   472  			panic("no support for := with multiple names")
   473  		}
   474  
   475  		name := node.Lhs[0].(*ast.Ident)
   476  		obj := &object{
   477  			name: name.Name,
   478  			pos:  name.NamePos,
   479  		}
   480  
   481  		old := u.defining
   482  		u.defining = obj
   483  		u.exprs(node.Rhs)
   484  		u.defining = old
   485  
   486  		u.scope.objects[name.Name] = obj
   487  	case *ast.ReturnStmt:
   488  		u.exprs(node.Results)
   489  
   490  	// expressions
   491  
   492  	case *ast.CallExpr:
   493  		u.node(node.Fun)
   494  		u.exprs(node.Args)
   495  	case *ast.SelectorExpr:
   496  		u.node(node.X)
   497  	case *ast.UnaryExpr:
   498  		u.node(node.X)
   499  	case *ast.BinaryExpr:
   500  		u.node(node.X)
   501  		u.node(node.Y)
   502  	case *ast.StarExpr:
   503  		u.node(node.X)
   504  	case *ast.ParenExpr:
   505  		u.node(node.X)
   506  	case *ast.IndexExpr:
   507  		u.node(node.X)
   508  		u.node(node.Index)
   509  	case *ast.TypeAssertExpr:
   510  		u.node(node.X)
   511  		u.node(node.Type)
   512  	case *ast.Ident:
   513  		if obj := u.scope.Lookup(node.Name); obj != nil {
   514  			obj.numUses++
   515  			if u.defining != nil {
   516  				u.defining.used = append(u.defining.used, obj)
   517  			}
   518  		}
   519  	case *ast.BasicLit:
   520  	default:
   521  		panic(fmt.Sprintf("unhandled node: %T", node))
   522  	}
   523  }
   524  
   525  // scope keeps track of a certain scope and its declared names, as well as the
   526  // outer (parent) scope.
   527  type scope struct {
   528  	outer   *scope             // can be nil, if this is the top-level scope
   529  	objects map[string]*object // indexed by each declared name
   530  }
   531  
   532  func (s *scope) Lookup(name string) *object {
   533  	if obj := s.objects[name]; obj != nil {
   534  		return obj
   535  	}
   536  	if s.outer == nil {
   537  		return nil
   538  	}
   539  	return s.outer.Lookup(name)
   540  }
   541  
   542  // object keeps track of a declared name, such as a variable or import.
   543  type object struct {
   544  	name string
   545  	pos  token.Pos // start position of the node declaring the object
   546  
   547  	numUses int       // number of times this object is used
   548  	used    []*object // objects that its declaration makes use of
   549  }
   550  
   551  func fprint(w io.Writer, n Node) {
   552  	switch n := n.(type) {
   553  	case *File:
   554  		fmt.Fprintf(w, "// Code generated from gen/%s%s.rules; DO NOT EDIT.\n", n.arch.name, n.suffix)
   555  		fmt.Fprintf(w, "// generated with: cd gen; go run *.go\n")
   556  		fmt.Fprintf(w, "\npackage ssa\n")
   557  		for _, path := range append([]string{
   558  			"fmt",
   559  			"math",
   560  			"github.com/gagliardetto/golang-go/cmd/internal/obj",
   561  			"github.com/gagliardetto/golang-go/cmd/internal/objabi",
   562  			"github.com/gagliardetto/golang-go/cmd/compile/internal/types",
   563  		}, n.arch.imports...) {
   564  			fmt.Fprintf(w, "import %q\n", path)
   565  		}
   566  		for _, f := range n.list {
   567  			f := f.(*Func)
   568  			fmt.Fprintf(w, "func rewrite%s%s%s%s(", f.kind, n.arch.name, n.suffix, f.suffix)
   569  			fmt.Fprintf(w, "%c *%s) bool {\n", strings.ToLower(f.kind)[0], f.kind)
   570  			for _, n := range f.list {
   571  				fprint(w, n)
   572  			}
   573  			fmt.Fprintf(w, "}\n")
   574  		}
   575  	case *Switch:
   576  		fmt.Fprintf(w, "switch ")
   577  		fprint(w, n.expr)
   578  		fmt.Fprintf(w, " {\n")
   579  		for _, n := range n.list {
   580  			fprint(w, n)
   581  		}
   582  		fmt.Fprintf(w, "}\n")
   583  	case *Case:
   584  		fmt.Fprintf(w, "case ")
   585  		fprint(w, n.expr)
   586  		fmt.Fprintf(w, ":\n")
   587  		for _, n := range n.list {
   588  			fprint(w, n)
   589  		}
   590  	case *RuleRewrite:
   591  		fmt.Fprintf(w, "// match: %s\n", n.match)
   592  		if n.cond != "" {
   593  			fmt.Fprintf(w, "// cond: %s\n", n.cond)
   594  		}
   595  		fmt.Fprintf(w, "// result: %s\n", n.result)
   596  		fmt.Fprintf(w, "for %s {\n", n.check)
   597  		for _, n := range n.list {
   598  			fprint(w, n)
   599  		}
   600  		fmt.Fprintf(w, "return true\n}\n")
   601  	case *Declare:
   602  		fmt.Fprintf(w, "%s := ", n.name)
   603  		fprint(w, n.value)
   604  		fmt.Fprintln(w)
   605  	case *CondBreak:
   606  		fmt.Fprintf(w, "if ")
   607  		fprint(w, n.expr)
   608  		fmt.Fprintf(w, " {\nbreak\n}\n")
   609  	case ast.Node:
   610  		printConfig.Fprint(w, emptyFset, n)
   611  		if _, ok := n.(ast.Stmt); ok {
   612  			fmt.Fprintln(w)
   613  		}
   614  	default:
   615  		log.Fatalf("cannot print %T", n)
   616  	}
   617  }
   618  
   619  var printConfig = printer.Config{
   620  	Mode: printer.RawFormat, // we use go/format later, so skip work here
   621  }
   622  
   623  var emptyFset = token.NewFileSet()
   624  
   625  // Node can be a Statement or an ast.Expr.
   626  type Node interface{}
   627  
   628  // Statement can be one of our high-level statement struct types, or an
   629  // ast.Stmt under some limited circumstances.
   630  type Statement interface{}
   631  
   632  // bodyBase is shared by all of our statement pseudo-node types which can
   633  // contain other statements.
   634  type bodyBase struct {
   635  	list    []Statement
   636  	canFail bool
   637  }
   638  
   639  func (w *bodyBase) add(node Statement) {
   640  	var last Statement
   641  	if len(w.list) > 0 {
   642  		last = w.list[len(w.list)-1]
   643  	}
   644  	if node, ok := node.(*CondBreak); ok {
   645  		w.canFail = true
   646  		if last, ok := last.(*CondBreak); ok {
   647  			// Add to the previous "if <cond> { break }" via a
   648  			// logical OR, which will save verbosity.
   649  			last.expr = &ast.BinaryExpr{
   650  				Op: token.LOR,
   651  				X:  last.expr,
   652  				Y:  node.expr,
   653  			}
   654  			return
   655  		}
   656  	}
   657  
   658  	w.list = append(w.list, node)
   659  }
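
        // For example, two consecutive conditional breaks added as
        //
        //      rr.add(breakf("x != y"))
        //      rr.add(breakf("z.Op != OpConst64"))
        //
        // are merged into a single CondBreak and later printed as
        //
        //      if x != y || z.Op != OpConst64 {
        //              break
        //      }
        //
        // (x, y and z here are arbitrary placeholder names.)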
   660  
   661  // declared reports whether the body contains a Declare with the given name.
   662  func (w *bodyBase) declared(name string) bool {
   663  	for _, s := range w.list {
   664  		if decl, ok := s.(*Declare); ok && decl.name == name {
   665  			return true
   666  		}
   667  	}
   668  	return false
   669  }
   670  
   671  // These types define some high-level statement struct types, which can be used
   672  // as a Statement. This allows us to keep some node structs simpler, and have
   673  // higher-level nodes such as an entire rule rewrite.
   674  //
   675  // Note that ast.Expr is always used as-is; we don't declare our own expression
   676  // nodes.
   677  type (
   678  	File struct {
   679  		bodyBase // []*Func
   680  		arch     arch
   681  		suffix   string
   682  	}
   683  	Func struct {
   684  		bodyBase
   685  		kind   string // "Value" or "Block"
   686  		suffix string
   687  	}
   688  	Switch struct {
   689  		bodyBase // []*Case
   690  		expr     ast.Expr
   691  	}
   692  	Case struct {
   693  		bodyBase
   694  		expr ast.Expr
   695  	}
   696  	RuleRewrite struct {
   697  		bodyBase
   698  		match, cond, result string // top comments
   699  		check               string // top-level boolean expression
   700  
   701  		alloc int    // for unique var names
   702  		loc   string // file name & line number of the original rule
   703  	}
   704  	Declare struct {
   705  		name  string
   706  		value ast.Expr
   707  	}
   708  	CondBreak struct {
   709  		expr ast.Expr
   710  	}
   711  )
   712  
   713  // exprf parses a Go expression generated from fmt.Sprintf, panicking if an
   714  // error occurs.
   715  func exprf(format string, a ...interface{}) ast.Expr {
   716  	src := fmt.Sprintf(format, a...)
   717  	expr, err := parser.ParseExpr(src)
   718  	if err != nil {
   719  		log.Fatalf("expr parse error on %q: %v", src, err)
   720  	}
   721  	return expr
   722  }
   723  
   724  // stmtf parses a Go statement generated from fmt.Sprintf. This function is only
   725  // meant for simple statements that don't have a custom Statement node declared
   726  // in this package, such as ast.ReturnStmt or ast.ExprStmt.
   727  func stmtf(format string, a ...interface{}) Statement {
   728  	src := fmt.Sprintf(format, a...)
   729  	fsrc := "package p\nfunc _() {\n" + src + "\n}\n"
   730  	file, err := parser.ParseFile(token.NewFileSet(), "", fsrc, 0)
   731  	if err != nil {
   732  		log.Fatalf("stmt parse error on %q: %v", src, err)
   733  	}
   734  	return file.Decls[0].(*ast.FuncDecl).Body.List[0]
   735  }
   736  
   737  // declf constructs a simple "name := value" declaration, using exprf for its
   738  // value.
   739  func declf(name, format string, a ...interface{}) *Declare {
   740  	return &Declare{name, exprf(format, a...)}
   741  }
   742  
   743  // breakf constructs a simple "if cond { break }" statement, using exprf for its
   744  // condition.
   745  func breakf(format string, a ...interface{}) *CondBreak {
   746  	return &CondBreak{exprf(format, a...)}
   747  }
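
        // A few hypothetical calls, to show how the helpers above print
        // (the arguments are placeholders, not special in any way):
        //
        //      declf("b", "v.Block")        // prints as: b := v.Block
        //      stmtf("v.reset(OpCopy)")     // prints as: v.reset(OpCopy)
        //      breakf("!(%s)", "c < 4")     // prints as: if !(c < 4) { break }
        //      exprf("%s.AuxInt", "v_0")    // the ast.Expr for v_0.AuxInt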
   748  
   749  func genBlockRewrite(rule Rule, arch arch, data blockData) *RuleRewrite {
   750  	rr := &RuleRewrite{loc: rule.loc}
   751  	rr.match, rr.cond, rr.result = rule.parse()
   752  	_, _, auxint, aux, s := extract(rr.match) // remove parens, then split
   753  
   754  	// check match of control values
   755  	if len(s) < data.controls {
   756  		log.Fatalf("incorrect number of arguments in %s, got %v wanted at least %v", rule, len(s), data.controls)
   757  	}
   758  	controls := s[:data.controls]
   759  	pos := make([]string, data.controls)
   760  	for i, arg := range controls {
   761  		if strings.Contains(arg, "(") {
   762  			// TODO: allow custom names?
   763  			cname := fmt.Sprintf("b.Controls[%v]", i)
   764  			vname := fmt.Sprintf("v_%v", i)
   765  			rr.add(declf(vname, cname))
   766  			p, op := genMatch0(rr, arch, arg, vname)
   767  			if op != "" {
   768  				check := fmt.Sprintf("%s.Op == %s", cname, op)
   769  				if rr.check == "" {
   770  					rr.check = check
   771  				} else {
   772  					rr.check = rr.check + " && " + check
   773  				}
   774  			}
   775  			if p == "" {
   776  				p = vname + ".Pos"
   777  			}
   778  			pos[i] = p
   779  		} else {
   780  			rr.add(declf(arg, "b.Controls[%v]", i))
   781  			pos[i] = arg + ".Pos"
   782  		}
   783  	}
   784  	for _, e := range []struct {
   785  		name, field string
   786  	}{
   787  		{auxint, "AuxInt"},
   788  		{aux, "Aux"},
   789  	} {
   790  		if e.name == "" {
   791  			continue
   792  		}
   793  		if !token.IsIdentifier(e.name) || rr.declared(e.name) {
   794  			// code or variable
   795  			rr.add(breakf("b.%s != %s", e.field, e.name))
   796  		} else {
   797  			rr.add(declf(e.name, "b.%s", e.field))
   798  		}
   799  	}
   800  	if rr.cond != "" {
   801  		rr.add(breakf("!(%s)", rr.cond))
   802  	}
   803  
   804  	// Rule matches. Generate result.
   805  	outop, _, auxint, aux, t := extract(rr.result) // remove parens, then split
   806  	_, outdata := getBlockInfo(outop, arch)
   807  	if len(t) < outdata.controls {
   808  		log.Fatalf("incorrect number of output arguments in %s, got %v wanted at least %v", rule, len(s), outdata.controls)
   809  	}
   810  
   811  	// Check if newsuccs is the same set as succs.
   812  	succs := s[data.controls:]
   813  	newsuccs := t[outdata.controls:]
   814  	m := map[string]bool{}
   815  	for _, succ := range succs {
   816  		if m[succ] {
   817  			log.Fatalf("can't have a repeat successor name %s in %s", succ, rule)
   818  		}
   819  		m[succ] = true
   820  	}
   821  	for _, succ := range newsuccs {
   822  		if !m[succ] {
   823  			log.Fatalf("unknown successor %s in %s", succ, rule)
   824  		}
   825  		delete(m, succ)
   826  	}
   827  	if len(m) != 0 {
   828  		log.Fatalf("unmatched successors %v in %s", m, rule)
   829  	}
   830  
   831  	blockName, _ := getBlockInfo(outop, arch)
   832  	rr.add(stmtf("b.Reset(%s)", blockName))
   833  	for i, control := range t[:outdata.controls] {
   834  		// Select a source position for any new control values.
   835  		// TODO: does it always make sense to use the source position
   836  		// of the original control values or should we be using the
   837  		// block's source position in some cases?
   838  		newpos := "b.Pos" // default to block's source position
   839  		if i < len(pos) && pos[i] != "" {
   840  			// Use the previous control value's source position.
   841  			newpos = pos[i]
   842  		}
   843  
   844  		// Generate a new control value (or copy an existing value).
   845  		v := genResult0(rr, arch, control, false, false, newpos)
   846  		rr.add(stmtf("b.AddControl(%s)", v))
   847  	}
   848  	if auxint != "" {
   849  		rr.add(stmtf("b.AuxInt = %s", auxint))
   850  	}
   851  	if aux != "" {
   852  		rr.add(stmtf("b.Aux = %s", aux))
   853  	}
   854  
   855  	succChanged := false
   856  	for i := 0; i < len(succs); i++ {
   857  		if succs[i] != newsuccs[i] {
   858  			succChanged = true
   859  		}
   860  	}
   861  	if succChanged {
   862  		if len(succs) != 2 {
   863  			log.Fatalf("changed successors, len!=2 in %s", rule)
   864  		}
   865  		if succs[0] != newsuccs[1] || succs[1] != newsuccs[0] {
   866  			log.Fatalf("can only handle swapped successors in %s", rule)
   867  		}
   868  		rr.add(stmtf("b.swapSuccessors()"))
   869  	}
   870  
   871  	if *genLog {
   872  		rr.add(stmtf("logRule(%q)", rule.loc))
   873  	}
   874  	return rr
   875  }
   876  
   877  // genMatch returns the variable whose source position should be used for the
   878  // result (or "" if no opinion), and the name of the Op constant that the matched value must have.
   879  func genMatch(rr *RuleRewrite, arch arch, match string) (pos, checkOp string) {
   880  	return genMatch0(rr, arch, match, "v")
   881  }
   882  
   883  func genMatch0(rr *RuleRewrite, arch arch, match, v string) (pos, checkOp string) {
   884  	if match[0] != '(' || match[len(match)-1] != ')' {
   885  		log.Fatalf("non-compound expr in genMatch0: %q", match)
   886  	}
   887  	op, oparch, typ, auxint, aux, args := parseValue(match, arch, rr.loc)
   888  
   889  	checkOp = fmt.Sprintf("Op%s%s", oparch, op.name)
   890  
   891  	if op.faultOnNilArg0 || op.faultOnNilArg1 {
   892  		// Prefer the position of an instruction which could fault.
   893  		pos = v + ".Pos"
   894  	}
   895  
   896  	for _, e := range []struct {
   897  		name, field string
   898  	}{
   899  		{typ, "Type"},
   900  		{auxint, "AuxInt"},
   901  		{aux, "Aux"},
   902  	} {
   903  		if e.name == "" {
   904  			continue
   905  		}
   906  		if !token.IsIdentifier(e.name) || rr.declared(e.name) {
   907  			// code or variable
   908  			rr.add(breakf("%s.%s != %s", v, e.field, e.name))
   909  		} else {
   910  			rr.add(declf(e.name, "%s.%s", v, e.field))
   911  		}
   912  	}
   913  
   914  	// Access last argument first to minimize bounds checks.
   915  	if n := len(args); n > 1 {
   916  		a := args[n-1]
   917  		if a != "_" && !rr.declared(a) && token.IsIdentifier(a) {
   918  			rr.add(declf(a, "%s.Args[%d]", v, n-1))
   919  
   920  			// delete the last argument so it is not reprocessed
   921  			args = args[:n-1]
   922  		} else {
   923  			rr.add(stmtf("_ = %s.Args[%d]", v, n-1))
   924  		}
   925  	}
   926  	for i, arg := range args {
   927  		if arg == "_" {
   928  			continue
   929  		}
   930  		if !strings.Contains(arg, "(") {
   931  			// leaf variable
   932  			if rr.declared(arg) {
   933  				// variable already has a definition. Check whether
   934  				// the old definition and the new definition match.
   935  				// For example, (add x x).  Equality is just pointer equality
   936  				// on Values (so cse is important to do before lowering).
   937  				rr.add(breakf("%s != %s.Args[%d]", arg, v, i))
   938  			} else {
   939  				rr.add(declf(arg, "%s.Args[%d]", v, i))
   940  			}
   941  			continue
   942  		}
   943  		// compound sexpr
   944  		argname := fmt.Sprintf("%s_%d", v, i)
   945  		colon := strings.Index(arg, ":")
   946  		openparen := strings.Index(arg, "(")
   947  		if colon >= 0 && openparen >= 0 && colon < openparen {
   948  			// rule-specified name
   949  			argname = arg[:colon]
   950  			arg = arg[colon+1:]
   951  		}
   952  		if argname == "b" {
   953  			log.Fatalf("don't name args 'b', it is ambiguous with blocks")
   954  		}
   955  
   956  		rr.add(declf(argname, "%s.Args[%d]", v, i))
   957  		bexpr := exprf("%s.Op != addLater", argname)
   958  		rr.add(&CondBreak{expr: bexpr})
   959  		argPos, argCheckOp := genMatch0(rr, arch, arg, argname)
   960  		bexpr.(*ast.BinaryExpr).Y.(*ast.Ident).Name = argCheckOp
   961  
   962  		if argPos != "" {
   963  			// Keep the argument in preference to the parent, as the
   964  			// argument is normally earlier in program flow.
   965  			// Keep the argument in preference to an earlier argument,
   966  			// as that prefers the memory argument which is also earlier
   967  			// in the program flow.
   968  			pos = argPos
   969  		}
   970  	}
   971  
   972  	if op.argLength == -1 {
   973  		rr.add(breakf("len(%s.Args) != %d", v, len(args)))
   974  	}
   975  	return pos, checkOp
   976  }
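
        // As a sketch (the exact output depends on the ops involved), matching the
        // invented pattern (Add64 x (Const64 [c])) against v makes genMatch0 emit
        // code along these lines inside the rule's "for" body:
        //
        //      _ = v.Args[1]
        //      x := v.Args[0]
        //      v_1 := v.Args[1]
        //      if v_1.Op != OpConst64 {
        //              break
        //      }
        //      c := v_1.AuxInt
        //
        // while returning checkOp "OpAdd64" so the caller can test v.Op.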
   977  
   978  func genResult(rr *RuleRewrite, arch arch, result, pos string) {
   979  	move := result[0] == '@'
   980  	if move {
   981  		// parse @block directive
   982  		s := strings.SplitN(result[1:], " ", 2)
   983  		rr.add(stmtf("b = %s", s[0]))
   984  		result = s[1]
   985  	}
   986  	genResult0(rr, arch, result, true, move, pos)
   987  }
   988  
   989  func genResult0(rr *RuleRewrite, arch arch, result string, top, move bool, pos string) string {
   990  	// TODO: when generating a constant result, use f.constVal to avoid
   991  	// introducing copies just to clean them up again.
   992  	if result[0] != '(' {
   993  		// variable
   994  		if top {
   995  			// It is not safe in general to move a variable between blocks
   996  			// (and particularly not a phi node).
   997  			// Introduce a copy.
   998  			rr.add(stmtf("v.reset(OpCopy)"))
   999  			rr.add(stmtf("v.Type = %s.Type", result))
  1000  			rr.add(stmtf("v.AddArg(%s)", result))
  1001  		}
  1002  		return result
  1003  	}
  1004  
  1005  	op, oparch, typ, auxint, aux, args := parseValue(result, arch, rr.loc)
  1006  
  1007  	// Find the type of the variable.
  1008  	typeOverride := typ != ""
  1009  	if typ == "" && op.typ != "" {
  1010  		typ = typeName(op.typ)
  1011  	}
  1012  
  1013  	v := "v"
  1014  	if top && !move {
  1015  		rr.add(stmtf("v.reset(Op%s%s)", oparch, op.name))
  1016  		if typeOverride {
  1017  			rr.add(stmtf("v.Type = %s", typ))
  1018  		}
  1019  	} else {
  1020  		if typ == "" {
  1021  			log.Fatalf("sub-expression %s (op=Op%s%s) at %s must have a type", result, oparch, op.name, rr.loc)
  1022  		}
  1023  		v = fmt.Sprintf("v%d", rr.alloc)
  1024  		rr.alloc++
  1025  		rr.add(declf(v, "b.NewValue0(%s, Op%s%s, %s)", pos, oparch, op.name, typ))
  1026  		if move && top {
  1027  			// Rewrite original into a copy
  1028  			rr.add(stmtf("v.reset(OpCopy)"))
  1029  			rr.add(stmtf("v.AddArg(%s)", v))
  1030  		}
  1031  	}
  1032  
  1033  	if auxint != "" {
  1034  		rr.add(stmtf("%s.AuxInt = %s", v, auxint))
  1035  	}
  1036  	if aux != "" {
  1037  		rr.add(stmtf("%s.Aux = %s", v, aux))
  1038  	}
  1039  	for _, arg := range args {
  1040  		x := genResult0(rr, arch, arg, false, move, pos)
  1041  		rr.add(stmtf("%s.AddArg(%s)", v, x))
  1042  	}
  1043  
  1044  	return v
  1045  }
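
        // As a sketch, rewriting to the invented result (Add64 x (Const64 <typ.Int64> [0]))
        // at the top level (with pos == "v.Pos") makes genResult0 emit roughly:
        //
        //      v.reset(OpAdd64)
        //      v.AddArg(x)
        //      v0 := b.NewValue0(v.Pos, OpConst64, typ.Int64)
        //      v0.AuxInt = 0
        //      v.AddArg(v0)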
  1046  
  1047  func split(s string) []string {
  1048  	var r []string
  1049  
  1050  outer:
  1051  	for s != "" {
  1052  		d := 0               // depth of ({[<
  1053  		var open, close byte // opening and closing markers ({[< or )}]>
  1054  		nonsp := false       // found a non-space char so far
  1055  		for i := 0; i < len(s); i++ {
  1056  			switch {
  1057  			case d == 0 && s[i] == '(':
  1058  				open, close = '(', ')'
  1059  				d++
  1060  			case d == 0 && s[i] == '<':
  1061  				open, close = '<', '>'
  1062  				d++
  1063  			case d == 0 && s[i] == '[':
  1064  				open, close = '[', ']'
  1065  				d++
  1066  			case d == 0 && s[i] == '{':
  1067  				open, close = '{', '}'
  1068  				d++
  1069  			case d == 0 && (s[i] == ' ' || s[i] == '\t'):
  1070  				if nonsp {
  1071  					r = append(r, strings.TrimSpace(s[:i]))
  1072  					s = s[i:]
  1073  					continue outer
  1074  				}
  1075  			case d > 0 && s[i] == open:
  1076  				d++
  1077  			case d > 0 && s[i] == close:
  1078  				d--
  1079  			default:
  1080  				nonsp = true
  1081  			}
  1082  		}
  1083  		if d != 0 {
  1084  			log.Fatalf("imbalanced expression: %q", s)
  1085  		}
  1086  		if nonsp {
  1087  			r = append(r, strings.TrimSpace(s))
  1088  		}
  1089  		break
  1090  	}
  1091  	return r
  1092  }
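
        // For example, split("x (Const64 [c]) <typ.UInt64> {sym}") returns
        //
        //      []string{"x", "(Const64 [c])", "<typ.UInt64>", "{sym}"}
        //
        // since spaces inside (), [], {} and <> never separate elements.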
  1093  
  1094  // isBlock reports whether this op is a block opcode.
  1095  func isBlock(name string, arch arch) bool {
  1096  	for _, b := range genericBlocks {
  1097  		if b.name == name {
  1098  			return true
  1099  		}
  1100  	}
  1101  	for _, b := range arch.blocks {
  1102  		if b.name == name {
  1103  			return true
  1104  		}
  1105  	}
  1106  	return false
  1107  }
  1108  
  1109  func extract(val string) (op, typ, auxint, aux string, args []string) {
  1110  	val = val[1 : len(val)-1] // remove ()
  1111  
  1112  	// Split val up into regions.
  1113  	// Split by spaces/tabs, except those contained in (), {}, [], or <>.
  1114  	s := split(val)
  1115  
  1116  	// Extract restrictions and args.
  1117  	op = s[0]
  1118  	for _, a := range s[1:] {
  1119  		switch a[0] {
  1120  		case '<':
  1121  			typ = a[1 : len(a)-1] // remove <>
  1122  		case '[':
  1123  			auxint = a[1 : len(a)-1] // remove []
  1124  		case '{':
  1125  			aux = a[1 : len(a)-1] // remove {}
  1126  		default:
  1127  			args = append(args, a)
  1128  		}
  1129  	}
  1130  	return
  1131  }
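
        // For example, extract("(Const64 <typ.UInt64> [c])") returns op "Const64",
        // typ "typ.UInt64", auxint "c", no aux and no args, while
        // extract("(Add64 x y)") returns op "Add64" with args {"x", "y"}.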
  1132  
  1133  // parseValue parses a parenthesized value from a rule.
  1134  // The value can be from the match or the result side.
  1135  // It returns the op and unparsed strings for typ, auxint, and aux restrictions and for all args.
  1136  // oparch is the architecture that op is located in, or "" for generic.
  1137  func parseValue(val string, arch arch, loc string) (op opData, oparch, typ, auxint, aux string, args []string) {
  1138  	// Resolve the op.
  1139  	var s string
  1140  	s, typ, auxint, aux, args = extract(val)
  1141  
  1142  	// match reports whether x is a good op to select.
  1143  	// If strict is true, rule generation might succeed.
  1144  	// If strict is false, rule generation has failed,
  1145  	// but we're trying to generate a useful error.
  1146  	// Doing strict=true then strict=false allows
  1147  	// precise op matching while retaining good error messages.
  1148  	match := func(x opData, strict bool, archname string) bool {
  1149  		if x.name != s {
  1150  			return false
  1151  		}
  1152  		if x.argLength != -1 && int(x.argLength) != len(args) {
  1153  			if strict {
  1154  				return false
  1155  			}
  1156  			log.Printf("%s: op %s (%s) should have %d args, has %d", loc, s, archname, x.argLength, len(args))
  1157  		}
  1158  		return true
  1159  	}
  1160  
  1161  	for _, x := range genericOps {
  1162  		if match(x, true, "generic") {
  1163  			op = x
  1164  			break
  1165  		}
  1166  	}
  1167  	for _, x := range arch.ops {
  1168  		if arch.name != "generic" && match(x, true, arch.name) {
  1169  			if op.name != "" {
  1170  				log.Fatalf("%s: matches for op %s found in both generic and %s", loc, op.name, arch.name)
  1171  			}
  1172  			op = x
  1173  			oparch = arch.name
  1174  			break
  1175  		}
  1176  	}
  1177  
  1178  	if op.name == "" {
  1179  		// Failed to find the op.
  1180  		// Run through everything again with strict=false
  1181  		// to generate useful diagnostic messages before failing.
  1182  		for _, x := range genericOps {
  1183  			match(x, false, "generic")
  1184  		}
  1185  		for _, x := range arch.ops {
  1186  			match(x, false, arch.name)
  1187  		}
  1188  		log.Fatalf("%s: unknown op %s", loc, s)
  1189  	}
  1190  
  1191  	// Sanity check aux, auxint.
  1192  	if auxint != "" {
  1193  		switch op.aux {
  1194  		case "Bool", "Int8", "Int16", "Int32", "Int64", "Int128", "Float32", "Float64", "SymOff", "SymValAndOff", "TypSize":
  1195  		default:
  1196  			log.Fatalf("%s: op %s %s can't have auxint", loc, op.name, op.aux)
  1197  		}
  1198  	}
  1199  	if aux != "" {
  1200  		switch op.aux {
  1201  		case "String", "Sym", "SymOff", "SymValAndOff", "Typ", "TypSize", "CCop", "ArchSpecific":
  1202  		default:
  1203  			log.Fatalf("%s: op %s %s can't have aux", loc, op.name, op.aux)
  1204  		}
  1205  	}
  1206  	return
  1207  }
  1208  
  1209  func getBlockInfo(op string, arch arch) (name string, data blockData) {
  1210  	for _, b := range genericBlocks {
  1211  		if b.name == op {
  1212  			return "Block" + op, b
  1213  		}
  1214  	}
  1215  	for _, b := range arch.blocks {
  1216  		if b.name == op {
  1217  			return "Block" + arch.name + op, b
  1218  		}
  1219  	}
  1220  	log.Fatalf("could not find block data for %s", op)
  1221  	panic("unreachable")
  1222  }
  1223  
  1224  // typeName returns the string to use to generate a type.
  1225  func typeName(typ string) string {
  1226  	if typ[0] == '(' {
  1227  		ts := strings.Split(typ[1:len(typ)-1], ",")
  1228  		if len(ts) != 2 {
  1229  			log.Fatalf("Tuple expects 2 arguments")
  1230  		}
  1231  		return "types.NewTuple(" + typeName(ts[0]) + ", " + typeName(ts[1]) + ")"
  1232  	}
  1233  	switch typ {
  1234  	case "Flags", "Mem", "Void", "Int128":
  1235  		return "types.Type" + typ
  1236  	default:
  1237  		return "typ." + typ
  1238  	}
  1239  }
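
        // For example, typeName("UInt64") is "typ.UInt64", typeName("Flags") is
        // "types.TypeFlags", and typeName("(UInt64,Flags)") is
        // "types.NewTuple(typ.UInt64, types.TypeFlags)".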
  1240  
  1241  // unbalanced reports whether there aren't the same number of ( and ) in the string.
  1242  func unbalanced(s string) bool {
  1243  	balance := 0
  1244  	for _, c := range s {
  1245  		if c == '(' {
  1246  			balance++
  1247  		} else if c == ')' {
  1248  			balance--
  1249  		}
  1250  	}
  1251  	return balance != 0
  1252  }
  1253  
  1254  // findAllOpcode is a function to find the opcode portion of s-expressions.
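        // For instance, in the rule text "(MOV(B|W|L)load [off] {sym} ptr mem)" it
        // matches only the "(B|W|L)" portion, returning its start and end indices so
        // that callers can splice in each alternative.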
  1255  var findAllOpcode = regexp.MustCompile(`[(](\w+[|])+\w+[)]`).FindAllStringIndex
  1256  
  1257  // excludeFromExpansion reports whether the substring s[idx[0]:idx[1]] in a rule
  1258  // should be disregarded as a candidate for | expansion.
  1259  // It uses simple syntactic checks to see whether the substring
  1260  // is inside an AuxInt expression or inside the && conditions.
  1261  func excludeFromExpansion(s string, idx []int) bool {
  1262  	left := s[:idx[0]]
  1263  	if strings.LastIndexByte(left, '[') > strings.LastIndexByte(left, ']') {
  1264  		// Inside an AuxInt expression.
  1265  		return true
  1266  	}
  1267  	right := s[idx[1]:]
  1268  	if strings.Contains(left, "&&") && strings.Contains(right, "->") {
  1269  		// Inside && conditions.
  1270  		return true
  1271  	}
  1272  	return false
  1273  }
  1274  
  1275  // expandOr converts a rule into multiple rules by expanding | ops.
  1276  func expandOr(r string) []string {
  1277  	// Find every occurrence of |-separated things.
  1278  	// They look like MOV(B|W|L|Q|SS|SD)load or MOV(Q|L)loadidx(1|8).
  1279  	// Generate rules selecting one case from each |-form.
  1280  
  1281  	// Count width of |-forms.  They must match.
  1282  	n := 1
  1283  	for _, idx := range findAllOpcode(r, -1) {
  1284  		if excludeFromExpansion(r, idx) {
  1285  			continue
  1286  		}
  1287  		s := r[idx[0]:idx[1]]
  1288  		c := strings.Count(s, "|") + 1
  1289  		if c == 1 {
  1290  			continue
  1291  		}
  1292  		if n > 1 && n != c {
  1293  			log.Fatalf("'|' count doesn't match in %s: both %d and %d\n", r, n, c)
  1294  		}
  1295  		n = c
  1296  	}
  1297  	if n == 1 {
  1298  		// No |-form in this rule.
  1299  		return []string{r}
  1300  	}
  1301  	// Build each new rule.
  1302  	res := make([]string, n)
  1303  	for i := 0; i < n; i++ {
  1304  		buf := new(strings.Builder)
  1305  		x := 0
  1306  		for _, idx := range findAllOpcode(r, -1) {
  1307  			if excludeFromExpansion(r, idx) {
  1308  				continue
  1309  			}
  1310  			buf.WriteString(r[x:idx[0]])              // write bytes we've skipped over so far
  1311  			s := r[idx[0]+1 : idx[1]-1]               // remove leading "(" and trailing ")"
  1312  			buf.WriteString(strings.Split(s, "|")[i]) // write the op component for this rule
  1313  			x = idx[1]                                // note that we've written more bytes
  1314  		}
  1315  		buf.WriteString(r[x:])
  1316  		res[i] = buf.String()
  1317  	}
  1318  	return res
  1319  }
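
        // As an illustration, an invented rule such as
        //
        //      (MOV(B|W)load [off] {sym} ptr mem) -> (MOV(B|W)load2 [off] {sym} ptr mem)
        //
        // expands into exactly two rules: one using the first alternative on both
        // sides (B ... B) and one using the second (W ... W). Mismatched widths of
        // |-forms within one rule are rejected above with a fatal error.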
  1320  
  1321  // commute returns all equivalent rules to r after applying all possible
  1322  // argument swaps to the commutable ops in r.
  1323  // Potentially exponential, be careful.
  1324  func commute(r string, arch arch) []string {
  1325  	match, cond, result := Rule{rule: r}.parse()
  1326  	a := commute1(match, varCount(match), arch)
  1327  	for i, m := range a {
  1328  		if cond != "" {
  1329  			m += " && " + cond
  1330  		}
  1331  		m += " -> " + result
  1332  		a[i] = m
  1333  	}
  1334  	if len(a) == 1 && normalizeWhitespace(r) != normalizeWhitespace(a[0]) {
  1335  		fmt.Println(normalizeWhitespace(r))
  1336  		fmt.Println(normalizeWhitespace(a[0]))
  1337  		log.Fatalf("commute() is not the identity for noncommuting rule")
  1338  	}
  1339  	if false && len(a) > 1 {
  1340  		fmt.Println(r)
  1341  		for _, x := range a {
  1342  			fmt.Println("  " + x)
  1343  		}
  1344  	}
  1345  	return a
  1346  }
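
        // For example (an illustrative rule), commuting
        //
        //      (Add64 x (Const64 [0])) -> x
        //
        // yields both that rule and the swapped match
        //
        //      (Add64 (Const64 [0]) x) -> x
        //
        // because Add64 is commutative, whereas a match like (Add64 x y), where x and
        // y each appear exactly once, is left alone: swapping the two arguments would
        // only produce a duplicate of the same match.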
  1347  
  1348  func commute1(m string, cnt map[string]int, arch arch) []string {
  1349  	if m[0] == '<' || m[0] == '[' || m[0] == '{' || token.IsIdentifier(m) {
  1350  		return []string{m}
  1351  	}
  1352  	// Split up input.
  1353  	var prefix string
  1354  	if i := strings.Index(m, ":"); i >= 0 && token.IsIdentifier(m[:i]) {
  1355  		prefix = m[:i+1]
  1356  		m = m[i+1:]
  1357  	}
  1358  	if m[0] != '(' || m[len(m)-1] != ')' {
  1359  		log.Fatalf("non-compound expr in commute1: %q", m)
  1360  	}
  1361  	s := split(m[1 : len(m)-1])
  1362  	op := s[0]
  1363  
  1364  	// Figure out if the op is commutative or not.
  1365  	commutative := false
  1366  	for _, x := range genericOps {
  1367  		if op == x.name {
  1368  			if x.commutative {
  1369  				commutative = true
  1370  			}
  1371  			break
  1372  		}
  1373  	}
  1374  	if arch.name != "generic" {
  1375  		for _, x := range arch.ops {
  1376  			if op == x.name {
  1377  				if x.commutative {
  1378  					commutative = true
  1379  				}
  1380  				break
  1381  			}
  1382  		}
  1383  	}
  1384  	var idx0, idx1 int
  1385  	if commutative {
  1386  		// Find indexes of two args we can swap.
  1387  		for i, arg := range s {
  1388  			if i == 0 || arg[0] == '<' || arg[0] == '[' || arg[0] == '{' {
  1389  				continue
  1390  			}
  1391  			if idx0 == 0 {
  1392  				idx0 = i
  1393  				continue
  1394  			}
  1395  			if idx1 == 0 {
  1396  				idx1 = i
  1397  				break
  1398  			}
  1399  		}
  1400  		if idx1 == 0 {
  1401  			log.Fatalf("couldn't find first two args of commutative op %q", s[0])
  1402  		}
  1403  		if cnt[s[idx0]] == 1 && cnt[s[idx1]] == 1 || s[idx0] == s[idx1] && cnt[s[idx0]] == 2 {
  1404  			// When we have (Add x y) with no other uses of x and y in the matching rule,
  1405  			// then we can skip the commutative match (Add y x).
  1406  			commutative = false
  1407  		}
  1408  	}
  1409  
  1410  	// Recursively commute arguments.
  1411  	a := make([][]string, len(s))
  1412  	for i, arg := range s {
  1413  		a[i] = commute1(arg, cnt, arch)
  1414  	}
  1415  
  1416  	// Choose all possibilities from all args.
  1417  	r := crossProduct(a)
  1418  
  1419  	// If commutative, do that again with its two args reversed.
  1420  	if commutative {
  1421  		a[idx0], a[idx1] = a[idx1], a[idx0]
  1422  		r = append(r, crossProduct(a)...)
  1423  	}
  1424  
  1425  	// Construct result.
  1426  	for i, x := range r {
  1427  		r[i] = prefix + "(" + x + ")"
  1428  	}
  1429  	return r
  1430  }
  1431  
  1432  // varCount returns a map which counts the number of occurrences of
  1433  // Value variables in m.
  1434  func varCount(m string) map[string]int {
  1435  	cnt := map[string]int{}
  1436  	varCount1(m, cnt)
  1437  	return cnt
  1438  }
  1439  
  1440  func varCount1(m string, cnt map[string]int) {
  1441  	if m[0] == '<' || m[0] == '[' || m[0] == '{' {
  1442  		return
  1443  	}
  1444  	if token.IsIdentifier(m) {
  1445  		cnt[m]++
  1446  		return
  1447  	}
  1448  	// Split up input.
  1449  	if i := strings.Index(m, ":"); i >= 0 && token.IsIdentifier(m[:i]) {
  1450  		cnt[m[:i]]++
  1451  		m = m[i+1:]
  1452  	}
  1453  	if m[0] != '(' || m[len(m)-1] != ')' {
  1454  		log.Fatalf("non-compound expr in varCount1: %q", m)
  1455  	}
  1456  	s := split(m[1 : len(m)-1])
  1457  	for _, arg := range s[1:] {
  1458  		varCount1(arg, cnt)
  1459  	}
  1460  }
  1461  
  1462  // crossProduct returns all possible values
  1463  // x[0][i] + " " + x[1][j] + " " + ... + " " + x[len(x)-1][k]
  1464  // for all valid values of i, j, ..., k.
  1465  func crossProduct(x [][]string) []string {
  1466  	if len(x) == 1 {
  1467  		return x[0]
  1468  	}
  1469  	var r []string
  1470  	for _, tail := range crossProduct(x[1:]) {
  1471  		for _, first := range x[0] {
  1472  			r = append(r, first+" "+tail)
  1473  		}
  1474  	}
  1475  	return r
  1476  }
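
        // For example, crossProduct([][]string{{"a", "b"}, {"c"}, {"d", "e"}}) returns
        //
        //      []string{"a c d", "b c d", "a c e", "b c e"}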
  1477  
  1478  // normalizeWhitespace replaces 2+ whitespace sequences with a single space and normalizes spacing around parentheses and the -> arrow.
  1479  func normalizeWhitespace(x string) string {
  1480  	x = strings.Join(strings.Fields(x), " ")
  1481  	x = strings.Replace(x, "( ", "(", -1)
  1482  	x = strings.Replace(x, " )", ")", -1)
  1483  	x = strings.Replace(x, ")->", ") ->", -1)
  1484  	return x
  1485  }