github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/go/callgraph/vta/propagation.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package vta
     6  
     7  import (
     8  	"go/types"
     9  
    10  	"golang.org/x/tools/go/callgraph/vta/internal/trie"
    11  	"golang.org/x/tools/go/ssa"
    12  
    13  	"golang.org/x/tools/go/types/typeutil"
    14  )
    15  
    16  // scc computes strongly connected components (SCCs) of `g` using the
    17  // classical Tarjan's algorithm for SCCs. The result is a pair <m, id>
    18  // where m is a map from nodes to unique id of their SCC in the range
    19  // [0, id). The SCCs are sorted in reverse topological order: for SCCs
    20  // with ids X and Y s.t. X < Y, Y comes before X in the topological order.
    21  func scc(g vtaGraph) (map[node]int, int) {
    22  	// standard data structures used by Tarjan's algorithm.
    23  	type state struct {
    24  		index   int
    25  		lowLink int
    26  		onStack bool
    27  	}
    28  	states := make(map[node]*state, len(g))
    29  	var stack []node
    30  
    31  	nodeToSccID := make(map[node]int, len(g))
    32  	sccID := 0
    33  
    34  	var doSCC func(node)
    35  	doSCC = func(n node) {
    36  		index := len(states)
    37  		ns := &state{index: index, lowLink: index, onStack: true}
    38  		states[n] = ns
    39  		stack = append(stack, n)
    40  
    41  		for s := range g[n] {
    42  			if ss, visited := states[s]; !visited {
    43  				// Analyze successor s that has not been visited yet.
    44  				doSCC(s)
    45  				ss = states[s]
    46  				ns.lowLink = min(ns.lowLink, ss.lowLink)
    47  			} else if ss.onStack {
    48  				// The successor is on the stack, meaning it has to be
    49  				// in the current SCC.
    50  				ns.lowLink = min(ns.lowLink, ss.index)
    51  			}
    52  		}
    53  
    54  		// if n is a root node, pop the stack and generate a new SCC.
    55  		if ns.lowLink == index {
    56  			var w node
    57  			for w != n {
    58  				w = stack[len(stack)-1]
    59  				stack = stack[:len(stack)-1]
    60  				states[w].onStack = false
    61  				nodeToSccID[w] = sccID
    62  			}
    63  			sccID++
    64  		}
    65  	}
    66  
    67  	for n := range g {
    68  		if _, visited := states[n]; !visited {
    69  			doSCC(n)
    70  		}
    71  	}
    72  
    73  	return nodeToSccID, sccID
    74  }
    75  
    76  func min(x, y int) int {
    77  	if x < y {
    78  		return x
    79  	}
    80  	return y
    81  }
    82  
// propType represents type information being propagated
// over the vta graph. f != nil only for function nodes
// and nodes reachable from function nodes. There, we also
// remember the actual *ssa.Function in order to more
// precisely model higher-order flow.
//
// propType values are compared with == and used as map keys
// (see propagate), so the struct must stay comparable.
type propType struct {
	typ types.Type    // the node type, canonicalized (see getPropType)
	f   *ssa.Function // non-nil only when originating from a function node
}
    92  
