github.com/april1989/origin-go-tools@v0.0.32/internal/lsp/helper/helper.go (about)

     1  // Invoke with //go:generate helper/helper -t Server -d protocol/tsserver.go -u lsp -o server_gen.go
     2  // invoke in internal/lsp
     3  package main
     4  
     5  import (
     6  	"bytes"
     7  	"flag"
     8  	"fmt"
     9  	"go/ast"
    10  	"go/format"
    11  	"go/parser"
    12  	"go/token"
    13  	"log"
    14  	"os"
    15  	"sort"
    16  	"strings"
    17  	"text/template"
    18  )
    19  
    20  var (
    21  	typ = flag.String("t", "Server", "generate code for this type")
    22  	def = flag.String("d", "", "the file the type is defined in") // this relies on punning
    23  	use = flag.String("u", "", "look for uses in this package")
    24  	out = flag.String("o", "", "where to write the generated file")
    25  )
    26  
    27  func main() {
    28  	log.SetFlags(log.Lshortfile)
    29  	flag.Parse()
    30  	if *typ == "" || *def == "" || *use == "" || *out == "" {
    31  		flag.PrintDefaults()
    32  		return
    33  	}
    34  	// read the type definition and see what methods we're looking for
    35  	doTypes()
    36  
    37  	// parse the package and see which methods are defined
    38  	doUses()
    39  
    40  	output()
    41  }
    42  
    43  // replace "\\\n" with nothing before using
    44  var tmpl = `
    45  package lsp
    46  
    47  // code generated by helper. DO NOT EDIT.
    48  
    49  import (
    50  	"context"
    51  
    52  	"github.com/april1989/origin-go-tools/internal/lsp/protocol"
    53  )
    54  
    55  {{range $key, $v := .Stuff}}
    56  func (s *{{$.Type}}) {{$v.Name}}({{.Param}}) {{.Result}} {
    57  	{{if ne .Found ""}} return s.{{.Internal}}({{.Invoke}})\
    58  	{{else}}return {{if lt 1 (len .Results)}}nil, {{end}}notImplemented("{{.Name}}"){{end}}
    59  }
    60  {{end}}
    61  `
    62  
    63  func output() {
    64  	// put in empty param names as needed
    65  	for _, t := range types {
    66  		if t.paramnames == nil {
    67  			t.paramnames = make([]string, len(t.paramtypes))
    68  		}
    69  		for i, p := range t.paramtypes {
    70  			cm := ""
    71  			if i > 0 {
    72  				cm = ", "
    73  			}
    74  			t.Param += fmt.Sprintf("%s%s %s", cm, t.paramnames[i], p)
    75  			t.Invoke += fmt.Sprintf("%s%s", cm, t.paramnames[i])
    76  		}
    77  		if len(t.Results) > 1 {
    78  			t.Result = "("
    79  		}
    80  		for i, r := range t.Results {
    81  			cm := ""
    82  			if i > 0 {
    83  				cm = ", "
    84  			}
    85  			t.Result += fmt.Sprintf("%s%s", cm, r)
    86  		}
    87  		if len(t.Results) > 1 {
    88  			t.Result += ")"
    89  		}
    90  	}
    91  
    92  	fd, err := os.Create(*out)
    93  	if err != nil {
    94  		log.Fatal(err)
    95  	}
    96  	t, err := template.New("foo").Parse(tmpl)
    97  	if err != nil {
    98  		log.Fatal(err)
    99  	}
   100  	type par struct {
   101  		Type  string
   102  		Stuff []*Function
   103  	}
   104  	p := par{*typ, types}
   105  	if false { // debugging the template
   106  		t.Execute(os.Stderr, &p)
   107  	}
   108  	buf := bytes.NewBuffer(nil)
   109  	err = t.Execute(buf, &p)
   110  	if err != nil {
   111  		log.Fatal(err)
   112  	}
   113  	ans, err := format.Source(bytes.Replace(buf.Bytes(), []byte("\\\n"), []byte{}, -1))
   114  	if err != nil {
   115  		log.Fatal(err)
   116  	}
   117  	fd.Write(ans)
   118  }
   119  
   120  func doUses() {
   121  	fset := token.NewFileSet()
   122  	pkgs, err := parser.ParseDir(fset, *use, nil, 0)
   123  	if err != nil {
   124  		log.Fatalf("%q:%v", *use, err)
   125  	}
   126  	pkg := pkgs["lsp"] // CHECK
   127  	files := pkg.Files
   128  	for fname, f := range files {
   129  		for _, d := range f.Decls {
   130  			fd, ok := d.(*ast.FuncDecl)
   131  			if !ok {
   132  				continue
   133  			}
   134  			nm := fd.Name.String()
   135  			if ast.IsExported(nm) {
   136  				// we're looking for things like didChange
   137  				continue
   138  			}
   139  			if fx, ok := byname[nm]; ok {
   140  				if fx.Found != "" {
   141  					log.Fatalf("found %s in %s and %s", fx.Internal, fx.Found, fname)
   142  				}
   143  				fx.Found = fname
   144  				// and the Paramnames
   145  				ft := fd.Type
   146  				for _, f := range ft.Params.List {
   147  					nm := ""
   148  					if len(f.Names) > 0 {
   149  						nm = f.Names[0].String()
   150  					}
   151  					fx.paramnames = append(fx.paramnames, nm)
   152  				}
   153  			}
   154  		}
   155  	}
   156  	if false {
   157  		for i, f := range types {
   158  			log.Printf("%d %s %s", i, f.Internal, f.Found)
   159  		}
   160  	}
   161  }
   162  
   163  type Function struct {
   164  	Name       string
   165  	Internal   string // first letter lower case
   166  	paramtypes []string
   167  	paramnames []string
   168  	Results    []string
   169  	Param      string
   170  	Result     string // do it in code, easier than in a template
   171  	Invoke     string
   172  	Found      string // file it was found in
   173  }
   174  
   175  var types []*Function
   176  var byname = map[string]*Function{} // internal names
   177  
   178  func doTypes() {
   179  	fset := token.NewFileSet()
   180  	f, err := parser.ParseFile(fset, *def, nil, 0)
   181  	if err != nil {
   182  		log.Fatal(err)
   183  	}
   184  	fd, err := os.Create("/tmp/ast")
   185  	if err != nil {
   186  		log.Fatal(err)
   187  	}
   188  	ast.Fprint(fd, fset, f, ast.NotNilFilter)
   189  	ast.Inspect(f, inter)
   190  	sort.Slice(types, func(i, j int) bool { return types[i].Name < types[j].Name })
   191  	if false {
   192  		for i, f := range types {
   193  			log.Printf("%d %s(%v) %v", i, f.Name, f.paramtypes, f.Results)
   194  		}
   195  	}
   196  }
   197  
   198  func inter(n ast.Node) bool {
   199  	x, ok := n.(*ast.TypeSpec)
   200  	if !ok || x.Name.Name != *typ {
   201  		return true
   202  	}
   203  	m := x.Type.(*ast.InterfaceType).Methods.List
   204  	for _, fld := range m {
   205  		fn := fld.Type.(*ast.FuncType)
   206  		p := fn.Params.List
   207  		r := fn.Results.List
   208  		fx := &Function{
   209  			Name: fld.Names[0].String(),
   210  		}
   211  		fx.Internal = strings.ToLower(fx.Name[:1]) + fx.Name[1:]
   212  		for _, f := range p {
   213  			fx.paramtypes = append(fx.paramtypes, whatis(f.Type))
   214  		}
   215  		for _, f := range r {
   216  			fx.Results = append(fx.Results, whatis(f.Type))
   217  		}
   218  		types = append(types, fx)
   219  		byname[fx.Internal] = fx
   220  	}
   221  	return false
   222  }
   223  
   224  func whatis(x ast.Expr) string {
   225  	switch n := x.(type) {
   226  	case *ast.SelectorExpr:
   227  		return whatis(n.X) + "." + n.Sel.String()
   228  	case *ast.StarExpr:
   229  		return "*" + whatis(n.X)
   230  	case *ast.Ident:
   231  		if ast.IsExported(n.Name) {
   232  			// these are from package protocol
   233  			return "protocol." + n.Name
   234  		}
   235  		return n.Name
   236  	case *ast.ArrayType:
   237  		return "[]" + whatis(n.Elt)
   238  	case *ast.InterfaceType:
   239  		return "interface{}"
   240  	default:
   241  		log.Fatalf("Fatal %T", x)
   242  		return fmt.Sprintf("%T", x)
   243  	}
   244  }