github.com/blend/go-sdk@v1.20220411.3/sourceutil/go_ast_rewrite.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package sourceutil 9 10 import ( 11 "context" 12 "go/ast" 13 ) 14 15 // GoAstRewrite returns a go ast visitor with a given set of options. 16 func GoAstRewrite(opts ...GoAstRewriteOption) GoAstVisitor { 17 var rewriteOpts GoAstRewriteOptions 18 for _, opt := range opts { 19 opt(&rewriteOpts) 20 } 21 return func(ctx context.Context, n ast.Node) bool { 22 return rewriteOpts.Apply(ctx, n) 23 } 24 } 25 26 // GoIsPackageCall returns a filter that determines if a function is a given sel.Fn. 27 // 28 // It will only evaluate for function calls that use a package selector 29 // that is, function calls that have a selector. 30 func GoIsPackageCall(pkg, fn string) GoAstRewriteOption { 31 return func(opts *GoAstRewriteOptions) { 32 opts.Filter = func(ctx context.Context, n ast.Node) (visit, recurse bool) { 33 if nt, ok := n.(*ast.CallExpr); ok { 34 if ft, ok := nt.Fun.(*ast.SelectorExpr); ok { 35 if exprIsName(ft.X, pkg) && exprIsName(ft.Sel, fn) { 36 return true, false // visit, do not recurse 37 } 38 } 39 return false, false // do not visit, do not recurse 40 } 41 return false, true // do not visit, do recurse 42 } 43 } 44 } 45 46 // GoIsCall returns a filter that determines if a function is a given name. 47 // 48 // It will only evaluate for function calls that appear local to the 49 // current package, that is, function calls that do not have a selector. 50 func GoIsCall(fn string) GoAstRewriteOption { 51 return func(opts *GoAstRewriteOptions) { 52 opts.Filter = func(_ context.Context, n ast.Node) (visit, recurse bool) { 53 if nt, ok := n.(*ast.CallExpr); ok { 54 if ft, ok := nt.Fun.(*ast.Ident); ok { 55 if exprIsName(ft, fn) { 56 return true, false // visit, do not recurse 57 } 58 } 59 return false, false // do not visit, do not recurse 60 } 61 return false, true // do not visit, do recurse 62 } 63 } 64 } 65 66 // GoRewritePackageCall changes a given function as filtered by a filter 67 // to a given call noted by sel.Fn. 68 func GoRewritePackageCall(sel, fn string) GoAstRewriteOption { 69 return func(opts *GoAstRewriteOptions) { 70 opts.NodeVisitor = func(_ context.Context, n ast.Node) { 71 if nt, ok := n.(*ast.CallExpr); ok { 72 if ft, ok := nt.Fun.(*ast.SelectorExpr); ok { 73 exprSetName(ft.X, sel) 74 exprSetName(ft.Sel, fn) 75 } 76 } 77 } 78 } 79 } 80 81 // GoRewriteCall changes a given function as filtered by a filter 82 // to a given call noted by Fn. 83 func GoRewriteCall(fn string) GoAstRewriteOption { 84 return func(opts *GoAstRewriteOptions) { 85 opts.NodeVisitor = func(_ context.Context, n ast.Node) { 86 if nt, ok := n.(*ast.CallExpr); ok { 87 if ft, ok := nt.Fun.(*ast.Ident); ok { 88 exprSetName(ft, fn) 89 } 90 } 91 } 92 } 93 } 94 95 // GoAstVisitor mutates an ast node. 96 type GoAstVisitor func(context.Context, ast.Node) bool 97 98 // GoAstRewriteOption the ast rewrite options. 99 type GoAstRewriteOption func(*GoAstRewriteOptions) 100 101 // GoAstFilter is a delegate type that filters ast nodes for visiting. 102 type GoAstFilter func(context.Context, ast.Node) (visit, recurse bool) 103 104 // GoAstNodeVisitor mutates a given node. 105 type GoAstNodeVisitor func(context.Context, ast.Node) 106 107 // GoAstRewriteOptions breaks the mutator out into field specific mutators. 108 type GoAstRewriteOptions struct { 109 Filter GoAstFilter 110 NodeVisitor GoAstNodeVisitor 111 } 112 113 // Apply applies the options to the ast node. 114 func (opts GoAstRewriteOptions) Apply(ctx context.Context, node ast.Node) bool { 115 if opts.Filter != nil { 116 visit, recurse := opts.Filter(ctx, node) 117 if visit { 118 if opts.NodeVisitor != nil { 119 opts.NodeVisitor(ctx, node) 120 } 121 } 122 return recurse 123 } 124 if opts.NodeVisitor != nil { 125 opts.NodeVisitor(ctx, node) 126 } 127 return true // if no filter, always recurse 128 } 129 130 func exprIsName(expr ast.Expr, name string) bool { 131 id, ok := expr.(*ast.Ident) 132 return ok && id.Name == name 133 } 134 135 func exprSetName(expr ast.Expr, name string) { 136 id, ok := expr.(*ast.Ident) 137 if ok { 138 id.Name = name 139 } 140 }