github.com/google/capslock@v0.2.3-0.20240517042941-dac19fc347c0/analyzer/rewrite.go (about)

     1  // Copyright 2023 Google LLC
     2  //
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file or at
     5  // https://developers.google.com/open-source/licenses/bsd
     6  
     7  package analyzer
     8  
     9  import (
    10  	"fmt"
    11  	"go/ast"
    12  	"go/constant"
    13  	"go/token"
    14  	"go/types"
    15  	"unsafe"
    16  
    17  	"golang.org/x/tools/go/ast/astutil"
    18  	"golang.org/x/tools/go/packages"
    19  )
    20  
    21  // operandMode has the same layout as types.operandMode.
    22  type operandMode byte
    23  
    24  const (
    25  	noValueMode  operandMode = 1
    26  	constantMode operandMode = 4
    27  	valueMode    operandMode = 7
    28  )
    29  
    30  // constructTypeAndValue constructs a types.TypeAndValue.  These are used in
    31  // the types.Info.Types map to store the known types of expressions, and the
    32  // values of constant expressions.
    33  func constructTypeAndValue(mode operandMode, t types.Type, v constant.Value) types.TypeAndValue {
    34  	// The mode field of types.TypeAndValue is not exported, so we make our own
    35  	// copy of the type definition, and use unsafe conversion to get the type we
    36  	// want.
    37  	tv := struct {
    38  		mode  operandMode
    39  		Types types.Type
    40  		Value constant.Value
    41  	}{mode, t, v}
    42  	return *(*types.TypeAndValue)(unsafe.Pointer(&tv))
    43  }
    44  
    45  // typeAndValueForResults constructs a TypeAndValue corresponding to the return
    46  // values of a function.
    47  func typeAndValueForResults(results *types.Tuple) types.TypeAndValue {
    48  	if results == nil {
    49  		// Case 1: the function has no return values.
    50  		return constructTypeAndValue(noValueMode, results, nil)
    51  	}
    52  	if results.Len() == 1 {
    53  		// Case 2: the function has a single return value.
    54  		return constructTypeAndValue(valueMode, results.At(0).Type(), nil)
    55  	}
    56  	// Case 3: the function returns a tuple of more than one value.
    57  	return constructTypeAndValue(valueMode, results, nil)
    58  }
    59  
    60  // zeroLiteral creates and returns a zero literal of type int, and adds its
    61  // type information to typeInfo.Types.
    62  func zeroLiteral(typeInfo *types.Info) ast.Expr {
    63  	expr := &ast.BasicLit{Kind: token.INT, Value: "0"}
    64  	typeInfo.Types[expr] = constructTypeAndValue(constantMode, types.Typ[types.Int], constant.MakeInt64(0))
    65  	return expr
    66  }
    67  
    68  // selectionForMethod finds the Selection object for the given method.
    69  func selectionForMethod(typ types.Type, name string) *types.Selection {
    70  	var ms *types.MethodSet = types.NewMethodSet(typ)
    71  	// The package is not needed for exported methods, so we can pass nil for the
    72  	// package parameter of Lookup.
    73  	sel := ms.Lookup(nil, name)
    74  	return sel
    75  }
    76  
    77  // rewriteCallsToSort iterates through the packages in pkgs, including all
    78  // transitively-imported packages, and finds calls to sort.Sort, sort.Stable,
    79  // and sort.IsSorted, which each have a sort.Interface parameter.  We replace
    80  // each of these calls with a set of calls to each of the interface methods
    81  // individually (Len, Less, and Swap.)  e.g., this code:
    82  //
    83  //	sort.Sort(xs)
    84  //
    85  // would be replaced with:
    86  //
    87  //	xs.Len()
    88  //	xs.Less(0,0)
    89  //	xs.Swap(0,0)
    90  //
    91  // This improves the precision of the callgraph the analysis produces.  The
    92  // analysis produces a set of possible dynamic types for the sort.Interface
    93  // value, and adds a callgraph edge to the methods for each of those.
    94  //
    95  // Without this change to the callgraph, we would get paths to the
    96  // sort.Interface methods for every possible dynamic type for all the values
    97  // passed to the same sort function anywhere in the program, which can result
    98  // in a large number of false positives.
    99  func rewriteCallsToSort(pkgs []*packages.Package) {
   100  	forEachPackageIncludingDependencies(pkgs, func(p *packages.Package) {
   101  		for _, file := range p.Syntax {
   102  			for _, node := range file.Decls {
   103  				var pre astutil.ApplyFunc
   104  				pre = func(c *astutil.Cursor) bool {
   105  					// If the current node, c.Node(), is a call to sort.Sort (or
   106  					// sort.Stable or sort.IsSorted), replace it with calls to
   107  					// obj.Less, obj.Swap, and obj.Len, where obj is the argument
   108  					// that was passed to sort.
   109  					if _, ok := c.Node().(ast.Stmt); !ok {
   110  						// c.Node() is not a statement.
   111  						return true
   112  					}
   113  					canRewrite := false
   114  					switch c.Parent().(type) {
   115  					case *ast.BlockStmt, *ast.CaseClause, *ast.LabeledStmt:
   116  						canRewrite = true
   117  					case *ast.CommClause:
   118  						canRewrite = c.Index() >= 0
   119  					}
   120  					if !canRewrite {
   121  						// The statement is in a position in the syntax tree where it
   122  						// can't be replaced with a block or with multiple statements, so
   123  						// we give up.
   124  						return true
   125  					}
   126  
   127  					obj := isCallToSort(p.TypesInfo, c.Node())
   128  					if obj == nil {
   129  						// This was not a call to a sort function.
   130  						//
   131  						// We always return true from this function, because the return
   132  						// value indicates to astutil.Apply whether to keep searching.
   133  						return true
   134  					}
   135  					// Less and Swap each take two integer arguments.  The values aren't
   136  					// important for our callgraph analysis -- we do not look at values
   137  					// to determine which way an if statement branches, for example --
   138  					// so we just use two zeroes.
   139  					args1 := []ast.Expr{zeroLiteral(p.TypesInfo), zeroLiteral(p.TypesInfo)}
   140  					args2 := []ast.Expr{zeroLiteral(p.TypesInfo), zeroLiteral(p.TypesInfo)}
   141  					// Create a block with three statements which call Less, Swap,
   142  					// and Len.  Replace the current node with this block.
   143  					s1 := statementCallingMethod(p.TypesInfo, obj, "Less", args1)
   144  					s2 := statementCallingMethod(p.TypesInfo, obj, "Swap", args2)
   145  					s3 := statementCallingMethod(p.TypesInfo, obj, "Len", nil)
   146  					if s1 == nil || s2 == nil || s3 == nil {
   147  						// We did not succeed in creating these statements.
   148  						return true
   149  					}
   150  					c.Replace(&ast.BlockStmt{List: []ast.Stmt{s1, s2, s3}})
   151  					return true
   152  				}
   153  				astutil.Apply(node, pre, nil)
   154  			}
   155  		}
   156  	})
   157  }
   158  
   159  // rewriteCallsToOnceDoEtc is similar to rewriteCallsToSort.  It finds calls
   160  // to some standard-library functions and methods which have a function
   161  // parameter, and changes those calls to call the function argument directly
   162  // instead.
   163  //
   164  // e.g. this code:
   165  //
   166  //	var myonce *sync.Once = ...
   167  //	myonce.Do(fn)
   168  //
   169  // would be replaced with:
   170  //
   171  //	var myonce *sync.Once = ...
   172  //	fn()
   173  func rewriteCallsToOnceDoEtc(pkgs []*packages.Package) {
   174  	forEachPackageIncludingDependencies(pkgs, func(p *packages.Package) {
   175  		for _, file := range p.Syntax {
   176  			for _, node := range file.Decls {
   177  				var pre astutil.ApplyFunc
   178  				pre = func(c *astutil.Cursor) bool {
   179  					obj := isCallToOnceDoEtc(p.TypesInfo, c.Node())
   180  					if obj == nil {
   181  						// This was not a call to a relevant function or method.
   182  						return true
   183  					}
   184  					fnType, ok := p.TypesInfo.TypeOf(obj).(*types.Signature)
   185  					if !ok {
   186  						// The argument does not appear to be a function.
   187  						return true
   188  					}
   189  					// Create some arguments to pass to the function.  The parameters
   190  					// must all be integers.
   191  					params := fnType.Params()
   192  					args := make([]ast.Expr, params.Len())
   193  					for i := range args {
   194  						args[i] = zeroLiteral(p.TypesInfo)
   195  					}
   196  					c.Replace(
   197  						statementCallingFunctionObject(p.TypesInfo, obj, args))
   198  					return true
   199  				}
   200  				astutil.Apply(node, pre, nil)
   201  			}
   202  		}
   203  	})
   204  }
   205  
   206  // isCallToSort checks if node is a statement calling sort.Sort, sort.Stable,
   207  // or sort.IsSorted.  If so, it returns the argument to that function.
   208  // Otherwise, it returns nil.
   209  func isCallToSort(typeInfo *types.Info, node ast.Node) ast.Expr {
   210  	expr, ok := node.(*ast.ExprStmt)
   211  	if !ok {
   212  		// Not a statement node.
   213  		return nil
   214  	}
   215  	call, ok := expr.X.(*ast.CallExpr)
   216  	if !ok {
   217  		// Not a function call.
   218  		return nil
   219  	}
   220  	callee, ok := call.Fun.(*ast.SelectorExpr)
   221  	if !ok {
   222  		// The function to be called is not a selection, so it can't be a call to
   223  		// the sort package.  (Unless the user has dot-imported "sort", but we
   224  		// don't need to worry much about false negatives in unusual cases here.)
   225  		return nil
   226  	}
   227  	pkgIdent, ok := callee.X.(*ast.Ident)
   228  	if !ok {
   229  		// The left-hand-side of the selection is not a plain identifier.
   230  		return nil
   231  	}
   232  	pkgName, ok := typeInfo.Uses[pkgIdent].(*types.PkgName)
   233  	if !ok {
   234  		// The identifier does not refer to a package.
   235  		return nil
   236  	}
   237  	if pkgName.Imported().Path() != "sort" {
   238  		// The package isn't "sort".  (We use Imported().Path() because the import
   239  		// name could be misleading, e.g.:
   240  		// import (
   241  		//   sort "os"
   242  		// )
   243  		return nil
   244  	}
   245  	if name := callee.Sel.Name; name != "Sort" && name != "Stable" && name != "IsSorted" {
   246  		// This isn't one of the functions we're looking for.
   247  		return nil
   248  	}
   249  	if len(call.Args) != 1 {
   250  		// The function call doesn't have one argument.
   251  		return nil
   252  	}
   253  	return call.Args[0]
   254  }
   255  
   256  // isCallToOnceDoEtc checks if node is a statement calling a function or method
   257  // like (*sync.Once).Do.  If so, it returns the function-typed argument to that
   258  // function.  Otherwise, it returns nil.
   259  func isCallToOnceDoEtc(typeInfo *types.Info, node ast.Node) ast.Expr {
   260  	expr, ok := node.(*ast.ExprStmt)
   261  	if !ok {
   262  		// Not a statement node.
   263  		return nil
   264  	}
   265  	call, ok := expr.X.(*ast.CallExpr)
   266  	if !ok {
   267  		// Not a call expression.
   268  		return nil
   269  	}
   270  	for _, m := range functionsToRewrite {
   271  		if e := m.match(typeInfo, call); e != nil {
   272  			return e
   273  		}
   274  	}
   275  	return nil
   276  }
   277  
   278  // statementCallingMethod constructs a statement that calls a method.  The
   279  // receiver is recv, the method name is methodName, and the arguments passed
   280  // to the call are in args.
   281  //
   282  // New AST structures that are created by statementCallingMethod are added
   283  // to the Types, Selections and Uses fields of typeInfo as needed.  The
   284  // expressions in methodName and args should already be in typeInfo.
   285  //
   286  // If the statement cannot be created, returns nil.
   287  func statementCallingMethod(typeInfo *types.Info, recv ast.Expr, methodName string, args []ast.Expr) *ast.ExprStmt {
   288  	// Construct an ast node for the method name, and add it to typeInfo.Uses.
   289  	methodIdent := ast.NewIdent(methodName)
   290  	var selection *types.Selection = selectionForMethod(typeInfo.TypeOf(recv), methodName)
   291  	if selection == nil {
   292  		// We did not find the desired method for this type.  recv might be an
   293  		// untyped nil.
   294  		return nil
   295  	}
   296  	typeInfo.Uses[methodIdent] = selection.Obj()
   297  	// Construct an ast node for the selection (e.g. "v.M"), and add it to
   298  	// typeInfo.Selections and typeInfo.Types.
   299  	selectorExpr := &ast.SelectorExpr{X: recv, Sel: methodIdent}
   300  	typeInfo.Selections[selectorExpr] = selection
   301  	typeInfo.Types[selectorExpr] = constructTypeAndValue(valueMode, selection.Type(), nil)
   302  	// Construct an ast node for the call (e.g. "v.M(arg1, arg2)") and add it
   303  	// to typeInfo.Types.
   304  	callExpr := &ast.CallExpr{Fun: selectorExpr, Args: append([]ast.Expr(nil), args...)}
   305  	typeInfo.Types[callExpr] = typeAndValueForResults(selection.Type().(*types.Signature).Results())
   306  	// Return an ast node for a statement which is just the call.  No type
   307  	// information is needed for statements.
   308  	return &ast.ExprStmt{X: callExpr}
   309  }
   310  
   311  // statementCallingFunctionObject constructs a statement that calls a function.
   312  //
   313  // New AST structures that are created by statementCallingFunctionObject are
   314  // added to the Types fields of typeInfo as needed.  The expressions in fn and
   315  // args should already be in typeInfo.
   316  func statementCallingFunctionObject(typeInfo *types.Info, fn ast.Expr, args []ast.Expr) *ast.ExprStmt {
   317  	// Construct an ast node for the call and add it to typeInfo.Types.
   318  	callExpr := &ast.CallExpr{Fun: fn, Args: append([]ast.Expr(nil), args...)}
   319  	fnType := typeInfo.TypeOf(fn)
   320  	fnTypeSignature, _ := fnType.(*types.Signature)
   321  	if fnTypeSignature == nil {
   322  		panic(fmt.Sprintf("cannot get type signature of function %v", fn))
   323  	}
   324  	typeInfo.Types[callExpr] = typeAndValueForResults(fnTypeSignature.Results())
   325  	// Return an ast node for a statement which is just the call.  No type
   326  	// information is needed for statements.
   327  	return &ast.ExprStmt{X: callExpr}
   328  }