github.com/kevinklinger/open_terraform@v1.3.6/noninternal/dag/dag.go (about)

     1  package dag
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  	"strings"
     7  
     8  	"github.com/kevinklinger/open_terraform/noninternal/tfdiags"
     9  
    10  	"github.com/hashicorp/go-multierror"
    11  )
    12  
// AcyclicGraph is a specialization of Graph that is intended to contain no
// cycles. It embeds Graph, so Graph's fields and methods are available on it
// directly. Nothing in this declaration enforces acyclicity; Validate and
// Cycles report violations.
type AcyclicGraph struct {
	Graph
}
    17  
// WalkFunc is the callback used for walking the graph with Walk. It receives
// a vertex and returns any diagnostics produced for it.
type WalkFunc func(Vertex) tfdiags.Diagnostics
    20  
// DepthWalkFunc is a walk function that also receives the current depth of
// the walk as its second argument. Start vertices are visited at depth 0.
// Returning a non-nil error stops the walk, and the walk returns that error.
type DepthWalkFunc func(Vertex, int) error
    24  
// DirectedGraph returns the receiver itself as a Grapher.
func (g *AcyclicGraph) DirectedGraph() Grapher {
	return g
}
    28  
    29  // Returns a Set that includes every Vertex yielded by walking down from the
    30  // provided starting Vertex v.
    31  func (g *AcyclicGraph) Ancestors(v Vertex) (Set, error) {
    32  	s := make(Set)
    33  	memoFunc := func(v Vertex, d int) error {
    34  		s.Add(v)
    35  		return nil
    36  	}
    37  
    38  	if err := g.DepthFirstWalk(g.downEdgesNoCopy(v), memoFunc); err != nil {
    39  		return nil, err
    40  	}
    41  
    42  	return s, nil
    43  }
    44  
    45  // Returns a Set that includes every Vertex yielded by walking up from the
    46  // provided starting Vertex v.
    47  func (g *AcyclicGraph) Descendents(v Vertex) (Set, error) {
    48  	s := make(Set)
    49  	memoFunc := func(v Vertex, d int) error {
    50  		s.Add(v)
    51  		return nil
    52  	}
    53  
    54  	if err := g.ReverseDepthFirstWalk(g.upEdgesNoCopy(v), memoFunc); err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	return s, nil
    59  }
    60  
    61  // Root returns the root of the DAG, or an error.
    62  //
    63  // Complexity: O(V)
    64  func (g *AcyclicGraph) Root() (Vertex, error) {
    65  	roots := make([]Vertex, 0, 1)
    66  	for _, v := range g.Vertices() {
    67  		if g.upEdgesNoCopy(v).Len() == 0 {
    68  			roots = append(roots, v)
    69  		}
    70  	}
    71  
    72  	if len(roots) > 1 {
    73  		// TODO(mitchellh): make this error message a lot better
    74  		return nil, fmt.Errorf("multiple roots: %#v", roots)
    75  	}
    76  
    77  	if len(roots) == 0 {
    78  		return nil, fmt.Errorf("no roots found")
    79  	}
    80  
    81  	return roots[0], nil
    82  }
    83  
    84  // TransitiveReduction performs the transitive reduction of graph g in place.
    85  // The transitive reduction of a graph is a graph with as few edges as
    86  // possible with the same reachability as the original graph. This means
    87  // that if there are three nodes A => B => C, and A connects to both
    88  // B and C, and B connects to C, then the transitive reduction is the
    89  // same graph with only a single edge between A and B, and a single edge
    90  // between B and C.
    91  //
    92  // The graph must be free of cycles for this operation to behave properly.
    93  //
    94  // Complexity: O(V(V+E)), or asymptotically O(VE)
    95  func (g *AcyclicGraph) TransitiveReduction() {
    96  	// For each vertex u in graph g, do a DFS starting from each vertex
    97  	// v such that the edge (u,v) exists (v is a direct descendant of u).
    98  	//
    99  	// For each v-prime reachable from v, remove the edge (u, v-prime).
   100  	for _, u := range g.Vertices() {
   101  		uTargets := g.downEdgesNoCopy(u)
   102  
   103  		g.DepthFirstWalk(g.downEdgesNoCopy(u), func(v Vertex, d int) error {
   104  			shared := uTargets.Intersection(g.downEdgesNoCopy(v))
   105  			for _, vPrime := range shared {
   106  				g.RemoveEdge(BasicEdge(u, vPrime))
   107  			}
   108  
   109  			return nil
   110  		})
   111  	}
   112  }
   113  
   114  // Validate validates the DAG. A DAG is valid if it has a single root
   115  // with no cycles.
   116  func (g *AcyclicGraph) Validate() error {
   117  	if _, err := g.Root(); err != nil {
   118  		return err
   119  	}
   120  
   121  	// Look for cycles of more than 1 component
   122  	var err error
   123  	cycles := g.Cycles()
   124  	if len(cycles) > 0 {
   125  		for _, cycle := range cycles {
   126  			cycleStr := make([]string, len(cycle))
   127  			for j, vertex := range cycle {
   128  				cycleStr[j] = VertexName(vertex)
   129  			}
   130  
   131  			err = multierror.Append(err, fmt.Errorf(
   132  				"Cycle: %s", strings.Join(cycleStr, ", ")))
   133  		}
   134  	}
   135  
   136  	// Look for cycles to self
   137  	for _, e := range g.Edges() {
   138  		if e.Source() == e.Target() {
   139  			err = multierror.Append(err, fmt.Errorf(
   140  				"Self reference: %s", VertexName(e.Source())))
   141  		}
   142  	}
   143  
   144  	return err
   145  }
   146  
   147  // Cycles reports any cycles between graph nodes.
   148  // Self-referencing nodes are not reported, and must be detected separately.
   149  func (g *AcyclicGraph) Cycles() [][]Vertex {
   150  	var cycles [][]Vertex
   151  	for _, cycle := range StronglyConnected(&g.Graph) {
   152  		if len(cycle) > 1 {
   153  			cycles = append(cycles, cycle)
   154  		}
   155  	}
   156  	return cycles
   157  }
   158  
   159  // Walk walks the graph, calling your callback as each node is visited.
   160  // This will walk nodes in parallel if it can. The resulting diagnostics
   161  // contains problems from all graphs visited, in no particular order.
   162  func (g *AcyclicGraph) Walk(cb WalkFunc) tfdiags.Diagnostics {
   163  	w := &Walker{Callback: cb, Reverse: true}
   164  	w.Update(g)
   165  	return w.Wait()
   166  }
   167  
   168  // simple convenience helper for converting a dag.Set to a []Vertex
   169  func AsVertexList(s Set) []Vertex {
   170  	vertexList := make([]Vertex, 0, len(s))
   171  	for _, raw := range s {
   172  		vertexList = append(vertexList, raw.(Vertex))
   173  	}
   174  	return vertexList
   175  }
   176  
