github.com/blend/go-sdk@v1.20220411.3/sourceutil/go_ast_rewrite.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package sourceutil
     9  
    10  import (
    11  	"context"
    12  	"go/ast"
    13  )
    14  
    15  // GoAstRewrite returns a go ast visitor with a given set of options.
    16  func GoAstRewrite(opts ...GoAstRewriteOption) GoAstVisitor {
    17  	var rewriteOpts GoAstRewriteOptions
    18  	for _, opt := range opts {
    19  		opt(&rewriteOpts)
    20  	}
    21  	return func(ctx context.Context, n ast.Node) bool {
    22  		return rewriteOpts.Apply(ctx, n)
    23  	}
    24  }
    25  
    26  // GoIsPackageCall returns a filter that determines if a function is a given sel.Fn.
    27  //
    28  // It will only evaluate for function calls that use a package selector
    29  // that is, function calls that have a selector.
    30  func GoIsPackageCall(pkg, fn string) GoAstRewriteOption {
    31  	return func(opts *GoAstRewriteOptions) {
    32  		opts.Filter = func(ctx context.Context, n ast.Node) (visit, recurse bool) {
    33  			if nt, ok := n.(*ast.CallExpr); ok {
    34  				if ft, ok := nt.Fun.(*ast.SelectorExpr); ok {
    35  					if exprIsName(ft.X, pkg) && exprIsName(ft.Sel, fn) {
    36  						return true, false // visit, do not recurse
    37  					}
    38  				}
    39  				return false, false // do not visit, do not recurse
    40  			}
    41  			return false, true // do not visit, do recurse
    42  		}
    43  	}
    44  }
    45  
    46  // GoIsCall returns a filter that determines if a function is a given name.
    47  //
    48  // It will only evaluate for function calls that appear local to the
    49  // current package, that is, function calls that do not have a selector.
    50  func GoIsCall(fn string) GoAstRewriteOption {
    51  	return func(opts *GoAstRewriteOptions) {
    52  		opts.Filter = func(_ context.Context, n ast.Node) (visit, recurse bool) {
    53  			if nt, ok := n.(*ast.CallExpr); ok {
    54  				if ft, ok := nt.Fun.(*ast.Ident); ok {
    55  					if exprIsName(ft, fn) {
    56  						return true, false // visit, do not recurse
    57  					}
    58  				}
    59  				return false, false // do not visit, do not recurse
    60  			}
    61  			return false, true // do not visit, do recurse
    62  		}
    63  	}
    64  }
    65  
    66  // GoRewritePackageCall changes a given function as filtered by a filter
    67  // to a given call noted by sel.Fn.
    68  func GoRewritePackageCall(sel, fn string) GoAstRewriteOption {
    69  	return func(opts *GoAstRewriteOptions) {
    70  		opts.NodeVisitor = func(_ context.Context, n ast.Node) {
    71  			if nt, ok := n.(*ast.CallExpr); ok {
    72  				if ft, ok := nt.Fun.(*ast.SelectorExpr); ok {
    73  					exprSetName(ft.X, sel)
    74  					exprSetName(ft.Sel, fn)
    75  				}
    76  			}
    77  		}
    78  	}
    79  }
    80  
    81  // GoRewriteCall changes a given function as filtered by a filter
    82  // to a given call noted by Fn.
    83  func GoRewriteCall(fn string) GoAstRewriteOption {
    84  	return func(opts *GoAstRewriteOptions) {
    85  		opts.NodeVisitor = func(_ context.Context, n ast.Node) {
    86  			if nt, ok := n.(*ast.CallExpr); ok {
    87  				if ft, ok := nt.Fun.(*ast.Ident); ok {
    88  					exprSetName(ft, fn)
    89  				}
    90  			}
    91  		}
    92  	}
    93  }
    94  
    95  // GoAstVisitor mutates an ast node.
    96  type GoAstVisitor func(context.Context, ast.Node) bool
    97  
    98  // GoAstRewriteOption the ast rewrite options.
    99  type GoAstRewriteOption func(*GoAstRewriteOptions)
   100  
   101  // GoAstFilter is a delegate type that filters ast nodes for visiting.
   102  type GoAstFilter func(context.Context, ast.Node) (visit, recurse bool)
   103  
   104  // GoAstNodeVisitor mutates a given node.
   105  type GoAstNodeVisitor func(context.Context, ast.Node)
   106  
   107  // GoAstRewriteOptions breaks the mutator out into field specific mutators.
   108  type GoAstRewriteOptions struct {
   109  	Filter      GoAstFilter
   110  	NodeVisitor GoAstNodeVisitor
   111  }
   112  
   113  // Apply applies the options to the ast node.
   114  func (opts GoAstRewriteOptions) Apply(ctx context.Context, node ast.Node) bool {
   115  	if opts.Filter != nil {
   116  		visit, recurse := opts.Filter(ctx, node)
   117  		if visit {
   118  			if opts.NodeVisitor != nil {
   119  				opts.NodeVisitor(ctx, node)
   120  			}
   121  		}
   122  		return recurse
   123  	}
   124  	if opts.NodeVisitor != nil {
   125  		opts.NodeVisitor(ctx, node)
   126  	}
   127  	return true // if no filter, always recurse
   128  }
   129  
   130  func exprIsName(expr ast.Expr, name string) bool {
   131  	id, ok := expr.(*ast.Ident)
   132  	return ok && id.Name == name
   133  }
   134  
   135  func exprSetName(expr ast.Expr, name string) {
   136  	id, ok := expr.(*ast.Ident)
   137  	if ok {
   138  		id.Name = name
   139  	}
   140  }