github.com/brownsys/tracing-framework-go@v0.0.0-20161210174012-0542a62412fe/cmd/rewrite/rewrite.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/format"
     8  	"go/parser"
     9  	"go/token"
    10  	"go/types"
    11  	"reflect"
    12  	"text/template"
    13  
    14  	"golang.org/x/tools/go/ast/astutil"
    15  )
    16  
    17  func rewriteGos(fset *token.FileSet, info types.Info, qual types.Qualifier, f *ast.File) (changed bool, err error) {
    18  	rname := runtimeName(f)
    19  	err = mapStmts(f, func(s ast.Stmt) ([]ast.Stmt, error) {
    20  		if g, ok := s.(*ast.GoStmt); ok {
    21  			stmts, err := rewriteGoStmt(fset, info, qual, rname, g)
    22  			if stmts != nil {
    23  				changed = true
    24  			}
    25  			return stmts, err
    26  		}
    27  		return nil, nil
    28  	})
    29  	if changed {
    30  		// astutil.AddNamedImport(fset, f, rname, "runtime")
    31  		astutil.AddImport(fset, f, "github.com/brownsys/tracing-framework-go/local")
    32  		// astutil.AddImport(fset, f, "runtime")
    33  	}
    34  	return changed, err
    35  }
    36  
    37  // runtimeName searches through f's imports to find whether
    38  // the "runtime" package has been imported, and if not, whether
    39  // another package whose name is also "runtime" has been
    40  // imported (which would conflict if we were to add "runtime"
    41  // as an import). It returns the name that should be used to
    42  // identify the "runtime" package.
    43  func runtimeName(f *ast.File) string {
    44  	for _, imp := range f.Imports {
    45  		if imp.Path.Value == `"runtime"` {
    46  			if imp.Name != nil {
    47  				return imp.Name.Name
    48  			}
    49  			return "runtime"
    50  		}
    51  	}
    52  	return "__runtime"
    53  }
    54  
    55  // nameForPackage searches through f's imports to find
    56  // whether the package identified by the given path has
    57  // been imported, and if not, whether another package
    58  // whose name is the same has been imported (which would
    59  // conflict if we were to add the given path as an
    60  // import). It returns the name that should be used to
    61  // identify the given package.
    62  //
