github.com/google/capslock@v0.2.3-0.20240517042941-dac19fc347c0/analyzer/util.go (about)

     1  // Copyright 2023 Google LLC
     2  //
     3  // Use of this source code is governed by a BSD-style
     4  // license that can be found in the LICENSE file or at
     5  // https://developers.google.com/open-source/licenses/bsd
     6  
     7  package analyzer
     8  
     9  import (
    10  	"go/ast"
    11  	"go/token"
    12  	"go/types"
    13  	"os"
    14  	"path"
    15  	"strings"
    16  
    17  	cpb "github.com/google/capslock/proto"
    18  	"golang.org/x/tools/go/callgraph"
    19  	"golang.org/x/tools/go/callgraph/cha"
    20  	"golang.org/x/tools/go/callgraph/vta"
    21  	"golang.org/x/tools/go/packages"
    22  	"golang.org/x/tools/go/ssa"
    23  	"golang.org/x/tools/go/ssa/ssautil"
    24  )
    25  
    26  type bfsState struct {
    27  	// edge is the callgraph edge leading to the next node in a path to an
    28  	// interesting function.
    29  	edge *callgraph.Edge
    30  }
    31  
    32  // next returns the next node in the path to an interesting function.
    33  func (b bfsState) next() *callgraph.Node {
    34  	if b.edge == nil {
    35  		return nil
    36  	}
    37  	return b.edge.Callee
    38  }
    39  
    40  type nodeset map[*callgraph.Node]struct{}
    41  type nodesetPerCapability map[cpb.Capability]nodeset
    42  
    43  func (nc nodesetPerCapability) add(cap cpb.Capability, node *callgraph.Node) {
    44  	m := nc[cap]
    45  	if m == nil {
    46  		m = make(nodeset)
    47  		nc[cap] = m
    48  	}
    49  	m[node] = struct{}{}
    50  }
    51  
    52  // byFunction is a slice of *callgraph.Node that can be sorted using sort.Sort.
    53  // The ordering is first by package name, then function name.
    54  type byFunction []*callgraph.Node
    55  
    56  func (s byFunction) Len() int { return len(s) }
    57  func (s byFunction) Less(i, j int) bool {
    58  	return nodeCompare(s[i], s[j]) < 0
    59  }
    60  func (s byFunction) Swap(i, j int) {
    61  	s[i], s[j] = s[j], s[i]
    62  }
    63  
    64  // byCaller is a slice of *callgraph.Edge that can be sorted using
    65  // sort.Sort.  It sorts by calling function, then callsite position.
    66  type byCaller []*callgraph.Edge
    67  
    68  func (s byCaller) Len() int { return len(s) }
    69  func (s byCaller) Less(i, j int) bool {
    70  	if c := nodeCompare(s[i].Caller, s[j].Caller); c != 0 {
    71  		return c < 0
    72  	}
    73  	return positionLess(callsitePosition(s[i]), callsitePosition(s[j]))
    74  }
    75  func (s byCaller) Swap(i, j int) {
    76  	s[i], s[j] = s[j], s[i]
    77  }
    78  
    79  func nodeCompare(a, b *callgraph.Node) int {
    80  	return funcCompare(a.Func, b.Func)
    81  }
    82  
    83  // funcCompare orders by package path, then by whether the function is a
    84  // method, then by name.  Returns {-1, 0, +1} in the manner of strings.Compare.
    85  func funcCompare(a, b *ssa.Function) int {
    86  	// Put nils last.
    87  	if a == nil && b == nil {
    88  		return 0
    89  	} else if b == nil {
    90  		return -1
    91  	} else if a == nil {
    92  		return +1
    93  	}
    94  	if c := strings.Compare(packagePath(a), packagePath(b)); c != 0 {
    95  		return c
    96  	}
    97  	hasReceiver := func(f *ssa.Function) bool {
    98  		sig := f.Signature
    99  		return sig != nil && sig.Recv() != nil
   100  	}
   101  	if ar, br := hasReceiver(a), hasReceiver(b); !ar && br {
   102  		return -1
   103  	} else if ar && !br {
   104  		return +1
   105  	}
   106  	return strings.Compare(a.String(), b.String())
   107  }
   108  
   109  // positionLess implements an ordering on token.Position.
   110  // It orders first by filename, then by position in the file.
   111  // Invalid positions are sorted last.
   112  func positionLess(p1, p2 token.Position) bool {
   113  	if p2.Line == 0 {
   114  		// A token.Position with Line == 0 is invalid.
   115  		return p1.Line != 0
   116  	}
   117  	if p1.Line == 0 {
   118  		return false
   119  	}
   120  	if p1.Filename != p2.Filename {
   121  		// Note that two positions from the same function can have different
   122  		// filenames because the ssa.Function for "init" can include
   123  		// initialization code for package-level variables in multiple files.
   124  		return p1.Filename < p2.Filename
   125  	}
   126  	return p1.Offset < p2.Offset
   127  }
   128  
   129  // packagePath returns the name of the package the function belongs to, or
   130  // "" if it has no package.
   131  func packagePath(f *ssa.Function) string {
   132  	// If f is an instantiation of a generic function, use its origin.
   133  	if f.Origin() != nil {
   134  		f = f.Origin()
   135  	}
   136  	if ssaPackage := f.Package(); ssaPackage != nil {
   137  		if typesPackage := ssaPackage.Pkg; typesPackage != nil {
   138  			return typesPackage.Path()
   139  		}
   140  	}
   141  	// Check f.Object() for a package.  This covers the case of synthetic wrapper
   142  	// functions for promoted methods of embedded fields.
   143  	if obj := types.Object(f.Object()); obj != nil {
   144  		if typesPackage := obj.Pkg(); typesPackage != nil {
   145  			return typesPackage.Path()
   146  		}
   147  	}
   148  	return ""
   149  }
   150  
   151  // callsitePosition returns a token.Position for the edge's callsite.
   152  // If edge is nil, or the source is unavailable, the returned token.Position
   153  // will have token.IsValid() == false.
   154  func callsitePosition(edge *callgraph.Edge) token.Position {
   155  	if edge == nil {
   156  		return token.Position{}
   157  	} else if f := edge.Caller.Func; f == nil {
   158  		return token.Position{}
   159  	} else if prog := f.Prog; prog == nil {
   160  		return token.Position{}
   161  	} else if fset := prog.Fset; fset == nil {
   162  		return token.Position{}
   163  	} else {
   164  		return fset.Position(edge.Pos())
   165  	}
   166  }
   167  
   168  func isStdLib(p string) bool {
   169  	if strings.Contains(p, ".") {
   170  		return false
   171  	}
   172  	return true
   173  }
   174  
   175  func buildGraph(pkgs []*packages.Package, populateSyntax bool) (*callgraph.Graph, *ssa.Program, map[*ssa.Function]bool) {
   176  	rewriteCallsToSort(pkgs)
   177  	rewriteCallsToOnceDoEtc(pkgs)
   178  	ssaBuilderMode := ssa.InstantiateGenerics
   179  	if populateSyntax {
   180  		// Debug mode makes ssa.Function.Syntax() point to the ast Node for the
   181  		// function.  This will allow us to link nodes in the callgraph with
   182  		// functions in the syntax tree which convert unsafe.Pointer objects or
   183  		// use the reflect package in notable ways.
   184  		ssaBuilderMode |= ssa.GlobalDebug
   185  	}
   186  	ssaProg, _ := ssautil.AllPackages(pkgs, ssaBuilderMode)
   187  	ssaProg.Build()
   188  	graph := cha.CallGraph(ssaProg)
   189  	allFunctions := ssautil.AllFunctions(ssaProg)
   190  	graph = vta.CallGraph(allFunctions, graph)
   191  	return graph, ssaProg, allFunctions
   192  }
   193  
   194  // functionsToRewrite lists the functions and methods like (*sync.Once).Do that
   195  // rewriteCallsToOnceDoEtc will rewrite to calls to their arguments.
   196  var functionsToRewrite = []matcher{
   197  	&methodMatcher{
   198  		pkg:                         "sync",
   199  		typeName:                    "Once",
   200  		methodName:                  "Do",
   201  		functionTypedParameterIndex: 0,
   202  	},
   203  	&packageFunctionMatcher{
   204  		pkg:                         "sort",
   205  		functionName:                "Slice",
   206  		functionTypedParameterIndex: 1,
   207  	},
   208  	&packageFunctionMatcher{
   209  		pkg:                         "sort",
   210  		functionName:                "SliceStable",
   211  		functionTypedParameterIndex: 1,
   212  	},
   213  }
   214  
   215  type matcher interface {
   216  	// match checks if a CallExpr is a call to a particular function or method
   217  	// that this object is looking for.  If it matches, it returns a particular
   218  	// argument in the call that has a function type.  Otherwise it returns nil.
   219  	match(*types.Info, *ast.CallExpr) ast.Expr
   220  }
   221  
   222  // packageFunctionMatcher objects match a package-scope function.
   223  type packageFunctionMatcher struct {
   224  	pkg                         string
   225  	functionName                string
   226  	functionTypedParameterIndex int
   227  }
   228  
   229  // methodMatcher objects match a method of some type.
   230  type methodMatcher struct {
   231  	pkg                         string
   232  	typeName                    string
   233  	methodName                  string
   234  	functionTypedParameterIndex int
   235  }
   236  
   237  func (m *packageFunctionMatcher) match(typeInfo *types.Info, call *ast.CallExpr) ast.Expr {
   238  	callee, ok := call.Fun.(*ast.SelectorExpr)
   239  	if !ok {
   240  		// The function to be called is not a selection, so it can't be a call to
   241  		// the relevant package.  (Unless the user has dot-imported the package,
   242  		// but we don't need to worry much about false negatives in unusual cases
   243  		// here.)
   244  		return nil
   245  	}
   246  	pkgIdent, ok := callee.X.(*ast.Ident)
   247  	if !ok {
   248  		// The left-hand side of the selection is not a plain identifier.
   249  		return nil
   250  	}
   251  	pkgName, ok := typeInfo.Uses[pkgIdent].(*types.PkgName)
   252  	if !ok {
   253  		// The identifier does not refer to a package.
   254  		return nil
   255  	}
   256  	if pkgName.Imported().Path() != m.pkg {
   257  		// Not the right package.
   258  		return nil
   259  	}
   260  	if name := callee.Sel.Name; name != m.functionName {
   261  		// This isn't the function we're looking for.
   262  		return nil
   263  	}
   264  	if len(call.Args) <= m.functionTypedParameterIndex {
   265  		// The function call doesn't have enough arguments.
   266  		return nil
   267  	}
   268  	return call.Args[m.functionTypedParameterIndex]
   269  }
   270  
   271  // mayHaveSideEffects determines whether an expression might write to a
   272  // variable or call a function.  It can have false positives.  It does not
   273  // consider panicking to be a side effect, so e.g. index expressions do not
   274  // have side effects unless one of its components do.
   275  //
   276  // This is used to determine whether we can delete the expression from the
   277  // syntax tree in isCallToOnceDoEtc.
   278  func mayHaveSideEffects(e ast.Expr) bool {
   279  	switch e := e.(type) {
   280  	case *ast.Ident, *ast.BasicLit:
   281  		return false
   282  	case nil:
   283  		return false // we can reach a nil via *ast.SliceExpr
   284  	case *ast.FuncLit:
   285  		return false // a definition doesn't do anything on its own
   286  	case *ast.CallExpr:
   287  		return true
   288  	case *ast.CompositeLit:
   289  		for _, elt := range e.Elts {
   290  			if mayHaveSideEffects(elt) {
   291  				return true
   292  			}
   293  		}
   294  		return false
   295  	case *ast.ParenExpr:
   296  		return mayHaveSideEffects(e.X)
   297  	case *ast.SelectorExpr:
   298  		return mayHaveSideEffects(e.X)
   299  	case *ast.IndexExpr:
   300  		return mayHaveSideEffects(e.X) || mayHaveSideEffects(e.Index)
   301  	case *ast.IndexListExpr:
   302  		for _, idx := range e.Indices {
   303  			if mayHaveSideEffects(idx) {
   304  				return true
   305  			}
   306  		}
   307  		return mayHaveSideEffects(e.X)
   308  	case *ast.SliceExpr:
   309  		return mayHaveSideEffects(e.X) ||
   310  			mayHaveSideEffects(e.Low) ||
   311  			mayHaveSideEffects(e.High) ||
   312  			mayHaveSideEffects(e.Max)
   313  	case *ast.TypeAssertExpr:
   314  		return mayHaveSideEffects(e.X)
   315  	case *ast.StarExpr:
   316  		return mayHaveSideEffects(e.X)
   317  	case *ast.UnaryExpr:
   318  		return mayHaveSideEffects(e.X)
   319  	case *ast.BinaryExpr:
   320  		return mayHaveSideEffects(e.X) || mayHaveSideEffects(e.Y)
   321  	case *ast.KeyValueExpr:
   322  		return mayHaveSideEffects(e.Key) || mayHaveSideEffects(e.Value)
   323  	}
   324  	return true
   325  }
   326  
   327  func (m *methodMatcher) match(typeInfo *types.Info, call *ast.CallExpr) ast.Expr {
   328  	sel, ok := call.Fun.(*ast.SelectorExpr)
   329  	if !ok {
   330  		return nil
   331  	}
   332  	if mayHaveSideEffects(sel.X) {
   333  		// The expression may be something like foo().Do(bar), which we can't
   334  		// rewrite to a call to bar because then the analysis would not see the
   335  		// call to foo.
   336  		return nil
   337  	}
   338  	calleeType := typeInfo.TypeOf(sel.X)
   339  	if calleeType == nil {
   340  		return nil
   341  	}
   342  	if ptr, ok := calleeType.(*types.Pointer); ok {
   343  		calleeType = ptr.Elem()
   344  	}
   345  	named, ok := calleeType.(*types.Named)
   346  	if !ok {
   347  		return nil
   348  	}
   349  	if named.Obj().Pkg() != nil {
   350  		if pkg := named.Obj().Pkg().Path(); pkg != m.pkg {
   351  			// Not the right package.
   352  			return nil
   353  		}
   354  	}
   355  	if named.Obj().Name() != m.typeName {
   356  		// Not the right type.
   357  		return nil
   358  	}
   359  	if name := sel.Sel.Name; name != m.methodName {
   360  		// Not the right method.
   361  		return nil
   362  	}
   363  	if len(call.Args) <= m.functionTypedParameterIndex {
   364  		// The method call doesn't have enough arguments.
   365  		return nil
   366  	}
   367  	return call.Args[m.functionTypedParameterIndex]
   368  }
   369  
   370  // visitor is passed to ast.Visit, to find AST nodes where
   371  // unsafe.Pointer values are converted to pointers.
   372  // It satisfies the ast.Visitor interface.
   373  type visitor struct {
   374  	// The sets we are populating.
   375  	unsafeFunctionNodes map[ast.Node]struct{}
   376  	// Set to true if an unsafe.Pointer conversion is found that is not inside
   377  	// a function, method, or function literal definition.
   378  	seenUnsafePointerUseInInitialization *bool
   379  	// The Package for the ast Node being visited.  This is used to get type
   380  	// information.
   381  	pkg *packages.Package
   382  	// The node for the current function being visited.  When function definitions
   383  	// are nested, this is the innermost function.
   384  	currentFunction ast.Node // *ast.FuncDecl or *ast.FuncLit
   385  }
   386  
   387  // containsReflectValue returns true if t is reflect.Value, or is a struct
   388  // or array containing reflect.Value.
   389  func containsReflectValue(t types.Type) bool {
   390  	seen := map[types.Type]struct{}{}
   391  	var rec func(t types.Type) bool
   392  	rec = func(t types.Type) bool {
   393  		if t == nil {
   394  			return false
   395  		}
   396  		if t.String() == "reflect.Value" {
   397  			return true
   398  		}
   399  		// avoid an infinite loop if the type is recursive somehow.
   400  		if _, ok := seen[t]; ok {
   401  			return false
   402  		}
   403  		seen[t] = struct{}{}
   404  		// If the underlying type is different, use that.
   405  		if u := t.Underlying(); !types.Identical(t, u) {
   406  			return rec(u)
   407  		}
   408  		// Check fields of structs.
   409  		if s, ok := t.(*types.Struct); ok {
   410  			for i := 0; i < s.NumFields(); i++ {
   411  				if rec(s.Field(i).Type()) {
   412  					return true
   413  				}
   414  			}
   415  		}
   416  		// Check elements of arrays.
   417  		if a, ok := t.(*types.Array); ok {
   418  			return rec(a.Elem())
   419  		}
   420  		return false
   421  	}
   422  	return rec(t)
   423  }
   424  
   425  func (v visitor) Visit(node ast.Node) ast.Visitor {
   426  	if node == nil {
   427  		return v // the return value is ignored if node == nil.
   428  	}
   429  	switch node := node.(type) {
   430  	case *ast.FuncDecl, *ast.FuncLit:
   431  		// The subtree at this node is a function definition or function literal.
   432  		// The visitor returned here is used to visit this node's children, so we
   433  		// return a visitor with the current function set to this node.
   434  		v.currentFunction = node
   435  		return v
   436  	case *ast.CallExpr:
   437  		// A type conversion is represented as a CallExpr node with a Fun that is a
   438  		// type, and Args containing the expression to be converted.
   439  		//
   440  		// If this node has a single argument which is an unsafe.Pointer (or
   441  		// is equivalent to an unsafe.Pointer) and the callee is a type which is not
   442  		// uintptr, we add the current function to v.unsafeFunctionNodes.
   443  		funType := v.pkg.TypesInfo.Types[node.Fun]
   444  		if !funType.IsType() {
   445  			// The callee is not a type; it's probably a function or method.
   446  			break
   447  		}
   448  		if b, ok := funType.Type.Underlying().(*types.Basic); ok && b.Kind() == types.Uintptr {
   449  			// The conversion is to a uintptr, not a pointer.  On its own, this is
   450  			// safe.
   451  			break
   452  		}
   453  		var args []ast.Expr = node.Args
   454  		if len(args) != 1 {
   455  			// There wasn't the right number of arguments.
   456  			break
   457  		}
   458  		argType := v.pkg.TypesInfo.Types[args[0]].Type
   459  		if argType == nil {
   460  			// The argument has no type information.
   461  			break
   462  		}
   463  		if b, ok := argType.Underlying().(*types.Basic); !ok || b.Kind() != types.UnsafePointer {
   464  			// The argument's type is not equivalent to unsafe.Pointer.
   465  			break
   466  		}
   467  		if v.currentFunction == nil {
   468  			*v.seenUnsafePointerUseInInitialization = true
   469  		} else {
   470  			v.unsafeFunctionNodes[v.currentFunction] = struct{}{}
   471  		}
   472  	}
   473  	return v
   474  }
   475  
   476  // forEachPackageIncludingDependencies calls fn exactly once for each package
   477  // that is in pkgs or in the transitive dependencies of pkgs.
   478  func forEachPackageIncludingDependencies(pkgs []*packages.Package, fn func(*packages.Package)) {
   479  	visitedPackages := make(map[*packages.Package]struct{})
   480  	var visit func(p *packages.Package)
   481  	visit = func(p *packages.Package) {
   482  		if _, ok := visitedPackages[p]; ok {
   483  			return
   484  		}
   485  		visitedPackages[p] = struct{}{}
   486  		for _, p2 := range p.Imports {
   487  			visit(p2)
   488  		}
   489  		fn(p)
   490  	}
   491  	for _, p := range pkgs {
   492  		visit(p)
   493  	}
   494  }
   495  
   496  func programName() string {
   497  	if a := os.Args; len(a) >= 1 {
   498  		return path.Base(a[0])
   499  	}
   500  	return "capslock"
   501  }