github.com/brownsys/tracing-framework-go@v0.0.0-20161210174012-0542a62412fe/other/cmd/rewrite/rewrite.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/format"
     8  	"go/parser"
     9  	"go/token"
    10  	"go/types"
    11  	"reflect"
    12  	"strings"
    13  	"text/template"
    14  
    15  	"golang.org/x/tools/go/ast/astutil"
    16  )
    17  
    18  func rewriteGos(fset *token.FileSet, info types.Info, qual types.Qualifier, f *ast.File) (changed bool, err error) {
    19  	rname := runtimeName(f)
    20  	err = mapStmts(f, func(s ast.Stmt) ([]ast.Stmt, error) {
    21  		if g, ok := s.(*ast.GoStmt); ok {
    22  			stmts, err := rewriteGoStmt(fset, info, qual, rname, g)
    23  			if stmts != nil {
    24  				changed = true
    25  			}
    26  			return stmts, err
    27  		}
    28  		return nil, nil
    29  	})
    30  	if changed {
    31  		astutil.AddNamedImport(fset, f, rname, "runtime")
    32  		// astutil.AddImport(fset, f, "runtime")
    33  	}
    34  	return changed, err
    35  }
    36  
    37  func rewriteCalls(fset *token.FileSet, info types.Info, qual types.Qualifier, f *ast.File) (changed bool, err error) {
    38  	rname := runtimeName(f)
    39  	err = mapStmts(f, func(s ast.Stmt) ([]ast.Stmt, error) {
    40  		if a, ok := s.(*ast.AssignStmt); ok {
    41  			stmts, err := rewriteCallStmt(fset, info, qual, rname, a)
    42  			if stmts != nil {
    43  				changed = true
    44  			}
    45  			return stmts, err
    46  		}
    47  		return nil, nil
    48  	})
    49  	if changed {
    50  		astutil.AddNamedImport(fset, f, rname, "runtime")
    51  		// astutil.AddImport(fset, f, "runtime")
    52  	}
    53  	return changed, err
    54  }
    55  
    56  // runtimeName searches through f's imports to find whether
    57  // the "runtime" package has been imported, and if not, whether
    58  // another package whose name is also "runtime" has been
    59  // imported (which would conflict if we were to add "runtime"
    60  // as an import). It returns the name that should be used to
    61  // identify the "runtime" package.
    62  func runtimeName(f *ast.File) string {
    63  	for _, imp := range f.Imports {
    64  		if imp.Path.Value == `"runtime"` {
    65  			if imp.Name != nil {
    66  				return imp.Name.Name
    67  			}
    68  			return "runtime"
    69  		}
    70  	}
    71  	return "__runtime"
    72  }
    73  
    74  // mapStmts walks v, searching for values of type []ast.Stmt
    75  // or [x]ast.Stmt. After recurring into such values, it loops
    76  // over the slice or array, and for each element, calls f.
    77  // If f returns nil, the value is left as is. If f returns a
    78  // non-nil slice (of any length, including 0), the contents
    79  // of this slice replace the original value in the slice or array.
    80  //
    81  // If f ever returns a non-nil error, it is immediately returned.
    82  func mapStmts(v ast.Node, f func(s ast.Stmt) ([]ast.Stmt, error)) error {
    83  	var blocks []*ast.BlockStmt
    84  	ast.Inspect(v, func(n ast.Node) bool {
    85  		if b, ok := n.(*ast.BlockStmt); ok {
    86  			blocks = append(blocks, b)
    87  		}
    88  		return true
    89  	})
    90  
    91  	// make sure to process blocks backwards
    92  	// so that children are processed before parents
    93  	for i := len(blocks) - 1; i >= 0; i-- {
    94  		b := blocks[i]
    95  		var newStmts []ast.Stmt
    96  		for _, s := range b.List {
    97  			new, err := f(s)
    98  			if err != nil {
    99  				return err
   100  			}
   101  			if new == nil {
   102  				newStmts = append(newStmts, s)
   103  			} else {
   104  				newStmts = append(newStmts, new...)
   105  			}
   106  		}
   107  		b.List = newStmts
   108  	}
   109  
   110  	return nil
   111  }
   112  
   113  func rewriteCallStmt(fset *token.FileSet, info types.Info, qual types.Qualifier, rname string, a *ast.AssignStmt) ([]ast.Stmt, error) {
   114  	// for the time being, we only handle
   115  	// statements which have a single
   116  	// function call on the RHS, like:
   117  	//  a, b = f()
   118  
   119  	if len(a.Rhs) != 1 {
   120  		for _, aa := range a.Rhs {
   121  			if _, ok := aa.(*ast.CallExpr); ok {
   122  				return nil, fmt.Errorf("%v: unsupported statement format", fset.Position(a.Pos()))
   123  			}
   124  		}
   125  		// none of the RHS expressions are function
   126  		// calls, so we can just safely ignore this
   127  		return nil, nil
   128  	}
   129  
   130  	c, ok := a.Rhs[0].(*ast.CallExpr)
   131  	if !ok {
   132  		return nil, nil
   133  	}
   134  
   135  	rettyp := info.TypeOf(c)
   136  	if rettyp == nil {
   137  		return nil, fmt.Errorf("%v: could not determine return type of function",
   138  			fset.Position(c.Pos()))
   139  	}
   140  
   141  	var vname string
   142  
   143  	// since the code has been type checked,
   144  	// we can assume that the function has
   145  	// at least one return value, and that
   146  	// len(LHS) = len(RHS)
   147  	if t, ok := rettyp.(*types.Tuple); ok {
   148  		context := false
   149  		for i := 0; i < t.Len(); i++ {
   150  			switch v := a.Lhs[i].(type) {
   151  			case *ast.Ident:
   152  				if v.Name != "_" && isContext(t.At(i).Type()) {
   153  					if context {
   154  						// more than one context.Context variable
   155  						return nil, fmt.Errorf("%v: unsupported statement format", fset.Position(a.Pos()))
   156  					}
   157  					context = true
   158  					vname = v.Name
   159  				}
   160  			default:
   161  				// TODO: handle LHS elements other than identifiers
   162  				return nil, nil
   163  				panic(fmt.Errorf("unexpected type %v", reflect.TypeOf(v)))
   164  			}
   165  		}
   166  		if !context {
   167  			return nil, nil
   168  		}
   169  	} else {
   170  		switch v := a.Lhs[0].(type) {
   171  		case *ast.Ident:
   172  			if v.Name == "_" || !isContext(rettyp) {
   173  				return nil, nil
   174  			}
   175  			vname = v.Name
   176  		default:
   177  			// TODO: handle LHS elements other than identifiers
   178  			return nil, nil
   179  			// panic(fmt.Errorf("unexpected type %v", reflect.TypeOf(v)))
   180  		}
   181  	}
   182  
   183  	arg := struct{ Runtime, Ctx string }{rname, vname}
   184  
   185  	var buf bytes.Buffer
   186  	err := callTmpl.Execute(&buf, arg)
   187  	if err != nil {
   188  		panic(fmt.Errorf("internal error: %v", err))
   189  	}
   190  	return append([]ast.Stmt{a}, parseStmts(string(buf.Bytes()))...), nil
   191  }
   192  
   193  var callTmpl = template.Must(template.New("").Parse(`{{.Runtime}}.SetLocal({{.Ctx}})`))
   194  
   195  func rewriteGoStmt(fset *token.FileSet, info types.Info, qual types.Qualifier, rname string, g *ast.GoStmt) ([]ast.Stmt, error) {
   196  	ftyp := info.TypeOf(g.Call.Fun)
   197  
   198  	if ftyp == nil {
   199  		return nil, fmt.Errorf("%v: could not determine type of function",
   200  			fset.Position(g.Call.Fun.Pos()))
   201  	}
   202  	sig := ftyp.(*types.Signature)
   203  
   204  	// According to the context documentation:
   205  	//
   206  	// Do not store Contexts inside a struct type;
   207  	// instead, pass a Context explicitly to each
   208  	// function that needs it. The Context should
   209  	// be the first parameter, typically named ctx.
   210  	//
   211  	// Thus, we only handle this case.
   212  	if sig.Params().Len() == 0 || !isContext(sig.Params().At(0).Type()) {
   213  		return nil, nil
   214  	}
   215  
   216  	var arg struct {
   217  		Runtime                       string
   218  		Func                          string
   219  		Typ                           string
   220  		DefArgs, InnerArgs, OuterArgs []string
   221  	}
   222  
   223  	arg.Runtime = rname
   224  	arg.Func = nodeString(fset, g.Call.Fun)
   225  	arg.Typ = types.TypeString(ftyp, qual)
   226  
   227  	params := sig.Params()
   228  	for i := 0; i < params.Len(); i++ {
   229  		typ := types.TypeString(params.At(i).Type(), qual)
   230  		name := fmt.Sprintf("arg%v", i)
   231  		if sig.Variadic() && i == params.Len()-1 {
   232  			arg.DefArgs = append(arg.DefArgs, name+" ..."+typ)
   233  			arg.InnerArgs = append(arg.InnerArgs, name+"...")
   234  		} else {
   235  			arg.DefArgs = append(arg.DefArgs, name+" "+typ)
   236  			arg.InnerArgs = append(arg.InnerArgs, name)
   237  		}
   238  	}
   239  
   240  	for _, a := range g.Call.Args {
   241  		arg.OuterArgs = append(arg.OuterArgs, nodeString(fset, a))
   242  	}
   243  
   244  	var buf bytes.Buffer
   245  	err := goTmpl.Execute(&buf, arg)
   246  	if err != nil {
   247  		panic(fmt.Errorf("internal error: %v", err))
   248  	}
   249  	return parseStmts(string(buf.Bytes())), nil
   250  }
   251  
   252  var goTmpl = template.Must(template.New("").Parse(`
   253  go func(__f {{.Typ}} {{range .DefArgs}},{{.}}{{end}}){
   254  	{{.Runtime}}.SetLocal(arg0)
   255  	__f({{range .InnerArgs}}{{.}},{{end}})
   256  }({{.Func}}{{range .OuterArgs}},{{.}}{{end}})
   257  `))
   258  
   259  func parseStmts(src string) []ast.Stmt {
   260  	src = `package main
   261  	func a() {` + src + `}`
   262  	fset := token.NewFileSet()
   263  	a, err := parser.ParseFile(fset, "", src, parser.ParseComments|parser.DeclarationErrors)
   264  	if err != nil {
   265  		panic(fmt.Errorf("internal error: %v", err))
   266  	}
   267  	stmts := a.Decls[0].(*ast.FuncDecl).Body.List
   268  	zeroPos(&stmts)
   269  	return stmts
   270  }
   271  
   272  // walk v and zero all values of type token.Pos
   273  func zeroPos(v interface{}) {
   274  	rv := reflect.ValueOf(v)
   275  	if rv.Kind() != reflect.Ptr {
   276  		panic("internal error")
   277  	}
   278  	zeroPosHelper(rv)
   279  }
   280  
   281  var posTyp = reflect.TypeOf(token.Pos(0))
   282  
   283  func zeroPosHelper(rv reflect.Value) {
   284  	if rv.Type() == posTyp {
   285  		rv.SetInt(0)
   286  		return
   287  	}
   288  	switch rv.Kind() {
   289  	case reflect.Ptr:
   290  		if !rv.IsNil() {
   291  			zeroPosHelper(rv.Elem())
   292  		}
   293  	case reflect.Slice, reflect.Array:
   294  		for i := 0; i < rv.Len(); i++ {
   295  			zeroPosHelper(rv.Index(i))
   296  		}
   297  	case reflect.Map:
   298  		keys := rv.MapKeys()
   299  		for _, k := range keys {
   300  			zeroPosHelper(rv.MapIndex(k))
   301  		}
   302  	case reflect.Struct:
   303  		for i := 0; i < rv.NumField(); i++ {
   304  			zeroPosHelper(rv.Field(i))
   305  		}
   306  	}
   307  }
   308  
   309  func nodeString(fset *token.FileSet, node interface{}) string {
   310  	var buf bytes.Buffer
   311  	err := format.Node(&buf, fset, node)
   312  	if err != nil {
   313  		panic(fmt.Errorf("unexpected internal error: %v", err))
   314  	}
   315  	return string(buf.Bytes())
   316  }
   317  
   318  func qualifierForFile(pkg *types.Package, f *ast.File) types.Qualifier {
   319  	pathToPackage := make(map[string]*types.Package)
   320  	for _, pkg := range pkg.Imports() {
   321  		pathToPackage[pkg.Path()] = pkg
   322  	}
   323  
   324  	m := make(map[*types.Package]string)
   325  	for _, imp := range f.Imports {
   326  		if imp.Path.Value == `"unsafe"` {
   327  			continue
   328  		}
   329  		// slice out quotation marks
   330  		l := len(imp.Path.Value)
   331  		pkg, ok := pathToPackage[imp.Path.Value[1:l-1]]
   332  		if !ok {
   333  			panic(fmt.Errorf("package %v (imported in %v) not in (*loader.Program).AllPackages", imp.Path.Value, f.Name.Name))
   334  		}
   335  		name := ""
   336  		if imp.Name == nil {
   337  			name = pkg.Name()
   338  		} else {
   339  			name = imp.Name.Name
   340  		}
   341  		m[pkg] = name
   342  	}
   343  	return func(p *types.Package) string { return m[p] }
   344  }
   345  
   346  func isContext(t types.Type) bool {
   347  	return t.String() == "golang.org/x/net/context.Context" ||
   348  		t.String() == "context.Context" || strings.HasSuffix(t.String(),
   349  		"_workspace/src/golang.org/x/net/context.Context")
   350  }