github.com/google/capslock@v0.2.3-0.20240517042941-dac19fc347c0/analyzer/util.go (about) 1 // Copyright 2023 Google LLC 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file or at 5 // https://developers.google.com/open-source/licenses/bsd 6 7 package analyzer 8 9 import ( 10 "go/ast" 11 "go/token" 12 "go/types" 13 "os" 14 "path" 15 "strings" 16 17 cpb "github.com/google/capslock/proto" 18 "golang.org/x/tools/go/callgraph" 19 "golang.org/x/tools/go/callgraph/cha" 20 "golang.org/x/tools/go/callgraph/vta" 21 "golang.org/x/tools/go/packages" 22 "golang.org/x/tools/go/ssa" 23 "golang.org/x/tools/go/ssa/ssautil" 24 ) 25 26 type bfsState struct { 27 // edge is the callgraph edge leading to the next node in a path to an 28 // interesting function. 29 edge *callgraph.Edge 30 } 31 32 // next returns the next node in the path to an interesting function. 33 func (b bfsState) next() *callgraph.Node { 34 if b.edge == nil { 35 return nil 36 } 37 return b.edge.Callee 38 } 39 40 type nodeset map[*callgraph.Node]struct{} 41 type nodesetPerCapability map[cpb.Capability]nodeset 42 43 func (nc nodesetPerCapability) add(cap cpb.Capability, node *callgraph.Node) { 44 m := nc[cap] 45 if m == nil { 46 m = make(nodeset) 47 nc[cap] = m 48 } 49 m[node] = struct{}{} 50 } 51 52 // byFunction is a slice of *callgraph.Node that can be sorted using sort.Sort. 53 // The ordering is first by package name, then function name. 54 type byFunction []*callgraph.Node 55 56 func (s byFunction) Len() int { return len(s) } 57 func (s byFunction) Less(i, j int) bool { 58 return nodeCompare(s[i], s[j]) < 0 59 } 60 func (s byFunction) Swap(i, j int) { 61 s[i], s[j] = s[j], s[i] 62 } 63 64 // byCaller is a slice of *callgraph.Edge that can be sorted using 65 // sort.Sort. It sorts by calling function, then callsite position. 66 type byCaller []*callgraph.Edge 67 68 func (s byCaller) Len() int { return len(s) } 69 func (s byCaller) Less(i, j int) bool { 70 if c := nodeCompare(s[i].Caller, s[j].Caller); c != 0 { 71 return c < 0 72 } 73 return positionLess(callsitePosition(s[i]), callsitePosition(s[j])) 74 } 75 func (s byCaller) Swap(i, j int) { 76 s[i], s[j] = s[j], s[i] 77 } 78 79 func nodeCompare(a, b *callgraph.Node) int { 80 return funcCompare(a.Func, b.Func) 81 } 82 83 // funcCompare orders by package path, then by whether the function is a 84 // method, then by name. Returns {-1, 0, +1} in the manner of strings.Compare. 85 func funcCompare(a, b *ssa.Function) int { 86 // Put nils last. 87 if a == nil && b == nil { 88 return 0 89 } else if b == nil { 90 return -1 91 } else if a == nil { 92 return +1 93 } 94 if c := strings.Compare(packagePath(a), packagePath(b)); c != 0 { 95 return c 96 } 97 hasReceiver := func(f *ssa.Function) bool { 98 sig := f.Signature 99 return sig != nil && sig.Recv() != nil 100 } 101 if ar, br := hasReceiver(a), hasReceiver(b); !ar && br { 102 return -1 103 } else if ar && !br { 104 return +1 105 } 106 return strings.Compare(a.String(), b.String()) 107 } 108 109 // positionLess implements an ordering on token.Position. 110 // It orders first by filename, then by position in the file. 111 // Invalid positions are sorted last. 112 func positionLess(p1, p2 token.Position) bool { 113 if p2.Line == 0 { 114 // A token.Position with Line == 0 is invalid. 115 return p1.Line != 0 116 } 117 if p1.Line == 0 { 118 return false 119 } 120 if p1.Filename != p2.Filename { 121 // Note that two positions from the same function can have different 122 // filenames because the ssa.Function for "init" can include 123 // initialization code for package-level variables in multiple files. 124 return p1.Filename < p2.Filename 125 } 126 return p1.Offset < p2.Offset 127 } 128 129 // packagePath returns the name of the package the function belongs to, or 130 // "" if it has no package. 131 func packagePath(f *ssa.Function) string { 132 // If f is an instantiation of a generic function, use its origin. 133 if f.Origin() != nil { 134 f = f.Origin() 135 } 136 if ssaPackage := f.Package(); ssaPackage != nil { 137 if typesPackage := ssaPackage.Pkg; typesPackage != nil { 138 return typesPackage.Path() 139 } 140 } 141 // Check f.Object() for a package. This covers the case of synthetic wrapper 142 // functions for promoted methods of embedded fields. 143 if obj := types.Object(f.Object()); obj != nil { 144 if typesPackage := obj.Pkg(); typesPackage != nil { 145 return typesPackage.Path() 146 } 147 } 148 return "" 149 } 150 151 // callsitePosition returns a token.Position for the edge's callsite. 152 // If edge is nil, or the source is unavailable, the returned token.Position 153 // will have token.IsValid() == false. 154 func callsitePosition(edge *callgraph.Edge) token.Position { 155 if edge == nil { 156 return token.Position{} 157 } else if f := edge.Caller.Func; f == nil { 158 return token.Position{} 159 } else if prog := f.Prog; prog == nil { 160 return token.Position{} 161 } else if fset := prog.Fset; fset == nil { 162 return token.Position{} 163 } else { 164 return fset.Position(edge.Pos()) 165 } 166 } 167 168 func isStdLib(p string) bool { 169 if strings.Contains(p, ".") { 170 return false 171 } 172 return true 173 } 174 175 func buildGraph(pkgs []*packages.Package, populateSyntax bool) (*callgraph.Graph, *ssa.Program, map[*ssa.Function]bool) { 176 rewriteCallsToSort(pkgs) 177 rewriteCallsToOnceDoEtc(pkgs) 178 ssaBuilderMode := ssa.InstantiateGenerics 179 if populateSyntax { 180 // Debug mode makes ssa.Function.Syntax() point to the ast Node for the 181 // function. This will allow us to link nodes in the callgraph with 182 // functions in the syntax tree which convert unsafe.Pointer objects or 183 // use the reflect package in notable ways. 184 ssaBuilderMode |= ssa.GlobalDebug 185 } 186 ssaProg, _ := ssautil.AllPackages(pkgs, ssaBuilderMode) 187 ssaProg.Build() 188 graph := cha.CallGraph(ssaProg) 189 allFunctions := ssautil.AllFunctions(ssaProg) 190 graph = vta.CallGraph(allFunctions, graph) 191 return graph, ssaProg, allFunctions 192 } 193 194 // functionsToRewrite lists the functions and methods like (*sync.Once).Do that 195 // rewriteCallsToOnceDoEtc will rewrite to calls to their arguments. 196 var functionsToRewrite = []matcher{ 197 &methodMatcher{ 198 pkg: "sync", 199 typeName: "Once", 200 methodName: "Do", 201 functionTypedParameterIndex: 0, 202 }, 203 &packageFunctionMatcher{ 204 pkg: "sort", 205 functionName: "Slice", 206 functionTypedParameterIndex: 1, 207 }, 208 &packageFunctionMatcher{ 209 pkg: "sort", 210 functionName: "SliceStable", 211 functionTypedParameterIndex: 1, 212 }, 213 } 214 215 type matcher interface { 216 // match checks if a CallExpr is a call to a particular function or method 217 // that this object is looking for. If it matches, it returns a particular 218 // argument in the call that has a function type. Otherwise it returns nil. 219 match(*types.Info, *ast.CallExpr) ast.Expr 220 } 221 222 // packageFunctionMatcher objects match a package-scope function. 223 type packageFunctionMatcher struct { 224 pkg string 225 functionName string 226 functionTypedParameterIndex int 227 } 228 229 // methodMatcher objects match a method of some type. 230 type methodMatcher struct { 231 pkg string 232 typeName string 233 methodName string 234 functionTypedParameterIndex int 235 } 236 237 func (m *packageFunctionMatcher) match(typeInfo *types.Info, call *ast.CallExpr) ast.Expr { 238 callee, ok := call.Fun.(*ast.SelectorExpr) 239 if !ok { 240 // The function to be called is not a selection, so it can't be a call to 241 // the relevant package. (Unless the user has dot-imported the package, 242 // but we don't need to worry much about false negatives in unusual cases 243 // here.) 244 return nil 245 } 246 pkgIdent, ok := callee.X.(*ast.Ident) 247 if !ok { 248 // The left-hand side of the selection is not a plain identifier. 249 return nil 250 } 251 pkgName, ok := typeInfo.Uses[pkgIdent].(*types.PkgName) 252 if !ok { 253 // The identifier does not refer to a package. 254 return nil 255 } 256 if pkgName.Imported().Path() != m.pkg { 257 // Not the right package. 258 return nil 259 } 260 if name := callee.Sel.Name; name != m.functionName { 261 // This isn't the function we're looking for. 262 return nil 263 } 264 if len(call.Args) <= m.functionTypedParameterIndex { 265 // The function call doesn't have enough arguments. 266 return nil 267 } 268 return call.Args[m.functionTypedParameterIndex] 269 } 270 271 // mayHaveSideEffects determines whether an expression might write to a 272 // variable or call a function. It can have false positives. It does not 273 // consider panicking to be a side effect, so e.g. index expressions do not 274 // have side effects unless one of its components do. 275 // 276 // This is used to determine whether we can delete the expression from the 277 // syntax tree in isCallToOnceDoEtc. 278 func mayHaveSideEffects(e ast.Expr) bool { 279 switch e := e.(type) { 280 case *ast.Ident, *ast.BasicLit: 281 return false 282 case nil: 283 return false // we can reach a nil via *ast.SliceExpr 284 case *ast.FuncLit: 285 return false // a definition doesn't do anything on its own 286 case *ast.CallExpr: 287 return true 288 case *ast.CompositeLit: 289 for _, elt := range e.Elts { 290 if mayHaveSideEffects(elt) { 291 return true 292 } 293 } 294 return false 295 case *ast.ParenExpr: 296 return mayHaveSideEffects(e.X) 297 case *ast.SelectorExpr: 298 return mayHaveSideEffects(e.X) 299 case *ast.IndexExpr: 300 return mayHaveSideEffects(e.X) || mayHaveSideEffects(e.Index) 301 case *ast.IndexListExpr: 302 for _, idx := range e.Indices { 303 if mayHaveSideEffects(idx) { 304 return true 305 } 306 } 307 return mayHaveSideEffects(e.X) 308 case *ast.SliceExpr: 309 return mayHaveSideEffects(e.X) || 310 mayHaveSideEffects(e.Low) || 311 mayHaveSideEffects(e.High) || 312 mayHaveSideEffects(e.Max) 313 case *ast.TypeAssertExpr: 314 return mayHaveSideEffects(e.X) 315 case *ast.StarExpr: 316 return mayHaveSideEffects(e.X) 317 case *ast.UnaryExpr: 318 return mayHaveSideEffects(e.X) 319 case *ast.BinaryExpr: 320 return mayHaveSideEffects(e.X) || mayHaveSideEffects(e.Y) 321 case *ast.KeyValueExpr: 322 return mayHaveSideEffects(e.Key) || mayHaveSideEffects(e.Value) 323 } 324 return true 325 } 326 327 func (m *methodMatcher) match(typeInfo *types.Info, call *ast.CallExpr) ast.Expr { 328 sel, ok := call.Fun.(*ast.SelectorExpr) 329 if !ok { 330 return nil 331 } 332 if mayHaveSideEffects(sel.X) { 333 // The expression may be something like foo().Do(bar), which we can't 334 // rewrite to a call to bar because then the analysis would not see the 335 // call to foo. 336 return nil 337 } 338 calleeType := typeInfo.TypeOf(sel.X) 339 if calleeType == nil { 340 return nil 341 } 342 if ptr, ok := calleeType.(*types.Pointer); ok { 343 calleeType = ptr.Elem() 344 } 345 named, ok := calleeType.(*types.Named) 346 if !ok { 347 return nil 348 } 349 if named.Obj().Pkg() != nil { 350 if pkg := named.Obj().Pkg().Path(); pkg != m.pkg { 351 // Not the right package. 352 return nil 353 } 354 } 355 if named.Obj().Name() != m.typeName { 356 // Not the right type. 357 return nil 358 } 359 if name := sel.Sel.Name; name != m.methodName { 360 // Not the right method. 361 return nil 362 } 363 if len(call.Args) <= m.functionTypedParameterIndex { 364 // The method call doesn't have enough arguments. 365 return nil 366 } 367 return call.Args[m.functionTypedParameterIndex] 368 } 369 370 // visitor is passed to ast.Visit, to find AST nodes where 371 // unsafe.Pointer values are converted to pointers. 372 // It satisfies the ast.Visitor interface. 373 type visitor struct { 374 // The sets we are populating. 375 unsafeFunctionNodes map[ast.Node]struct{} 376 // Set to true if an unsafe.Pointer conversion is found that is not inside 377 // a function, method, or function literal definition. 378 seenUnsafePointerUseInInitialization *bool 379 // The Package for the ast Node being visited. This is used to get type 380 // information. 381 pkg *packages.Package 382 // The node for the current function being visited. When function definitions 383 // are nested, this is the innermost function. 384 currentFunction ast.Node // *ast.FuncDecl or *ast.FuncLit 385 } 386 387 // containsReflectValue returns true if t is reflect.Value, or is a struct 388 // or array containing reflect.Value. 389 func containsReflectValue(t types.Type) bool { 390 seen := map[types.Type]struct{}{} 391 var rec func(t types.Type) bool 392 rec = func(t types.Type) bool { 393 if t == nil { 394 return false 395 } 396 if t.String() == "reflect.Value" { 397 return true 398 } 399 // avoid an infinite loop if the type is recursive somehow. 400 if _, ok := seen[t]; ok { 401 return false 402 } 403 seen[t] = struct{}{} 404 // If the underlying type is different, use that. 405 if u := t.Underlying(); !types.Identical(t, u) { 406 return rec(u) 407 } 408 // Check fields of structs. 409 if s, ok := t.(*types.Struct); ok { 410 for i := 0; i < s.NumFields(); i++ { 411 if rec(s.Field(i).Type()) { 412 return true 413 } 414 } 415 } 416 // Check elements of arrays. 417 if a, ok := t.(*types.Array); ok { 418 return rec(a.Elem()) 419 } 420 return false 421 } 422 return rec(t) 423 } 424 425 func (v visitor) Visit(node ast.Node) ast.Visitor { 426 if node == nil { 427 return v // the return value is ignored if node == nil. 428 } 429 switch node := node.(type) { 430 case *ast.FuncDecl, *ast.FuncLit: 431 // The subtree at this node is a function definition or function literal. 432 // The visitor returned here is used to visit this node's children, so we 433 // return a visitor with the current function set to this node. 434 v.currentFunction = node 435 return v 436 case *ast.CallExpr: 437 // A type conversion is represented as a CallExpr node with a Fun that is a 438 // type, and Args containing the expression to be converted. 439 // 440 // If this node has a single argument which is an unsafe.Pointer (or 441 // is equivalent to an unsafe.Pointer) and the callee is a type which is not 442 // uintptr, we add the current function to v.unsafeFunctionNodes. 443 funType := v.pkg.TypesInfo.Types[node.Fun] 444 if !funType.IsType() { 445 // The callee is not a type; it's probably a function or method. 446 break 447 } 448 if b, ok := funType.Type.Underlying().(*types.Basic); ok && b.Kind() == types.Uintptr { 449 // The conversion is to a uintptr, not a pointer. On its own, this is 450 // safe. 451 break 452 } 453 var args []ast.Expr = node.Args 454 if len(args) != 1 { 455 // There wasn't the right number of arguments. 456 break 457 } 458 argType := v.pkg.TypesInfo.Types[args[0]].Type 459 if argType == nil { 460 // The argument has no type information. 461 break 462 } 463 if b, ok := argType.Underlying().(*types.Basic); !ok || b.Kind() != types.UnsafePointer { 464 // The argument's type is not equivalent to unsafe.Pointer. 465 break 466 } 467 if v.currentFunction == nil { 468 *v.seenUnsafePointerUseInInitialization = true 469 } else { 470 v.unsafeFunctionNodes[v.currentFunction] = struct{}{} 471 } 472 } 473 return v 474 } 475 476 // forEachPackageIncludingDependencies calls fn exactly once for each package 477 // that is in pkgs or in the transitive dependencies of pkgs. 478 func forEachPackageIncludingDependencies(pkgs []*packages.Package, fn func(*packages.Package)) { 479 visitedPackages := make(map[*packages.Package]struct{}) 480 var visit func(p *packages.Package) 481 visit = func(p *packages.Package) { 482 if _, ok := visitedPackages[p]; ok { 483 return 484 } 485 visitedPackages[p] = struct{}{} 486 for _, p2 := range p.Imports { 487 visit(p2) 488 } 489 fn(p) 490 } 491 for _, p := range pkgs { 492 visit(p) 493 } 494 } 495 496 func programName() string { 497 if a := os.Args; len(a) >= 1 { 498 return path.Base(a[0]) 499 } 500 return "capslock" 501 }