github.com/mvdan/interfacer@v0.0.0-20180901003855-c20040233aed/check/check.go (about)

     1  // Copyright (c) 2015, Daniel Martí <mvdan@mvdan.cc>
     2  // See LICENSE for licensing information
     3  
     4  package check // import "mvdan.cc/interfacer/check"
     5  
     6  import (
     7  	"fmt"
     8  	"go/ast"
     9  	"go/token"
    10  	"go/types"
    11  	"os"
    12  	"strings"
    13  
    14  	"golang.org/x/tools/go/loader"
    15  	"golang.org/x/tools/go/ssa"
    16  	"golang.org/x/tools/go/ssa/ssautil"
    17  
    18  	"github.com/kisielk/gotool"
    19  	"mvdan.cc/lint"
    20  )
    21  
    22  func toDiscard(usage *varUsage) bool {
    23  	if usage.discard {
    24  		return true
    25  	}
    26  	for to := range usage.assigned {
    27  		if toDiscard(to) {
    28  			return true
    29  		}
    30  	}
    31  	return false
    32  }
    33  
    34  func allCalls(usage *varUsage, all, ftypes map[string]string) {
    35  	for fname := range usage.calls {
    36  		all[fname] = ftypes[fname]
    37  	}
    38  	for to := range usage.assigned {
    39  		allCalls(to, all, ftypes)
    40  	}
    41  }
    42  
    43  func (c *Checker) interfaceMatching(param *types.Var, usage *varUsage) (string, string) {
    44  	if toDiscard(usage) {
    45  		return "", ""
    46  	}
    47  	ftypes := typeFuncMap(param.Type())
    48  	called := make(map[string]string, len(usage.calls))
    49  	allCalls(usage, called, ftypes)
    50  	s := funcMapString(called)
    51  	return c.ifaces[s], s
    52  }
    53  
    54  type varUsage struct {
    55  	calls   map[string]struct{}
    56  	discard bool
    57  
    58  	assigned map[*varUsage]struct{}
    59  }
    60  
    61  type funcDecl struct {
    62  	astDecl *ast.FuncDecl
    63  	ssaFn   *ssa.Function
    64  }
    65  
    66  // CheckArgs checks the packages specified by their import paths in
    67  // args.
    68  func CheckArgs(args []string) ([]string, error) {
    69  	paths := gotool.ImportPaths(args)
    70  	conf := loader.Config{}
    71  	conf.AllowErrors = true
    72  	rest, err := conf.FromArgs(paths, false)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	if len(rest) > 0 {
    77  		return nil, fmt.Errorf("unwanted extra args: %v", rest)
    78  	}
    79  	lprog, err := conf.Load()
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	prog := ssautil.CreateProgram(lprog, 0)
    84  	prog.Build()
    85  	c := new(Checker)
    86  	c.Program(lprog)
    87  	c.ProgramSSA(prog)
    88  	issues, err := c.Check()
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	wd, err := os.Getwd()
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	lines := make([]string, len(issues))
    97  	for i, issue := range issues {
    98  		fpos := prog.Fset.Position(issue.Pos()).String()
    99  		if strings.HasPrefix(fpos, wd) {
   100  			fpos = fpos[len(wd)+1:]
   101  		}
   102  		lines[i] = fmt.Sprintf("%s: %s", fpos, issue.Message())
   103  	}
   104  	return lines, nil
   105  }
   106  
   107  type Checker struct {
   108  	lprog *loader.Program
   109  	prog  *ssa.Program
   110  
   111  	pkgTypes
   112  	*loader.PackageInfo
   113  
   114  	funcs []*funcDecl
   115  
   116  	ssaByPos map[token.Pos]*ssa.Function
   117  
   118  	discardFuncs map[*types.Signature]struct{}
   119  
   120  	vars map[*types.Var]*varUsage
   121  }
   122  
   123  var (
   124  	_ lint.Checker = (*Checker)(nil)
   125  	_ lint.WithSSA = (*Checker)(nil)
   126  )
   127  
   128  func (c *Checker) Program(lprog *loader.Program) {
   129  	c.lprog = lprog
   130  }
   131  
   132  func (c *Checker) ProgramSSA(prog *ssa.Program) {
   133  	c.prog = prog
   134  }
   135  
   136  func (c *Checker) Check() ([]lint.Issue, error) {
   137  	var total []lint.Issue
   138  	c.ssaByPos = make(map[token.Pos]*ssa.Function)
   139  	wantPkg := make(map[*types.Package]bool)
   140  	for _, pinfo := range c.lprog.InitialPackages() {
   141  		wantPkg[pinfo.Pkg] = true
   142  	}
   143  	for fn := range ssautil.AllFunctions(c.prog) {
   144  		if fn.Pkg == nil { // builtin?
   145  			continue
   146  		}
   147  		if len(fn.Blocks) == 0 { // stub
   148  			continue
   149  		}
   150  		if !wantPkg[fn.Pkg.Pkg] { // not part of given pkgs
   151  			continue
   152  		}
   153  		c.ssaByPos[fn.Pos()] = fn
   154  	}
   155  	for _, pinfo := range c.lprog.InitialPackages() {
   156  		pkg := pinfo.Pkg
   157  		c.getTypes(pkg)
   158  		c.PackageInfo = c.lprog.AllPackages[pkg]
   159  		total = append(total, c.checkPkg()...)
   160  	}
   161  	return total, nil
   162  }
   163  
   164  func (c *Checker) checkPkg() []lint.Issue {
   165  	c.discardFuncs = make(map[*types.Signature]struct{})
   166  	c.vars = make(map[*types.Var]*varUsage)
   167  	c.funcs = c.funcs[:0]
   168  	findFuncs := func(node ast.Node) bool {
   169  		decl, ok := node.(*ast.FuncDecl)
   170  		if !ok {
   171  			return true
   172  		}
   173  		ssaFn := c.ssaByPos[decl.Name.Pos()]
   174  		if ssaFn == nil {
   175  			return true
   176  		}
   177  		fd := &funcDecl{
   178  			astDecl: decl,
   179  			ssaFn:   ssaFn,
   180  		}
   181  		if c.funcSigns[signString(fd.ssaFn.Signature)] {
   182  			// implements interface
   183  			return true
   184  		}
   185  		c.funcs = append(c.funcs, fd)
   186  		ast.Walk(c, decl.Body)
   187  		return true
   188  	}
   189  	for _, f := range c.Files {
   190  		ast.Inspect(f, findFuncs)
   191  	}
   192  	return c.packageIssues()
   193  }
   194  
   195  func paramVarAndType(sign *types.Signature, i int) (*types.Var, types.Type) {
   196  	params := sign.Params()
   197  	extra := sign.Variadic() && i >= params.Len()-1
   198  	if !extra {
   199  		if i >= params.Len() {
   200  			// builtins with multiple signatures
   201  			return nil, nil
   202  		}
   203  		vr := params.At(i)
   204  		return vr, vr.Type()
   205  	}
   206  	last := params.At(params.Len() - 1)
   207  	switch x := last.Type().(type) {
   208  	case *types.Slice:
   209  		return nil, x.Elem()
   210  	default:
   211  		return nil, x
   212  	}
   213  }
   214  
   215  func (c *Checker) varUsage(e ast.Expr) *varUsage {
   216  	id, ok := e.(*ast.Ident)
   217  	if !ok {
   218  		return nil
   219  	}
   220  	param, ok := c.ObjectOf(id).(*types.Var)
   221  	if !ok {
   222  		// not a variable
   223  		return nil
   224  	}
   225  	if usage, e := c.vars[param]; e {
   226  		return usage
   227  	}
   228  	if !interesting(param.Type()) {
   229  		return nil
   230  	}
   231  	usage := &varUsage{
   232  		calls:    make(map[string]struct{}),
   233  		assigned: make(map[*varUsage]struct{}),
   234  	}
   235  	c.vars[param] = usage
   236  	return usage
   237  }
   238  
   239  func (c *Checker) addUsed(e ast.Expr, as types.Type) {
   240  	if as == nil {
   241  		return
   242  	}
   243  	if usage := c.varUsage(e); usage != nil {
   244  		// using variable
   245  		iface, ok := as.Underlying().(*types.Interface)
   246  		if !ok {
   247  			usage.discard = true
   248  			return
   249  		}
   250  		for i := 0; i < iface.NumMethods(); i++ {
   251  			m := iface.Method(i)
   252  			usage.calls[m.Name()] = struct{}{}
   253  		}
   254  	} else if t, ok := c.TypeOf(e).(*types.Signature); ok {
   255  		// using func
   256  		c.discardFuncs[t] = struct{}{}
   257  	}
   258  }
   259  
   260  func (c *Checker) addAssign(to, from ast.Expr) {
   261  	pto := c.varUsage(to)
   262  	pfrom := c.varUsage(from)
   263  	if pto == nil || pfrom == nil {
   264  		// either isn't interesting
   265  		return
   266  	}
   267  	pfrom.assigned[pto] = struct{}{}
   268  }
   269  
   270  func (c *Checker) discard(e ast.Expr) {
   271  	if usage := c.varUsage(e); usage != nil {
   272  		usage.discard = true
   273  	}
   274  }
   275  
   276  func (c *Checker) comparedWith(e, with ast.Expr) {
   277  	if _, ok := with.(*ast.BasicLit); ok {
   278  		c.discard(e)
   279  	}
   280  }
   281  
   282  func (c *Checker) Visit(node ast.Node) ast.Visitor {
   283  	switch x := node.(type) {
   284  	case *ast.SelectorExpr:
   285  		if _, ok := c.TypeOf(x.Sel).(*types.Signature); !ok {
   286  			c.discard(x.X)
   287  		}
   288  	case *ast.StarExpr:
   289  		c.discard(x.X)
   290  	case *ast.UnaryExpr:
   291  		c.discard(x.X)
   292  	case *ast.IndexExpr:
   293  		c.discard(x.X)
   294  	case *ast.IncDecStmt:
   295  		c.discard(x.X)
   296  	case *ast.BinaryExpr:
   297  		switch x.Op {
   298  		case token.EQL, token.NEQ:
   299  			c.comparedWith(x.X, x.Y)
   300  			c.comparedWith(x.Y, x.X)
   301  		default:
   302  			c.discard(x.X)
   303  			c.discard(x.Y)
   304  		}
   305  	case *ast.ValueSpec:
   306  		for _, val := range x.Values {
   307  			c.addUsed(val, c.TypeOf(x.Type))
   308  		}
   309  	case *ast.AssignStmt:
   310  		for i, val := range x.Rhs {
   311  			left := x.Lhs[i]
   312  			if x.Tok == token.ASSIGN {
   313  				c.addUsed(val, c.TypeOf(left))
   314  			}
   315  			c.addAssign(left, val)
   316  		}
   317  	case *ast.CompositeLit:
   318  		for i, e := range x.Elts {
   319  			switch y := e.(type) {
   320  			case *ast.KeyValueExpr:
   321  				c.addUsed(y.Key, c.TypeOf(y.Value))
   322  				c.addUsed(y.Value, c.TypeOf(y.Key))
   323  			case *ast.Ident:
   324  				c.addUsed(y, compositeIdentType(c.TypeOf(x), i))
   325  			}
   326  		}
   327  	case *ast.CallExpr:
   328  		switch y := c.TypeOf(x.Fun).Underlying().(type) {
   329  		case *types.Signature:
   330  			c.onMethodCall(x, y)
   331  		default:
   332  			// type conversion
   333  			if len(x.Args) == 1 {
   334  				c.addUsed(x.Args[0], y)
   335  			}
   336  		}
   337  	}
   338  	return c
   339  }
   340  
   341  func compositeIdentType(t types.Type, i int) types.Type {
   342  	switch x := t.(type) {
   343  	case *types.Named:
   344  		return compositeIdentType(x.Underlying(), i)
   345  	case *types.Struct:
   346  		return x.Field(i).Type()
   347  	case *types.Array:
   348  		return x.Elem()
   349  	case *types.Slice:
   350  		return x.Elem()
   351  	}
   352  	return nil
   353  }
   354  
   355  func (c *Checker) onMethodCall(ce *ast.CallExpr, sign *types.Signature) {
   356  	for i, e := range ce.Args {
   357  		paramObj, t := paramVarAndType(sign, i)
   358  		// Don't if this is a parameter being re-used as itself
   359  		// in a recursive call
   360  		if id, ok := e.(*ast.Ident); ok {
   361  			if paramObj == c.ObjectOf(id) {
   362  				continue
   363  			}
   364  		}
   365  		c.addUsed(e, t)
   366  	}
   367  	sel, ok := ce.Fun.(*ast.SelectorExpr)
   368  	if !ok {
   369  		return
   370  	}
   371  	// receiver func call on the left side
   372  	if usage := c.varUsage(sel.X); usage != nil {
   373  		usage.calls[sel.Sel.Name] = struct{}{}
   374  	}
   375  }
   376  
   377  func (fd *funcDecl) paramGroups() [][]*types.Var {
   378  	astList := fd.astDecl.Type.Params.List
   379  	groups := make([][]*types.Var, len(astList))
   380  	signIndex := 0
   381  	for i, field := range astList {
   382  		group := make([]*types.Var, len(field.Names))
   383  		for j := range field.Names {
   384  			group[j] = fd.ssaFn.Signature.Params().At(signIndex)
   385  			signIndex++
   386  		}
   387  		groups[i] = group
   388  	}
   389  	return groups
   390  }
   391  
   392  func (c *Checker) packageIssues() []lint.Issue {
   393  	var issues []lint.Issue
   394  	for _, fd := range c.funcs {
   395  		if _, e := c.discardFuncs[fd.ssaFn.Signature]; e {
   396  			continue
   397  		}
   398  		for _, group := range fd.paramGroups() {
   399  			issues = append(issues, c.groupIssues(fd, group)...)
   400  		}
   401  	}
   402  	return issues
   403  }
   404  
   405  type Issue struct {
   406  	pos token.Pos
   407  	msg string
   408  }
   409  
   410  func (i Issue) Pos() token.Pos  { return i.pos }
   411  func (i Issue) Message() string { return i.msg }
   412  
   413  func (c *Checker) groupIssues(fd *funcDecl, group []*types.Var) []lint.Issue {
   414  	var issues []lint.Issue
   415  	for _, param := range group {
   416  		usage := c.vars[param]
   417  		if usage == nil {
   418  			return nil
   419  		}
   420  		newType := c.paramNewType(fd.astDecl.Name.Name, param, usage)
   421  		if newType == "" {
   422  			return nil
   423  		}
   424  		issues = append(issues, Issue{
   425  			pos: param.Pos(),
   426  			msg: fmt.Sprintf("%s can be %s", param.Name(), newType),
   427  		})
   428  	}
   429  	return issues
   430  }
   431  
   432  func willAddAllocation(t types.Type) bool {
   433  	switch t.Underlying().(type) {
   434  	case *types.Pointer, *types.Interface:
   435  		return false
   436  	}
   437  	return true
   438  }
   439  
   440  func (c *Checker) paramNewType(funcName string, param *types.Var, usage *varUsage) string {
   441  	t := param.Type()
   442  	if !ast.IsExported(funcName) && willAddAllocation(t) {
   443  		return ""
   444  	}
   445  	if named := typeNamed(t); named != nil {
   446  		tname := named.Obj().Name()
   447  		vname := param.Name()
   448  		if mentionsName(funcName, tname) || mentionsName(funcName, vname) {
   449  			return ""
   450  		}
   451  	}
   452  	ifname, iftype := c.interfaceMatching(param, usage)
   453  	if ifname == "" {
   454  		return ""
   455  	}
   456  	if types.IsInterface(t.Underlying()) {
   457  		if have := funcMapString(typeFuncMap(t)); have == iftype {
   458  			return ""
   459  		}
   460  	}
   461  	return ifname
   462  }