gorgonia.org/gorgonia@v0.9.17/graph.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  
     7  	"github.com/awalterschulze/gographviz"
     8  	"gonum.org/v1/gonum/graph"
     9  	"gonum.org/v1/gonum/graph/iterator"
    10  )
    11  
    12  // ExprGraph is a data structure for a directed acyclic graph (of expressions). This structure is the main entry point
    13  // for Gorgonia.
    14  type ExprGraph struct {
    15  	name string
    16  
    17  	all Nodes
    18  
    19  	byID   map[int64]int
    20  	byHash map[uint32]*Node
    21  	evac   map[uint32]Nodes
    22  	to     map[*Node]Nodes
    23  
    24  	leaves    Nodes
    25  	constants Nodes
    26  	roots     Nodes
    27  	counter   uint
    28  }
    29  
    30  // graphconopt sets options
    31  type graphconopt func(g *ExprGraph)
    32  
    33  // WithGraphName is a ExprGraph construction option that provides a name.
    34  func WithGraphName(name string) graphconopt {
    35  	f := func(g *ExprGraph) {
    36  		g.name = name
    37  	}
    38  	return f
    39  }
    40  
    41  // NewGraph creates a new graph. Duh
    42  func NewGraph(opts ...graphconopt) *ExprGraph {
    43  	g := &ExprGraph{
    44  		byID:   make(map[int64]int),
    45  		byHash: make(map[uint32]*Node),
    46  		evac:   make(map[uint32]Nodes),
    47  		to:     make(map[*Node]Nodes),
    48  
    49  		leaves:    make(Nodes, 0, 64),
    50  		constants: make(Nodes, 0, 8),
    51  	}
    52  
    53  	for _, opt := range opts {
    54  		opt(g)
    55  	}
    56  
    57  	return g
    58  }
    59  
    60  // Clone clones the graph. All nodes gets cloned, and their values are cloned as well.
    61  func (g *ExprGraph) Clone() interface{} {
    62  	g2 := new(ExprGraph)
    63  	g2.name = g.name
    64  
    65  	mapping := make(map[*Node]*Node) // a map of old nodes to new nodes
    66  	g2.all = make(Nodes, len(g.all))
    67  	for i, n := range g.all {
    68  		cloned := n.Clone().(*Node)
    69  		cloned.g = g2
    70  		cloned.id = n.id
    71  
    72  		g2.all[i] = cloned
    73  		mapping[n] = cloned
    74  	}
    75  
    76  	// handle each node's children, deriv ofs, etc
    77  	for i, n := range g.all {
    78  		cloned := g2.all[i]
    79  		cloned.children = make(Nodes, len(n.children))
    80  		for j, c := range n.children {
    81  			cloned.children[j] = mapping[c]
    82  		}
    83  
    84  		cloned.derivOf = make(Nodes, len(n.derivOf))
    85  		for j, c := range n.derivOf {
    86  			cloned.derivOf[j] = mapping[c]
    87  		}
    88  
    89  		if n.deriv != nil {
    90  			cloned.deriv = mapping[n.deriv]
    91  		}
    92  	}
    93  
    94  	g2.byID = make(map[int64]int)
    95  	g2.byHash = make(map[uint32]*Node)
    96  	for k, v := range g.byHash {
    97  		g2.byHash[k] = mapping[v]
    98  	}
    99  
   100  	g2.evac = make(map[uint32]Nodes)
   101  	for k, v := range g.evac {
   102  		g2.evac[k] = make(Nodes, len(v))
   103  		for i, n := range v {
   104  			g2.evac[k][i] = mapping[n]
   105  		}
   106  	}
   107  
   108  	g2.to = make(map[*Node]Nodes)
   109  	for k, v := range g.to {
   110  		to := mapping[k]
   111  		g2.to[to] = make(Nodes, len(v))
   112  		for i, n := range v {
   113  			g2.to[to][i] = mapping[n]
   114  		}
   115  	}
   116  
   117  	g2.leaves = make(Nodes, len(g.leaves))
   118  	for i, n := range g.leaves {
   119  		g2.leaves[i] = mapping[n]
   120  	}
   121  
   122  	g2.constants = make(Nodes, len(g.constants))
   123  	for i, n := range g.constants {
   124  		g2.constants[i] = mapping[n]
   125  	}
   126  
   127  	g2.roots = make(Nodes, len(g.roots))
   128  	for i, n := range g.roots {
   129  		g2.roots[i] = mapping[n]
   130  	}
   131  
   132  	g2.counter = g.counter
   133  	return g2
   134  }
   135  
   136  // AddNode adds n to the graph. It panics if the added node ID matches an existing node ID.
   137  func (g *ExprGraph) AddNode(n *Node) (retVal *Node) {
   138  	defer func() {
   139  		if _, ok := g.to[retVal]; !ok {
   140  			g.to[retVal] = nil
   141  		}
   142  	}()
   143  	// check for node with the same name in the graph
   144  	// we don't update the graph if this is the case
   145  	for _, node := range g.constants {
   146  		if node.name == n.name && n.isConstant() {
   147  			return node
   148  		}
   149  	}
   150  	hash := n.Hashcode()
   151  	if existing, ok := g.byHash[hash]; ok {
   152  		if existing == nil {
   153  			// this means that there has been previous collisions
   154  			// so look at evac map
   155  			for _, e := range g.evac[hash] {
   156  				if nodeEq(n, e) {
   157  					return e
   158  				}
   159  			}
   160  			g.evac[hash] = append(g.evac[hash], n)
   161  			g.addToAll(n)
   162  			incrCC() // collision counter
   163  			return n
   164  		}
   165  
   166  		if !nodeEq(n, existing) {
   167  			g.evac[hash] = Nodes{existing, n}
   168  			g.byHash[hash] = nil // to signal that it's collided
   169  			g.addToAll(n)
   170  			incrCC()
   171  			return n
   172  		}
   173  		incrEC() // expected collision (they're the same node!)
   174  		return existing
   175  	}
   176  
   177  	if n.isConstant() {
   178  		n = n.clone()
   179  		g.constants = g.constants.Add(n)
   180  		n.g = g
   181  	}
   182  
   183  	g.addToAll(n)
   184  	g.byHash[hash] = n
   185  	return n
   186  }
   187  
   188  func (g *ExprGraph) addToAll(n *Node) {
   189  	if n == nil {
   190  		panic("HELP! trying to add nil")
   191  	}
   192  	g.all = append(g.all, n)
   193  	n.id = int64(g.counter)
   194  	g.counter++
   195  }
   196  
   197  // RemoveNode removes n from the graph, as well as any edges attached to it. If the node
   198  // is not in the graph it is a no-op.
   199  func (g *ExprGraph) RemoveNode(node graph.Node) {
   200  	n := node.(*Node)
   201  	if n.id == -1 {
   202  		return // if it's -1, it was never in the graph to begin with
   203  	}
   204  
   205  	hash := n.Hashcode()
   206  
   207  	delete(g.byHash, hash)
   208  	delete(g.to, n)
   209  	g.evac[hash] = g.evac[hash].remove(n)
   210  	g.all = g.all.remove(n)
   211  }
   212  
   213  // SetEdge adds e, an edge from one node to another. If the nodes do not exist, they are added.
   214  // It will panic if the IDs of the e.From and e.To are equal.
   215  func (g *ExprGraph) SetEdge(e graph.Edge) {
   216  	from := e.From().(*Node)
   217  	to := e.To().(*Node)
   218  
   219  	if from == to {
   220  		panic(fmt.Sprintf("cannot add self edge: from %v to %v", from, to))
   221  	}
   222  
   223  	if !g.Has(from.ID()) {
   224  		from = g.AddNode(from)
   225  	}
   226  
   227  	if !g.Has(to.ID()) {
   228  		to = g.AddNode(to)
   229  	}
   230  
   231  	// g.to[to] = g.to[to].Add(from)
   232  	g.to[to] = append(g.to[to], from)
   233  }
   234  
   235  // Roots returns a list of nodes that are not children of any other nodes
   236  func (g *ExprGraph) Roots() (retVal Nodes) {
   237  	// handle subgraph
   238  	if g.roots != nil {
   239  		return g.roots
   240  	}
   241  
   242  	for n, tos := range g.to {
   243  		if len(tos) == 0 {
   244  			retVal = append(retVal, n)
   245  		}
   246  		// if the root is a statement (typically a read), and it only has one child
   247  		if len(n.children) == 1 && n.isStmt {
   248  			child := n.children[0]
   249  			if len(g.to[child]) == 1 {
   250  				retVal = append(retVal, child)
   251  			}
   252  		}
   253  	}
   254  	g.roots = retVal
   255  	return retVal
   256  }
   257  
   258  // Inputs returns a list of nodes which are inputs (that is to say, the user is required to set a value in it)
   259  func (g *ExprGraph) Inputs() (retVal Nodes) {
   260  	for _, n := range g.all {
   261  		if n.isInput() {
   262  			retVal = append(retVal, n)
   263  		}
   264  	}
   265  	return
   266  }
   267  
   268  // UnbindAll unbinds all the values from the nodes
   269  func (g *ExprGraph) UnbindAll() {
   270  	for _, n := range g.all {
   271  		n.unbind()
   272  	}
   273  }
   274  
   275  // UnbindAllNonInputs unbinds all the values from nodes that aren't input nodes
   276  func (g *ExprGraph) UnbindAllNonInputs() {
   277  	for _, n := range g.all {
   278  		if n.isInput() || n.isConstant() {
   279  			continue
   280  		}
   281  		n.unbind()
   282  	}
   283  }
   284  
   285  // ByName returns nodes that have the name provided.
   286  // Bear in mind that the name that is compared to is the internal name,
   287  // not the result of calling node.Name(). The reason for doing this is
   288  // for ease of finding only names that are user-supplied, instead of autogenerated names
   289  func (g *ExprGraph) ByName(name string) (retVal Nodes) {
   290  	for _, n := range g.all {
   291  		if n.name == name {
   292  			retVal = append(retVal, n)
   293  		}
   294  	}
   295  	return
   296  }
   297  
   298  // Constant returns a constant that may be found in the graph. If no constant were found, a new one is created instead
   299  func (g *ExprGraph) Constant(v Value) *Node {
   300  	for _, n := range g.constants {
   301  		if ValueEq(n.Value(), v) {
   302  			return n
   303  		}
   304  	}
   305  
   306  	n := NewConstant(v)
   307  	return g.AddNode(n)
   308  }
   309  
   310  func (g *ExprGraph) String() string {
   311  	var buf bytes.Buffer
   312  	buf.WriteString("Graph: [\n")
   313  	for _, n := range g.byHash {
   314  		fmt.Fprintf(&buf, "\t%d: %s\n", n.Hashcode(), n)
   315  	}
   316  	buf.WriteString("]")
   317  	return buf.String()
   318  }
   319  
   320  // ToDot generates the graph in graphviz format. The use of this is to generate for the entire graph
   321  // which may have multiple trees with different roots
   322  // TODO: This is getting unwieldy. Perhaps refactor out into a ToDot(...Opt)?
   323  func (g *ExprGraph) ToDot() string {
   324  	gv := gographviz.NewEscape()
   325  	gv.SetName(fullGraphName)
   326  	gv.SetDir(true)
   327  
   328  	gv.AddAttr(fullGraphName, "nodesep", "1")
   329  	gv.AddAttr(fullGraphName, "ranksep", "1.5 equally")
   330  	gv.AddAttr(fullGraphName, "rankdir", "TB")
   331  	if len(g.byHash) > 100 {
   332  		gv.AddAttr(fullGraphName, "nslimit", "3") // numiter=3*len(nodes)
   333  		// gv.AddAttr(fullGraphName, "splines", "line") // ugly as sin.
   334  	}
   335  
   336  	groups := make(map[string]struct{})
   337  	for h, n := range g.byHash {
   338  		if n != nil {
   339  			group := n.dotCluster()
   340  			groups[group] = struct{}{}
   341  			continue
   342  		}
   343  		// other wise it'se a clash of hash
   344  		for _, n := range g.evac[h] {
   345  			group := n.dotCluster()
   346  			groups[group] = struct{}{}
   347  
   348  		}
   349  	}
   350  
   351  	for grp := range groups {
   352  		attrs := map[string]string{"label": grp}
   353  
   354  		parentGraph := fullGraphName
   355  		if grp == inputsClust || grp == constantsClust {
   356  			parentGraph = inputConsts
   357  			if !gv.IsSubGraph(inputConsts) {
   358  				groupAttrs := map[string]string{"rank": "max"}
   359  				gv.AddSubGraph(fullGraphName, inputConsts, groupAttrs)
   360  			}
   361  		}
   362  		gv.AddSubGraph(parentGraph, "cluster_"+grp, attrs)
   363  	}
   364  
   365  	// for _, n := range g.byHash {
   366  	for _, n := range g.all {
   367  		group := n.dotCluster()
   368  		n.dotString(gv, "cluster_"+group)
   369  	}
   370  
   371  	// for _, from := range g.byHash {
   372  	for _, from := range g.all {
   373  		for i, child := range from.children {
   374  			if ok := g.all.Contains(child); !ok {
   375  				// not in graph, so ignore it...
   376  				continue
   377  			}
   378  			fromID := fmt.Sprintf("Node_%p", from)
   379  			toID := fmt.Sprintf("Node_%p", child)
   380  
   381  			edgeAttrs := map[string]string{
   382  				"taillabel":  fmt.Sprintf(" %d ", i),
   383  				"labelfloat": "false",
   384  			}
   385  
   386  			// we invert the from and to nodes for gradients, As the expressionGraph builds upwards from bottom, the gradient builds downwards.
   387  			if from.group == gradClust && child.group == gradClust {
   388  				edgeAttrs["dir"] = "back"
   389  				gv.AddPortEdge(toID, toID+":anchor:s", fromID, fromID+":anchor:n", true, edgeAttrs)
   390  			} else {
   391  				gv.AddPortEdge(fromID, fromID+":anchor:s", toID, toID+":anchor:n", true, edgeAttrs)
   392  			}
   393  		}
   394  	}
   395  
   396  	// draw deriv lines
   397  	if debugDerives {
   398  		edgeAttrs := map[string]string{
   399  			"style":      "dashed",
   400  			"constraint": "false",
   401  			"weight":     "999",
   402  		}
   403  
   404  		for _, n := range g.byHash {
   405  			if n == nil {
   406  				// collision found... what to do?
   407  				continue
   408  			}
   409  			if n.derivOf != nil {
   410  				id := fmt.Sprintf("Node_%p", n)
   411  				for _, derivOf := range n.derivOf {
   412  					if _, ok := g.to[derivOf]; !ok {
   413  						continue
   414  					}
   415  					ofID := fmt.Sprintf("Node_%p", derivOf)
   416  					// gv.AddPortEdge(id, ":anchor:w", ofID, ofID+":anchor:e", true, edgeAttrs)
   417  					gv.AddEdge(id, ofID, true, edgeAttrs)
   418  				}
   419  			}
   420  		}
   421  	}
   422  
   423  	// stupid invisible nodes to keep expressiongraph on the left
   424  	subGAttrs := make(map[string]string)
   425  	// subGAttrs.Add("rank", "max")
   426  	gv.AddSubGraph(fullGraphName, outsideSubG, subGAttrs)
   427  
   428  	attrs := map[string]string{
   429  		"style": "invis",
   430  	}
   431  	gv.AddNode(outsideSubG, outsideRoot, attrs)
   432  
   433  	outsides := []string{outsideRoot}
   434  	var insides []string
   435  
   436  	// build the inside and outside list
   437  	if _, hasInputs := groups[inputsClust]; hasInputs {
   438  		insides = append(insides, insideInputs)
   439  		gv.AddNode("cluster_inputs", insideInputs, attrs)
   440  	}
   441  
   442  	if _, hasConst := groups[constantsClust]; hasConst {
   443  		if len(insides) > 0 {
   444  			outsides = append(outsides, outsideConsts)
   445  			gv.AddNode(outsideSubG, outsideConsts, attrs)
   446  		}
   447  		insides = append(insides, insideConsts)
   448  		gv.AddNode("cluster_constants", insideConsts, attrs)
   449  	}
   450  
   451  	if len(insides) > 0 {
   452  		outsides = append(outsides, outsideExprG)
   453  		gv.AddNode(outsideSubG, outsideExprG, attrs)
   454  	}
   455  	insides = append(insides, insideExprG)
   456  	gv.AddNode("cluster_expressionGraph", insideExprG, attrs)
   457  
   458  	for group := range groups {
   459  		if group == exprgraphClust || group == constantsClust || group == inputsClust {
   460  			continue
   461  		}
   462  		inside := "inside_" + group
   463  		outside := "outside_" + group
   464  		insides = append(insides, inside)
   465  		outsides = append(outsides, outside)
   466  
   467  		gv.AddNode(outsideSubG, outside, attrs)
   468  		gv.AddNode("cluster_"+group, inside, attrs)
   469  	}
   470  
   471  	edgeAttrs := map[string]string{
   472  		"style":      "invis",
   473  		"weight":     "999",
   474  		"constraint": "false",
   475  	}
   476  	for i, o := range outsides {
   477  		// outside-inside
   478  		gv.AddEdge(o, insides[i], true, edgeAttrs)
   479  
   480  		if i > 0 {
   481  			// outside-outside
   482  			gv.AddEdge(outsides[i-1], o, true, edgeAttrs)
   483  
   484  			// inside-inside
   485  			gv.AddEdge(insides[i-1], insides[i], true, edgeAttrs)
   486  		}
   487  	}
   488  	return gv.String()
   489  }
   490  
   491  // Edges returns all the edges in the graph.
   492  func (g *ExprGraph) Edges() graph.Edges {
   493  	var edges []graph.Edge
   494  	for _, n := range g.all {
   495  		for _, toN := range g.to[n] {
   496  			edges = append(edges, edge{
   497  				from: n,
   498  				to:   toN,
   499  			})
   500  		}
   501  	}
   502  	if len(edges) == 0 {
   503  		return graph.Empty
   504  	}
   505  	return iterator.NewOrderedEdges(edges)
   506  }
   507  
   508  // other private methods
   509  func (g *ExprGraph) removeAllEdgesFrom(n *Node) {
   510  	for k, ns := range g.to {
   511  		g.to[k] = ns.remove(n)
   512  	}
   513  }
   514  
   515  /* Graph interface */
   516  
   517  // Node returns the node in the graph with the given ID.
   518  func (g *ExprGraph) Node(id int64) graph.Node {
   519  	// n := (*Node)(unsafe.Pointer(uintptr(id)))
   520  	// for _, n := range g.all {
   521  	// 	if n.id == id {
   522  	// 		return n
   523  	// 	}
   524  	// }
   525  	// return nil
   526  	return g.node(id)
   527  }
   528  
   529  func (g *ExprGraph) node(id int64) *Node {
   530  	if idx, ok := g.byID[id]; ok {
   531  		if idx >= len(g.all) {
   532  			return nil
   533  		}
   534  		return g.all[idx]
   535  	}
   536  	for i, n := range g.all {
   537  		if n.id == id {
   538  			g.byID[id] = i
   539  			return n
   540  		}
   541  	}
   542  	return nil
   543  }
   544  
   545  // Has returns whether the node exists within the graph.
   546  func (g *ExprGraph) Has(nodeid int64) bool {
   547  	n := g.node(nodeid)
   548  	return n != nil
   549  }
   550  
   551  // Nodes returns all the nodes in the graph.
   552  func (g *ExprGraph) Nodes() graph.Nodes {
   553  	// nodes := make([]graph.Node, len(g.from))
   554  	ns := g.AllNodes()
   555  
   556  	return nodeToGraphNode(ns)
   557  }
   558  
   559  // AllNodes is like Nodes, but returns Nodes instead of []graph.Node.
   560  // Nodes() has been reserved for the graph.Directed interface, so this one is named AllNodes instead
   561  func (g *ExprGraph) AllNodes() Nodes { return g.all }
   562  
   563  // From returns all nodes in g that can be reached directly from n.
   564  func (g *ExprGraph) From(nodeid int64) graph.Nodes {
   565  	if n := g.node(nodeid); n != nil {
   566  		return nodeToGraphNode(n.children)
   567  	}
   568  	return nil
   569  }
   570  
   571  // HasEdgeBetween returns whether an edge exists between nodes x and y without
   572  // considering direction.
   573  func (g *ExprGraph) HasEdgeBetween(x, y int64) bool {
   574  	xid := g.node(x)
   575  	yid := g.node(y)
   576  	if xid == nil || yid == nil {
   577  		return false
   578  	}
   579  
   580  	return xid.children.Contains(yid) || yid.children.Contains(xid)
   581  }
   582  
   583  // Edge returns the edge from u to v if such an edge exists and nil otherwise.
   584  // The node v must be directly reachable from u as defined by the From method.
   585  func (g *ExprGraph) Edge(u, v int64) graph.Edge {
   586  	uid := g.node(u)
   587  	vid := g.node(v)
   588  
   589  	if uid == nil || vid == nil {
   590  		return nil
   591  	}
   592  
   593  	if !uid.children.Contains(vid) {
   594  		return nil
   595  	}
   596  	e := edge{from: uid, to: vid}
   597  	return e
   598  }
   599  
   600  /* Directed interface */
   601  
   602  // HasEdgeFromTo returns whether an edge exists in the graph from u to v.
   603  func (g *ExprGraph) HasEdgeFromTo(u, v int64) bool {
   604  	uid := g.node(u)
   605  	vid := g.node(v)
   606  	if uid == nil || vid == nil {
   607  		return false
   608  	}
   609  
   610  	return uid.children.Contains(vid)
   611  }
   612  
   613  // To returns all nodes in g that can reach directly to n.
   614  func (g *ExprGraph) To(nid int64) graph.Nodes {
   615  	n := g.node(nid)
   616  	if n == nil {
   617  		return nil
   618  	}
   619  
   620  	ns := g.to[n]
   621  	ns = ns.Set()
   622  	g.to[n] = ns
   623  	return nodeToGraphNode(ns)
   624  }
   625  
   626  // subgraph is basically a subset of nodes. This is useful for compiling sub sections of the graph
   627  func (g *ExprGraph) subgraph(ns Nodes, findMissing bool, opts ...Nodes) *ExprGraph {
   628  	// ns = ns.Set()
   629  
   630  	var roots Nodes
   631  	// add missing stuff first
   632  	if findMissing {
   633  		for _, n := range ns {
   634  			for _, parent := range g.to[n] {
   635  				if parent.isStmt {
   636  					roots = append(roots, parent)
   637  					ns = append(ns, parent)
   638  				}
   639  			}
   640  		}
   641  	}
   642  
   643  	// uniquify the froms and at the same time build a new roots
   644  	allset := ns.mapSet()
   645  	if len(opts) == 0 {
   646  		for _, n := range ns {
   647  			if len(g.to[n]) == 0 {
   648  				if n.isStmt {
   649  					roots = append(roots, n.children[0])
   650  				} else {
   651  					roots = append(roots, n)
   652  				}
   653  				continue
   654  			}
   655  
   656  			var hasParent bool
   657  			for _, parent := range g.to[n] {
   658  				if allset.Contains(parent) {
   659  					hasParent = true
   660  					break
   661  				}
   662  			}
   663  			if !hasParent {
   664  				roots = append(roots, n)
   665  			}
   666  		}
   667  	} else {
   668  		rs := opts[0]
   669  		roots = make(Nodes, len(rs))
   670  		for i, n := range rs {
   671  			if n.isStmt {
   672  				roots[i] = n.children[0]
   673  				continue
   674  			}
   675  			roots[i] = n
   676  
   677  		}
   678  	}
   679  	var leaves Nodes
   680  	for _, n := range ns {
   681  		if len(n.children) == 0 {
   682  			leaves = append(leaves, n)
   683  		}
   684  	}
   685  
   686  	// uniquify all the things
   687  	roots = roots.Set()
   688  	leaves = leaves.Set()
   689  	ns = ns.Set()
   690  
   691  	retVal := &ExprGraph{
   692  		all:    ns,
   693  		byID:   make(map[int64]int),
   694  		byHash: g.byHash,
   695  		evac:   g.evac,
   696  		to:     g.to,
   697  
   698  		leaves:    leaves,
   699  		constants: g.constants,
   700  		roots:     roots,
   701  	}
   702  
   703  	return retVal
   704  }
   705  
   706  // Subgraph subsets a graph. This function has overloaded meanings - If only one node is passed in, it assumes that the one node is the root,
   707  // otherwise, it treats ns as the subset of nodes to be included in the subgraph
   708  func (g *ExprGraph) Subgraph(ns ...*Node) *ExprGraph {
   709  	if len(ns) == 1 {
   710  		g.SubgraphRoots(ns[0])
   711  	}
   712  	return g.subgraph(ns, true)
   713  }
   714  
   715  // SubgraphRoots creates a subgraph, assuming the provided nodes are roots to the new subgraph.
   716  func (g *ExprGraph) SubgraphRoots(ns ...*Node) *ExprGraph {
   717  	sub := g.walkFromRoots(ns...)
   718  	return g.subgraph(sub, true, ns)
   719  }
   720  
   721  // ExactSubgraphRoots creates a subgraph from the roots provided.
   722  // The difference between SubgraphRoots and ExactSubgraphRoots is that ExactSubGraphRoots
   723  // will not attempt to discover if any nodes are missing.
   724  //
   725  // Given a function like the following:
   726  //		z = x + y
   727  //		set(x, -x.Grad) // setting the value of x to the negative of the gradient
   728  //
   729  // When SubgraphRoots is used on z, the `-x.Grad` will be included.
   730  // When using ExactSubgraphRoots, only `x` and `y` are included in the subgraph
   731  func (g *ExprGraph) ExactSubgraphRoots(ns ...*Node) *ExprGraph {
   732  	sub := g.walkFromRoots(ns...)
   733  	return g.subgraph(sub, false, ns)
   734  }
   735  
   736  func (g *ExprGraph) walkFromRoots(ns ...*Node) Nodes {
   737  	sub := make(Nodes, len(ns))
   738  	copy(sub, ns)
   739  
   740  	walked := NewNodeSet()
   741  	for _, n := range ns {
   742  		ch := make(chan *Node)
   743  		go func(ch chan *Node) {
   744  			defer close(ch)
   745  			walkGraph(n, ch, walked)
   746  		}(ch)
   747  
   748  		for node := range ch {
   749  			sub = append(sub, node)
   750  		}
   751  	}
   752  	return sub
   753  }
   754  
   755  type edge struct {
   756  	from, to graph.Node
   757  	weight   float64
   758  }
   759  
   760  func (e edge) From() graph.Node         { return e.from }
   761  func (e edge) To() graph.Node           { return e.to }
   762  func (e edge) ReversedEdge() graph.Edge { e.from, e.to = e.to, e.from; return e }
   763  func (e edge) Weight() float64          { return e.weight }