github.com/google/capslock@v0.2.3-0.20240517042941-dac19fc347c0/analyzer/rewrite.go (about) 1 // Copyright 2023 Google LLC 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file or at 5 // https://developers.google.com/open-source/licenses/bsd 6 7 package analyzer 8 9 import ( 10 "fmt" 11 "go/ast" 12 "go/constant" 13 "go/token" 14 "go/types" 15 "unsafe" 16 17 "golang.org/x/tools/go/ast/astutil" 18 "golang.org/x/tools/go/packages" 19 ) 20 21 // operandMode has the same layout as types.operandMode. 22 type operandMode byte 23 24 const ( 25 noValueMode operandMode = 1 26 constantMode operandMode = 4 27 valueMode operandMode = 7 28 ) 29 30 // constructTypeAndValue constructs a types.TypeAndValue. These are used in 31 // the types.Info.Types map to store the known types of expressions, and the 32 // values of constant expressions. 33 func constructTypeAndValue(mode operandMode, t types.Type, v constant.Value) types.TypeAndValue { 34 // The mode field of types.TypeAndValue is not exported, so we make our own 35 // copy of the type definition, and use unsafe conversion to get the type we 36 // want. 37 tv := struct { 38 mode operandMode 39 Types types.Type 40 Value constant.Value 41 }{mode, t, v} 42 return *(*types.TypeAndValue)(unsafe.Pointer(&tv)) 43 } 44 45 // typeAndValueForResults constructs a TypeAndValue corresponding to the return 46 // values of a function. 47 func typeAndValueForResults(results *types.Tuple) types.TypeAndValue { 48 if results == nil { 49 // Case 1: the function has no return values. 50 return constructTypeAndValue(noValueMode, results, nil) 51 } 52 if results.Len() == 1 { 53 // Case 2: the function has a single return value. 54 return constructTypeAndValue(valueMode, results.At(0).Type(), nil) 55 } 56 // Case 3: the function returns a tuple of more than one value. 57 return constructTypeAndValue(valueMode, results, nil) 58 } 59 60 // zeroLiteral creates and returns a zero literal of type int, and adds its 61 // type information to typeInfo.Types. 62 func zeroLiteral(typeInfo *types.Info) ast.Expr { 63 expr := &ast.BasicLit{Kind: token.INT, Value: "0"} 64 typeInfo.Types[expr] = constructTypeAndValue(constantMode, types.Typ[types.Int], constant.MakeInt64(0)) 65 return expr 66 } 67 68 // selectionForMethod finds the Selection object for the given method. 69 func selectionForMethod(typ types.Type, name string) *types.Selection { 70 var ms *types.MethodSet = types.NewMethodSet(typ) 71 // The package is not needed for exported methods, so we can pass nil for the 72 // package parameter of Lookup. 73 sel := ms.Lookup(nil, name) 74 return sel 75 } 76 77 // rewriteCallsToSort iterates through the packages in pkgs, including all 78 // transitively-imported packages, and finds calls to sort.Sort, sort.Stable, 79 // and sort.IsSorted, which each have a sort.Interface parameter. We replace 80 // each of these calls with a set of calls to each of the interface methods 81 // individually (Len, Less, and Swap.) e.g., this code: 82 // 83 // sort.Sort(xs) 84 // 85 // would be replaced with: 86 // 87 // xs.Len() 88 // xs.Less(0,0) 89 // xs.Swap(0,0) 90 // 91 // This improves the precision of the callgraph the analysis produces. The 92 // analysis produces a set of possible dynamic types for the sort.Interface 93 // value, and adds a callgraph edge to the methods for each of those. 94 // 95 // Without this change to the callgraph, we would get paths to the 96 // sort.Interface methods for every possible dynamic type for all the values 97 // passed to the same sort function anywhere in the program, which can result 98 // in a large number of false positives. 99 func rewriteCallsToSort(pkgs []*packages.Package) { 100 forEachPackageIncludingDependencies(pkgs, func(p *packages.Package) { 101 for _, file := range p.Syntax { 102 for _, node := range file.Decls { 103 var pre astutil.ApplyFunc 104 pre = func(c *astutil.Cursor) bool { 105 // If the current node, c.Node(), is a call to sort.Sort (or 106 // sort.Stable or sort.IsSorted), replace it with calls to 107 // obj.Less, obj.Swap, and obj.Len, where obj is the argument 108 // that was passed to sort. 109 if _, ok := c.Node().(ast.Stmt); !ok { 110 // c.Node() is not a statement. 111 return true 112 } 113 canRewrite := false 114 switch c.Parent().(type) { 115 case *ast.BlockStmt, *ast.CaseClause, *ast.LabeledStmt: 116 canRewrite = true 117 case *ast.CommClause: 118 canRewrite = c.Index() >= 0 119 } 120 if !canRewrite { 121 // The statement is in a position in the syntax tree where it 122 // can't be replaced with a block or with multiple statements, so 123 // we give up. 124 return true 125 } 126 127 obj := isCallToSort(p.TypesInfo, c.Node()) 128 if obj == nil { 129 // This was not a call to a sort function. 130 // 131 // We always return true from this function, because the return 132 // value indicates to astutil.Apply whether to keep searching. 133 return true 134 } 135 // Less and Swap each take two integer arguments. The values aren't 136 // important for our callgraph analysis -- we do not look at values 137 // to determine which way an if statement branches, for example -- 138 // so we just use two zeroes. 139 args1 := []ast.Expr{zeroLiteral(p.TypesInfo), zeroLiteral(p.TypesInfo)} 140 args2 := []ast.Expr{zeroLiteral(p.TypesInfo), zeroLiteral(p.TypesInfo)} 141 // Create a block with three statements which call Less, Swap, 142 // and Len. Replace the current node with this block. 143 s1 := statementCallingMethod(p.TypesInfo, obj, "Less", args1) 144 s2 := statementCallingMethod(p.TypesInfo, obj, "Swap", args2) 145 s3 := statementCallingMethod(p.TypesInfo, obj, "Len", nil) 146 if s1 == nil || s2 == nil || s3 == nil { 147 // We did not succeed in creating these statements. 148 return true 149 } 150 c.Replace(&ast.BlockStmt{List: []ast.Stmt{s1, s2, s3}}) 151 return true 152 } 153 astutil.Apply(node, pre, nil) 154 } 155 } 156 }) 157 } 158 159 // rewriteCallsToOnceDoEtc is similar to rewriteCallsToSort. It finds calls 160 // to some standard-library functions and methods which have a function 161 // parameter, and changes those calls to call the function argument directly 162 // instead. 163 // 164 // e.g. this code: 165 // 166 // var myonce *sync.Once = ... 167 // myonce.Do(fn) 168 // 169 // would be replaced with: 170 // 171 // var myonce *sync.Once = ... 172 // fn() 173 func rewriteCallsToOnceDoEtc(pkgs []*packages.Package) { 174 forEachPackageIncludingDependencies(pkgs, func(p *packages.Package) { 175 for _, file := range p.Syntax { 176 for _, node := range file.Decls { 177 var pre astutil.ApplyFunc 178 pre = func(c *astutil.Cursor) bool { 179 obj := isCallToOnceDoEtc(p.TypesInfo, c.Node()) 180 if obj == nil { 181 // This was not a call to a relevant function or method. 182 return true 183 } 184 fnType, ok := p.TypesInfo.TypeOf(obj).(*types.Signature) 185 if !ok { 186 // The argument does not appear to be a function. 187 return true 188 } 189 // Create some arguments to pass to the function. The parameters 190 // must all be integers. 191 params := fnType.Params() 192 args := make([]ast.Expr, params.Len()) 193 for i := range args { 194 args[i] = zeroLiteral(p.TypesInfo) 195 } 196 c.Replace( 197 statementCallingFunctionObject(p.TypesInfo, obj, args)) 198 return true 199 } 200 astutil.Apply(node, pre, nil) 201 } 202 } 203 }) 204 } 205 206 // isCallToSort checks if node is a statement calling sort.Sort, sort.Stable, 207 // or sort.IsSorted. If so, it returns the argument to that function. 208 // Otherwise, it returns nil. 209 func isCallToSort(typeInfo *types.Info, node ast.Node) ast.Expr { 210 expr, ok := node.(*ast.ExprStmt) 211 if !ok { 212 // Not a statement node. 213 return nil 214 } 215 call, ok := expr.X.(*ast.CallExpr) 216 if !ok { 217 // Not a function call. 218 return nil 219 } 220 callee, ok := call.Fun.(*ast.SelectorExpr) 221 if !ok { 222 // The function to be called is not a selection, so it can't be a call to 223 // the sort package. (Unless the user has dot-imported "sort", but we 224 // don't need to worry much about false negatives in unusual cases here.) 225 return nil 226 } 227 pkgIdent, ok := callee.X.(*ast.Ident) 228 if !ok { 229 // The left-hand-side of the selection is not a plain identifier. 230 return nil 231 } 232 pkgName, ok := typeInfo.Uses[pkgIdent].(*types.PkgName) 233 if !ok { 234 // The identifier does not refer to a package. 235 return nil 236 } 237 if pkgName.Imported().Path() != "sort" { 238 // The package isn't "sort". (We use Imported().Path() because the import 239 // name could be misleading, e.g.: 240 // import ( 241 // sort "os" 242 // ) 243 return nil 244 } 245 if name := callee.Sel.Name; name != "Sort" && name != "Stable" && name != "IsSorted" { 246 // This isn't one of the functions we're looking for. 247 return nil 248 } 249 if len(call.Args) != 1 { 250 // The function call doesn't have one argument. 251 return nil 252 } 253 return call.Args[0] 254 } 255 256 // isCallToOnceDoEtc checks if node is a statement calling a function or method 257 // like (*sync.Once).Do. If so, it returns the function-typed argument to that 258 // function. Otherwise, it returns nil. 259 func isCallToOnceDoEtc(typeInfo *types.Info, node ast.Node) ast.Expr { 260 expr, ok := node.(*ast.ExprStmt) 261 if !ok { 262 // Not a statement node. 263 return nil 264 } 265 call, ok := expr.X.(*ast.CallExpr) 266 if !ok { 267 // Not a call expression. 268 return nil 269 } 270 for _, m := range functionsToRewrite { 271 if e := m.match(typeInfo, call); e != nil { 272 return e 273 } 274 } 275 return nil 276 } 277 278 // statementCallingMethod constructs a statement that calls a method. The 279 // receiver is recv, the method name is methodName, and the arguments passed 280 // to the call are in args. 281 // 282 // New AST structures that are created by statementCallingMethod are added 283 // to the Types, Selections and Uses fields of typeInfo as needed. The 284 // expressions in methodName and args should already be in typeInfo. 285 // 286 // If the statement cannot be created, returns nil. 287 func statementCallingMethod(typeInfo *types.Info, recv ast.Expr, methodName string, args []ast.Expr) *ast.ExprStmt { 288 // Construct an ast node for the method name, and add it to typeInfo.Uses. 289 methodIdent := ast.NewIdent(methodName) 290 var selection *types.Selection = selectionForMethod(typeInfo.TypeOf(recv), methodName) 291 if selection == nil { 292 // We did not find the desired method for this type. recv might be an 293 // untyped nil. 294 return nil 295 } 296 typeInfo.Uses[methodIdent] = selection.Obj() 297 // Construct an ast node for the selection (e.g. "v.M"), and add it to 298 // typeInfo.Selections and typeInfo.Types. 299 selectorExpr := &ast.SelectorExpr{X: recv, Sel: methodIdent} 300 typeInfo.Selections[selectorExpr] = selection 301 typeInfo.Types[selectorExpr] = constructTypeAndValue(valueMode, selection.Type(), nil) 302 // Construct an ast node for the call (e.g. "v.M(arg1, arg2)") and add it 303 // to typeInfo.Types. 304 callExpr := &ast.CallExpr{Fun: selectorExpr, Args: append([]ast.Expr(nil), args...)} 305 typeInfo.Types[callExpr] = typeAndValueForResults(selection.Type().(*types.Signature).Results()) 306 // Return an ast node for a statement which is just the call. No type 307 // information is needed for statements. 308 return &ast.ExprStmt{X: callExpr} 309 } 310 311 // statementCallingFunctionObject constructs a statement that calls a function. 312 // 313 // New AST structures that are created by statementCallingFunctionObject are 314 // added to the Types fields of typeInfo as needed. The expressions in fn and 315 // args should already be in typeInfo. 316 func statementCallingFunctionObject(typeInfo *types.Info, fn ast.Expr, args []ast.Expr) *ast.ExprStmt { 317 // Construct an ast node for the call and add it to typeInfo.Types. 318 callExpr := &ast.CallExpr{Fun: fn, Args: append([]ast.Expr(nil), args...)} 319 fnType := typeInfo.TypeOf(fn) 320 fnTypeSignature, _ := fnType.(*types.Signature) 321 if fnTypeSignature == nil { 322 panic(fmt.Sprintf("cannot get type signature of function %v", fn)) 323 } 324 typeInfo.Types[callExpr] = typeAndValueForResults(fnTypeSignature.Results()) 325 // Return an ast node for a statement which is just the call. No type 326 // information is needed for statements. 327 return &ast.ExprStmt{X: callExpr} 328 }