github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/analysis/callcheck/callcheck.go (about)

     1  // Package callcheck provides a framework for validating arguments in function calls.
     2  package callcheck
     3  
     4  import (
     5  	"fmt"
     6  	"go/ast"
     7  	"go/constant"
     8  	"go/types"
     9  
    10  	"github.com/amarpal/go-tools/analysis/report"
    11  	"github.com/amarpal/go-tools/go/ir"
    12  	"github.com/amarpal/go-tools/go/ir/irutil"
    13  	"github.com/amarpal/go-tools/go/types/typeutil"
    14  	"github.com/amarpal/go-tools/internal/passes/buildir"
    15  	"golang.org/x/tools/go/analysis"
    16  )
    17  
// Call describes one statically resolved function call whose callee matched
// a rule, and is the value handed to that rule's Check.
//
// Fields:
//   - Pass is the analysis pass in which the call was found.
//   - Instr is the IR call instruction.
//   - Args holds one Argument per explicit argument; for method calls the
//     receiver is not included.
//   - Parent is the IR function that contains the call instruction.
//
// Problems with the call as a whole are recorded with Invalid; problems with
// a specific argument are recorded with that Argument's Invalid method.
type Call struct {
	Pass  *analysis.Pass
	Instr ir.CallInstruction
	Args  []*Argument

	Parent *ir.Function

	invalids []string
}
    27  
    28  func (c *Call) Invalid(msg string) {
    29  	c.invalids = append(c.invalids, msg)
    30  }
    31  
// Argument is a single explicit argument of a Call as seen by a Check.
// Value holds the IR value that was passed. Problems specific to this
// argument are recorded with Invalid and kept in invalids until they are
// reported.
type Argument struct {
	Value    Value
	invalids []string
}
    36  
// Value wraps the IR value of a call argument. When the argument was
// implicitly converted to an interface type (an *ir.MakeInterface), the
// wrapped value is the converted operand rather than the conversion itself.
type Value struct {
	Value ir.Value
}
    40  
    41  func (arg *Argument) Invalid(msg string) {
    42  	arg.invalids = append(arg.invalids, msg)
    43  }
    44  
// Check is a rule applied to a matching call. It inspects call (and
// call.Args) and records problems through Call.Invalid and
// Argument.Invalid; recorded problems are reported after the Check returns.
type Check func(call *Call)
    46  
    47  func Analyzer(rules map[string]Check) func(pass *analysis.Pass) (interface{}, error) {
    48  	return func(pass *analysis.Pass) (interface{}, error) {
    49  		return checkCalls(pass, rules)
    50  	}
    51  }
    52  
    53  func checkCalls(pass *analysis.Pass, rules map[string]Check) (interface{}, error) {
    54  	cb := func(caller *ir.Function, site ir.CallInstruction, callee *ir.Function) {
    55  		obj, ok := callee.Object().(*types.Func)
    56  		if !ok {
    57  			return
    58  		}
    59  
    60  		r, ok := rules[typeutil.FuncName(obj)]
    61  		if !ok {
    62  			return
    63  		}
    64  		var args []*Argument
    65  		irargs := site.Common().Args
    66  		if callee.Signature.Recv() != nil {
    67  			irargs = irargs[1:]
    68  		}
    69  		for _, arg := range irargs {
    70  			if iarg, ok := arg.(*ir.MakeInterface); ok {
    71  				arg = iarg.X
    72  			}
    73  			args = append(args, &Argument{Value: Value{arg}})
    74  		}
    75  		call := &Call{
    76  			Pass:   pass,
    77  			Instr:  site,
    78  			Args:   args,
    79  			Parent: site.Parent(),
    80  		}
    81  		r(call)
    82  
    83  		var astcall *ast.CallExpr
    84  		switch source := site.Source().(type) {
    85  		case *ast.CallExpr:
    86  			astcall = source
    87  		case *ast.DeferStmt:
    88  			astcall = source.Call
    89  		case *ast.GoStmt:
    90  			astcall = source.Call
    91  		case nil:
    92  			// TODO(dh): I am not sure this can actually happen. If it
    93  			// can't, we should remove this case, and also stop
    94  			// checking for astcall == nil in the code that follows.
    95  		default:
    96  			panic(fmt.Sprintf("unhandled case %T", source))
    97  		}
    98  
    99  		for idx, arg := range call.Args {
   100  			for _, e := range arg.invalids {
   101  				if astcall != nil {
   102  					if idx < len(astcall.Args) {
   103  						report.Report(pass, astcall.Args[idx], e)
   104  					} else {
   105  						// this is an instance of fn1(fn2()) where fn2
   106  						// returns multiple values. Report the error
   107  						// at the next-best position that we have, the
   108  						// first argument. An example of a check that
   109  						// triggers this is checkEncodingBinaryRules.
   110  						report.Report(pass, astcall.Args[0], e)
   111  					}
   112  				} else {
   113  					report.Report(pass, site, e)
   114  				}
   115  			}
   116  		}
   117  		for _, e := range call.invalids {
   118  			report.Report(pass, call.Instr, e)
   119  		}
   120  	}
   121  	for _, fn := range pass.ResultOf[buildir.Analyzer].(*buildir.IR).SrcFuncs {
   122  		eachCall(fn, cb)
   123  	}
   124  	return nil, nil
   125  }
   126  
   127  func eachCall(fn *ir.Function, cb func(caller *ir.Function, site ir.CallInstruction, callee *ir.Function)) {
   128  	for _, b := range fn.Blocks {
   129  		for _, instr := range b.Instrs {
   130  			if site, ok := instr.(ir.CallInstruction); ok {
   131  				if g := site.Common().StaticCallee(); g != nil {
   132  					cb(fn, site, g)
   133  				}
   134  			}
   135  		}
   136  	}
   137  }
   138  
   139  func ExtractConstExpectKind(v Value, kind constant.Kind) *ir.Const {
   140  	k := extractConst(v.Value)
   141  	if k == nil || k.Value == nil || k.Value.Kind() != kind {
   142  		return nil
   143  	}
   144  	return k
   145  }
   146  
   147  func ExtractConst(v Value) *ir.Const {
   148  	return extractConst(v.Value)
   149  }
   150  
   151  func extractConst(v ir.Value) *ir.Const {
   152  	v = irutil.Flatten(v)
   153  	switch v := v.(type) {
   154  	case *ir.Const:
   155  		return v
   156  	case *ir.MakeInterface:
   157  		return extractConst(v.X)
   158  	default:
   159  		return nil
   160  	}
   161  }