github.com/jd-ly/tools@v0.5.7/internal/lsp/helper/helper.go (about)

     1  // Invoke with //go:generate helper/helper -t Server -d protocol/tsserver.go -u lsp -o server_gen.go
     2  // invoke in internal/lsp
     3  package main
     4  
     5  import (
     6  	"bytes"
     7  	"flag"
     8  	"fmt"
     9  	"go/ast"
    10  	"go/format"
    11  	"go/parser"
    12  	"go/token"
    13  	"log"
    14  	"os"
    15  	"sort"
    16  	"strings"
    17  	"text/template"
    18  )
    19  
    20  var (
    21  	typ = flag.String("t", "Server", "generate code for this type")
    22  	def = flag.String("d", "", "the file the type is defined in") // this relies on punning
    23  	use = flag.String("u", "", "look for uses in this package")
    24  	out = flag.String("o", "", "where to write the generated file")
    25  )
    26  
    27  func main() {
    28  	log.SetFlags(log.Lshortfile)
    29  	flag.Parse()
    30  	if *typ == "" || *def == "" || *use == "" || *out == "" {
    31  		flag.PrintDefaults()
    32  		return
    33  	}
    34  	// read the type definition and see what methods we're looking for
    35  	doTypes()
    36  
    37  	// parse the package and see which methods are defined
    38  	doUses()
    39  
    40  	output()
    41  }
    42  
