golang.org/x/tools@v0.21.1-0.20240520172518-788d39e776b1/cmd/digraph/digraph.go (about) 1 // Copyright 2019 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 package main // import "golang.org/x/tools/cmd/digraph" 5 6 // TODO(adonovan): 7 // - support input files other than stdin 8 // - support alternative formats (AT&T GraphViz, CSV, etc), 9 // a comment syntax, etc. 10 // - allow queries to nest, like Blaze query language. 11 12 import ( 13 "bufio" 14 "bytes" 15 _ "embed" 16 "errors" 17 "flag" 18 "fmt" 19 "io" 20 "os" 21 "sort" 22 "strconv" 23 "strings" 24 "unicode" 25 "unicode/utf8" 26 ) 27 28 func usage() { 29 // Extract the content of the /* ... */ comment in doc.go. 30 _, after, _ := strings.Cut(doc, "/*") 31 doc, _, _ := strings.Cut(after, "*/") 32 io.WriteString(flag.CommandLine.Output(), doc) 33 flag.PrintDefaults() 34 35 os.Exit(2) 36 } 37 38 //go:embed doc.go 39 var doc string 40 41 func main() { 42 flag.Usage = usage 43 flag.Parse() 44 45 args := flag.Args() 46 if len(args) == 0 { 47 usage() 48 } 49 50 if err := digraph(args[0], args[1:]); err != nil { 51 fmt.Fprintf(os.Stderr, "digraph: %s\n", err) 52 os.Exit(1) 53 } 54 } 55 56 type nodelist []string 57 58 func (l nodelist) println(sep string) { 59 for i, node := range l { 60 if i > 0 { 61 fmt.Fprint(stdout, sep) 62 } 63 fmt.Fprint(stdout, node) 64 } 65 fmt.Fprintln(stdout) 66 } 67 68 type nodeset map[string]bool 69 70 func (s nodeset) sort() nodelist { 71 nodes := make(nodelist, len(s)) 72 var i int 73 for node := range s { 74 nodes[i] = node 75 i++ 76 } 77 sort.Strings(nodes) 78 return nodes 79 } 80 81 func (s nodeset) addAll(x nodeset) { 82 for node := range x { 83 s[node] = true 84 } 85 } 86 87 // A graph maps nodes to the non-nil set of their immediate successors. 
88 type graph map[string]nodeset 89 90 func (g graph) addNode(node string) nodeset { 91 edges := g[node] 92 if edges == nil { 93 edges = make(nodeset) 94 g[node] = edges 95 } 96 return edges 97 } 98 99 func (g graph) addEdges(from string, to ...string) { 100 edges := g.addNode(from) 101 for _, to := range to { 102 g.addNode(to) 103 edges[to] = true 104 } 105 } 106 107 func (g graph) nodelist() nodelist { 108 nodes := make(nodeset) 109 for node := range g { 110 nodes[node] = true 111 } 112 return nodes.sort() 113 } 114 115 func (g graph) reachableFrom(roots nodeset) nodeset { 116 seen := make(nodeset) 117 var visit func(node string) 118 visit = func(node string) { 119 if !seen[node] { 120 seen[node] = true 121 for e := range g[node] { 122 visit(e) 123 } 124 } 125 } 126 for root := range roots { 127 visit(root) 128 } 129 return seen 130 } 131 132 func (g graph) transpose() graph { 133 rev := make(graph) 134 for node, edges := range g { 135 rev.addNode(node) 136 for succ := range edges { 137 rev.addEdges(succ, node) 138 } 139 } 140 return rev 141 } 142 143 func (g graph) sccs() []nodeset { 144 // Kosaraju's algorithm---Tarjan is overkill here. 145 146 // Forward pass. 147 S := make(nodelist, 0, len(g)) // postorder stack 148 seen := make(nodeset) 149 var visit func(node string) 150 visit = func(node string) { 151 if !seen[node] { 152 seen[node] = true 153 for e := range g[node] { 154 visit(e) 155 } 156 S = append(S, node) 157 } 158 } 159 for node := range g { 160 visit(node) 161 } 162 163 // Reverse pass. 
164 rev := g.transpose() 165 var scc nodeset 166 seen = make(nodeset) 167 var rvisit func(node string) 168 rvisit = func(node string) { 169 if !seen[node] { 170 seen[node] = true 171 scc[node] = true 172 for e := range rev[node] { 173 rvisit(e) 174 } 175 } 176 } 177 var sccs []nodeset 178 for len(S) > 0 { 179 top := S[len(S)-1] 180 S = S[:len(S)-1] // pop 181 if !seen[top] { 182 scc = make(nodeset) 183 rvisit(top) 184 if len(scc) == 1 && !g[top][top] { 185 continue 186 } 187 sccs = append(sccs, scc) 188 } 189 } 190 return sccs 191 } 192 193 func (g graph) allpaths(from, to string) error { 194 // Mark all nodes to "to". 195 seen := make(nodeset) // value of seen[x] indicates whether x is on some path to "to" 196 var visit func(node string) bool 197 visit = func(node string) bool { 198 reachesTo, ok := seen[node] 199 if !ok { 200 reachesTo = node == to 201 seen[node] = reachesTo 202 for e := range g[node] { 203 if visit(e) { 204 reachesTo = true 205 } 206 } 207 if reachesTo && node != to { 208 seen[node] = true 209 } 210 } 211 return reachesTo 212 } 213 visit(from) 214 215 // For each marked node, collect its marked successors. 216 var edges []string 217 for n := range seen { 218 for succ := range g[n] { 219 if seen[succ] { 220 edges = append(edges, n+" "+succ) 221 } 222 } 223 } 224 225 // Sort (so that this method is deterministic) and print edges. 226 sort.Strings(edges) 227 for _, e := range edges { 228 fmt.Fprintln(stdout, e) 229 } 230 231 return nil 232 } 233 234 func (g graph) somepath(from, to string) error { 235 // Search breadth-first so that we return a minimal path. 236 237 // A path is a linked list whose head is a candidate "to" node 238 // and whose tail is the path ending in the "from" node. 
239 type path struct { 240 node string 241 tail *path 242 } 243 244 seen := nodeset{from: true} 245 246 var queue []*path 247 queue = append(queue, &path{node: from, tail: nil}) 248 for len(queue) > 0 { 249 p := queue[0] 250 queue = queue[1:] 251 252 if p.node == to { 253 // Found a path. Print, tail first. 254 var print func(p *path) 255 print = func(p *path) { 256 if p.tail != nil { 257 print(p.tail) 258 fmt.Fprintln(stdout, p.tail.node+" "+p.node) 259 } 260 } 261 print(p) 262 return nil 263 } 264 265 for succ := range g[p.node] { 266 if !seen[succ] { 267 seen[succ] = true 268 queue = append(queue, &path{node: succ, tail: p}) 269 } 270 } 271 } 272 return fmt.Errorf("no path from %q to %q", from, to) 273 } 274 275 func (g graph) toDot(w *bytes.Buffer) { 276 fmt.Fprintln(w, "digraph {") 277 for _, src := range g.nodelist() { 278 for _, dst := range g[src].sort() { 279 // Dot's quoting rules appear to align with Go's for escString, 280 // which is the syntax of node IDs. Labels require significantly 281 // more quoting, but that appears not to be necessary if the node ID 282 // is implicitly used as the label. 283 fmt.Fprintf(w, "\t%q -> %q;\n", src, dst) 284 } 285 } 286 fmt.Fprintln(w, "}") 287 } 288 289 func parse(rd io.Reader) (graph, error) { 290 g := make(graph) 291 292 var linenum int 293 // We avoid bufio.Scanner as it imposes a (configurable) limit 294 // on line length, whereas Reader.ReadString does not. 295 in := bufio.NewReader(rd) 296 for { 297 linenum++ 298 line, err := in.ReadString('\n') 299 eof := false 300 if err == io.EOF { 301 eof = true 302 } else if err != nil { 303 return nil, err 304 } 305 // Split into words, honoring double-quotes per Go spec. 306 words, err := split(line) 307 if err != nil { 308 return nil, fmt.Errorf("at line %d: %v", linenum, err) 309 } 310 if len(words) > 0 { 311 g.addEdges(words[0], words[1:]...) 312 } 313 if eof { 314 break 315 } 316 } 317 return g, nil 318 } 319 320 // Overridable for redirection. 
321 var stdin io.Reader = os.Stdin 322 var stdout io.Writer = os.Stdout 323 324 func digraph(cmd string, args []string) error { 325 // Parse the input graph. 326 g, err := parse(stdin) 327 if err != nil { 328 return err 329 } 330 331 // Parse the command line. 332 switch cmd { 333 case "nodes": 334 if len(args) != 0 { 335 return fmt.Errorf("usage: digraph nodes") 336 } 337 g.nodelist().println("\n") 338 339 case "degree": 340 if len(args) != 0 { 341 return fmt.Errorf("usage: digraph degree") 342 } 343 nodes := make(nodeset) 344 for node := range g { 345 nodes[node] = true 346 } 347 rev := g.transpose() 348 for _, node := range nodes.sort() { 349 fmt.Fprintf(stdout, "%d\t%d\t%s\n", len(rev[node]), len(g[node]), node) 350 } 351 352 case "transpose": 353 if len(args) != 0 { 354 return fmt.Errorf("usage: digraph transpose") 355 } 356 var revEdges []string 357 for node, succs := range g.transpose() { 358 for succ := range succs { 359 revEdges = append(revEdges, fmt.Sprintf("%s %s", node, succ)) 360 } 361 } 362 sort.Strings(revEdges) // make output deterministic 363 for _, e := range revEdges { 364 fmt.Fprintln(stdout, e) 365 } 366 367 case "succs", "preds": 368 if len(args) == 0 { 369 return fmt.Errorf("usage: digraph %s <node> ... ", cmd) 370 } 371 g := g 372 if cmd == "preds" { 373 g = g.transpose() 374 } 375 result := make(nodeset) 376 for _, root := range args { 377 edges := g[root] 378 if edges == nil { 379 return fmt.Errorf("no such node %q", root) 380 } 381 result.addAll(edges) 382 } 383 result.sort().println("\n") 384 385 case "forward", "reverse": 386 if len(args) == 0 { 387 return fmt.Errorf("usage: digraph %s <node> ... 
", cmd) 388 } 389 roots := make(nodeset) 390 for _, root := range args { 391 if g[root] == nil { 392 return fmt.Errorf("no such node %q", root) 393 } 394 roots[root] = true 395 } 396 g := g 397 if cmd == "reverse" { 398 g = g.transpose() 399 } 400 g.reachableFrom(roots).sort().println("\n") 401 402 case "somepath": 403 if len(args) != 2 { 404 return fmt.Errorf("usage: digraph somepath <from> <to>") 405 } 406 from, to := args[0], args[1] 407 if g[from] == nil { 408 return fmt.Errorf("no such 'from' node %q", from) 409 } 410 if g[to] == nil { 411 return fmt.Errorf("no such 'to' node %q", to) 412 } 413 if err := g.somepath(from, to); err != nil { 414 return err 415 } 416 417 case "allpaths": 418 if len(args) != 2 { 419 return fmt.Errorf("usage: digraph allpaths <from> <to>") 420 } 421 from, to := args[0], args[1] 422 if g[from] == nil { 423 return fmt.Errorf("no such 'from' node %q", from) 424 } 425 if g[to] == nil { 426 return fmt.Errorf("no such 'to' node %q", to) 427 } 428 if err := g.allpaths(from, to); err != nil { 429 return err 430 } 431 432 case "sccs": 433 if len(args) != 0 { 434 return fmt.Errorf("usage: digraph sccs") 435 } 436 buf := new(bytes.Buffer) 437 oldStdout := stdout 438 stdout = buf 439 for _, scc := range g.sccs() { 440 scc.sort().println(" ") 441 } 442 lines := strings.SplitAfter(buf.String(), "\n") 443 sort.Strings(lines) 444 stdout = oldStdout 445 io.WriteString(stdout, strings.Join(lines, "")) 446 447 case "scc": 448 if len(args) != 1 { 449 return fmt.Errorf("usage: digraph scc <node>") 450 } 451 node := args[0] 452 if g[node] == nil { 453 return fmt.Errorf("no such node %q", node) 454 } 455 for _, scc := range g.sccs() { 456 if scc[node] { 457 scc.sort().println("\n") 458 break 459 } 460 } 461 462 case "focus": 463 if len(args) != 1 { 464 return fmt.Errorf("usage: digraph focus <node>") 465 } 466 node := args[0] 467 if g[node] == nil { 468 return fmt.Errorf("no such node %q", node) 469 } 470 471 edges := make(map[string]struct{}) 472 for 
from := range g.reachableFrom(nodeset{node: true}) { 473 for to := range g[from] { 474 edges[fmt.Sprintf("%s %s", from, to)] = struct{}{} 475 } 476 } 477 478 gtrans := g.transpose() 479 for from := range gtrans.reachableFrom(nodeset{node: true}) { 480 for to := range gtrans[from] { 481 edges[fmt.Sprintf("%s %s", to, from)] = struct{}{} 482 } 483 } 484 485 edgesSorted := make([]string, 0, len(edges)) 486 for e := range edges { 487 edgesSorted = append(edgesSorted, e) 488 } 489 sort.Strings(edgesSorted) 490 fmt.Fprintln(stdout, strings.Join(edgesSorted, "\n")) 491 492 case "to": 493 if len(args) != 1 || args[0] != "dot" { 494 return fmt.Errorf("usage: digraph to dot") 495 } 496 var b bytes.Buffer 497 g.toDot(&b) 498 stdout.Write(b.Bytes()) 499 500 default: 501 return fmt.Errorf("no such command %q", cmd) 502 } 503 504 return nil 505 } 506 507 // -- Utilities -------------------------------------------------------- 508 509 // split splits a line into words, which are generally separated by 510 // spaces, but Go-style double-quoted string literals are also supported. 511 // (This approximates the behaviour of the Bourne shell.) 
//
//	`one "two three"` -> ["one" "two three"]
//	`a"\n"b` -> ["a\nb"]
//
// Unquoted text is copied to the output byte for byte, so bytes that
// are not valid UTF-8 are preserved rather than replaced. An error is
// returned if a quoted section is unterminated or is not a valid Go
// string literal.
func split(line string) ([]string, error) {
	var (
		words   []string
		inWord  bool
		current bytes.Buffer
	)

	for len(line) > 0 {
		r, size := utf8.DecodeRuneInString(line)
		if unicode.IsSpace(r) {
			if inWord {
				words = append(words, current.String())
				current.Reset()
				inWord = false
			}
		} else if r == '"' {
			var ok bool
			size, ok = quotedLength(line)
			if !ok {
				return nil, errors.New("invalid quotation")
			}
			s, err := strconv.Unquote(line[:size])
			if err != nil {
				return nil, err
			}
			current.WriteString(s)
			inWord = true // even "" starts (or continues) a word
		} else {
			// Copy the original bytes instead of re-encoding r: for an
			// invalid byte r is utf8.RuneError, and writing that rune
			// would turn the byte into U+FFFD and change the word.
			current.WriteString(line[:size])
			inWord = true
		}
		line = line[size:]
	}
	if inWord {
		words = append(words, current.String())
	}
	return words, nil
}

// quotedLength returns the length in bytes of the prefix of input that
// contains a possibly-valid double-quoted Go string literal.
//
// On success, n is at least two (""); input[:n] may be passed to
// strconv.Unquote to interpret its value, and input[n:] contains the
// rest of the input.
//
// On failure, quotedLength returns false, and the entire input can be
// passed to strconv.Unquote if an informative error message is desired.
//
// quotedLength does not and need not detect all errors, such as
// invalid hex or octal escape sequences, since it assumes
// strconv.Unquote will be applied to the prefix. It guarantees only
// that if there is a prefix of input containing a valid string literal,
// its length is returned.
//
// TODO(adonovan): move this into a strconv-like utility package.
func quotedLength(input string) (n int, ok bool) {
	var offset int

	// next returns the rune at offset, or -1 on EOF.
	// offset advances to just after that rune.
	next := func() rune {
		if offset < len(input) {
			r, size := utf8.DecodeRuneInString(input[offset:])
			offset += size
			return r
		}
		return -1
	}

	if next() != '"' {
		return 0, false // error: not a quotation
	}

	for {
		r := next()
		if r == '\n' || r < 0 {
			return 0, false // error: string literal not terminated
		}
		if r == '"' {
			return offset, true // success
		}
		if r == '\\' {
			// skip is the number of runes that follow the escape
			// character and belong to the same escape sequence.
			var skip int
			switch next() {
			case 'a', 'b', 'f', 'n', 'r', 't', 'v', '\\', '"':
				skip = 0
			case '0', '1', '2', '3', '4', '5', '6', '7':
				skip = 2
			case 'x':
				skip = 2
			case 'u':
				skip = 4
			case 'U':
				skip = 8
			default:
				return 0, false // error: invalid escape
			}

			// Running past the end here is harmless: the next
			// iteration of the outer loop reports the EOF.
			for i := 0; i < skip; i++ {
				next()
			}
		}
	}
}