github.com/serversong/goreporter@v0.0.0-20200325104552-3cfaf44fd178/linters/interfacer/check.go (about)

     1  // Copyright (c) 2015, Daniel Martí <mvdan@mvdan.cc>
     2  // See LICENSE for licensing information
     3  
     4  package interfacer
     5  
     6  import (
     7  	"fmt"
     8  	"go/ast"
     9  	"go/token"
    10  	"go/types"
    11  	"os"
    12  	"strings"
    13  
    14  	"golang.org/x/tools/go/loader"
    15  	"golang.org/x/tools/go/ssa"
    16  	"golang.org/x/tools/go/ssa/ssautil"
    17  
    18  	"github.com/kisielk/gotool"
    19  	"mvdan.cc/lint"
    20  )
    21  
    22  func toDiscard(usage *varUsage) bool {
    23  	if usage.discard {
    24  		return true
    25  	}
    26  	for to := range usage.assigned {
    27  		if toDiscard(to) {
    28  			return true
    29  		}
    30  	}
    31  	return false
    32  }
    33  
    34  func allCalls(usage *varUsage, all, ftypes map[string]string) {
    35  	for fname := range usage.calls {
    36  		all[fname] = ftypes[fname]
    37  	}
    38  	for to := range usage.assigned {
    39  		allCalls(to, all, ftypes)
    40  	}
    41  }
    42  
    43  func (c *Checker) interfaceMatching(param *types.Var, usage *varUsage) (string, string) {
    44  	if toDiscard(usage) {
    45  		return "", ""
    46  	}
    47  	ftypes := typeFuncMap(param.Type())
    48  	called := make(map[string]string, len(usage.calls))
    49  	allCalls(usage, called, ftypes)
    50  	s := funcMapString(called)
    51  	return c.ifaces[s], s
    52  }
    53  
    54  type varUsage struct {
    55  	calls   map[string]struct{}
    56  	discard bool
    57  
    58  	assigned map[*varUsage]struct{}
    59  }
    60  
    61  type funcDecl struct {
    62  	astDecl *ast.FuncDecl
    63  	ssaFn   *ssa.Function
    64  }
    65  
    66  // CheckArgs checks the packages specified by their import paths in
    67  // args.
    68  func CheckArgs(args []string) ([]string, error) {
    69  	paths := gotool.ImportPaths(args)
    70  	conf := loader.Config{}
    71  	conf.AllowErrors = true
    72  	rest, err := conf.FromArgs(paths, false)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	if len(rest) > 0 {
    77  		return nil, fmt.Errorf("unwanted extra args: %v", rest)
    78  	}
    79  	lprog, err := conf.Load()
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	prog := ssautil.CreateProgram(lprog, 0)
    85  	prog.Build()
    86  	c := new(Checker)
    87  	c.Program(lprog)
    88  	c.ProgramSSA(prog)
    89  	issues, err := c.Check()
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  	wd, err := os.Getwd()
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	lines := make([]string, len(issues))
    99  	for i, issue := range issues {
   100  		fpos := prog.Fset.Position(issue.Pos()).String()
   101  		if strings.HasPrefix(fpos, wd) {
   102  			fpos = fpos[len(wd)+1:]
   103  		}
   104  		lines[i] = fmt.Sprintf("%s: %s", fpos, issue.Message())
   105  	}
   106  	return lines, nil
   107  }
   108  
   109  type Checker struct {
   110  	lprog *loader.Program
   111  	prog  *ssa.Program
   112  
   113  	pkgTypes
   114  	*loader.PackageInfo
   115  
   116  	fset  *token.FileSet
   117  	funcs []*funcDecl
   118  
   119  	ssaByPos map[token.Pos]*ssa.Function
   120  
   121  	discardFuncs map[*types.Signature]struct{}
   122  
   123  	vars map[*types.Var]*varUsage
   124  }
   125  
   126  var (
   127  	_ lint.Checker = (*Checker)(nil)
   128  	_ lint.WithSSA = (*Checker)(nil)
   129  )
   130  
   131  func (c *Checker) Program(lprog *loader.Program) {
   132  	c.lprog = lprog
   133  }
   134  
   135  func (c *Checker) ProgramSSA(prog *ssa.Program) {
   136  	c.prog = prog
   137  }
   138  
   139  func (c *Checker) Check() ([]lint.Issue, error) {
   140  	var total []lint.Issue
   141  	c.ssaByPos = make(map[token.Pos]*ssa.Function)
   142  	wantPkg := make(map[*types.Package]bool)
   143  	for _, pinfo := range c.lprog.InitialPackages() {
   144  		wantPkg[pinfo.Pkg] = true
   145  	}
   146  	for fn := range ssautil.AllFunctions(c.prog) {
   147  		if fn.Pkg == nil { // builtin?
   148  			continue
   149  		}
   150  		if len(fn.Blocks) == 0 { // stub
   151  			continue
   152  		}
   153  		if !wantPkg[fn.Pkg.Pkg] { // not part of given pkgs
   154  			continue
   155  		}
   156  		c.ssaByPos[fn.Pos()] = fn
   157  	}
   158  	for _, pinfo := range c.lprog.InitialPackages() {
   159  		pkg := pinfo.Pkg
   160  		c.getTypes(pkg)
   161  		c.PackageInfo = c.lprog.AllPackages[pkg]
   162  		total = append(total, c.checkPkg()...)
   163  	}
   164  	return total, nil
   165  }
   166  
   167  func (c *Checker) checkPkg() []lint.Issue {
   168  	c.discardFuncs = make(map[*types.Signature]struct{})
   169  	c.vars = make(map[*types.Var]*varUsage)
   170  	c.funcs = c.funcs[:0]
   171  	findFuncs := func(node ast.Node) bool {
   172  		decl, ok := node.(*ast.FuncDecl)
   173  		if !ok {
   174  			return true
   175  		}
   176  		ssaFn := c.ssaByPos[decl.Name.Pos()]
   177  		if ssaFn == nil {
   178  			return true
   179  		}
   180  		fd := &funcDecl{
   181  			astDecl: decl,
   182  			ssaFn:   ssaFn,
   183  		}
   184  		if c.funcSigns[signString(fd.ssaFn.Signature)] {
   185  			// implements interface
   186  			return true
   187  		}
   188  		c.funcs = append(c.funcs, fd)
   189  		ast.Walk(c, decl.Body)
   190  		return true
   191  	}
   192  	for _, f := range c.Files {
   193  		ast.Inspect(f, findFuncs)
   194  	}
   195  	return c.packageIssues()
   196  }
   197  
   198  func paramVarAndType(sign *types.Signature, i int) (*types.Var, types.Type) {
   199  	params := sign.Params()
   200  	extra := sign.Variadic() && i >= params.Len()-1
   201  	if !extra {
   202  		if i >= params.Len() {
   203  			// builtins with multiple signatures
   204  			return nil, nil
   205  		}
   206  		vr := params.At(i)
   207  		return vr, vr.Type()
   208  	}
   209  	last := params.At(params.Len() - 1)
   210  	switch x := last.Type().(type) {
   211  	case *types.Slice:
   212  		return nil, x.Elem()
   213  	default:
   214  		return nil, x
   215  	}
   216  }
   217  
   218  func (c *Checker) varUsage(e ast.Expr) *varUsage {
   219  	id, ok := e.(*ast.Ident)
   220  	if !ok {
   221  		return nil
   222  	}
   223  	param, ok := c.ObjectOf(id).(*types.Var)
   224  	if !ok {
   225  		// not a variable
   226  		return nil
   227  	}
   228  	if usage, e := c.vars[param]; e {
   229  		return usage
   230  	}
   231  	if !interesting(param.Type()) {
   232  		return nil
   233  	}
   234  	usage := &varUsage{
   235  		calls:    make(map[string]struct{}),
   236  		assigned: make(map[*varUsage]struct{}),
   237  	}
   238  	c.vars[param] = usage
   239  	return usage
   240  }
   241  
   242  func (c *Checker) addUsed(e ast.Expr, as types.Type) {
   243  	if as == nil {
   244  		return
   245  	}
   246  	if usage := c.varUsage(e); usage != nil {
   247  		// using variable
   248  		iface, ok := as.Underlying().(*types.Interface)
   249  		if !ok {
   250  			usage.discard = true
   251  			return
   252  		}
   253  		for i := 0; i < iface.NumMethods(); i++ {
   254  			m := iface.Method(i)
   255  			usage.calls[m.Name()] = struct{}{}
   256  		}
   257  	} else if t, ok := c.TypeOf(e).(*types.Signature); ok {
   258  		// using func
   259  		c.discardFuncs[t] = struct{}{}
   260  	}
   261  }
   262  
   263  func (c *Checker) addAssign(to, from ast.Expr) {
   264  	pto := c.varUsage(to)
   265  	pfrom := c.varUsage(from)
   266  	if pto == nil || pfrom == nil {
   267  		// either isn't interesting
   268  		return
   269  	}
   270  	pfrom.assigned[pto] = struct{}{}
   271  }
   272  
   273  func (c *Checker) discard(e ast.Expr) {
   274  	if usage := c.varUsage(e); usage != nil {
   275  		usage.discard = true
   276  	}
   277  }
   278  
   279  func (c *Checker) comparedWith(e, with ast.Expr) {
   280  	if _, ok := with.(*ast.BasicLit); ok {
   281  		c.discard(e)
   282  	}
   283  }
   284  
   285  func (c *Checker) Visit(node ast.Node) ast.Visitor {
   286  	switch x := node.(type) {
   287  	case *ast.SelectorExpr:
   288  		if _, ok := c.TypeOf(x.Sel).(*types.Signature); !ok {
   289  			c.discard(x.X)
   290  		}
   291  	case *ast.StarExpr:
   292  		c.discard(x.X)
   293  	case *ast.UnaryExpr:
   294  		c.discard(x.X)
   295  	case *ast.IndexExpr:
   296  		c.discard(x.X)
   297  	case *ast.IncDecStmt:
   298  		c.discard(x.X)
   299  	case *ast.BinaryExpr:
   300  		switch x.Op {
   301  		case token.EQL, token.NEQ:
   302  			c.comparedWith(x.X, x.Y)
   303  			c.comparedWith(x.Y, x.X)
   304  		default:
   305  			c.discard(x.X)
   306  			c.discard(x.Y)
   307  		}
   308  	case *ast.ValueSpec:
   309  		for _, val := range x.Values {
   310  			c.addUsed(val, c.TypeOf(x.Type))
   311  		}
   312  	case *ast.AssignStmt:
   313  		for i, val := range x.Rhs {
   314  			left := x.Lhs[i]
   315  			if x.Tok == token.ASSIGN {
   316  				c.addUsed(val, c.TypeOf(left))
   317  			}
   318  			c.addAssign(left, val)
   319  		}
   320  	case *ast.CompositeLit:
   321  		for i, e := range x.Elts {
   322  			switch y := e.(type) {
   323  			case *ast.KeyValueExpr:
   324  				c.addUsed(y.Key, c.TypeOf(y.Value))
   325  				c.addUsed(y.Value, c.TypeOf(y.Key))
   326  			case *ast.Ident:
   327  				c.addUsed(y, compositeIdentType(c.TypeOf(x), i))
   328  			}
   329  		}
   330  	case *ast.CallExpr:
   331  		switch y := c.TypeOf(x.Fun).Underlying().(type) {
   332  		case *types.Signature:
   333  			c.onMethodCall(x, y)
   334  		default:
   335  			// type conversion
   336  			if len(x.Args) == 1 {
   337  				c.addUsed(x.Args[0], y)
   338  			}
   339  		}
   340  	}
   341  	return c
   342  }
   343  
   344  func compositeIdentType(t types.Type, i int) types.Type {
   345  	switch x := t.(type) {
   346  	case *types.Named:
   347  		return compositeIdentType(x.Underlying(), i)
   348  	case *types.Struct:
   349  		return x.Field(i).Type()
   350  	case *types.Array:
   351  		return x.Elem()
   352  	case *types.Slice:
   353  		return x.Elem()
   354  	}
   355  	return nil
   356  }
   357  
   358  func (c *Checker) onMethodCall(ce *ast.CallExpr, sign *types.Signature) {
   359  	for i, e := range ce.Args {
   360  		paramObj, t := paramVarAndType(sign, i)
   361  		// Don't if this is a parameter being re-used as itself
   362  		// in a recursive call
   363  		if id, ok := e.(*ast.Ident); ok {
   364  			if paramObj == c.ObjectOf(id) {
   365  				continue
   366  			}
   367  		}
   368  		c.addUsed(e, t)
   369  	}
   370  	sel, ok := ce.Fun.(*ast.SelectorExpr)
   371  	if !ok {
   372  		return
   373  	}
   374  	// receiver func call on the left side
   375  	if usage := c.varUsage(sel.X); usage != nil {
   376  		usage.calls[sel.Sel.Name] = struct{}{}
   377  	}
   378  }
   379  
   380  func (fd *funcDecl) paramGroups() [][]*types.Var {
   381  	astList := fd.astDecl.Type.Params.List
   382  	groups := make([][]*types.Var, len(astList))
   383  	signIndex := 0
   384  	for i, field := range astList {
   385  		group := make([]*types.Var, len(field.Names))
   386  		for j := range field.Names {
   387  			group[j] = fd.ssaFn.Signature.Params().At(signIndex)
   388  			signIndex++
   389  		}
   390  		groups[i] = group
   391  	}
   392  	return groups
   393  }
   394  
   395  func (c *Checker) packageIssues() []lint.Issue {
   396  	var issues []lint.Issue
   397  	for _, fd := range c.funcs {
   398  		if _, e := c.discardFuncs[fd.ssaFn.Signature]; e {
   399  			continue
   400  		}
   401  		for _, group := range fd.paramGroups() {
   402  			issues = append(issues, c.groupIssues(fd, group)...)
   403  		}
   404  	}
   405  	return issues
   406  }
   407  
   408  type Issue struct {
   409  	pos token.Pos
   410  	msg string
   411  }
   412  
   413  func (i Issue) Pos() token.Pos  { return i.pos }
   414  func (i Issue) Message() string { return i.msg }
   415  
   416  func (c *Checker) groupIssues(fd *funcDecl, group []*types.Var) []lint.Issue {
   417  	var issues []lint.Issue
   418  	for _, param := range group {
   419  		usage := c.vars[param]
   420  		if usage == nil {
   421  			return nil
   422  		}
   423  		newType := c.paramNewType(fd.astDecl.Name.Name, param, usage)
   424  		if newType == "" {
   425  			return nil
   426  		}
   427  		issues = append(issues, Issue{
   428  			pos: param.Pos(),
   429  			msg: fmt.Sprintf("%s can be %s", param.Name(), newType),
   430  		})
   431  	}
   432  	return issues
   433  }
   434  
   435  func willAddAllocation(t types.Type) bool {
   436  	switch t.Underlying().(type) {
   437  	case *types.Pointer, *types.Interface:
   438  		return false
   439  	}
   440  	return true
   441  }
   442  
   443  func (c *Checker) paramNewType(funcName string, param *types.Var, usage *varUsage) string {
   444  	t := param.Type()
   445  	if !ast.IsExported(funcName) && willAddAllocation(t) {
   446  		return ""
   447  	}
   448  	if named := typeNamed(t); named != nil {
   449  		tname := named.Obj().Name()
   450  		vname := param.Name()
   451  		if mentionsName(funcName, tname) || mentionsName(funcName, vname) {
   452  			return ""
   453  		}
   454  	}
   455  	ifname, iftype := c.interfaceMatching(param, usage)
   456  	if ifname == "" {
   457  		return ""
   458  	}
   459  	if types.IsInterface(t.Underlying()) {
   460  		if have := funcMapString(typeFuncMap(t)); have == iftype {
   461  			return ""
   462  		}
   463  	}
   464  	return ifname
   465  }