// vertexAtDepth pairs a Vertex with the depth at which it was added to a
// walk frontier.
type vertexAtDepth struct {
	// Vertex is the vertex queued for a visit.
	Vertex Vertex
	// Depth is 0 for start vertices, and one more than the depth of the
	// vertex whose edges queued this entry otherwise.
	Depth int
}
   181  
// TopologicalOrder returns a topological sort of the given graph. The nodes
// are not sorted, and any valid order may be returned. This function will
// panic if it encounters a cycle.
//
// It delegates to topoOrder with upOrder, so each vertex appears in the
// result after every vertex reached by following its up edges.
func (g *AcyclicGraph) TopologicalOrder() []Vertex {
	return g.topoOrder(upOrder)
}
   188  
// ReverseTopologicalOrder returns a topological sort of the given graph,
// following each edge in reverse. The nodes are not sorted, and any valid
// order may be returned. This function will panic if it encounters a cycle.
//
// It delegates to topoOrder with downOrder, so each vertex appears in the
// result after every vertex reached by following its down edges.
func (g *AcyclicGraph) ReverseTopologicalOrder() []Vertex {
	return g.topoOrder(downOrder)
}
   195  
   196  func (g *AcyclicGraph) topoOrder(order walkType) []Vertex {
   197  	// Use a dfs-based sorting algorithm, similar to that used in
   198  	// TransitiveReduction.
   199  	sorted := make([]Vertex, 0, len(g.vertices))
   200  
   201  	// tmp track the current working node to check for cycles
   202  	tmp := map[Vertex]bool{}
   203  
   204  	// perm tracks completed nodes to end the recursion
   205  	perm := map[Vertex]bool{}
   206  
   207  	var visit func(v Vertex)
   208  
   209  	visit = func(v Vertex) {
   210  		if perm[v] {
   211  			return
   212  		}
   213  
   214  		if tmp[v] {
   215  			panic("cycle found in dag")
   216  		}
   217  
   218  		tmp[v] = true
   219  		var next Set
   220  		switch {
   221  		case order&downOrder != 0:
   222  			next = g.downEdgesNoCopy(v)
   223  		case order&upOrder != 0:
   224  			next = g.upEdgesNoCopy(v)
   225  		default:
   226  			panic(fmt.Sprintln("invalid order", order))
   227  		}
   228  
   229  		for _, u := range next {
   230  			visit(u)
   231  		}
   232  
   233  		tmp[v] = false
   234  		perm[v] = true
   235  		sorted = append(sorted, v)
   236  	}
   237  
   238  	for _, v := range g.Vertices() {
   239  		visit(v)
   240  	}
   241  
   242  	return sorted
   243  }
   244  
   245  type walkType uint64
   246  
   247  const (
   248  	depthFirst walkType = 1 << iota
   249  	breadthFirst
   250  	downOrder
   251  	upOrder
   252  )
   253  
   254  // DepthFirstWalk does a depth-first walk of the graph starting from
   255  // the vertices in start.
   256  func (g *AcyclicGraph) DepthFirstWalk(start Set, f DepthWalkFunc) error {
   257  	return g.walk(depthFirst|downOrder, false, start, f)
   258  }
   259  
   260  // ReverseDepthFirstWalk does a depth-first walk _up_ the graph starting from
   261  // the vertices in start.
   262  func (g *AcyclicGraph) ReverseDepthFirstWalk(start Set, f DepthWalkFunc) error {
   263  	return g.walk(depthFirst|upOrder, false, start, f)
   264  }
   265  
   266  // BreadthFirstWalk does a breadth-first walk of the graph starting from
   267  // the vertices in start.
   268  func (g *AcyclicGraph) BreadthFirstWalk(start Set, f DepthWalkFunc) error {
   269  	return g.walk(breadthFirst|downOrder, false, start, f)
   270  }
   271  
   272  // ReverseBreadthFirstWalk does a breadth-first walk _up_ the graph starting from
   273  // the vertices in start.
   274  func (g *AcyclicGraph) ReverseBreadthFirstWalk(start Set, f DepthWalkFunc) error {
   275  	return g.walk(breadthFirst|upOrder, false, start, f)
   276  }
   277  
   278  // Setting test to true will walk sets of vertices in sorted order for
   279  // deterministic testing.
   280  func (g *AcyclicGraph) walk(order walkType, test bool, start Set, f DepthWalkFunc) error {
   281  	seen := make(map[Vertex]struct{})
   282  	frontier := make([]vertexAtDepth, 0, len(start))
   283  	for _, v := range start {
   284  		frontier = append(frontier, vertexAtDepth{
   285  			Vertex: v,
   286  			Depth:  0,
   287  		})
   288  	}
   289  
   290  	if test {
   291  		testSortFrontier(frontier)
   292  	}
   293  
   294  	for len(frontier) > 0 {
   295  		// Pop the current vertex
   296  		var current vertexAtDepth
   297  
   298  		switch {
   299  		case order&depthFirst != 0:
   300  			// depth first, the frontier is used like a stack
   301  			n := len(frontier)
   302  			current = frontier[n-1]
   303  			frontier = frontier[:n-1]
   304  		case order&breadthFirst != 0:
   305  			// breadth first, the frontier is used like a queue
   306  			current = frontier[0]
   307  			frontier = frontier[1:]
   308  		default:
   309  			panic(fmt.Sprint("invalid visit order", order))
   310  		}
   311  
   312  		// Check if we've seen this already and return...
   313  		if _, ok := seen[current.Vertex]; ok {
   314  			continue
   315  		}
   316  		seen[current.Vertex] = struct{}{}
   317  
   318  		// Visit the current node
   319  		if err := f(current.Vertex, current.Depth); err != nil {
   320  			return err
   321  		}
   322  
   323  		var edges Set
   324  		switch {
   325  		case order&downOrder != 0:
   326  			edges = g.downEdgesNoCopy(current.Vertex)
   327  		case order&upOrder != 0:
   328  			edges = g.upEdgesNoCopy(current.Vertex)
   329  		default:
   330  			panic(fmt.Sprint("invalid walk order", order))
   331  		}
   332  
   333  		if test {
   334  			frontier = testAppendNextSorted(frontier, edges, current.Depth+1)
   335  		} else {
   336  			frontier = appendNext(frontier, edges, current.Depth+1)
   337  		}
   338  	}
   339  	return nil
   340  }
   341  
   342  func appendNext(frontier []vertexAtDepth, next Set, depth int) []vertexAtDepth {
   343  	for _, v := range next {
   344  		frontier = append(frontier, vertexAtDepth{
   345  			Vertex: v,
   346  			Depth:  depth,
   347  		})
   348  	}
   349  	return frontier
   350  }
   351  
   352  func testAppendNextSorted(frontier []vertexAtDepth, edges Set, depth int) []vertexAtDepth {
   353  	var newEdges []vertexAtDepth
   354  	for _, v := range edges {
   355  		newEdges = append(newEdges, vertexAtDepth{
   356  			Vertex: v,
   357  			Depth:  depth,
   358  		})
   359  	}
   360  	testSortFrontier(newEdges)
   361  	return append(frontier, newEdges...)
   362  }
   363  func testSortFrontier(f []vertexAtDepth) {
   364  	sort.Slice(f, func(i, j int) bool {
   365  		return VertexName(f[i].Vertex) < VertexName(f[j].Vertex)
   366  	})
   367  }