github.com/jhump/golang-x-tools@v0.0.0-20220218190644-4958d6d39439/cmd/guru/whicherrs.go (about)

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/token"
    11  	"go/types"
    12  	"sort"
    13  
    14  	"github.com/jhump/golang-x-tools/cmd/guru/serial"
    15  	"github.com/jhump/golang-x-tools/go/ast/astutil"
    16  	"github.com/jhump/golang-x-tools/go/loader"
    17  	"github.com/jhump/golang-x-tools/go/pointer"
    18  	"github.com/jhump/golang-x-tools/go/ssa"
    19  	"github.com/jhump/golang-x-tools/go/ssa/ssautil"
    20  )
    21  
    22  var builtinErrorType = types.Universe.Lookup("error").Type()
    23  
    24  // whicherrs takes an position to an error and tries to find all types, constants
    25  // and global value which a given error can point to and which can be checked from the
    26  // scope where the error lives.
    27  // In short, it returns a list of things that can be checked against in order to handle
    28  // an error properly.
    29  //
    30  // TODO(dmorsing): figure out if fields in errors like *os.PathError.Err
    31  // can be queried recursively somehow.
    32  func whicherrs(q *Query) error {
    33  	lconf := loader.Config{Build: q.Build}
    34  
    35  	if err := setPTAScope(&lconf, q.Scope); err != nil {
    36  		return err
    37  	}
    38  
    39  	// Load/parse/type-check the program.
    40  	lprog, err := loadWithSoftErrors(&lconf)
    41  	if err != nil {
    42  		return err
    43  	}
    44  
    45  	qpos, err := parseQueryPos(lprog, q.Pos, true) // needs exact pos
    46  	if err != nil {
    47  		return err
    48  	}
    49  
    50  	prog := ssautil.CreateProgram(lprog, ssa.GlobalDebug)
    51  
    52  	ptaConfig, err := setupPTA(prog, lprog, q.PTALog, q.Reflection)
    53  	if err != nil {
    54  		return err
    55  	}
    56  
    57  	path, action := findInterestingNode(qpos.info, qpos.path)
    58  	if action != actionExpr {
    59  		return fmt.Errorf("whicherrs wants an expression; got %s",
    60  			astutil.NodeDescription(qpos.path[0]))
    61  	}
    62  	var expr ast.Expr
    63  	var obj types.Object
    64  	switch n := path[0].(type) {
    65  	case *ast.ValueSpec:
    66  		// ambiguous ValueSpec containing multiple names
    67  		return fmt.Errorf("multiple value specification")
    68  	case *ast.Ident:
    69  		obj = qpos.info.ObjectOf(n)
    70  		expr = n
    71  	case ast.Expr:
    72  		expr = n
    73  	default:
    74  		return fmt.Errorf("unexpected AST for expr: %T", n)
    75  	}
    76  
    77  	typ := qpos.info.TypeOf(expr)
    78  	if !types.Identical(typ, builtinErrorType) {
    79  		return fmt.Errorf("selection is not an expression of type 'error'")
    80  	}
    81  	// Determine the ssa.Value for the expression.
    82  	var value ssa.Value
    83  	if obj != nil {
    84  		// def/ref of func/var object
    85  		value, _, err = ssaValueForIdent(prog, qpos.info, obj, path)
    86  	} else {
    87  		value, _, err = ssaValueForExpr(prog, qpos.info, path)
    88  	}
    89  	if err != nil {
    90  		return err // e.g. trivially dead code
    91  	}
    92  
    93  	// Defer SSA construction till after errors are reported.
    94  	prog.Build()
    95  
    96  	globals := findVisibleErrs(prog, qpos)
    97  	constants := findVisibleConsts(prog, qpos)
    98  
    99  	res := &whicherrsResult{
   100  		qpos:   qpos,
   101  		errpos: expr.Pos(),
   102  	}
   103  
   104  	// TODO(adonovan): the following code is heavily duplicated
   105  	// w.r.t.  "pointsto".  Refactor?
   106  
   107  	// Find the instruction which initialized the
   108  	// global error. If more than one instruction has stored to the global
   109  	// remove the global from the set of values that we want to query.
   110  	allFuncs := ssautil.AllFunctions(prog)
   111  	for fn := range allFuncs {
   112  		for _, b := range fn.Blocks {
   113  			for _, instr := range b.Instrs {
   114  				store, ok := instr.(*ssa.Store)
   115  				if !ok {
   116  					continue
   117  				}
   118  				gval, ok := store.Addr.(*ssa.Global)
   119  				if !ok {
   120  					continue
   121  				}
   122  				gbl, ok := globals[gval]
   123  				if !ok {
   124  					continue
   125  				}
   126  				// we already found a store to this global
   127  				// The normal error define is just one store in the init
   128  				// so we just remove this global from the set we want to query
   129  				if gbl != nil {
   130  					delete(globals, gval)
   131  				}
   132  				globals[gval] = store.Val
   133  			}
   134  		}
   135  	}
   136  
   137  	ptaConfig.AddQuery(value)
   138  	for _, v := range globals {
   139  		ptaConfig.AddQuery(v)
   140  	}
   141  
   142  	ptares := ptrAnalysis(ptaConfig)
   143  	valueptr := ptares.Queries[value]
   144  	if valueptr == (pointer.Pointer{}) {
   145  		return fmt.Errorf("pointer analysis did not find expression (dead code?)")
   146  	}
   147  	for g, v := range globals {
   148  		ptr, ok := ptares.Queries[v]
   149  		if !ok {
   150  			continue
   151  		}
   152  		if !ptr.MayAlias(valueptr) {
   153  			continue
   154  		}
   155  		res.globals = append(res.globals, g)
   156  	}
   157  	pts := valueptr.PointsTo()
   158  	dedup := make(map[*ssa.NamedConst]bool)
   159  	for _, label := range pts.Labels() {
   160  		// These values are either MakeInterfaces or reflect
   161  		// generated interfaces. For the purposes of this
   162  		// analysis, we don't care about reflect generated ones
   163  		makeiface, ok := label.Value().(*ssa.MakeInterface)
   164  		if !ok {
   165  			continue
   166  		}
   167  		constval, ok := makeiface.X.(*ssa.Const)
   168  		if !ok {
   169  			continue
   170  		}
   171  		c := constants[*constval]
   172  		if c != nil && !dedup[c] {
   173  			dedup[c] = true
   174  			res.consts = append(res.consts, c)
   175  		}
   176  	}
   177  	concs := pts.DynamicTypes()
   178  	concs.Iterate(func(conc types.Type, _ interface{}) {
   179  		// go/types is a bit annoying here.
   180  		// We want to find all the types that we can
   181  		// typeswitch or assert to. This means finding out
   182  		// if the type pointed to can be seen by us.
   183  		//
   184  		// For the purposes of this analysis, we care only about
   185  		// TypeNames of Named or pointer-to-Named types.
   186  		// We ignore other types (e.g. structs) that implement error.
   187  		var name *types.TypeName
   188  		switch t := conc.(type) {
   189  		case *types.Pointer:
   190  			named, ok := t.Elem().(*types.Named)
   191  			if !ok {
   192  				return
   193  			}
   194  			name = named.Obj()
   195  		case *types.Named:
   196  			name = t.Obj()
   197  		default:
   198  			return
   199  		}
   200  		if !isAccessibleFrom(name, qpos.info.Pkg) {
   201  			return
   202  		}
   203  		res.types = append(res.types, &errorType{conc, name})
   204  	})
   205  	sort.Sort(membersByPosAndString(res.globals))
   206  	sort.Sort(membersByPosAndString(res.consts))
   207  	sort.Sort(sorterrorType(res.types))
   208  
   209  	q.Output(lprog.Fset, res)
   210  	return nil
   211  }
   212  
   213  // findVisibleErrs returns a mapping from each package-level variable of type "error" to nil.
   214  func findVisibleErrs(prog *ssa.Program, qpos *queryPos) map[*ssa.Global]ssa.Value {
   215  	globals := make(map[*ssa.Global]ssa.Value)
   216  	for _, pkg := range prog.AllPackages() {
   217  		for _, mem := range pkg.Members {
   218  			gbl, ok := mem.(*ssa.Global)
   219  			if !ok {
   220  				continue
   221  			}
   222  			gbltype := gbl.Type()
   223  			// globals are always pointers
   224  			if !types.Identical(deref(gbltype), builtinErrorType) {
   225  				continue
   226  			}
   227  			if !isAccessibleFrom(gbl.Object(), qpos.info.Pkg) {
   228  				continue
   229  			}
   230  			globals[gbl] = nil
   231  		}
   232  	}
   233  	return globals
   234  }
   235  
   236  // findVisibleConsts returns a mapping from each package-level constant assignable to type "error", to nil.
   237  func findVisibleConsts(prog *ssa.Program, qpos *queryPos) map[ssa.Const]*ssa.NamedConst {
   238  	constants := make(map[ssa.Const]*ssa.NamedConst)
   239  	for _, pkg := range prog.AllPackages() {
   240  		for _, mem := range pkg.Members {
   241  			obj, ok := mem.(*ssa.NamedConst)
   242  			if !ok {
   243  				continue
   244  			}
   245  			consttype := obj.Type()
   246  			if !types.AssignableTo(consttype, builtinErrorType) {
   247  				continue
   248  			}
   249  			if !isAccessibleFrom(obj.Object(), qpos.info.Pkg) {
   250  				continue
   251  			}
   252  			constants[*obj.Value] = obj
   253  		}
   254  	}
   255  
   256  	return constants
   257  }
   258  
   259  type membersByPosAndString []ssa.Member
   260  
   261  func (a membersByPosAndString) Len() int { return len(a) }
   262  func (a membersByPosAndString) Less(i, j int) bool {
   263  	cmp := a[i].Pos() - a[j].Pos()
   264  	return cmp < 0 || cmp == 0 && a[i].String() < a[j].String()
   265  }
   266  func (a membersByPosAndString) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
   267  
   268  type sorterrorType []*errorType
   269  
   270  func (a sorterrorType) Len() int { return len(a) }
   271  func (a sorterrorType) Less(i, j int) bool {
   272  	cmp := a[i].obj.Pos() - a[j].obj.Pos()
   273  	return cmp < 0 || cmp == 0 && a[i].typ.String() < a[j].typ.String()
   274  }
   275  func (a sorterrorType) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
   276  
   277  type errorType struct {
   278  	typ types.Type      // concrete type N or *N that implements error
   279  	obj *types.TypeName // the named type N
   280  }
   281  
   282  type whicherrsResult struct {
   283  	qpos    *queryPos
   284  	errpos  token.Pos
   285  	globals []ssa.Member
   286  	consts  []ssa.Member
   287  	types   []*errorType
   288  }
   289  
   290  func (r *whicherrsResult) PrintPlain(printf printfFunc) {
   291  	if len(r.globals) > 0 {
   292  		printf(r.qpos, "this error may point to these globals:")
   293  		for _, g := range r.globals {
   294  			printf(g.Pos(), "\t%s", g.RelString(r.qpos.info.Pkg))
   295  		}
   296  	}
   297  	if len(r.consts) > 0 {
   298  		printf(r.qpos, "this error may contain these constants:")
   299  		for _, c := range r.consts {
   300  			printf(c.Pos(), "\t%s", c.RelString(r.qpos.info.Pkg))
   301  		}
   302  	}
   303  	if len(r.types) > 0 {
   304  		printf(r.qpos, "this error may contain these dynamic types:")
   305  		for _, t := range r.types {
   306  			printf(t.obj.Pos(), "\t%s", r.qpos.typeString(t.typ))
   307  		}
   308  	}
   309  }
   310  
   311  func (r *whicherrsResult) JSON(fset *token.FileSet) []byte {
   312  	we := &serial.WhichErrs{}
   313  	we.ErrPos = fset.Position(r.errpos).String()
   314  	for _, g := range r.globals {
   315  		we.Globals = append(we.Globals, fset.Position(g.Pos()).String())
   316  	}
   317  	for _, c := range r.consts {
   318  		we.Constants = append(we.Constants, fset.Position(c.Pos()).String())
   319  	}
   320  	for _, t := range r.types {
   321  		var et serial.WhichErrsType
   322  		et.Type = r.qpos.typeString(t.typ)
   323  		et.Position = fset.Position(t.obj.Pos()).String()
   324  		we.Types = append(we.Types, et)
   325  	}
   326  	return toJSON(we)
   327  }