// TODO: it does not: when path is not imported, it returns "__"+name without checking whether another import already uses name.
    64  // func nameForPackage(f *ast.File, path, name string) string {
    65  // 	path = '"' + path + '"'
    66  // 	for _, imp := range f.Imports {
    67  // 		if imp.Path.Value == path {
    68  // 			if imp.Name != nil {
    69  // 				return imp.Name.Name
    70  // 			}
    71  // 			return name
    72  // 		}
    73  // 	}
    74  // 	return "__" + name
    75  // }
    76  
    77  // mapStmts walks v, searching for values of type []ast.Stmt
    78  // or [x]ast.Stmt. After recurring into such values, it loops
    79  // over the slice or array, and for each element, calls f.
    80  // If f returns nil, the value is left as is. If f returns a
    81  // non-nil slice (of any length, including 0), the contents
    82  // of this slice replace the original value in the slice or array.
    83  //
    84  // If f ever returns a non-nil error, it is immediately returned.
    85  func mapStmts(v ast.Node, f func(s ast.Stmt) ([]ast.Stmt, error)) error {
    86  	var blocks []*ast.BlockStmt
    87  	ast.Inspect(v, func(n ast.Node) bool {
    88  		if b, ok := n.(*ast.BlockStmt); ok {
    89  			blocks = append(blocks, b)
    90  		}
    91  		return true
    92  	})
    93  
    94  	// make sure to process blocks backwards
    95  	// so that children are processed before parents
    96  	for i := len(blocks) - 1; i >= 0; i-- {
    97  		b := blocks[i]
    98  		var newStmts []ast.Stmt
    99  		for _, s := range b.List {
   100  			new, err := f(s)
   101  			if err != nil {
   102  				return err
   103  			}
   104  			if new == nil {
   105  				newStmts = append(newStmts, s)
   106  			} else {
   107  				newStmts = append(newStmts, new...)
   108  			}
   109  		}
   110  		b.List = newStmts
   111  	}
   112  
   113  	return nil
   114  }
   115  
   116  func rewriteGoStmt(fset *token.FileSet, info types.Info, qual types.Qualifier, rname string, g *ast.GoStmt) ([]ast.Stmt, error) {
   117  	ftyp := info.TypeOf(g.Call.Fun)
   118  
   119  	if ftyp == nil {
   120  		return nil, fmt.Errorf("%v: could not determine type of function",
   121  			fset.Position(g.Call.Fun.Pos()))
   122  	}
   123  	sig := ftyp.(*types.Signature)
   124  
   125  	var arg struct {
   126  		Runtime                       string
   127  		Func                          string
   128  		Typ                           string
   129  		DefArgs, InnerArgs, OuterArgs []string
   130  	}
   131  
   132  	arg.Runtime = rname
   133  	arg.Func = nodeString(fset, g.Call.Fun)
   134  	arg.Typ = types.TypeString(ftyp, qual)
   135  
   136  	params := sig.Params()
   137  	for i := 0; i < params.Len(); i++ {
   138  		name := fmt.Sprintf("arg%v", i)
   139  		if sig.Variadic() && i == params.Len()-1 {
   140  			typ := types.TypeString(params.At(i).Type().(*types.Slice).Elem(), qual)
   141  			arg.DefArgs = append(arg.DefArgs, name+" ..."+typ)
   142  			arg.InnerArgs = append(arg.InnerArgs, name+"...")
   143  		} else {
   144  			typ := types.TypeString(params.At(i).Type(), qual)
   145  			arg.DefArgs = append(arg.DefArgs, name+" "+typ)
   146  			arg.InnerArgs = append(arg.InnerArgs, name)
   147  		}
   148  	}
   149  
   150  	for i, a := range g.Call.Args {
   151  		if g.Call.Ellipsis.IsValid() && i == len(g.Call.Args)-1 {
   152  			// g.Call.Ellipsis.IsValid() is true if g is variadic
   153  			arg.OuterArgs = append(arg.OuterArgs, nodeString(fset, a)+"...")
   154  		} else {
   155  			arg.OuterArgs = append(arg.OuterArgs, nodeString(fset, a))
   156  		}
   157  	}
   158  
   159  	var buf bytes.Buffer
   160  	err := goTmpl.Execute(&buf, arg)
   161  	if err != nil {
   162  		panic(fmt.Errorf("internal error: %v", err))
   163  	}
   164  	return parseStmts(string(buf.Bytes())), nil
   165  }
   166  
   167  var goTmpl = template.Must(template.New("").Parse(`
   168  go func(__f1 func(), __f2 {{.Typ}} {{range .DefArgs}},{{.}}{{end}}){
   169  	__f1()
   170  	__f2({{range .InnerArgs}}{{.}},{{end}})
   171  }(local.GetSpawnCallback(), {{.Func}}{{range .OuterArgs}},{{.}}{{end}})
   172  `))
   173  
   174  func parseStmts(src string) []ast.Stmt {
   175  	src = `package main
   176  	func a() {` + src + `}`
   177  	fset := token.NewFileSet()
   178  	a, err := parser.ParseFile(fset, "", src, parser.ParseComments|parser.DeclarationErrors)
   179  	if err != nil {
   180  		panic(fmt.Errorf("internal error: %v", err))
   181  	}
   182  	stmts := a.Decls[0].(*ast.FuncDecl).Body.List
   183  	zeroPos(&stmts)
   184  	return stmts
   185  }
   186  
   187  // walk v and zero all values of type token.Pos
   188  func zeroPos(v interface{}) {
   189  	rv := reflect.ValueOf(v)
   190  	if rv.Kind() != reflect.Ptr {
   191  		panic("internal error")
   192  	}
   193  	zeroPosHelper(rv)
   194  }
   195  
   196  var posTyp = reflect.TypeOf(token.Pos(0))
   197  
   198  func zeroPosHelper(rv reflect.Value) {
   199  	if rv.Type() == posTyp {
   200  		rv.SetInt(0)
   201  		return
   202  	}
   203  	switch rv.Kind() {
   204  	case reflect.Ptr:
   205  		if !rv.IsNil() {
   206  			zeroPosHelper(rv.Elem())
   207  		}
   208  	case reflect.Slice, reflect.Array:
   209  		for i := 0; i < rv.Len(); i++ {
   210  			zeroPosHelper(rv.Index(i))
   211  		}
   212  	case reflect.Map:
   213  		keys := rv.MapKeys()
   214  		for _, k := range keys {
   215  			zeroPosHelper(rv.MapIndex(k))
   216  		}
   217  	case reflect.Struct:
   218  		for i := 0; i < rv.NumField(); i++ {
   219  			zeroPosHelper(rv.Field(i))
   220  		}
   221  	}
   222  }
   223  
   224  func nodeString(fset *token.FileSet, node interface{}) string {
   225  	var buf bytes.Buffer
   226  	err := format.Node(&buf, fset, node)
   227  	if err != nil {
   228  		panic(fmt.Errorf("unexpected internal error: %v", err))
   229  	}
   230  	return string(buf.Bytes())
   231  }
   232  
   233  func qualifierForFile(pkg *types.Package, f *ast.File) types.Qualifier {
   234  	pathToPackage := make(map[string]*types.Package)
   235  	for _, pkg := range pkg.Imports() {
   236  		pathToPackage[pkg.Path()] = pkg
   237  	}
   238  
   239  	m := make(map[*types.Package]string)
   240  	for _, imp := range f.Imports {
   241  		if imp.Path.Value == `"unsafe"` {
   242  			continue
   243  		}
   244  		// slice out quotation marks
   245  		l := len(imp.Path.Value)
   246  		pkg, ok := pathToPackage[imp.Path.Value[1:l-1]]
   247  		if !ok {
   248  			panic(fmt.Errorf("package %v (imported in %v) not in (*loader.Program).AllPackages", imp.Path.Value, f.Name.Name))
   249  		}
   250  		name := ""
   251  		if imp.Name == nil {
   252  			name = pkg.Name()
   253  		} else {
   254  			name = imp.Name.Name
   255  		}
   256  		m[pkg] = name
   257  	}
   258  	return func(p *types.Package) string { return m[p] }
   259  }