// propTypeMap is an auxiliary structure that serves
// the role of a map from nodes to a set of propTypes.
// Instead of one set per node, it keeps one set per SCC
// and maps each node to the id of its SCC, since
// propagate computes a single type set for each SCC.
type propTypeMap struct {
	nodeToScc  map[node]int         // node -> id of the SCC containing it
	sccToTypes map[int]*trie.MutMap // SCC id -> set of propTypes, keyed by propType id
}
    99  
   100  // propTypes returns a list of propTypes associated with
   101  // node `n`. If `n` is not in the map `ptm`, nil is returned.
   102  func (ptm propTypeMap) propTypes(n node) []propType {
   103  	id, ok := ptm.nodeToScc[n]
   104  	if !ok {
   105  		return nil
   106  	}
   107  	var pts []propType
   108  	for _, elem := range trie.Elems(ptm.sccToTypes[id].M) {
   109  		pts = append(pts, elem.(propType))
   110  	}
   111  	return pts
   112  }
   113  
   114  // propagate reduces the `graph` based on its SCCs and
   115  // then propagates type information through the reduced
   116  // graph. The result is a map from nodes to a set of types
   117  // and functions, stemming from higher-order data flow,
   118  // reaching the node. `canon` is used for type uniqueness.
   119  func propagate(graph vtaGraph, canon *typeutil.Map) propTypeMap {
   120  	nodeToScc, sccID := scc(graph)
   121  
   122  	// We also need the reverse map, from ids to SCCs.
   123  	sccs := make(map[int][]node, sccID)
   124  	for n, id := range nodeToScc {
   125  		sccs[id] = append(sccs[id], n)
   126  	}
   127  
   128  	// propTypeIds are used to create unique ids for
   129  	// propType, to be used for trie-based type sets.
   130  	propTypeIds := make(map[propType]uint64)
   131  	// Id creation is based on == equality, which works
   132  	// as types are canonicalized (see getPropType).
   133  	propTypeId := func(p propType) uint64 {
   134  		if id, ok := propTypeIds[p]; ok {
   135  			return id
   136  		}
   137  		id := uint64(len(propTypeIds))
   138  		propTypeIds[p] = id
   139  		return id
   140  	}
   141  	builder := trie.NewBuilder()
   142  	// Initialize sccToTypes to avoid repeated check
   143  	// for initialization later.
   144  	sccToTypes := make(map[int]*trie.MutMap, sccID)
   145  	for i := 0; i <= sccID; i++ {
   146  		sccToTypes[i] = nodeTypes(sccs[i], builder, propTypeId, canon)
   147  	}
   148  
   149  	for i := len(sccs) - 1; i >= 0; i-- {
   150  		nextSccs := make(map[int]struct{})
   151  		for _, node := range sccs[i] {
   152  			for succ := range graph[node] {
   153  				nextSccs[nodeToScc[succ]] = struct{}{}
   154  			}
   155  		}
   156  		// Propagate types to all successor SCCs.
   157  		for nextScc := range nextSccs {
   158  			sccToTypes[nextScc].Merge(sccToTypes[i].M)
   159  		}
   160  	}
   161  	return propTypeMap{nodeToScc: nodeToScc, sccToTypes: sccToTypes}
   162  }
   163  
   164  // nodeTypes returns a set of propTypes for `nodes`. These are the
   165  // propTypes stemming from the type of each node in `nodes` plus.
   166  func nodeTypes(nodes []node, builder *trie.Builder, propTypeId func(p propType) uint64, canon *typeutil.Map) *trie.MutMap {
   167  	typeSet := builder.MutEmpty()
   168  	for _, n := range nodes {
   169  		if hasInitialTypes(n) {
   170  			pt := getPropType(n, canon)
   171  			typeSet.Update(propTypeId(pt), pt)
   172  		}
   173  	}
   174  	return &typeSet
   175  }
   176  
   177  // hasInitialTypes check if a node can have initial types.
   178  // Returns true iff `n` is not a panic, recover, nestedPtr*
   179  // node, nor a node whose type is an interface.
   180  func hasInitialTypes(n node) bool {
   181  	switch n.(type) {
   182  	case panicArg, recoverReturn, nestedPtrFunction, nestedPtrInterface:
   183  		return false
   184  	default:
   185  		return !types.IsInterface(n.Type())
   186  	}
   187  }
   188  
   189  // getPropType creates a propType for `node` based on its type.
   190  // propType.typ is always node.Type(). If node is function, then
   191  // propType.val is the underlying function; nil otherwise.
   192  func getPropType(node node, canon *typeutil.Map) propType {
   193  	t := canonicalize(node.Type(), canon)
   194  	if fn, ok := node.(function); ok {
   195  		return propType{f: fn.f, typ: t}
   196  	}
   197  	return propType{f: nil, typ: t}
   198  }