github.com/golangci/go-tools@v0.0.0-20190318060251-af6baa5dc196/callgraph/rta/rta.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // This package provides Rapid Type Analysis (RTA) for Go, a fast
     6  // algorithm for call graph construction and discovery of reachable code
     7  // (and hence dead code) and runtime types.  The algorithm was first
     8  // described in:
     9  //
    10  // David F. Bacon and Peter F. Sweeney. 1996.
    11  // Fast static analysis of C++ virtual function calls. (OOPSLA '96)
    12  // http://doi.acm.org/10.1145/236337.236371
    13  //
    14  // The algorithm uses dynamic programming to tabulate the cross-product
    15  // of the set of known "address taken" functions with the set of known
    16  // dynamic calls of the same type.  As each new address-taken function
    17  // is discovered, call graph edges are added from each known callsite,
    18  // and as each new call site is discovered, call graph edges are added
    19  // from it to each known address-taken function.
    20  //
    21  // A similar approach is used for dynamic calls via interfaces: it
    22  // tabulates the cross-product of the set of known "runtime types",
    23  // i.e. types that may appear in an interface value, or be derived from
    24  // one via reflection, with the set of known "invoke"-mode dynamic
    25  // calls.  As each new "runtime type" is discovered, call edges are
    26  // added from the known call sites, and as each new call site is
    27  // discovered, call graph edges are added to each compatible
    28  // method.
    29  //
    30  // In addition, we must consider all exported methods of any runtime type
    31  // as reachable, since they may be called via reflection.
    32  //
    33  // Each time a newly added call edge causes a new function to become
    34  // reachable, the code of that function is analyzed for more call sites,
    35  // address-taken functions, and runtime types.  The process continues
    36  // until a fixed point is achieved.
    37  //
    38  // The resulting call graph is less precise than one produced by pointer
    39  // analysis, but the algorithm is much faster.  For example, running the
    40  // cmd/callgraph tool on its own source takes ~2.1s for RTA and ~5.4s
    41  // for points-to analysis.
    42  //
    43  package rta // import "github.com/golangci/go-tools/callgraph/rta"
    44  
    45  // TODO(adonovan): test it by connecting it to the interpreter and
    46  // replacing all "unreachable" functions by a special intrinsic, and
    47  // ensure that that intrinsic is never called.
    48  
    49  import (
    50  	"fmt"
    51  	"go/types"
    52  
    53  	"github.com/golangci/go-tools/callgraph"
    54  	"github.com/golangci/go-tools/ssa"
    55  	"golang.org/x/tools/go/types/typeutil"
    56  )
    57  
// A Result holds the results of Rapid Type Analysis, which includes the
// set of reachable functions/methods, runtime types, and the call graph.
type Result struct {
	// CallGraph is the discovered callgraph.
	// It does not include edges for calls made via reflection.
	// It is nil unless Analyze was called with buildCallGraph set.
	CallGraph *callgraph.Graph

	// Reachable contains the set of reachable functions and methods.
	// This includes exported methods of runtime types, since
	// they may be accessed via reflection.
	// The value indicates whether the function is address-taken.
	//
	// (We wrap the bool in a struct to avoid inadvertent use of
	// "if Reachable[f] {" to test for set membership.)
	Reachable map[*ssa.Function]struct{ AddrTaken bool }

	// RuntimeTypes contains the set of types that are needed at
	// runtime, for interfaces or reflection.
	// Keys are types.Type, values are bool.
	//
	// The value indicates whether the type is inaccessible to reflection.
	// Consider:
	// 	type A struct{B}
	// 	fmt.Println(new(A))
	// Types *A, A and B are accessible to reflection, but the unnamed
	// type struct{B} is not.
	RuntimeTypes typeutil.Map
}
    86  