// tmpl is the text/template for the generated file. output executes it with
// a value whose Type field is the -t type name and whose Stuff field holds
// the methods collected from the interface.
//
// For each method there are two cases:
//   - If an implementation was found (.Found is non-empty), the generated
//     body forwards to s.<Internal>(<Invoke>).
//   - Otherwise it returns notImplemented("<Name>"), preceded by a single
//     "nil, " when the method has more than one result.
//
// NOTE(review): only one nil is emitted, so this assumes methods have at
// most two results, and that nil is a valid value for the first one.
//
// A backslash immediately before a newline is a line continuation. Every
// backslash-newline pair must be removed from the rendered text before it
// is used; output does this before calling format.Source.
var tmpl = `
package lsp

// code generated by helper. DO NOT EDIT.

import (
	"context"

	"github.com/jd-ly/tools/internal/lsp/protocol"
)

{{range $key, $v := .Stuff}}
func (s *{{$.Type}}) {{$v.Name}}({{.Param}}) {{.Result}} {
	{{if ne .Found ""}} return s.{{.Internal}}({{.Invoke}})\
	{{else}}return {{if lt 1 (len .Results)}}nil, {{end}}notImplemented("{{.Name}}"){{end}}
}
{{end}}
`
    62  
    63  func output() {
    64  	// put in empty param names as needed
    65  	for _, t := range types {
    66  		if t.paramnames == nil {
    67  			t.paramnames = make([]string, len(t.paramtypes))
    68  		}
    69  		for i, p := range t.paramtypes {
    70  			cm := ""
    71  			if i > 0 {
    72  				cm = ", "
    73  			}
    74  			t.Param += fmt.Sprintf("%s%s %s", cm, t.paramnames[i], p)
    75  			this := t.paramnames[i]
    76  			if this == "_" {
    77  				this = "nil"
    78  			}
    79  			t.Invoke += fmt.Sprintf("%s%s", cm, this)
    80  		}
    81  		if len(t.Results) > 1 {
    82  			t.Result = "("
    83  		}
    84  		for i, r := range t.Results {
    85  			cm := ""
    86  			if i > 0 {
    87  				cm = ", "
    88  			}
    89  			t.Result += fmt.Sprintf("%s%s", cm, r)
    90  		}
    91  		if len(t.Results) > 1 {
    92  			t.Result += ")"
    93  		}
    94  	}
    95  
    96  	fd, err := os.Create(*out)
    97  	if err != nil {
    98  		log.Fatal(err)
    99  	}
   100  	t, err := template.New("foo").Parse(tmpl)
   101  	if err != nil {
   102  		log.Fatal(err)
   103  	}
   104  	type par struct {
   105  		Type  string
   106  		Stuff []*Function
   107  	}
   108  	p := par{*typ, types}
   109  	if false { // debugging the template
   110  		t.Execute(os.Stderr, &p)
   111  	}
   112  	buf := bytes.NewBuffer(nil)
   113  	err = t.Execute(buf, &p)
   114  	if err != nil {
   115  		log.Fatal(err)
   116  	}
   117  	ans, err := format.Source(bytes.Replace(buf.Bytes(), []byte("\\\n"), []byte{}, -1))
   118  	if err != nil {
   119  		log.Fatal(err)
   120  	}
   121  	fd.Write(ans)
   122  }
   123  
   124  func doUses() {
   125  	fset := token.NewFileSet()
   126  	pkgs, err := parser.ParseDir(fset, *use, nil, 0)
   127  	if err != nil {
   128  		log.Fatalf("%q:%v", *use, err)
   129  	}
   130  	pkg := pkgs["lsp"] // CHECK
   131  	files := pkg.Files
   132  	for fname, f := range files {
   133  		for _, d := range f.Decls {
   134  			fd, ok := d.(*ast.FuncDecl)
   135  			if !ok {
   136  				continue
   137  			}
   138  			nm := fd.Name.String()
   139  			if ast.IsExported(nm) {
   140  				// we're looking for things like didChange
   141  				continue
   142  			}
   143  			if fx, ok := byname[nm]; ok {
   144  				if fx.Found != "" {
   145  					log.Fatalf("found %s in %s and %s", fx.Internal, fx.Found, fname)
   146  				}
   147  				fx.Found = fname
   148  				// and the Paramnames
   149  				ft := fd.Type
   150  				for _, f := range ft.Params.List {
   151  					nm := ""
   152  					if len(f.Names) > 0 {
   153  						nm = f.Names[0].String()
   154  					}
   155  					fx.paramnames = append(fx.paramnames, nm)
   156  				}
   157  			}
   158  		}
   159  	}
   160  	if false {
   161  		for i, f := range types {
   162  			log.Printf("%d %s %s", i, f.Internal, f.Found)
   163  		}
   164  	}
   165  }
   166  
   167  type Function struct {
   168  	Name       string
   169  	Internal   string // first letter lower case
   170  	paramtypes []string
   171  	paramnames []string
   172  	Results    []string
   173  	Param      string
   174  	Result     string // do it in code, easier than in a template
   175  	Invoke     string
   176  	Found      string // file it was found in
   177  }
   178  
   179  var types []*Function
   180  var byname = map[string]*Function{} // internal names
   181  
   182  func doTypes() {
   183  	fset := token.NewFileSet()
   184  	f, err := parser.ParseFile(fset, *def, nil, 0)
   185  	if err != nil {
   186  		log.Fatal(err)
   187  	}
   188  	fd, err := os.Create("/tmp/ast")
   189  	if err != nil {
   190  		log.Fatal(err)
   191  	}
   192  	ast.Fprint(fd, fset, f, ast.NotNilFilter)
   193  	ast.Inspect(f, inter)
   194  	sort.Slice(types, func(i, j int) bool { return types[i].Name < types[j].Name })
   195  	if false {
   196  		for i, f := range types {
   197  			log.Printf("%d %s(%v) %v", i, f.Name, f.paramtypes, f.Results)
   198  		}
   199  	}
   200  }
   201  
   202  func inter(n ast.Node) bool {
   203  	x, ok := n.(*ast.TypeSpec)
   204  	if !ok || x.Name.Name != *typ {
   205  		return true
   206  	}
   207  	m := x.Type.(*ast.InterfaceType).Methods.List
   208  	for _, fld := range m {
   209  		fn := fld.Type.(*ast.FuncType)
   210  		p := fn.Params.List
   211  		r := fn.Results.List
   212  		fx := &Function{
   213  			Name: fld.Names[0].String(),
   214  		}
   215  		fx.Internal = strings.ToLower(fx.Name[:1]) + fx.Name[1:]
   216  		for _, f := range p {
   217  			fx.paramtypes = append(fx.paramtypes, whatis(f.Type))
   218  		}
   219  		for _, f := range r {
   220  			fx.Results = append(fx.Results, whatis(f.Type))
   221  		}
   222  		types = append(types, fx)
   223  		byname[fx.Internal] = fx
   224  	}
   225  	return false
   226  }
   227  
   228  func whatis(x ast.Expr) string {
   229  	switch n := x.(type) {
   230  	case *ast.SelectorExpr:
   231  		return whatis(n.X) + "." + n.Sel.String()
   232  	case *ast.StarExpr:
   233  		return "*" + whatis(n.X)
   234  	case *ast.Ident:
   235  		if ast.IsExported(n.Name) {
   236  			// these are from package protocol
   237  			return "protocol." + n.Name
   238  		}
   239  		return n.Name
   240  	case *ast.ArrayType:
   241  		return "[]" + whatis(n.Elt)
   242  	case *ast.InterfaceType:
   243  		return "interface{}"
   244  	default:
   245  		log.Fatalf("Fatal %T", x)
   246  		return fmt.Sprintf("%T", x)
   247  	}
   248  }