github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/analysis/callcheck/callcheck.go (about) 1 // Package callcheck provides a framework for validating arguments in function calls. 2 package callcheck 3 4 import ( 5 "fmt" 6 "go/ast" 7 "go/constant" 8 "go/types" 9 10 "github.com/amarpal/go-tools/analysis/report" 11 "github.com/amarpal/go-tools/go/ir" 12 "github.com/amarpal/go-tools/go/ir/irutil" 13 "github.com/amarpal/go-tools/go/types/typeutil" 14 "github.com/amarpal/go-tools/internal/passes/buildir" 15 "golang.org/x/tools/go/analysis" 16 ) 17 18 type Call struct { 19 Pass *analysis.Pass 20 Instr ir.CallInstruction 21 Args []*Argument 22 23 Parent *ir.Function 24 25 invalids []string 26 } 27 28 func (c *Call) Invalid(msg string) { 29 c.invalids = append(c.invalids, msg) 30 } 31 32 type Argument struct { 33 Value Value 34 invalids []string 35 } 36 37 type Value struct { 38 Value ir.Value 39 } 40 41 func (arg *Argument) Invalid(msg string) { 42 arg.invalids = append(arg.invalids, msg) 43 } 44 45 type Check func(call *Call) 46 47 func Analyzer(rules map[string]Check) func(pass *analysis.Pass) (interface{}, error) { 48 return func(pass *analysis.Pass) (interface{}, error) { 49 return checkCalls(pass, rules) 50 } 51 } 52 53 func checkCalls(pass *analysis.Pass, rules map[string]Check) (interface{}, error) { 54 cb := func(caller *ir.Function, site ir.CallInstruction, callee *ir.Function) { 55 obj, ok := callee.Object().(*types.Func) 56 if !ok { 57 return 58 } 59 60 r, ok := rules[typeutil.FuncName(obj)] 61 if !ok { 62 return 63 } 64 var args []*Argument 65 irargs := site.Common().Args 66 if callee.Signature.Recv() != nil { 67 irargs = irargs[1:] 68 } 69 for _, arg := range irargs { 70 if iarg, ok := arg.(*ir.MakeInterface); ok { 71 arg = iarg.X 72 } 73 args = append(args, &Argument{Value: Value{arg}}) 74 } 75 call := &Call{ 76 Pass: pass, 77 Instr: site, 78 Args: args, 79 Parent: site.Parent(), 80 } 81 r(call) 82 83 var astcall *ast.CallExpr 84 switch source := site.Source().(type) { 85 case *ast.CallExpr: 86 astcall = source 87 case *ast.DeferStmt: 88 astcall = source.Call 89 case *ast.GoStmt: 90 astcall = source.Call 91 case nil: 92 // TODO(dh): I am not sure this can actually happen. If it 93 // can't, we should remove this case, and also stop 94 // checking for astcall == nil in the code that follows. 95 default: 96 panic(fmt.Sprintf("unhandled case %T", source)) 97 } 98 99 for idx, arg := range call.Args { 100 for _, e := range arg.invalids { 101 if astcall != nil { 102 if idx < len(astcall.Args) { 103 report.Report(pass, astcall.Args[idx], e) 104 } else { 105 // this is an instance of fn1(fn2()) where fn2 106 // returns multiple values. Report the error 107 // at the next-best position that we have, the 108 // first argument. An example of a check that 109 // triggers this is checkEncodingBinaryRules. 110 report.Report(pass, astcall.Args[0], e) 111 } 112 } else { 113 report.Report(pass, site, e) 114 } 115 } 116 } 117 for _, e := range call.invalids { 118 report.Report(pass, call.Instr, e) 119 } 120 } 121 for _, fn := range pass.ResultOf[buildir.Analyzer].(*buildir.IR).SrcFuncs { 122 eachCall(fn, cb) 123 } 124 return nil, nil 125 } 126 127 func eachCall(fn *ir.Function, cb func(caller *ir.Function, site ir.CallInstruction, callee *ir.Function)) { 128 for _, b := range fn.Blocks { 129 for _, instr := range b.Instrs { 130 if site, ok := instr.(ir.CallInstruction); ok { 131 if g := site.Common().StaticCallee(); g != nil { 132 cb(fn, site, g) 133 } 134 } 135 } 136 } 137 } 138 139 func ExtractConstExpectKind(v Value, kind constant.Kind) *ir.Const { 140 k := extractConst(v.Value) 141 if k == nil || k.Value == nil || k.Value.Kind() != kind { 142 return nil 143 } 144 return k 145 } 146 147 func ExtractConst(v Value) *ir.Const { 148 return extractConst(v.Value) 149 } 150 151 func extractConst(v ir.Value) *ir.Const { 152 v = irutil.Flatten(v) 153 switch v := v.(type) { 154 case *ir.Const: 155 return v 156 case *ir.MakeInterface: 157 return extractConst(v.X) 158 default: 159 return nil 160 } 161 }