// rta is the working state of the RTA algorithm.
// A single value is created by Analyze and threaded through every visit.
type rta struct {
	// result accumulates the output that Analyze eventually returns.
	result *Result

	// prog is the program under analysis (taken from the first root).
	prog *ssa.Program

	// worklist holds functions that became reachable but whose bodies
	// have not been visited yet.
	worklist []*ssa.Function // list of functions to visit

	// addrTakenFuncsBySig contains all address-taken *Functions, grouped by signature.
	// Keys are *types.Signature, values are map[*ssa.Function]bool sets.
	addrTakenFuncsBySig typeutil.Map

	// dynCallSites contains all dynamic "call"-mode call sites, grouped by signature.
	// Keys are *types.Signature, values are unordered []ssa.CallInstruction.
	dynCallSites typeutil.Map

	// invokeSites contains all "invoke"-mode call sites, grouped by interface.
	// Keys are *types.Interface (never *types.Named),
	// Values are unordered []ssa.CallInstruction sets.
	invokeSites typeutil.Map

	// The following two maps together define the subset of the
	// m:n "implements" relation needed by the algorithm.

	// concreteTypes maps each concrete type to the set of interfaces that it implements.
	// Keys are types.Type, values are unordered []*types.Interface.
	// Only concrete types used as MakeInterface operands are included.
	// NOTE(review): addRuntimeType also registers the non-interface types it
	// reaches by recursion, so the set looks wider than MakeInterface
	// operands alone -- confirm and reword if so.
	concreteTypes typeutil.Map

	// interfaceTypes maps each interface type to
	// the set of concrete types that implement it.
	// Keys are *types.Interface, values are unordered []types.Type.
	// Only interfaces used in "invoke"-mode CallInstructions are included.
	interfaceTypes typeutil.Map
}
   122  
   123  // addReachable marks a function as potentially callable at run-time,
   124  // and ensures that it gets processed.
   125  func (r *rta) addReachable(f *ssa.Function, addrTaken bool) {
   126  	reachable := r.result.Reachable
   127  	n := len(reachable)
   128  	v := reachable[f]
   129  	if addrTaken {
   130  		v.AddrTaken = true
   131  	}
   132  	reachable[f] = v
   133  	if len(reachable) > n {
   134  		// First time seeing f.  Add it to the worklist.
   135  		r.worklist = append(r.worklist, f)
   136  	}
   137  }
   138  
   139  // addEdge adds the specified call graph edge, and marks it reachable.
   140  // addrTaken indicates whether to mark the callee as "address-taken".
   141  func (r *rta) addEdge(site ssa.CallInstruction, callee *ssa.Function, addrTaken bool) {
   142  	r.addReachable(callee, addrTaken)
   143  
   144  	if g := r.result.CallGraph; g != nil {
   145  		if site.Parent() == nil {
   146  			panic(site)
   147  		}
   148  		from := g.CreateNode(site.Parent())
   149  		to := g.CreateNode(callee)
   150  		callgraph.AddEdge(from, site, to)
   151  	}
   152  }
   153  
   154  // ---------- addrTakenFuncs × dynCallSites ----------
   155  
   156  // visitAddrTakenFunc is called each time we encounter an address-taken function f.
   157  func (r *rta) visitAddrTakenFunc(f *ssa.Function) {
   158  	// Create two-level map (Signature -> Function -> bool).
   159  	S := f.Signature
   160  	funcs, _ := r.addrTakenFuncsBySig.At(S).(map[*ssa.Function]bool)
   161  	if funcs == nil {
   162  		funcs = make(map[*ssa.Function]bool)
   163  		r.addrTakenFuncsBySig.Set(S, funcs)
   164  	}
   165  	if !funcs[f] {
   166  		// First time seeing f.
   167  		funcs[f] = true
   168  
   169  		// If we've seen any dyncalls of this type, mark it reachable,
   170  		// and add call graph edges.
   171  		sites, _ := r.dynCallSites.At(S).([]ssa.CallInstruction)
   172  		for _, site := range sites {
   173  			r.addEdge(site, f, true)
   174  		}
   175  	}
   176  }
   177  
   178  // visitDynCall is called each time we encounter a dynamic "call"-mode call.
   179  func (r *rta) visitDynCall(site ssa.CallInstruction) {
   180  	S := site.Common().Signature()
   181  
   182  	// Record the call site.
   183  	sites, _ := r.dynCallSites.At(S).([]ssa.CallInstruction)
   184  	r.dynCallSites.Set(S, append(sites, site))
   185  
   186  	// For each function of signature S that we know is address-taken,
   187  	// mark it reachable.  We'll add the callgraph edges later.
   188  	funcs, _ := r.addrTakenFuncsBySig.At(S).(map[*ssa.Function]bool)
   189  	for g := range funcs {
   190  		r.addEdge(site, g, true)
   191  	}
   192  }
   193  
   194  // ---------- concrete types × invoke sites ----------
   195  
   196  // addInvokeEdge is called for each new pair (site, C) in the matrix.
   197  func (r *rta) addInvokeEdge(site ssa.CallInstruction, C types.Type) {
   198  	// Ascertain the concrete method of C to be called.
   199  	imethod := site.Common().Method
   200  	cmethod := r.prog.MethodValue(r.prog.MethodSets.MethodSet(C).Lookup(imethod.Pkg(), imethod.Name()))
   201  	r.addEdge(site, cmethod, true)
   202  }
   203  
   204  // visitInvoke is called each time the algorithm encounters an "invoke"-mode call.
   205  func (r *rta) visitInvoke(site ssa.CallInstruction) {
   206  	I := site.Common().Value.Type().Underlying().(*types.Interface)
   207  
   208  	// Record the invoke site.
   209  	sites, _ := r.invokeSites.At(I).([]ssa.CallInstruction)
   210  	r.invokeSites.Set(I, append(sites, site))
   211  
   212  	// Add callgraph edge for each existing
   213  	// address-taken concrete type implementing I.
   214  	for _, C := range r.implementations(I) {
   215  		r.addInvokeEdge(site, C)
   216  	}
   217  }
   218  
   219  // ---------- main algorithm ----------
   220  
   221  // visitFunc processes function f.
   222  func (r *rta) visitFunc(f *ssa.Function) {
   223  	var space [32]*ssa.Value // preallocate space for common case
   224  
   225  	for _, b := range f.Blocks {
   226  		for _, instr := range b.Instrs {
   227  			rands := instr.Operands(space[:0])
   228  
   229  			switch instr := instr.(type) {
   230  			case ssa.CallInstruction:
   231  				call := instr.Common()
   232  				if call.IsInvoke() {
   233  					r.visitInvoke(instr)
   234  				} else if g := call.StaticCallee(); g != nil {
   235  					r.addEdge(instr, g, false)
   236  				} else if _, ok := call.Value.(*ssa.Builtin); !ok {
   237  					r.visitDynCall(instr)
   238  				}
   239  
   240  				// Ignore the call-position operand when
   241  				// looking for address-taken Functions.
   242  				// Hack: assume this is rands[0].
   243  				rands = rands[1:]
   244  
   245  			case *ssa.MakeInterface:
   246  				r.addRuntimeType(instr.X.Type(), false)
   247  			}
   248  
   249  			// Process all address-taken functions.
   250  			for _, op := range rands {
   251  				if g, ok := (*op).(*ssa.Function); ok {
   252  					r.visitAddrTakenFunc(g)
   253  				}
   254  			}
   255  		}
   256  	}
   257  }
   258  
   259  // Analyze performs Rapid Type Analysis, starting at the specified root
   260  // functions.  It returns nil if no roots were specified.
   261  //
   262  // If buildCallGraph is true, Result.CallGraph will contain a call
   263  // graph; otherwise, only the other fields (reachable functions) are
   264  // populated.
   265  //
   266  func Analyze(roots []*ssa.Function, buildCallGraph bool) *Result {
   267  	if len(roots) == 0 {
   268  		return nil
   269  	}
   270  
   271  	r := &rta{
   272  		result: &Result{Reachable: make(map[*ssa.Function]struct{ AddrTaken bool })},
   273  		prog:   roots[0].Prog,
   274  	}
   275  
   276  	if buildCallGraph {
   277  		// TODO(adonovan): change callgraph API to eliminate the
   278  		// notion of a distinguished root node.  Some callgraphs
   279  		// have many roots, or none.
   280  		r.result.CallGraph = callgraph.New(roots[0])
   281  	}
   282  
   283  	hasher := typeutil.MakeHasher()
   284  	r.result.RuntimeTypes.SetHasher(hasher)
   285  	r.addrTakenFuncsBySig.SetHasher(hasher)
   286  	r.dynCallSites.SetHasher(hasher)
   287  	r.invokeSites.SetHasher(hasher)
   288  	r.concreteTypes.SetHasher(hasher)
   289  	r.interfaceTypes.SetHasher(hasher)
   290  
   291  	// Visit functions, processing their instructions, and adding
   292  	// new functions to the worklist, until a fixed point is
   293  	// reached.
   294  	var shadow []*ssa.Function // for efficiency, we double-buffer the worklist
   295  	r.worklist = append(r.worklist, roots...)
   296  	for len(r.worklist) > 0 {
   297  		shadow, r.worklist = r.worklist, shadow[:0]
   298  		for _, f := range shadow {
   299  			r.visitFunc(f)
   300  		}
   301  	}
   302  	return r.result
   303  }
   304  
   305  // interfaces(C) returns all currently known interfaces implemented by C.
   306  func (r *rta) interfaces(C types.Type) []*types.Interface {
   307  	// Ascertain set of interfaces C implements
   308  	// and update 'implements' relation.
   309  	var ifaces []*types.Interface
   310  	r.interfaceTypes.Iterate(func(I types.Type, concs interface{}) {
   311  		if I := I.(*types.Interface); types.Implements(C, I) {
   312  			concs, _ := concs.([]types.Type)
   313  			r.interfaceTypes.Set(I, append(concs, C))
   314  			ifaces = append(ifaces, I)
   315  		}
   316  	})
   317  	r.concreteTypes.Set(C, ifaces)
   318  	return ifaces
   319  }
   320  
   321  // implementations(I) returns all currently known concrete types that implement I.
   322  func (r *rta) implementations(I *types.Interface) []types.Type {
   323  	var concs []types.Type
   324  	if v := r.interfaceTypes.At(I); v != nil {
   325  		concs = v.([]types.Type)
   326  	} else {
   327  		// First time seeing this interface.
   328  		// Update the 'implements' relation.
   329  		r.concreteTypes.Iterate(func(C types.Type, ifaces interface{}) {
   330  			if types.Implements(C, I) {
   331  				ifaces, _ := ifaces.([]*types.Interface)
   332  				r.concreteTypes.Set(C, append(ifaces, I))
   333  				concs = append(concs, C)
   334  			}
   335  		})
   336  		r.interfaceTypes.Set(I, concs)
   337  	}
   338  	return concs
   339  }
   340  
