github.com/jgarto/itcv@v0.0.0-20180826224514-4eea09c1aa0d/cmd/stateGen/gen.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/parser"
     8  	"go/printer"
     9  	"go/token"
    10  	"html/template"
    11  	"io"
    12  	"io/ioutil"
    13  	"os"
    14  	"path/filepath"
    15  	"strings"
    16  	"unicode"
    17  	"unicode/utf8"
    18  
    19  	"golang.org/x/tools/imports"
    20  
    21  	"myitcv.io/gogenerate"
    22  )
    23  
    24  const (
    25  	rootVar      = "root"
    26  	leafTypeName = "Leaf"
    27  	nodePrefix   = "_Node_"
    28  )
    29  
    30  type gen struct {
    31  	fset *token.FileSet
    32  
    33  	dir string
    34  
    35  	buf *bytes.Buffer
    36  
    37  	rootType *node
    38  
    39  	roots   []*ast.ValueSpec
    40  	nodes   map[string]node
    41  	pkg     *ast.Package
    42  	pkgName string
    43  
    44  	imports map[*ast.ImportSpec]bool
    45  
    46  	file *ast.File
    47  
    48  	// a map from the _Node_XYZ name
    49  	seenNodes    map[string]bool
    50  	nodesToVisit []node
    51  
    52  	seenLeaves    map[string]bool
    53  	leavesToVisit []leafField
    54  
    55  	stderr io.Writer
    56  	failed bool
    57  }
    58  
    59  func dogen(stderr io.Writer, dir, license string) bool {
    60  	fset := token.NewFileSet()
    61  
    62  	notGenByUs := func(fi os.FileInfo) bool {
    63  		return !gogenerate.FileGeneratedBy(fi.Name(), stateGenCmd)
    64  	}
    65  
    66  	pkgs, err := parser.ParseDir(fset, dir, notGenByUs, 0)
    67  	if err != nil {
    68  		panic(fmt.Errorf("unable to parse directory %v: %v", dir, err))
    69  	}
    70  
    71  	failed := false
    72  
    73  	for pn, pkg := range pkgs {
    74  		g := &gen{
    75  			fset:    fset,
    76  			dir:     dir,
    77  			pkg:     pkg,
    78  			pkgName: pn,
    79  
    80  			buf: bytes.NewBuffer(nil),
    81  
    82  			imports: make(map[*ast.ImportSpec]bool),
    83  
    84  			nodes:      make(map[string]node),
    85  			seenNodes:  make(map[string]bool),
    86  			seenLeaves: make(map[string]bool),
    87  
    88  			stderr: stderr,
    89  		}
    90  
    91  		g.parse()
    92  		if !g.ok() {
    93  			failed = true
    94  			continue
    95  		}
    96  
    97  		if g.rootType == nil {
    98  			continue
    99  		}
   100  
   101  		g.pf("// Code generated by %v; DO NOT EDIT.\n", stateGenCmd)
   102  		g.pln()
   103  
   104  		g.pf("package %v\n", pn)
   105  
   106  		g.pf(`
   107  		import "path"
   108  		`)
   109  
   110  		for i := range g.imports {
   111  			if i.Name != nil {
   112  				g.pf("import %v %v\n", i.Name.Name, i.Path.Value)
   113  			} else {
   114  				g.pf("import %v\n", i.Path.Value)
   115  			}
   116  		}
   117  
   118  		g.gen()
   119  
   120  		fn := gogenerate.NameFile(pn, stateGenCmd)
   121  		fp := filepath.Join(dir, fn)
   122  
   123  		toWrite := g.buf.Bytes()
   124  
   125  		res, err := imports.Process(fn, toWrite, nil)
   126  		if err == nil {
   127  			toWrite = res
   128  		}
   129  
   130  		if err := ioutil.WriteFile(fp, toWrite, 0644); err != nil {
   131  			panic(fmt.Errorf("unable to write to %v: %v", fp, err))
   132  		}
   133  	}
   134  
   135  	return !failed
   136  }
   137  
   138  type node struct {
   139  	Name     string
   140  	children []field
   141  	leaves   []leafField
   142  }
   143  
   144  type field struct {
   145  	Name string
   146  	Type string
   147  }
   148  
   149  type leafField struct {
   150  	Name     string
   151  	Type     string
   152  	LeafType string
   153  }
   154  
   155  func (g *gen) parse() {
   156  	for _, f := range g.pkg.Files {
   157  		g.file = f
   158  
   159  		for _, d := range f.Decls {
   160  			gd, ok := d.(*ast.GenDecl)
   161  			if !ok {
   162  				continue
   163  			}
   164  
   165  			switch gd.Tok {
   166  			case token.TYPE:
   167  				for _, s := range gd.Specs {
   168  					g.parseNode(s.(*ast.TypeSpec))
   169  				}
   170  			case token.VAR:
   171  				for _, s := range gd.Specs {
   172  					s := s.(*ast.ValueSpec)
   173  
   174  					if len(s.Names) != 1 {
   175  						continue
   176  					}
   177  
   178  					if s.Names[0].Name != rootVar {
   179  						continue
   180  					}
   181  
   182  					_, ok := s.Type.(*ast.Ident)
   183  					if !ok {
   184  						continue
   185  					}
   186  
   187  					g.roots = append(g.roots, s)
   188  				}
   189  			}
   190  		}
   191  	}
   192  }
   193  
   194  func (g *gen) parseNode(s *ast.TypeSpec) {
   195  	st, ok := s.Type.(*ast.StructType)
   196  	if !ok {
   197  		return
   198  	}
   199  
   200  	tn := s.Name.Name
   201  
   202  	if !strings.HasPrefix(tn, nodePrefix) {
   203  		return
   204  	}
   205  
   206  	n := strings.TrimPrefix(tn, nodePrefix)
   207  
   208  	var children []field
   209  	var leaves []leafField
   210  
   211  	for _, f := range st.Fields.List {
   212  		var id *ast.Ident
   213  
   214  		switch typ := f.Type.(type) {
   215  		case *ast.Ident:
   216  			id = typ
   217  		case *ast.StarExpr:
   218  			if v, ok := typ.X.(*ast.Ident); ok {
   219  				id = v
   220  			}
   221  		}
   222  
   223  		if id != nil && strings.HasPrefix(id.Name, nodePrefix) {
   224  			for _, n := range f.Names {
   225  				children = append(children, field{
   226  					Name: n.Name,
   227  					Type: strings.TrimPrefix(id.Name, nodePrefix),
   228  				})
   229  			}
   230  		} else {
   231  			typ, leafTyp := g.addImports(f.Type)
   232  			for _, n := range f.Names {
   233  				leaves = append(leaves, leafField{
   234  					Name:     n.Name,
   235  					Type:     typ,
   236  					LeafType: leafTyp,
   237  				})
   238  			}
   239  		}
   240  	}
   241  
   242  	g.nodes[s.Name.Name] = node{
   243  		Name:     n,
   244  		children: children,
   245  		leaves:   leaves,
   246  	}
   247  }
   248  
   249  func (g *gen) addImports(exp ast.Expr) (string, string) {
   250  	finder := &importFinder{
   251  		imports: g.file.Imports,
   252  		matches: g.imports,
   253  	}
   254  
   255  	ast.Walk(finder, exp)
   256  
   257  	es := g.expString(exp)
   258  
   259  	s := strings.Replace(es, ".", "", -1)
   260  
   261  	if strings.HasPrefix(s, "*") {
   262  		s = strings.TrimPrefix(s, "*")
   263  		s = s + "P"
   264  	}
   265  
   266  	r, l := utf8.DecodeRune([]byte(s))
   267  
   268  	return es, string(unicode.ToUpper(r)) + string(s[l:]) + leafTypeName
   269  }
   270  
   271  func (g *gen) ok() bool {
   272  	if len(g.roots) == 0 {
   273  		return true
   274  	}
   275  
   276  	if v := len(g.roots); v > 1 {
   277  		g.errorf("expected 1 root, found %v\n", v)
   278  		for _, v := range g.roots {
   279  			g.errorf("  %v\n", g.fset.Position(v.Pos()))
   280  		}
   281  
   282  		return false
   283  	}
   284  
   285  	r := g.roots[0]
   286  	rtn := r.Type.(*ast.Ident)
   287  
   288  	rt, ok := g.nodes[rtn.Name]
   289  	if !ok {
   290  		g.errorf("need root type to be node type; instead it was %v", rtn.Name)
   291  	}
   292  
   293  	g.rootType = &rt
   294  
   295  	return !g.failed
   296  }
   297  
   298  func (g *gen) gen() {
   299  	var h node
   300  	var l leafField
   301  
   302  	g.nodesToVisit = append(g.nodesToVisit, *g.rootType)
   303  
   304  	for len(g.nodesToVisit) != 0 {
   305  		h, g.nodesToVisit = g.nodesToVisit[0], g.nodesToVisit[1:]
   306  		g.genNode(h)
   307  	}
   308  
   309  	for len(g.leavesToVisit) != 0 {
   310  		l, g.leavesToVisit = g.leavesToVisit[0], g.leavesToVisit[1:]
   311  		g.genLeaf(l)
   312  	}
   313  
   314  	g.pt(`
   315  	func NewRoot() *{{.Name}} {
   316  		r := &rootNode{
   317  			store: make(map[string]interface{}),
   318  			cbs:   make(map[string]map[*Sub]struct{}),
   319  			subs:  make(map[*Sub]struct{}),
   320  		}
   321  
   322  		return new{{.Name}}(r, "")
   323  	}
   324  	`, g.rootType)
   325  
   326  	g.pf(`
   327  	type Node interface {
   328  		Subscribe(cb func()) *Sub
   329  	}
   330  
   331  	type Sub struct {
   332  		*rootNode
   333  		prefix string
   334  		cb     func()
   335  	}
   336  
   337  	func (s *Sub) Clear() {
   338  		s.rootNode.unsubscribe(s)
   339  	}
   340  
   341  	var NoSuchSubErr = errors.New("No such sub")
   342  
   343  	type rootNode struct {
   344  		store map[string]interface{}
   345  		cbs   map[string]map[*Sub]struct{}
   346  		subs  map[*Sub]struct{}
   347  	}
   348  
   349  	func (r *rootNode) subscribe(prefix string, cb func()) *Sub {
   350  
   351  		res := &Sub{
   352  			cb:     cb,
   353  			prefix: prefix,
   354  			rootNode: r,
   355  		}
   356  
   357  		l, ok := r.cbs[prefix]
   358  		if !ok {
   359  			l = make(map[*Sub]struct{})
   360  			r.cbs[prefix] = l
   361  		}
   362  
   363  		l[res] = struct{}{}
   364  		r.subs[res] = struct{}{}
   365  
   366  		return res
   367  	}
   368  
   369  	func (r *rootNode) unsubscribe(s *Sub) {
   370  		if _, ok := r.subs[s]; !ok {
   371  			panic(NoSuchSubErr)
   372  		}
   373  
   374  		l, ok := r.cbs[s.prefix]
   375  		if !ok {
   376  			panic("Real problems...")
   377  		}
   378  
   379  		delete(l, s)
   380  		delete(r.subs, s)
   381  	}
   382  
   383  	func (r *rootNode) get(k string) (interface{}, bool) {
   384  		v, ok := r.store[k]
   385  		return v, ok
   386  	}
   387  
   388  	func (r rootNode) set(k string, v interface{}) {
   389  		if curr, ok := r.store[k]; ok && v == curr {
   390  			return
   391  		}
   392  
   393  		r.store[k] = v
   394  
   395  		parts := strings.Split(k, "/")
   396  
   397  		var subs []*Sub
   398  
   399  		var kk string
   400  
   401  		for _, p := range parts {
   402  			kk = path.Join(kk, p)
   403  
   404  			if ll, ok := r.cbs[kk]; ok {
   405  				for k := range ll {
   406  					subs = append(subs, k)
   407  				}
   408  			}
   409  
   410  		}
   411  
   412  		for _, s := range subs {
   413  			s.cb()
   414  		}
   415  	}
   416  	`)
   417  }
   418  
   419  func (g *gen) genLeaf(n leafField) {
   420  	g.pt(`
   421  	type {{.LeafType}} struct {
   422  		*rootNode
   423  		prefix string
   424  	}
   425  
   426  	var _ Node = new({{.LeafType}})
   427  
   428  	func new{{.LeafType}}(r *rootNode, prefix string) *{{.LeafType}} {
   429  		prefix = path.Join(prefix, "{{.LeafType}}")
   430  
   431  		return &{{.LeafType}}{
   432  			rootNode:   r,
   433  			prefix: prefix,
   434  		}
   435  	}
   436  
   437  	func (m *{{.LeafType}}) Get() {{.Type}} {
   438  		var res {{.Type}}
   439  		if v, ok := m.rootNode.get(m.prefix); ok {
   440  			return v.({{.Type}})
   441  		}
   442  		return res
   443  	}
   444  
   445  	func (m *{{.LeafType}}) Set(v {{.Type}}) {
   446  		m.rootNode.set(m.prefix, v)
   447  	}
   448  
   449  	func (m *{{.LeafType}}) Subscribe(cb func()) *Sub {
   450  		return m.rootNode.subscribe(m.prefix, cb)
   451  	}
   452  	`, n)
   453  
   454  }
   455  
   456  func (g *gen) genNode(n node) {
   457  	g.pt(`
   458  	var _ Node = new({{.Name}})
   459  
   460  	type {{.Name}} struct {
   461  		*rootNode
   462  		prefix string
   463  
   464  	`, n)
   465  
   466  	for _, c := range n.children {
   467  		if !g.seenNodes[c.Type] {
   468  			g.nodesToVisit = append(g.nodesToVisit, g.nodes[nodePrefix+c.Type])
   469  			g.seenNodes[c.Type] = true
   470  		}
   471  
   472  		g.pt(`
   473  		_{{.Name}} *{{.Type}}
   474  		`, c)
   475  	}
   476  
   477  	for _, l := range n.leaves {
   478  		if !g.seenLeaves[l.LeafType] {
   479  			g.leavesToVisit = append(g.leavesToVisit, l)
   480  			g.seenLeaves[l.LeafType] = true
   481  		}
   482  
   483  		g.pt(`
   484  		_{{.Name}} *{{.LeafType}}
   485  		`, l)
   486  	}
   487  
   488  	g.pt(`
   489  	}
   490  
   491  	func new{{.Name}}(r *rootNode, prefix string) *{{.Name}} {
   492  		prefix = path.Join(prefix, "{{.Name}}")
   493  
   494  		res := &{{.Name}}{
   495  			rootNode:   r,
   496  			prefix: prefix,
   497  		}
   498  	`, n)
   499  
   500  	for _, c := range n.children {
   501  		g.pt(`
   502  		res._{{.Name}} = new{{.Type}}(r, prefix)
   503  		`, c)
   504  	}
   505  	for _, l := range n.leaves {
   506  		g.pt(`
   507  		res._{{.Name}} = new{{.LeafType}}(r, prefix)
   508  		`, l)
   509  	}
   510  
   511  	g.pt(`
   512  		return res
   513  	}
   514  
   515  	func (n *{{.Name}}) Subscribe(cb func()) *Sub {
   516  		return n.rootNode.subscribe(n.prefix, cb)
   517  	}
   518  	`, n)
   519  
   520  	for _, c := range n.children {
   521  		tmpl := struct {
   522  			Node  node
   523  			Child field
   524  		}{
   525  			Node:  n,
   526  			Child: c,
   527  		}
   528  		g.pt(`
   529  		func (n *{{.Node.Name}}) {{.Child.Name}}() *{{.Child.Type}} {
   530  			return n._{{.Child.Name}}
   531  		}
   532  		`, tmpl)
   533  	}
   534  
   535  	for _, l := range n.leaves {
   536  		tmpl := struct {
   537  			Node node
   538  			Leaf leafField
   539  		}{
   540  			Node: n,
   541  			Leaf: l,
   542  		}
   543  		g.pt(`
   544  		func (n *{{.Node.Name}}) {{.Leaf.Name}}() *{{.Leaf.LeafType}} {
   545  			return n._{{.Leaf.Name}}
   546  		}
   547  		`, tmpl)
   548  	}
   549  }
   550  
   551  func (g *gen) errorf(format string, args ...interface{}) {
   552  	g.failed = true
   553  	fmt.Fprintf(g.stderr, format, args...)
   554  }
   555  
   556  func (g *gen) expString(e interface{}) string {
   557  	b := bytes.NewBuffer(nil)
   558  	err := printer.Fprint(b, g.fset, e)
   559  	if err != nil {
   560  		panic(err)
   561  	}
   562  
   563  	return b.String()
   564  }
   565  
   566  func (g *gen) pf(format string, args ...interface{}) {
   567  	fmt.Fprintf(g.buf, format, args...)
   568  }
   569  
   570  func (g *gen) pln(args ...interface{}) {
   571  	fmt.Fprintln(g.buf, args...)
   572  }
   573  
   574  func (g *gen) pt(tmpl string, val interface{}) {
   575  	// on the basis most templates are for convenience define inline
   576  	// as raw string literals which start the ` on one line but then start
   577  	// the template on the next (for readability) we strip the first leading
   578  	// \n if one exists
   579  	tmpl = strings.TrimPrefix(tmpl, "\n")
   580  
   581  	t := template.New("tmp")
   582  
   583  	_, err := t.Parse(tmpl)
   584  	if err != nil {
   585  		fatalf("unable to parse template: %v", err)
   586  	}
   587  
   588  	err = t.Execute(g.buf, val)
   589  	if err != nil {
   590  		fatalf("cannot execute template: %v", err)
   591  	}
   592  }
   593  
   594  type importFinder struct {
   595  	imports []*ast.ImportSpec
   596  	matches map[*ast.ImportSpec]bool
   597  }
   598  
   599  func (i *importFinder) Visit(node ast.Node) ast.Visitor {
   600  	switch node := node.(type) {
   601  	case *ast.SelectorExpr:
   602  		if x, ok := node.X.(*ast.Ident); ok {
   603  			for _, imp := range i.imports {
   604  				if imp.Name != nil {
   605  					if x.Name == imp.Name.Name {
   606  						i.matches[imp] = true
   607  					}
   608  				} else {
   609  					cleanPath := strings.Trim(imp.Path.Value, "\"")
   610  					parts := strings.Split(cleanPath, "/")
   611  					if x.Name == parts[len(parts)-1] {
   612  						i.matches[imp] = true
   613  					}
   614  				}
   615  			}
   616  
   617  		}
   618  	}
   619  
   620  	return i
   621  }