github.com/bir3/gocompiler@v0.9.2202/src/cmd/compile/internal/ir/mknode.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build ignore
     6  
     7  // Note: this program must be run in this directory.
     8  //   go run mknode.go
     9  
    10  package main
    11  
    12  import (
    13  	"bytes"
    14  	"fmt"
    15  	"github.com/bir3/gocompiler/src/go/ast"
    16  	"github.com/bir3/gocompiler/src/go/format"
    17  	"github.com/bir3/gocompiler/src/go/parser"
    18  	"github.com/bir3/gocompiler/src/go/token"
    19  	"io/fs"
    20  	"log"
    21  	"os"
    22  	"sort"
    23  	"strings"
    24  )
    25  
    26  var fset = token.NewFileSet()
    27  
    28  var buf bytes.Buffer
    29  
    30  // concreteNodes contains all concrete types in the package that implement Node
    31  // (except for the mini* types).
    32  var concreteNodes []*ast.TypeSpec
    33  
    34  // interfaceNodes contains all interface types in the package that implement Node.
    35  var interfaceNodes []*ast.TypeSpec
    36  
    37  // mini contains the embeddable mini types (miniNode, miniExpr, and miniStmt).
    38  var mini = map[string]*ast.TypeSpec{}
    39  
    40  // implementsNode reports whether the type t is one which represents a Node
    41  // in the AST.
    42  func implementsNode(t ast.Expr) bool {
    43  	id, ok := t.(*ast.Ident)
    44  	if !ok {
    45  		return false	// only named types
    46  	}
    47  	for _, ts := range interfaceNodes {
    48  		if ts.Name.Name == id.Name {
    49  			return true
    50  		}
    51  	}
    52  	for _, ts := range concreteNodes {
    53  		if ts.Name.Name == id.Name {
    54  			return true
    55  		}
    56  	}
    57  	return false
    58  }
    59  
    60  func isMini(t ast.Expr) bool {
    61  	id, ok := t.(*ast.Ident)
    62  	return ok && mini[id.Name] != nil
    63  }
    64  
    65  func isNamedType(t ast.Expr, name string) bool {
    66  	if id, ok := t.(*ast.Ident); ok {
    67  		if id.Name == name {
    68  			return true
    69  		}
    70  	}
    71  	return false
    72  }
    73  
    74  func main() {
    75  	fmt.Fprintln(&buf, "// Code generated by mknode.go. DO NOT EDIT.")
    76  	fmt.Fprintln(&buf)
    77  	fmt.Fprintln(&buf, "package ir")
    78  	fmt.Fprintln(&buf)
    79  	fmt.Fprintln(&buf, `import "fmt"`)
    80  
    81  	filter := func(file fs.FileInfo) bool {
    82  		return !strings.HasPrefix(file.Name(), "mknode")
    83  	}
    84  	pkgs, err := parser.ParseDir(fset, ".", filter, 0)
    85  	if err != nil {
    86  		panic(err)
    87  	}
    88  	pkg := pkgs["ir"]
    89  
    90  	// Find all the mini types. These let us determine which
    91  	// concrete types implement Node, so we need to find them first.
    92  	for _, f := range pkg.Files {
    93  		for _, d := range f.Decls {
    94  			g, ok := d.(*ast.GenDecl)
    95  			if !ok {
    96  				continue
    97  			}
    98  			for _, s := range g.Specs {
    99  				t, ok := s.(*ast.TypeSpec)
   100  				if !ok {
   101  					continue
   102  				}
   103  				if strings.HasPrefix(t.Name.Name, "mini") {
   104  					mini[t.Name.Name] = t
   105  					// Double-check that it is or embeds miniNode.
   106  					if t.Name.Name != "miniNode" {
   107  						s := t.Type.(*ast.StructType)
   108  						if !isNamedType(s.Fields.List[0].Type, "miniNode") {
   109  							panic(fmt.Sprintf("can't find miniNode in %s", t.Name.Name))
   110  						}
   111  					}
   112  				}
   113  			}
   114  		}
   115  	}
   116  
   117  	// Find all the declarations of concrete types that implement Node.
   118  	for _, f := range pkg.Files {
   119  		for _, d := range f.Decls {
   120  			g, ok := d.(*ast.GenDecl)
   121  			if !ok {
   122  				continue
   123  			}
   124  			for _, s := range g.Specs {
   125  				t, ok := s.(*ast.TypeSpec)
   126  				if !ok {
   127  					continue
   128  				}
   129  				if strings.HasPrefix(t.Name.Name, "mini") {
   130  					// We don't treat the mini types as
   131  					// concrete implementations of Node
   132  					// (even though they are) because
   133  					// we only use them by embedding them.
   134  					continue
   135  				}
   136  				if isConcreteNode(t) {
   137  					concreteNodes = append(concreteNodes, t)
   138  				}
   139  				if isInterfaceNode(t) {
   140  					interfaceNodes = append(interfaceNodes, t)
   141  				}
   142  			}
   143  		}
   144  	}
   145  	// Sort for deterministic output.
   146  	sort.Slice(concreteNodes, func(i, j int) bool {
   147  		return concreteNodes[i].Name.Name < concreteNodes[j].Name.Name
   148  	})
   149  	// Generate code for each concrete type.
   150  	for _, t := range concreteNodes {
   151  		processType(t)
   152  	}
   153  	// Add some helpers.
   154  	generateHelpers()
   155  
   156  	// Format and write output.
   157  	out, err := format.Source(buf.Bytes())
   158  	if err != nil {
   159  		// write out mangled source so we can see the bug.
   160  		out = buf.Bytes()
   161  	}
   162  	err = os.WriteFile("node_gen.go", out, 0666)
   163  	if err != nil {
   164  		log.Fatal(err)
   165  	}
   166  }
   167  
   168  // isConcreteNode reports whether the type t is a concrete type
   169  // implementing Node.
   170  func isConcreteNode(t *ast.TypeSpec) bool {
   171  	s, ok := t.Type.(*ast.StructType)
   172  	if !ok {
   173  		return false
   174  	}
   175  	for _, f := range s.Fields.List {
   176  		if isMini(f.Type) {
   177  			return true
   178  		}
   179  	}
   180  	return false
   181  }
   182  
   183  // isInterfaceNode reports whether the type t is an interface type
   184  // implementing Node (including Node itself).
   185  func isInterfaceNode(t *ast.TypeSpec) bool {
   186  	s, ok := t.Type.(*ast.InterfaceType)
   187  	if !ok {
   188  		return false
   189  	}
   190  	if t.Name.Name == "Node" {
   191  		return true
   192  	}
   193  	if t.Name.Name == "OrigNode" || t.Name.Name == "InitNode" {
   194  		// These we exempt from consideration (fields of
   195  		// this type don't need to be walked or copied).
   196  		return false
   197  	}
   198  
   199  	// Look for embedded Node type.
   200  	// Note that this doesn't handle multi-level embedding, but
   201  	// we have none of that at the moment.
   202  	for _, f := range s.Methods.List {
   203  		if len(f.Names) != 0 {
   204  			continue
   205  		}
   206  		if isNamedType(f.Type, "Node") {
   207  			return true
   208  		}
   209  	}
   210  	return false
   211  }
   212  
   213  func processType(t *ast.TypeSpec) {
   214  	name := t.Name.Name
   215  	fmt.Fprintf(&buf, "\n")
   216  	fmt.Fprintf(&buf, "func (n *%s) Format(s fmt.State, verb rune) { fmtNode(n, s, verb) }\n", name)
   217  
   218  	switch name {
   219  	case "Name", "Func":
   220  		// Too specialized to automate.
   221  		return
   222  	}
   223  
   224  	s := t.Type.(*ast.StructType)
   225  	fields := s.Fields.List
   226  
   227  	// Expand any embedded fields.
   228  	for i := 0; i < len(fields); i++ {
   229  		f := fields[i]
   230  		if len(f.Names) != 0 {
   231  			continue	// not embedded
   232  		}
   233  		if isMini(f.Type) {
   234  			// Insert the fields of the embedded type into the main type.
   235  			// (It would be easier just to append, but inserting in place
   236  			// matches the old mknode behavior.)
   237  			ss := mini[f.Type.(*ast.Ident).Name].Type.(*ast.StructType)
   238  			var f2 []*ast.Field
   239  			f2 = append(f2, fields[:i]...)
   240  			f2 = append(f2, ss.Fields.List...)
   241  			f2 = append(f2, fields[i+1:]...)
   242  			fields = f2
   243  			i--
   244  			continue
   245  		} else if isNamedType(f.Type, "origNode") {
   246  			// Ignore this field
   247  			copy(fields[i:], fields[i+1:])
   248  			fields = fields[:len(fields)-1]
   249  			i--
   250  			continue
   251  		} else {
   252  			panic("unknown embedded field " + fmt.Sprintf("%v", f.Type))
   253  		}
   254  	}
   255  	// Process fields.
   256  	var copyBody strings.Builder
   257  	var doChildrenBody strings.Builder
   258  	var editChildrenBody strings.Builder
   259  	var editChildrenWithHiddenBody strings.Builder
   260  	for _, f := range fields {
   261  		names := f.Names
   262  		ft := f.Type
   263  		hidden := false
   264  		if f.Tag != nil {
   265  			tag := f.Tag.Value[1 : len(f.Tag.Value)-1]
   266  			if strings.HasPrefix(tag, "mknode:") {
   267  				if tag[7:] == "\"-\"" {
   268  					if !isNamedType(ft, "Node") {
   269  						continue
   270  					}
   271  					hidden = true
   272  				} else {
   273  					panic(fmt.Sprintf("unexpected tag value: %s", tag))
   274  				}
   275  			}
   276  		}
   277  		if isNamedType(ft, "Nodes") {
   278  			// Nodes == []Node
   279  			ft = &ast.ArrayType{Elt: &ast.Ident{Name: "Node"}}
   280  		}
   281  		isSlice := false
   282  		if a, ok := ft.(*ast.ArrayType); ok && a.Len == nil {
   283  			isSlice = true
   284  			ft = a.Elt
   285  		}
   286  		isPtr := false
   287  		if p, ok := ft.(*ast.StarExpr); ok {
   288  			isPtr = true
   289  			ft = p.X
   290  		}
   291  		if !implementsNode(ft) {
   292  			continue
   293  		}
   294  		for _, name := range names {
   295  			ptr := ""
   296  			if isPtr {
   297  				ptr = "*"
   298  			}
   299  			if isSlice {
   300  				fmt.Fprintf(&editChildrenWithHiddenBody,
   301  					"edit%ss(n.%s, edit)\n", ft, name)
   302  			} else {
   303  				fmt.Fprintf(&editChildrenWithHiddenBody,
   304  					"if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft)
   305  			}
   306  			if hidden {
   307  				continue
   308  			}
   309  			if isSlice {
   310  				fmt.Fprintf(&copyBody, "c.%s = copy%ss(c.%s)\n", name, ft, name)
   311  				fmt.Fprintf(&doChildrenBody,
   312  					"if do%ss(n.%s, do) {\nreturn true\n}\n", ft, name)
   313  				fmt.Fprintf(&editChildrenBody,
   314  					"edit%ss(n.%s, edit)\n", ft, name)
   315  			} else {
   316  				fmt.Fprintf(&doChildrenBody,
   317  					"if n.%s != nil && do(n.%s) {\nreturn true\n}\n", name, name)
   318  				fmt.Fprintf(&editChildrenBody,
   319  					"if n.%s != nil {\nn.%s = edit(n.%s).(%s%s)\n}\n", name, name, name, ptr, ft)
   320  			}
   321  		}
   322  	}
   323  	fmt.Fprintf(&buf, "func (n *%s) copy() Node {\nc := *n\n", name)
   324  	buf.WriteString(copyBody.String())
   325  	fmt.Fprintf(&buf, "return &c\n}\n")
   326  	fmt.Fprintf(&buf, "func (n *%s) doChildren(do func(Node) bool) bool {\n", name)
   327  	buf.WriteString(doChildrenBody.String())
   328  	fmt.Fprintf(&buf, "return false\n}\n")
   329  	fmt.Fprintf(&buf, "func (n *%s) editChildren(edit func(Node) Node) {\n", name)
   330  	buf.WriteString(editChildrenBody.String())
   331  	fmt.Fprintf(&buf, "}\n")
   332  	fmt.Fprintf(&buf, "func (n *%s) editChildrenWithHidden(edit func(Node) Node) {\n", name)
   333  	buf.WriteString(editChildrenWithHiddenBody.String())
   334  	fmt.Fprintf(&buf, "}\n")
   335  }
   336  
   337  func generateHelpers() {
   338  	for _, typ := range []string{"CaseClause", "CommClause", "Name", "Node"} {
   339  		ptr := "*"
   340  		if typ == "Node" {
   341  			ptr = ""	// interfaces don't need *
   342  		}
   343  		fmt.Fprintf(&buf, "\n")
   344  		fmt.Fprintf(&buf, "func copy%ss(list []%s%s) []%s%s {\n", typ, ptr, typ, ptr, typ)
   345  		fmt.Fprintf(&buf, "if list == nil { return nil }\n")
   346  		fmt.Fprintf(&buf, "c := make([]%s%s, len(list))\n", ptr, typ)
   347  		fmt.Fprintf(&buf, "copy(c, list)\n")
   348  		fmt.Fprintf(&buf, "return c\n")
   349  		fmt.Fprintf(&buf, "}\n")
   350  		fmt.Fprintf(&buf, "func do%ss(list []%s%s, do func(Node) bool) bool {\n", typ, ptr, typ)
   351  		fmt.Fprintf(&buf, "for _, x := range list {\n")
   352  		fmt.Fprintf(&buf, "if x != nil && do(x) {\n")
   353  		fmt.Fprintf(&buf, "return true\n")
   354  		fmt.Fprintf(&buf, "}\n")
   355  		fmt.Fprintf(&buf, "}\n")
   356  		fmt.Fprintf(&buf, "return false\n")
   357  		fmt.Fprintf(&buf, "}\n")
   358  		fmt.Fprintf(&buf, "func edit%ss(list []%s%s, edit func(Node) Node) {\n", typ, ptr, typ)
   359  		fmt.Fprintf(&buf, "for i, x := range list {\n")
   360  		fmt.Fprintf(&buf, "if x != nil {\n")
   361  		fmt.Fprintf(&buf, "list[i] = edit(x).(%s%s)\n", ptr, typ)
   362  		fmt.Fprintf(&buf, "}\n")
   363  		fmt.Fprintf(&buf, "}\n")
   364  		fmt.Fprintf(&buf, "}\n")
   365  	}
   366  }