gorgonia.org/gorgonia@v0.9.17/graph.go (about) 1 package gorgonia 2 3 import ( 4 "bytes" 5 "fmt" 6 7 "github.com/awalterschulze/gographviz" 8 "gonum.org/v1/gonum/graph" 9 "gonum.org/v1/gonum/graph/iterator" 10 ) 11 12 // ExprGraph is a data structure for a directed acyclic graph (of expressions). This structure is the main entry point 13 // for Gorgonia. 14 type ExprGraph struct { 15 name string 16 17 all Nodes 18 19 byID map[int64]int 20 byHash map[uint32]*Node 21 evac map[uint32]Nodes 22 to map[*Node]Nodes 23 24 leaves Nodes 25 constants Nodes 26 roots Nodes 27 counter uint 28 } 29 30 // graphconopt sets options 31 type graphconopt func(g *ExprGraph) 32 33 // WithGraphName is a ExprGraph construction option that provides a name. 34 func WithGraphName(name string) graphconopt { 35 f := func(g *ExprGraph) { 36 g.name = name 37 } 38 return f 39 } 40 41 // NewGraph creates a new graph. Duh 42 func NewGraph(opts ...graphconopt) *ExprGraph { 43 g := &ExprGraph{ 44 byID: make(map[int64]int), 45 byHash: make(map[uint32]*Node), 46 evac: make(map[uint32]Nodes), 47 to: make(map[*Node]Nodes), 48 49 leaves: make(Nodes, 0, 64), 50 constants: make(Nodes, 0, 8), 51 } 52 53 for _, opt := range opts { 54 opt(g) 55 } 56 57 return g 58 } 59 60 // Clone clones the graph. All nodes gets cloned, and their values are cloned as well. 61 func (g *ExprGraph) Clone() interface{} { 62 g2 := new(ExprGraph) 63 g2.name = g.name 64 65 mapping := make(map[*Node]*Node) // a map of old nodes to new nodes 66 g2.all = make(Nodes, len(g.all)) 67 for i, n := range g.all { 68 cloned := n.Clone().(*Node) 69 cloned.g = g2 70 cloned.id = n.id 71 72 g2.all[i] = cloned 73 mapping[n] = cloned 74 } 75 76 // handle each node's children, deriv ofs, etc 77 for i, n := range g.all { 78 cloned := g2.all[i] 79 cloned.children = make(Nodes, len(n.children)) 80 for j, c := range n.children { 81 cloned.children[j] = mapping[c] 82 } 83 84 cloned.derivOf = make(Nodes, len(n.derivOf)) 85 for j, c := range n.derivOf { 86 cloned.derivOf[j] = mapping[c] 87 } 88 89 if n.deriv != nil { 90 cloned.deriv = mapping[n.deriv] 91 } 92 } 93 94 g2.byID = make(map[int64]int) 95 g2.byHash = make(map[uint32]*Node) 96 for k, v := range g.byHash { 97 g2.byHash[k] = mapping[v] 98 } 99 100 g2.evac = make(map[uint32]Nodes) 101 for k, v := range g.evac { 102 g2.evac[k] = make(Nodes, len(v)) 103 for i, n := range v { 104 g2.evac[k][i] = mapping[n] 105 } 106 } 107 108 g2.to = make(map[*Node]Nodes) 109 for k, v := range g.to { 110 to := mapping[k] 111 g2.to[to] = make(Nodes, len(v)) 112 for i, n := range v { 113 g2.to[to][i] = mapping[n] 114 } 115 } 116 117 g2.leaves = make(Nodes, len(g.leaves)) 118 for i, n := range g.leaves { 119 g2.leaves[i] = mapping[n] 120 } 121 122 g2.constants = make(Nodes, len(g.constants)) 123 for i, n := range g.constants { 124 g2.constants[i] = mapping[n] 125 } 126 127 g2.roots = make(Nodes, len(g.roots)) 128 for i, n := range g.roots { 129 g2.roots[i] = mapping[n] 130 } 131 132 g2.counter = g.counter 133 return g2 134 } 135 136 // AddNode adds n to the graph. It panics if the added node ID matches an existing node ID. 137 func (g *ExprGraph) AddNode(n *Node) (retVal *Node) { 138 defer func() { 139 if _, ok := g.to[retVal]; !ok { 140 g.to[retVal] = nil 141 } 142 }() 143 // check for node with the same name in the graph 144 // we don't update the graph if this is the case 145 for _, node := range g.constants { 146 if node.name == n.name && n.isConstant() { 147 return node 148 } 149 } 150 hash := n.Hashcode() 151 if existing, ok := g.byHash[hash]; ok { 152 if existing == nil { 153 // this means that there has been previous collisions 154 // so look at evac map 155 for _, e := range g.evac[hash] { 156 if nodeEq(n, e) { 157 return e 158 } 159 } 160 g.evac[hash] = append(g.evac[hash], n) 161 g.addToAll(n) 162 incrCC() // collision counter 163 return n 164 } 165 166 if !nodeEq(n, existing) { 167 g.evac[hash] = Nodes{existing, n} 168 g.byHash[hash] = nil // to signal that it's collided 169 g.addToAll(n) 170 incrCC() 171 return n 172 } 173 incrEC() // expected collision (they're the same node!) 174 return existing 175 } 176 177 if n.isConstant() { 178 n = n.clone() 179 g.constants = g.constants.Add(n) 180 n.g = g 181 } 182 183 g.addToAll(n) 184 g.byHash[hash] = n 185 return n 186 } 187 188 func (g *ExprGraph) addToAll(n *Node) { 189 if n == nil { 190 panic("HELP! trying to add nil") 191 } 192 g.all = append(g.all, n) 193 n.id = int64(g.counter) 194 g.counter++ 195 } 196 197 // RemoveNode removes n from the graph, as well as any edges attached to it. If the node 198 // is not in the graph it is a no-op. 199 func (g *ExprGraph) RemoveNode(node graph.Node) { 200 n := node.(*Node) 201 if n.id == -1 { 202 return // if it's -1, it was never in the graph to begin with 203 } 204 205 hash := n.Hashcode() 206 207 delete(g.byHash, hash) 208 delete(g.to, n) 209 g.evac[hash] = g.evac[hash].remove(n) 210 g.all = g.all.remove(n) 211 } 212 213 // SetEdge adds e, an edge from one node to another. If the nodes do not exist, they are added. 214 // It will panic if the IDs of the e.From and e.To are equal. 215 func (g *ExprGraph) SetEdge(e graph.Edge) { 216 from := e.From().(*Node) 217 to := e.To().(*Node) 218 219 if from == to { 220 panic(fmt.Sprintf("cannot add self edge: from %v to %v", from, to)) 221 } 222 223 if !g.Has(from.ID()) { 224 from = g.AddNode(from) 225 } 226 227 if !g.Has(to.ID()) { 228 to = g.AddNode(to) 229 } 230 231 // g.to[to] = g.to[to].Add(from) 232 g.to[to] = append(g.to[to], from) 233 } 234 235 // Roots returns a list of nodes that are not children of any other nodes 236 func (g *ExprGraph) Roots() (retVal Nodes) { 237 // handle subgraph 238 if g.roots != nil { 239 return g.roots 240 } 241 242 for n, tos := range g.to { 243 if len(tos) == 0 { 244 retVal = append(retVal, n) 245 } 246 // if the root is a statement (typically a read), and it only has one child 247 if len(n.children) == 1 && n.isStmt { 248 child := n.children[0] 249 if len(g.to[child]) == 1 { 250 retVal = append(retVal, child) 251 } 252 } 253 } 254 g.roots = retVal 255 return retVal 256 } 257 258 // Inputs returns a list of nodes which are inputs (that is to say, the user is required to set a value in it) 259 func (g *ExprGraph) Inputs() (retVal Nodes) { 260 for _, n := range g.all { 261 if n.isInput() { 262 retVal = append(retVal, n) 263 } 264 } 265 return 266 } 267 268 // UnbindAll unbinds all the values from the nodes 269 func (g *ExprGraph) UnbindAll() { 270 for _, n := range g.all { 271 n.unbind() 272 } 273 } 274 275 // UnbindAllNonInputs unbinds all the values from nodes that aren't input nodes 276 func (g *ExprGraph) UnbindAllNonInputs() { 277 for _, n := range g.all { 278 if n.isInput() || n.isConstant() { 279 continue 280 } 281 n.unbind() 282 } 283 } 284 285 // ByName returns nodes that have the name provided. 286 // Bear in mind that the name that is compared to is the internal name, 287 // not the result of calling node.Name(). The reason for doing this is 288 // for ease of finding only names that are user-supplied, instead of autogenerated names 289 func (g *ExprGraph) ByName(name string) (retVal Nodes) { 290 for _, n := range g.all { 291 if n.name == name { 292 retVal = append(retVal, n) 293 } 294 } 295 return 296 } 297 298 // Constant returns a constant that may be found in the graph. If no constant were found, a new one is created instead 299 func (g *ExprGraph) Constant(v Value) *Node { 300 for _, n := range g.constants { 301 if ValueEq(n.Value(), v) { 302 return n 303 } 304 } 305 306 n := NewConstant(v) 307 return g.AddNode(n) 308 } 309 310 func (g *ExprGraph) String() string { 311 var buf bytes.Buffer 312 buf.WriteString("Graph: [\n") 313 for _, n := range g.byHash { 314 fmt.Fprintf(&buf, "\t%d: %s\n", n.Hashcode(), n) 315 } 316 buf.WriteString("]") 317 return buf.String() 318 } 319 320 // ToDot generates the graph in graphviz format. The use of this is to generate for the entire graph 321 // which may have multiple trees with different roots 322 // TODO: This is getting unwieldy. Perhaps refactor out into a ToDot(...Opt)? 323 func (g *ExprGraph) ToDot() string { 324 gv := gographviz.NewEscape() 325 gv.SetName(fullGraphName) 326 gv.SetDir(true) 327 328 gv.AddAttr(fullGraphName, "nodesep", "1") 329 gv.AddAttr(fullGraphName, "ranksep", "1.5 equally") 330 gv.AddAttr(fullGraphName, "rankdir", "TB") 331 if len(g.byHash) > 100 { 332 gv.AddAttr(fullGraphName, "nslimit", "3") // numiter=3*len(nodes) 333 // gv.AddAttr(fullGraphName, "splines", "line") // ugly as sin. 334 } 335 336 groups := make(map[string]struct{}) 337 for h, n := range g.byHash { 338 if n != nil { 339 group := n.dotCluster() 340 groups[group] = struct{}{} 341 continue 342 } 343 // other wise it'se a clash of hash 344 for _, n := range g.evac[h] { 345 group := n.dotCluster() 346 groups[group] = struct{}{} 347 348 } 349 } 350 351 for grp := range groups { 352 attrs := map[string]string{"label": grp} 353 354 parentGraph := fullGraphName 355 if grp == inputsClust || grp == constantsClust { 356 parentGraph = inputConsts 357 if !gv.IsSubGraph(inputConsts) { 358 groupAttrs := map[string]string{"rank": "max"} 359 gv.AddSubGraph(fullGraphName, inputConsts, groupAttrs) 360 } 361 } 362 gv.AddSubGraph(parentGraph, "cluster_"+grp, attrs) 363 } 364 365 // for _, n := range g.byHash { 366 for _, n := range g.all { 367 group := n.dotCluster() 368 n.dotString(gv, "cluster_"+group) 369 } 370 371 // for _, from := range g.byHash { 372 for _, from := range g.all { 373 for i, child := range from.children { 374 if ok := g.all.Contains(child); !ok { 375 // not in graph, so ignore it... 376 continue 377 } 378 fromID := fmt.Sprintf("Node_%p", from) 379 toID := fmt.Sprintf("Node_%p", child) 380 381 edgeAttrs := map[string]string{ 382 "taillabel": fmt.Sprintf(" %d ", i), 383 "labelfloat": "false", 384 } 385 386 // we invert the from and to nodes for gradients, As the expressionGraph builds upwards from bottom, the gradient builds downwards. 387 if from.group == gradClust && child.group == gradClust { 388 edgeAttrs["dir"] = "back" 389 gv.AddPortEdge(toID, toID+":anchor:s", fromID, fromID+":anchor:n", true, edgeAttrs) 390 } else { 391 gv.AddPortEdge(fromID, fromID+":anchor:s", toID, toID+":anchor:n", true, edgeAttrs) 392 } 393 } 394 } 395 396 // draw deriv lines 397 if debugDerives { 398 edgeAttrs := map[string]string{ 399 "style": "dashed", 400 "constraint": "false", 401 "weight": "999", 402 } 403 404 for _, n := range g.byHash { 405 if n == nil { 406 // collision found... what to do? 407 continue 408 } 409 if n.derivOf != nil { 410 id := fmt.Sprintf("Node_%p", n) 411 for _, derivOf := range n.derivOf { 412 if _, ok := g.to[derivOf]; !ok { 413 continue 414 } 415 ofID := fmt.Sprintf("Node_%p", derivOf) 416 // gv.AddPortEdge(id, ":anchor:w", ofID, ofID+":anchor:e", true, edgeAttrs) 417 gv.AddEdge(id, ofID, true, edgeAttrs) 418 } 419 } 420 } 421 } 422 423 // stupid invisible nodes to keep expressiongraph on the left 424 subGAttrs := make(map[string]string) 425 // subGAttrs.Add("rank", "max") 426 gv.AddSubGraph(fullGraphName, outsideSubG, subGAttrs) 427 428 attrs := map[string]string{ 429 "style": "invis", 430 } 431 gv.AddNode(outsideSubG, outsideRoot, attrs) 432 433 outsides := []string{outsideRoot} 434 var insides []string 435 436 // build the inside and outside list 437 if _, hasInputs := groups[inputsClust]; hasInputs { 438 insides = append(insides, insideInputs) 439 gv.AddNode("cluster_inputs", insideInputs, attrs) 440 } 441 442 if _, hasConst := groups[constantsClust]; hasConst { 443 if len(insides) > 0 { 444 outsides = append(outsides, outsideConsts) 445 gv.AddNode(outsideSubG, outsideConsts, attrs) 446 } 447 insides = append(insides, insideConsts) 448 gv.AddNode("cluster_constants", insideConsts, attrs) 449 } 450 451 if len(insides) > 0 { 452 outsides = append(outsides, outsideExprG) 453 gv.AddNode(outsideSubG, outsideExprG, attrs) 454 } 455 insides = append(insides, insideExprG) 456 gv.AddNode("cluster_expressionGraph", insideExprG, attrs) 457 458 for group := range groups { 459 if group == exprgraphClust || group == constantsClust || group == inputsClust { 460 continue 461 } 462 inside := "inside_" + group 463 outside := "outside_" + group 464 insides = append(insides, inside) 465 outsides = append(outsides, outside) 466 467 gv.AddNode(outsideSubG, outside, attrs) 468 gv.AddNode("cluster_"+group, inside, attrs) 469 } 470 471 edgeAttrs := map[string]string{ 472 "style": "invis", 473 "weight": "999", 474 "constraint": "false", 475 } 476 for i, o := range outsides { 477 // outside-inside 478 gv.AddEdge(o, insides[i], true, edgeAttrs) 479 480 if i > 0 { 481 // outside-outside 482 gv.AddEdge(outsides[i-1], o, true, edgeAttrs) 483 484 // inside-inside 485 gv.AddEdge(insides[i-1], insides[i], true, edgeAttrs) 486 } 487 } 488 return gv.String() 489 } 490 491 // Edges returns all the edges in the graph. 492 func (g *ExprGraph) Edges() graph.Edges { 493 var edges []graph.Edge 494 for _, n := range g.all { 495 for _, toN := range g.to[n] { 496 edges = append(edges, edge{ 497 from: n, 498 to: toN, 499 }) 500 } 501 } 502 if len(edges) == 0 { 503 return graph.Empty 504 } 505 return iterator.NewOrderedEdges(edges) 506 } 507 508 // other private methods 509 func (g *ExprGraph) removeAllEdgesFrom(n *Node) { 510 for k, ns := range g.to { 511 g.to[k] = ns.remove(n) 512 } 513 } 514 515 /* Graph interface */ 516 517 // Node returns the node in the graph with the given ID. 518 func (g *ExprGraph) Node(id int64) graph.Node { 519 // n := (*Node)(unsafe.Pointer(uintptr(id))) 520 // for _, n := range g.all { 521 // if n.id == id { 522 // return n 523 // } 524 // } 525 // return nil 526 return g.node(id) 527 } 528 529 func (g *ExprGraph) node(id int64) *Node { 530 if idx, ok := g.byID[id]; ok { 531 if idx >= len(g.all) { 532 return nil 533 } 534 return g.all[idx] 535 } 536 for i, n := range g.all { 537 if n.id == id { 538 g.byID[id] = i 539 return n 540 } 541 } 542 return nil 543 } 544 545 // Has returns whether the node exists within the graph. 546 func (g *ExprGraph) Has(nodeid int64) bool { 547 n := g.node(nodeid) 548 return n != nil 549 } 550 551 // Nodes returns all the nodes in the graph. 552 func (g *ExprGraph) Nodes() graph.Nodes { 553 // nodes := make([]graph.Node, len(g.from)) 554 ns := g.AllNodes() 555 556 return nodeToGraphNode(ns) 557 } 558 559 // AllNodes is like Nodes, but returns Nodes instead of []graph.Node. 560 // Nodes() has been reserved for the graph.Directed interface, so this one is named AllNodes instead 561 func (g *ExprGraph) AllNodes() Nodes { return g.all } 562 563 // From returns all nodes in g that can be reached directly from n. 564 func (g *ExprGraph) From(nodeid int64) graph.Nodes { 565 if n := g.node(nodeid); n != nil { 566 return nodeToGraphNode(n.children) 567 } 568 return nil 569 } 570 571 // HasEdgeBetween returns whether an edge exists between nodes x and y without 572 // considering direction. 573 func (g *ExprGraph) HasEdgeBetween(x, y int64) bool { 574 xid := g.node(x) 575 yid := g.node(y) 576 if xid == nil || yid == nil { 577 return false 578 } 579 580 return xid.children.Contains(yid) || yid.children.Contains(xid) 581 } 582 583 // Edge returns the edge from u to v if such an edge exists and nil otherwise. 584 // The node v must be directly reachable from u as defined by the From method. 585 func (g *ExprGraph) Edge(u, v int64) graph.Edge { 586 uid := g.node(u) 587 vid := g.node(v) 588 589 if uid == nil || vid == nil { 590 return nil 591 } 592 593 if !uid.children.Contains(vid) { 594 return nil 595 } 596 e := edge{from: uid, to: vid} 597 return e 598 } 599 600 /* Directed interface */ 601 602 // HasEdgeFromTo returns whether an edge exists in the graph from u to v. 603 func (g *ExprGraph) HasEdgeFromTo(u, v int64) bool { 604 uid := g.node(u) 605 vid := g.node(v) 606 if uid == nil || vid == nil { 607 return false 608 } 609 610 return uid.children.Contains(vid) 611 } 612 613 // To returns all nodes in g that can reach directly to n. 614 func (g *ExprGraph) To(nid int64) graph.Nodes { 615 n := g.node(nid) 616 if n == nil { 617 return nil 618 } 619 620 ns := g.to[n] 621 ns = ns.Set() 622 g.to[n] = ns 623 return nodeToGraphNode(ns) 624 } 625 626 // subgraph is basically a subset of nodes. This is useful for compiling sub sections of the graph 627 func (g *ExprGraph) subgraph(ns Nodes, findMissing bool, opts ...Nodes) *ExprGraph { 628 // ns = ns.Set() 629 630 var roots Nodes 631 // add missing stuff first 632 if findMissing { 633 for _, n := range ns { 634 for _, parent := range g.to[n] { 635 if parent.isStmt { 636 roots = append(roots, parent) 637 ns = append(ns, parent) 638 } 639 } 640 } 641 } 642 643 // uniquify the froms and at the same time build a new roots 644 allset := ns.mapSet() 645 if len(opts) == 0 { 646 for _, n := range ns { 647 if len(g.to[n]) == 0 { 648 if n.isStmt { 649 roots = append(roots, n.children[0]) 650 } else { 651 roots = append(roots, n) 652 } 653 continue 654 } 655 656 var hasParent bool 657 for _, parent := range g.to[n] { 658 if allset.Contains(parent) { 659 hasParent = true 660 break 661 } 662 } 663 if !hasParent { 664 roots = append(roots, n) 665 } 666 } 667 } else { 668 rs := opts[0] 669 roots = make(Nodes, len(rs)) 670 for i, n := range rs { 671 if n.isStmt { 672 roots[i] = n.children[0] 673 continue 674 } 675 roots[i] = n 676 677 } 678 } 679 var leaves Nodes 680 for _, n := range ns { 681 if len(n.children) == 0 { 682 leaves = append(leaves, n) 683 } 684 } 685 686 // uniquify all the things 687 roots = roots.Set() 688 leaves = leaves.Set() 689 ns = ns.Set() 690 691 retVal := &ExprGraph{ 692 all: ns, 693 byID: make(map[int64]int), 694 byHash: g.byHash, 695 evac: g.evac, 696 to: g.to, 697 698 leaves: leaves, 699 constants: g.constants, 700 roots: roots, 701 } 702 703 return retVal 704 } 705 706 // Subgraph subsets a graph. This function has overloaded meanings - If only one node is passed in, it assumes that the one node is the root, 707 // otherwise, it treats ns as the subset of nodes to be included in the subgraph 708 func (g *ExprGraph) Subgraph(ns ...*Node) *ExprGraph { 709 if len(ns) == 1 { 710 g.SubgraphRoots(ns[0]) 711 } 712 return g.subgraph(ns, true) 713 } 714 715 // SubgraphRoots creates a subgraph, assuming the provided nodes are roots to the new subgraph. 716 func (g *ExprGraph) SubgraphRoots(ns ...*Node) *ExprGraph { 717 sub := g.walkFromRoots(ns...) 718 return g.subgraph(sub, true, ns) 719 } 720 721 // ExactSubgraphRoots creates a subgraph from the roots provided. 722 // The difference between SubgraphRoots and ExactSubgraphRoots is that ExactSubGraphRoots 723 // will not attempt to discover if any nodes are missing. 724 // 725 // Given a function like the following: 726 // z = x + y 727 // set(x, -x.Grad) // setting the value of x to the negative of the gradient 728 // 729 // When SubgraphRoots is used on z, the `-x.Grad` will be included. 730 // When using ExactSubgraphRoots, only `x` and `y` are included in the subgraph 731 func (g *ExprGraph) ExactSubgraphRoots(ns ...*Node) *ExprGraph { 732 sub := g.walkFromRoots(ns...) 733 return g.subgraph(sub, false, ns) 734 } 735 736 func (g *ExprGraph) walkFromRoots(ns ...*Node) Nodes { 737 sub := make(Nodes, len(ns)) 738 copy(sub, ns) 739 740 walked := NewNodeSet() 741 for _, n := range ns { 742 ch := make(chan *Node) 743 go func(ch chan *Node) { 744 defer close(ch) 745 walkGraph(n, ch, walked) 746 }(ch) 747 748 for node := range ch { 749 sub = append(sub, node) 750 } 751 } 752 return sub 753 } 754 755 type edge struct { 756 from, to graph.Node 757 weight float64 758 } 759 760 func (e edge) From() graph.Node { return e.from } 761 func (e edge) To() graph.Node { return e.to } 762 func (e edge) ReversedEdge() graph.Edge { e.from, e.to = e.to, e.from; return e } 763 func (e edge) Weight() float64 { return e.weight }