github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/cmd/digraph/digraph.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  /*
     6  The digraph command performs queries over unlabelled directed graphs
     7  represented in text form.  It is intended to integrate nicely with
     8  typical UNIX command pipelines.
     9  
    10  Usage:
    11  
    12  	your-application | digraph [command]
    13  
    14  The supported commands are:
    15  
    16  	nodes
    17  		the set of all nodes
    18  	degree
    19  		the in-degree and out-degree of each node
    20  	transpose
    21  		the reverse of the input edges
    22  	preds <node> ...
    23  		the set of immediate predecessors of the specified nodes
    24  	succs <node> ...
    25  		the set of immediate successors of the specified nodes
    26  	forward <node> ...
    27  		the set of nodes transitively reachable from the specified nodes
    28  	reverse <node> ...
    29  		the set of nodes that transitively reach the specified nodes
    30  	somepath <node> <node>
    31  		the list of nodes on some arbitrary path from the first node to the second
    32  	allpaths <node> <node>
    33  		the set of nodes on all paths from the first node to the second
    34  	sccs
    35  		all non-trivial strongly connected components (one per line)
    36  	scc <node>
    37  		the set of nodes strongly connected to the specified one
    38  	focus <node>
    39  		the subgraph containing all directed paths that pass through the specified node
    40  
    41  Input format:
    42  
    43  Each line contains zero or more words. Words are separated by unquoted
    44  whitespace; words may contain Go-style double-quoted portions, allowing spaces
    45  and other characters to be expressed.
    46  
    47  Each word declares a node, and if there are more than one, an edge from the
    48  first to each subsequent one. The graph is provided on the standard input.
    49  
    50  For instance, the following (acyclic) graph specifies a partial order among the
    51  subtasks of getting dressed:
    52  
    53  	$ cat clothes.txt
    54  	socks shoes
    55  	"boxer shorts" pants
    56  	pants belt shoes
    57  	shirt tie sweater
    58  	sweater jacket
    59  	hat
    60  
    61  The line "shirt tie sweater" indicates the two edges shirt -> tie and
    62  shirt -> sweater, not shirt -> tie -> sweater.
    63  
    64  Example usage:
    65  
    66  Using digraph with existing Go tools:
    67  
    68  	$ go mod graph | digraph nodes # Operate on the Go module graph.
    69  	$ go list -m all | digraph nodes # Operate on the Go package graph.
    70  
    71  Show the transitive closure of imports of the digraph tool itself:
    72  
    73  	$ go list -f '{{.ImportPath}} {{join .Imports " "}}' ... | digraph forward golang.org/x/tools/cmd/digraph
    74  
    75  Show which clothes (see above) must be donned before a jacket:
    76  
    77  	$ digraph reverse jacket
    78  */
    79  package main // import "golang.org/x/tools/cmd/digraph"
    80  
    81  // TODO(adonovan):
    82  // - support input files other than stdin
    83  // - support alternative formats (AT&T GraphViz, CSV, etc),
    84  //   a comment syntax, etc.
    85  // - allow queries to nest, like Blaze query language.
    86  
    87  import (
    88  	"bufio"
    89  	"bytes"
    90  	"errors"
    91  	"flag"
    92  	"fmt"
    93  	"io"
    94  	"os"
    95  	"sort"
    96  	"strconv"
    97  	"strings"
    98  	"unicode"
    99  	"unicode/utf8"
   100  )
   101  
   102  func usage() {
   103  	fmt.Fprintf(os.Stderr, `Usage: your-application | digraph [command]
   104  
   105  The support commands are:
   106  	nodes
   107  		the set of all nodes
   108  	degree
   109  		the in-degree and out-degree of each node
   110  	transpose
   111  		the reverse of the input edges
   112  	preds <node> ...
   113  		the set of immediate predecessors of the specified nodes
   114  	succs <node> ...
   115  		the set of immediate successors of the specified nodes
   116  	forward <node> ...
   117  		the set of nodes transitively reachable from the specified nodes
   118  	reverse <node> ...
   119  		the set of nodes that transitively reach the specified nodes
   120  	somepath <node> <node>
   121  		the list of nodes on some arbitrary path from the first node to the second
   122  	allpaths <node> <node>
   123  		the set of nodes on all paths from the first node to the second
   124  	sccs
   125  		all non-trivial strongly connected components, one per line
   126  		(single-node components are only printed for nodes with self-loops)
   127  	scc <node>
   128  		the set of nodes nodes strongly connected to the specified one
   129  	focus <node>
   130  		the subgraph containing all directed paths that pass through the specified node
   131  `)
   132  	os.Exit(2)
   133  }
   134  
   135  func main() {
   136  	flag.Usage = usage
   137  	flag.Parse()
   138  
   139  	args := flag.Args()
   140  	if len(args) == 0 {
   141  		usage()
   142  	}
   143  
   144  	if err := digraph(args[0], args[1:]); err != nil {
   145  		fmt.Fprintf(os.Stderr, "digraph: %s\n", err)
   146  		os.Exit(1)
   147  	}
   148  }
   149  
   150  type nodelist []string
   151  
   152  func (l nodelist) println(sep string) {
   153  	for i, node := range l {
   154  		if i > 0 {
   155  			fmt.Fprint(stdout, sep)
   156  		}
   157  		fmt.Fprint(stdout, node)
   158  	}
   159  	fmt.Fprintln(stdout)
   160  }
   161  
   162  type nodeset map[string]bool
   163  
   164  func (s nodeset) sort() nodelist {
   165  	nodes := make(nodelist, len(s))
   166  	var i int
   167  	for node := range s {
   168  		nodes[i] = node
   169  		i++
   170  	}
   171  	sort.Strings(nodes)
   172  	return nodes
   173  }
   174  
   175  func (s nodeset) addAll(x nodeset) {
   176  	for node := range x {
   177  		s[node] = true
   178  	}
   179  }
   180  
   181  // A graph maps nodes to the non-nil set of their immediate successors.
   182  type graph map[string]nodeset
   183  
   184  func (g graph) addNode(node string) nodeset {
   185  	edges := g[node]
   186  	if edges == nil {
   187  		edges = make(nodeset)
   188  		g[node] = edges
   189  	}
   190  	return edges
   191  }
   192  
   193  func (g graph) addEdges(from string, to ...string) {
   194  	edges := g.addNode(from)
   195  	for _, to := range to {
   196  		g.addNode(to)
   197  		edges[to] = true
   198  	}
   199  }
   200  
   201  func (g graph) reachableFrom(roots nodeset) nodeset {
   202  	seen := make(nodeset)
   203  	var visit func(node string)
   204  	visit = func(node string) {
   205  		if !seen[node] {
   206  			seen[node] = true
   207  			for e := range g[node] {
   208  				visit(e)
   209  			}
   210  		}
   211  	}
   212  	for root := range roots {
   213  		visit(root)
   214  	}
   215  	return seen
   216  }
   217  
   218  func (g graph) transpose() graph {
   219  	rev := make(graph)
   220  	for node, edges := range g {
   221  		rev.addNode(node)
   222  		for succ := range edges {
   223  			rev.addEdges(succ, node)
   224  		}
   225  	}
   226  	return rev
   227  }
   228  
   229  func (g graph) sccs() []nodeset {
   230  	// Kosaraju's algorithm---Tarjan is overkill here.
   231  
   232  	// Forward pass.
   233  	S := make(nodelist, 0, len(g)) // postorder stack
   234  	seen := make(nodeset)
   235  	var visit func(node string)
   236  	visit = func(node string) {
   237  		if !seen[node] {
   238  			seen[node] = true
   239  			for e := range g[node] {
   240  				visit(e)
   241  			}
   242  			S = append(S, node)
   243  		}
   244  	}
   245  	for node := range g {
   246  		visit(node)
   247  	}
   248  
   249  	// Reverse pass.
   250  	rev := g.transpose()
   251  	var scc nodeset
   252  	seen = make(nodeset)
   253  	var rvisit func(node string)
   254  	rvisit = func(node string) {
   255  		if !seen[node] {
   256  			seen[node] = true
   257  			scc[node] = true
   258  			for e := range rev[node] {
   259  				rvisit(e)
   260  			}
   261  		}
   262  	}
   263  	var sccs []nodeset
   264  	for len(S) > 0 {
   265  		top := S[len(S)-1]
   266  		S = S[:len(S)-1] // pop
   267  		if !seen[top] {
   268  			scc = make(nodeset)
   269  			rvisit(top)
   270  			if len(scc) == 1 && !g[top][top] {
   271  				continue
   272  			}
   273  			sccs = append(sccs, scc)
   274  		}
   275  	}
   276  	return sccs
   277  }
   278  
   279  func (g graph) allpaths(from, to string) error {
   280  	// Mark all nodes to "to".
   281  	seen := make(nodeset) // value of seen[x] indicates whether x is on some path to "to"
   282  	var visit func(node string) bool
   283  	visit = func(node string) bool {
   284  		reachesTo, ok := seen[node]
   285  		if !ok {
   286  			reachesTo = node == to
   287  			seen[node] = reachesTo
   288  			for e := range g[node] {
   289  				if visit(e) {
   290  					reachesTo = true
   291  				}
   292  			}
   293  			if reachesTo && node != to {
   294  				seen[node] = true
   295  			}
   296  		}
   297  		return reachesTo
   298  	}
   299  	visit(from)
   300  
   301  	// For each marked node, collect its marked successors.
   302  	var edges []string
   303  	for n := range seen {
   304  		for succ := range g[n] {
   305  			if seen[succ] {
   306  				edges = append(edges, n+" "+succ)
   307  			}
   308  		}
   309  	}
   310  
   311  	// Sort (so that this method is deterministic) and print edges.
   312  	sort.Strings(edges)
   313  	for _, e := range edges {
   314  		fmt.Fprintln(stdout, e)
   315  	}
   316  
   317  	return nil
   318  }
   319  
   320  func (g graph) somepath(from, to string) error {
   321  	type edge struct{ from, to string }
   322  	seen := make(nodeset)
   323  	var dfs func(path []edge, from string) bool
   324  	dfs = func(path []edge, from string) bool {
   325  		if !seen[from] {
   326  			seen[from] = true
   327  			if from == to {
   328  				// fmt.Println(path, len(path), cap(path))
   329  				// Print and unwind.
   330  				for _, e := range path {
   331  					fmt.Fprintln(stdout, e.from+" "+e.to)
   332  				}
   333  				return true
   334  			}
   335  			for e := range g[from] {
   336  				if dfs(append(path, edge{from: from, to: e}), e) {
   337  					return true
   338  				}
   339  			}
   340  		}
   341  		return false
   342  	}
   343  	maxEdgesInGraph := len(g) * (len(g) - 1)
   344  	if !dfs(make([]edge, 0, maxEdgesInGraph), from) {
   345  		return fmt.Errorf("no path from %q to %q", from, to)
   346  	}
   347  	return nil
   348  }
   349  
   350  func parse(rd io.Reader) (graph, error) {
   351  	g := make(graph)
   352  
   353  	var linenum int
   354  	// We avoid bufio.Scanner as it imposes a (configurable) limit
   355  	// on line length, whereas Reader.ReadString does not.
   356  	in := bufio.NewReader(rd)
   357  	for {
   358  		linenum++
   359  		line, err := in.ReadString('\n')
   360  		eof := false
   361  		if err == io.EOF {
   362  			eof = true
   363  		} else if err != nil {
   364  			return nil, err
   365  		}
   366  		// Split into words, honoring double-quotes per Go spec.
   367  		words, err := split(line)
   368  		if err != nil {
   369  			return nil, fmt.Errorf("at line %d: %v", linenum, err)
   370  		}
   371  		if len(words) > 0 {
   372  			g.addEdges(words[0], words[1:]...)
   373  		}
   374  		if eof {
   375  			break
   376  		}
   377  	}
   378  	return g, nil
   379  }
   380  
   381  // Overridable for redirection.
   382  var stdin io.Reader = os.Stdin
   383  var stdout io.Writer = os.Stdout
   384  
   385  func digraph(cmd string, args []string) error {
   386  	// Parse the input graph.
   387  	g, err := parse(stdin)
   388  	if err != nil {
   389  		return err
   390  	}
   391  
   392  	// Parse the command line.
   393  	switch cmd {
   394  	case "nodes":
   395  		if len(args) != 0 {
   396  			return fmt.Errorf("usage: digraph nodes")
   397  		}
   398  		nodes := make(nodeset)
   399  		for node := range g {
   400  			nodes[node] = true
   401  		}
   402  		nodes.sort().println("\n")
   403  
   404  	case "degree":
   405  		if len(args) != 0 {
   406  			return fmt.Errorf("usage: digraph degree")
   407  		}
   408  		nodes := make(nodeset)
   409  		for node := range g {
   410  			nodes[node] = true
   411  		}
   412  		rev := g.transpose()
   413  		for _, node := range nodes.sort() {
   414  			fmt.Fprintf(stdout, "%d\t%d\t%s\n", len(rev[node]), len(g[node]), node)
   415  		}
   416  
   417  	case "transpose":
   418  		if len(args) != 0 {
   419  			return fmt.Errorf("usage: digraph transpose")
   420  		}
   421  		var revEdges []string
   422  		for node, succs := range g.transpose() {
   423  			for succ := range succs {
   424  				revEdges = append(revEdges, fmt.Sprintf("%s %s", node, succ))
   425  			}
   426  		}
   427  		sort.Strings(revEdges) // make output deterministic
   428  		for _, e := range revEdges {
   429  			fmt.Fprintln(stdout, e)
   430  		}
   431  
   432  	case "succs", "preds":
   433  		if len(args) == 0 {
   434  			return fmt.Errorf("usage: digraph %s <node> ... ", cmd)
   435  		}
   436  		g := g
   437  		if cmd == "preds" {
   438  			g = g.transpose()
   439  		}
   440  		result := make(nodeset)
   441  		for _, root := range args {
   442  			edges := g[root]
   443  			if edges == nil {
   444  				return fmt.Errorf("no such node %q", root)
   445  			}
   446  			result.addAll(edges)
   447  		}
   448  		result.sort().println("\n")
   449  
   450  	case "forward", "reverse":
   451  		if len(args) == 0 {
   452  			return fmt.Errorf("usage: digraph %s <node> ... ", cmd)
   453  		}
   454  		roots := make(nodeset)
   455  		for _, root := range args {
   456  			if g[root] == nil {
   457  				return fmt.Errorf("no such node %q", root)
   458  			}
   459  			roots[root] = true
   460  		}
   461  		g := g
   462  		if cmd == "reverse" {
   463  			g = g.transpose()
   464  		}
   465  		g.reachableFrom(roots).sort().println("\n")
   466  
   467  	case "somepath":
   468  		if len(args) != 2 {
   469  			return fmt.Errorf("usage: digraph somepath <from> <to>")
   470  		}
   471  		from, to := args[0], args[1]
   472  		if g[from] == nil {
   473  			return fmt.Errorf("no such 'from' node %q", from)
   474  		}
   475  		if g[to] == nil {
   476  			return fmt.Errorf("no such 'to' node %q", to)
   477  		}
   478  		if err := g.somepath(from, to); err != nil {
   479  			return err
   480  		}
   481  
   482  	case "allpaths":
   483  		if len(args) != 2 {
   484  			return fmt.Errorf("usage: digraph allpaths <from> <to>")
   485  		}
   486  		from, to := args[0], args[1]
   487  		if g[from] == nil {
   488  			return fmt.Errorf("no such 'from' node %q", from)
   489  		}
   490  		if g[to] == nil {
   491  			return fmt.Errorf("no such 'to' node %q", to)
   492  		}
   493  		if err := g.allpaths(from, to); err != nil {
   494  			return err
   495  		}
   496  
   497  	case "sccs":
   498  		if len(args) != 0 {
   499  			return fmt.Errorf("usage: digraph sccs")
   500  		}
   501  		buf := new(bytes.Buffer)
   502  		oldStdout := stdout
   503  		stdout = buf
   504  		for _, scc := range g.sccs() {
   505  			scc.sort().println(" ")
   506  		}
   507  		lines := strings.SplitAfter(buf.String(), "\n")
   508  		sort.Strings(lines)
   509  		stdout = oldStdout
   510  		io.WriteString(stdout, strings.Join(lines, ""))
   511  
   512  	case "scc":
   513  		if len(args) != 1 {
   514  			return fmt.Errorf("usage: digraph scc <node>")
   515  		}
   516  		node := args[0]
   517  		if g[node] == nil {
   518  			return fmt.Errorf("no such node %q", node)
   519  		}
   520  		for _, scc := range g.sccs() {
   521  			if scc[node] {
   522  				scc.sort().println("\n")
   523  				break
   524  			}
   525  		}
   526  
   527  	case "focus":
   528  		if len(args) != 1 {
   529  			return fmt.Errorf("usage: digraph focus <node>")
   530  		}
   531  		node := args[0]
   532  		if g[node] == nil {
   533  			return fmt.Errorf("no such node %q", node)
   534  		}
   535  
   536  		edges := make(map[string]struct{})
   537  		for from := range g.reachableFrom(nodeset{node: true}) {
   538  			for to := range g[from] {
   539  				edges[fmt.Sprintf("%s %s", from, to)] = struct{}{}
   540  			}
   541  		}
   542  
   543  		gtrans := g.transpose()
   544  		for from := range gtrans.reachableFrom(nodeset{node: true}) {
   545  			for to := range gtrans[from] {
   546  				edges[fmt.Sprintf("%s %s", to, from)] = struct{}{}
   547  			}
   548  		}
   549  
   550  		edgesSorted := make([]string, 0, len(edges))
   551  		for e := range edges {
   552  			edgesSorted = append(edgesSorted, e)
   553  		}
   554  		sort.Strings(edgesSorted)
   555  		fmt.Fprintln(stdout, strings.Join(edgesSorted, "\n"))
   556  
   557  	default:
   558  		return fmt.Errorf("no such command %q", cmd)
   559  	}
   560  
   561  	return nil
   562  }
   563  
   564  // -- Utilities --------------------------------------------------------
   565  
   566  // split splits a line into words, which are generally separated by
   567  // spaces, but Go-style double-quoted string literals are also supported.
   568  // (This approximates the behaviour of the Bourne shell.)
   569  //
   570  //	`one "two three"` -> ["one" "two three"]
   571  //	`a"\n"b` -> ["a\nb"]
   572  func split(line string) ([]string, error) {
   573  	var (
   574  		words   []string
   575  		inWord  bool
   576  		current bytes.Buffer
   577  	)
   578  
   579  	for len(line) > 0 {
   580  		r, size := utf8.DecodeRuneInString(line)
   581  		if unicode.IsSpace(r) {
   582  			if inWord {
   583  				words = append(words, current.String())
   584  				current.Reset()
   585  				inWord = false
   586  			}
   587  		} else if r == '"' {
   588  			var ok bool
   589  			size, ok = quotedLength(line)
   590  			if !ok {
   591  				return nil, errors.New("invalid quotation")
   592  			}
   593  			s, err := strconv.Unquote(line[:size])
   594  			if err != nil {
   595  				return nil, err
   596  			}
   597  			current.WriteString(s)
   598  			inWord = true
   599  		} else {
   600  			current.WriteRune(r)
   601  			inWord = true
   602  		}
   603  		line = line[size:]
   604  	}
   605  	if inWord {
   606  		words = append(words, current.String())
   607  	}
   608  	return words, nil
   609  }
   610  
   611  // quotedLength returns the length in bytes of the prefix of input that
   612  // contain a possibly-valid double-quoted Go string literal.
   613  //
   614  // On success, n is at least two (""); input[:n] may be passed to
   615  // strconv.Unquote to interpret its value, and input[n:] contains the
   616  // rest of the input.
   617  //
   618  // On failure, quotedLength returns false, and the entire input can be
   619  // passed to strconv.Unquote if an informative error message is desired.
   620  //
   621  // quotedLength does not and need not detect all errors, such as
   622  // invalid hex or octal escape sequences, since it assumes
   623  // strconv.Unquote will be applied to the prefix.  It guarantees only
   624  // that if there is a prefix of input containing a valid string literal,
   625  // its length is returned.
   626  //
   627  // TODO(adonovan): move this into a strconv-like utility package.
   628  func quotedLength(input string) (n int, ok bool) {
   629  	var offset int
   630  
   631  	// next returns the rune at offset, or -1 on EOF.
   632  	// offset advances to just after that rune.
   633  	next := func() rune {
   634  		if offset < len(input) {
   635  			r, size := utf8.DecodeRuneInString(input[offset:])
   636  			offset += size
   637  			return r
   638  		}
   639  		return -1
   640  	}
   641  
   642  	if next() != '"' {
   643  		return // error: not a quotation
   644  	}
   645  
   646  	for {
   647  		r := next()
   648  		if r == '\n' || r < 0 {
   649  			return // error: string literal not terminated
   650  		}
   651  		if r == '"' {
   652  			return offset, true // success
   653  		}
   654  		if r == '\\' {
   655  			var skip int
   656  			switch next() {
   657  			case 'a', 'b', 'f', 'n', 'r', 't', 'v', '\\', '"':
   658  				skip = 0
   659  			case '0', '1', '2', '3', '4', '5', '6', '7':
   660  				skip = 2
   661  			case 'x':
   662  				skip = 2
   663  			case 'u':
   664  				skip = 4
   665  			case 'U':
   666  				skip = 8
   667  			default:
   668  				return // error: invalid escape
   669  			}
   670  
   671  			for i := 0; i < skip; i++ {
   672  				next()
   673  			}
   674  		}
   675  	}
   676  }