github.com/qri-io/qri@v0.10.1-0.20220104210721-c771715036cb/transform/staticlark/call_graph.go (about)

     1  package staticlark
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  	"strings"
     7  )
     8  
     9  // callGraph is a graph of function nodes and what they call
    10  type callGraph struct {
    11  	nodes  []*funcNode
    12  	lookup map[string]*funcNode
    13  }
    14  
    15  // buildCallGraph iterates the function nodes provided, and adds
    16  // the list of calls that each makes, forming an acyclic graph
    17  // of the entire script. This is the basis of whole file analysis,
    18  // such as dataflow analysis
    19  func buildCallGraph(functions []*funcNode, entryPoints []string, symtable map[string]*funcNode) *callGraph {
    20  	// Add top level functions to the symbol table
    21  	for _, f := range functions {
    22  		symtable[f.name] = f
    23  	}
    24  
    25  	// Build the call graph
    26  	graph := &callGraph{
    27  		nodes:  make([]*funcNode, 0, len(functions)),
    28  		lookup: make(map[string]*funcNode),
    29  	}
    30  	for _, f := range functions {
    31  		addToCallGraph(f, graph, symtable)
    32  	}
    33  
    34  	for _, n := range graph.nodes {
    35  		n.setCallHeight()
    36  	}
    37  
    38  	// Determine reachability using the given entry points
    39  	if entryPoints != nil {
    40  		for _, entry := range entryPoints {
    41  			root := graph.lookup[entry]
    42  			if root != nil {
    43  				root.markReachable()
    44  			}
    45  		}
    46  	}
    47  
    48  	return graph
    49  }
    50  
    51  func addToCallGraph(f *funcNode, graph *callGraph, symtable map[string]*funcNode) *funcNode {
    52  	me, ok := graph.lookup[f.name]
    53  	if ok {
    54  		return me
    55  	}
    56  	me = &funcNode{
    57  		name:   f.name,
    58  		params: f.params,
    59  		body:   f.body,
    60  		calls:  make([]*funcNode, 0),
    61  	}
    62  	for _, name := range f.callNames {
    63  		child, ok := symtable[name]
    64  		if !ok {
    65  			log.Debugw("addToCallGraph func not found", "name", name)
    66  			continue
    67  		}
    68  		n := addToCallGraph(child, graph, symtable)
    69  		me.calls = append(me.calls, n)
    70  	}
    71  	graph.lookup[f.name] = me
    72  	graph.nodes = append(graph.nodes, me)
    73  	return me
    74  }
    75  
    76  func (n *funcNode) setCallHeight() {
    77  	maxChild := -1
    78  	for _, call := range n.calls {
    79  		call.setCallHeight()
    80  		if call.height > maxChild {
    81  			maxChild = call.height
    82  		}
    83  	}
    84  	n.height = maxChild + 1
    85  }
    86  
    87  func (n *funcNode) markReachable() {
    88  	n.reach = true
    89  	for _, call := range n.calls {
    90  		call.markReachable()
    91  	}
    92  }
    93  
    94  func (cg *callGraph) findUnusedFuncs() []Diagnostic {
    95  	// Recursively walk the tree to find unreachable nodes
    96  	unusedNames := map[string]struct{}{}
    97  	for _, f := range cg.nodes {
    98  		checkfuncNodeUnused(f, unusedNames)
    99  	}
   100  	// Sort the function names
   101  	results := make([]Diagnostic, 0, len(unusedNames))
   102  	for fname := range unusedNames {
   103  		results = append(results, Diagnostic{
   104  			Category: "unused",
   105  			Message:  fname,
   106  		})
   107  	}
   108  	sort.Slice(results, func(i, j int) bool {
   109  		return results[i].Message < results[j].Message
   110  	})
   111  	return results
   112  }
   113  
   114  func checkfuncNodeUnused(node *funcNode, unusedNames map[string]struct{}) {
   115  	if !node.reach {
   116  		// TODO(dustmop): Copy the position of the function definition
   117  		unusedNames[node.name] = struct{}{}
   118  	}
   119  	for _, call := range node.calls {
   120  		checkfuncNodeUnused(call, unusedNames)
   121  	}
   122  }
   123  
   124  // String creates a string representation of functions in the call graph
   125  func (cg *callGraph) String() string {
   126  	text := ""
   127  	for _, n := range cg.nodes {
   128  		text += stringifyNode(n, 0)
   129  	}
   130  	return text
   131  }
   132  
   133  func stringifyNode(n *funcNode, depth int) string {
   134  	padding := strings.Repeat(" ", depth)
   135  	seen := map[string]struct{}{}
   136  	text := fmt.Sprintf("%s%s\n", padding, n.name)
   137  	for _, call := range n.calls {
   138  		if _, ok := seen[call.name]; ok {
   139  			continue
   140  		}
   141  		seen[call.name] = struct{}{}
   142  		text += stringifyNode(call, depth+1)
   143  	}
   144  	return text
   145  }