// addRuntimeType is called for each concrete type that can be the
// dynamic type of some interface or reflect.Value.
// Adapted from needMethods in go/ssa/builder.go
//
// T is recorded in Result.RuntimeTypes with skip as its value. For a
// type not seen before, the function also marks the exported methods of
// a non-interface T as reachable, links T to the known invoke sites of
// every interface it implements, and then recurses into the types
// mentioned by T's exported method signatures and by T's own structure.
//
// It panics if T is a method signature (a *types.Signature with a
// receiver) or a kind of type not handled by the final switch.
func (r *rta) addRuntimeType(T types.Type, skip bool) {
	// Already recorded: at most update the stored flag, and do not recurse
	// again (this is what terminates the recursion on cyclic types).
	if prev, ok := r.result.RuntimeTypes.At(T).(bool); ok {
		// NOTE(review): this only ever flips a stored false to true, i.e.
		// a type first recorded with skip=false is re-marked when seen
		// again with skip=true, and never the reverse. Given that the
		// value is documented as "inaccessible to reflection", confirm
		// that this is the intended direction.
		if skip && !prev {
			r.result.RuntimeTypes.Set(T, skip)
		}
		return
	}
	r.result.RuntimeTypes.Set(T, skip)

	mset := r.prog.MethodSets.MethodSet(T)

	if _, ok := T.Underlying().(*types.Interface); !ok {
		// T is a new concrete type.
		for i, n := 0, mset.Len(); i < n; i++ {
			sel := mset.At(i)
			m := sel.Obj()

			if m.Exported() {
				// Exported methods are always potentially callable via reflection.
				r.addReachable(r.prog.MethodValue(sel), true)
			}
		}

		// Add callgraph edge for each existing dynamic
		// "invoke"-mode call via that interface.
		// (interfaces(T) also registers T in the 'implements' relation.)
		for _, I := range r.interfaces(T) {
			sites, _ := r.invokeSites.At(I).([]ssa.CallInstruction)
			for _, site := range sites {
				r.addInvokeEdge(site, T)
			}
		}
	}

	// Precondition: T is not a method signature (*Signature with Recv()!=nil).
	// (The two remarks below are inherited from needMethods; this function
	// neither calls makeMethods nor keeps a per-package visited set.)
	// Recursive case: skip => don't call makeMethods(T).
	// Each package maintains its own set of types it has visited.

	// Find the named type behind T, if T is a named type or a pointer to one.
	var n *types.Named
	switch T := T.(type) {
	case *types.Named:
		n = T
	case *types.Pointer:
		n, _ = T.Elem().(*types.Named)
	}
	if n != nil {
		owner := n.Obj().Pkg()
		if owner == nil {
			return // built-in error type
		}
	}

	// Recursion over signatures of each exported method.
	for i := 0; i < mset.Len(); i++ {
		if mset.At(i).Obj().Exported() {
			sig := mset.At(i).Type().(*types.Signature)
			r.addRuntimeType(sig.Params(), true)  // skip the Tuple itself
			r.addRuntimeType(sig.Results(), true) // skip the Tuple itself
		}
	}

	// Recursion over the structure of T itself.
	switch t := T.(type) {
	case *types.Basic:
		// nop

	case *types.Interface:
		// nop---handled by recursion over method set.

	case *types.Pointer:
		r.addRuntimeType(t.Elem(), false)

	case *types.Slice:
		r.addRuntimeType(t.Elem(), false)

	case *types.Chan:
		r.addRuntimeType(t.Elem(), false)

	case *types.Map:
		r.addRuntimeType(t.Key(), false)
		r.addRuntimeType(t.Elem(), false)

	case *types.Signature:
		if t.Recv() != nil {
			panic(fmt.Sprintf("Signature %s has Recv %s", t, t.Recv()))
		}
		r.addRuntimeType(t.Params(), true)  // skip the Tuple itself
		r.addRuntimeType(t.Results(), true) // skip the Tuple itself

	case *types.Named:
		// A pointer-to-named type can be derived from a named
		// type via reflection.  It may have methods too.
		r.addRuntimeType(types.NewPointer(T), false)

		// Consider 'type T struct{S}' where S has methods.
		// Reflection provides no way to get from T to struct{S},
		// only to S, so the method set of struct{S} is unwanted,
		// so set 'skip' flag during recursion.
		r.addRuntimeType(t.Underlying(), true)

	case *types.Array:
		r.addRuntimeType(t.Elem(), false)

	case *types.Struct:
		for i, n := 0, t.NumFields(); i < n; i++ {
			r.addRuntimeType(t.Field(i).Type(), false)
		}

	case *types.Tuple:
		for i, n := 0, t.Len(); i < n; i++ {
			r.addRuntimeType(t.At(i).Type(), false)
		}

	default:
		// Unhandled kind of type.
		panic(T)
	}
}