gorgonia.org/gorgonia@v0.9.17/node.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"hash"
     8  	"hash/fnv"
     9  
    10  	"github.com/awalterschulze/gographviz"
    11  	"github.com/chewxy/hm"
    12  	"github.com/pkg/errors"
    13  	"gonum.org/v1/gonum/graph"
    14  	"gorgonia.org/gorgonia/internal/encoding"
    15  	"gorgonia.org/tensor"
    16  )
    17  
    18  // A Node is a node in the computation graph
    19  type Node struct {
    20  	// metadata of the node
    21  	t     hm.Type // pruned types only plz
    22  	shape tensor.Shape
    23  
    24  	// this node is the result of applying the op to the children
    25  	op       Op
    26  	children Nodes // shortcut, instead of having to go through the graph
    27  
    28  	// For nicely grouping stuff in graphviz.
    29  	// TODO: Should this be in *Node?
    30  	name string
    31  	// DEPRECATED: the group attribute will be removed in the next version in favor of groups
    32  	group string
    33  	// the grouping notion is only useful for exporting to another format
    34  	groups encoding.Groups
    35  
    36  	g *ExprGraph // this node belongs in this graph
    37  
    38  	// value bondage
    39  	// inputs are bound to values directly
    40  	boundTo Value
    41  	dataOn  Device // where is the data on
    42  
    43  	// to track derivations
    44  	derivOf Nodes
    45  	deriv   *Node
    46  
    47  	// for hashing nodes
    48  	id   int64 // id is the ID at which the node is added to the graph
    49  	hash uint32
    50  
    51  	hashed        bool
    52  	inferredShape bool // is shape inferred?
    53  	unchanged     bool // has this node been modified
    54  	isStmt        bool // is this a statement node
    55  	ofInterest    bool // is this node of particular interest? (for debugging)
    56  }
    57  
    58  // NodeConsOpt is a function that provides construction options for any Node.
    59  type NodeConsOpt func(*Node)
    60  
    61  // WithType is a node construction option to set a node to the specified type.
    62  // Types in *Node are immutable once set. If the type has already been specified in the node,
    63  // a check will be made to see if the both types are the same. If it isn't, it will panic.
    64  func WithType(t hm.Type) NodeConsOpt {
    65  	f := func(n *Node) {
    66  		if n.t == nil {
    67  			n.t = t
    68  		} else if !n.t.Eq(t) {
    69  			panic(fmt.Sprintf("Node's type is %v. Asking to construct a Node with %v", n.t, t))
    70  		}
    71  	}
    72  	return f
    73  }
    74  
    75  // WithChildren sets the children of a node to the specified chidren.
    76  // This construction option does NOT check if existing children exists, and will overwrite the existing children.
    77  func WithChildren(children Nodes) NodeConsOpt {
    78  	f := func(n *Node) {
    79  		n.children = children
    80  	}
    81  	return f
    82  }
    83  
    84  // WithOp is a node construction option to set a node's Op to the specified Op.
    85  // `Op`s in `*Node`s are immutable once set and cannot be changed. If the node already has an Op specified
    86  // a check will be made to see if the provided Op and the one already specified in the `*Node` is the same -
    87  // do note that comparison of Ops is done using the `Hashcode()` method of Ops, and hash collisions MAY occur -
    88  // If both ops are different, this function will panic.
    89  func WithOp(op Op) NodeConsOpt {
    90  	f := func(n *Node) {
    91  		if n.op != nil {
    92  			if op.Hashcode() != n.op.Hashcode() {
    93  				panic(fmt.Sprintf("Node Ops are immutable. Cannot set op %v", op))
    94  			}
    95  			return
    96  		}
    97  		n.op = op
    98  		if _, ok := op.(stmtOp); ok {
    99  			n.isStmt = true
   100  		}
   101  	}
   102  	return f
   103  }
   104  
   105  // In is a node construction option to set a node's graph.
   106  // A `*Node`'s graph is immutable. If the graph has already been set, a check will be made that the specifiec *Graph
   107  // and the *Graph set in *Node are the same. If they are not, the function will panic/
   108  func In(g *ExprGraph) NodeConsOpt {
   109  	f := func(n *Node) {
   110  		if n.g != nil {
   111  			if g != n.g {
   112  				panic(fmt.Sprintf("Node Graphs are immutable. Cannot set g %v", g))
   113  			}
   114  		}
   115  		n.g = g
   116  	}
   117  	return f
   118  }
   119  
   120  // WithName is a node construction option that gives the *Node the provided name. This is especially useful in debugging graphs.
   121  func WithName(name string) NodeConsOpt {
   122  	f := func(n *Node) {
   123  		n.name = name
   124  	}
   125  	return f
   126  }
   127  
   128  // WithValue is a node construction option that binds the value to the *Node. This function may panic if:
   129  //	- Gorgonia was unable to convert interface{} into a Value.
   130  //	- The type of the Value does not match the type of the nodes.
   131  func WithValue(any interface{}) NodeConsOpt {
   132  	v, t, _, err := anyToValue(any)
   133  	if err != nil {
   134  		panic(err)
   135  	}
   136  
   137  	f := func(n *Node) {
   138  		if n.t == nil {
   139  			n.t = t
   140  		} else if !n.t.Eq(t) {
   141  			// scalars are exempted
   142  			// if n.t is a scalar type, and the `anyToValue` returns a vector that is scalar-like,
   143  			// we'll mark it as a scalar
   144  			tt, ok1 := t.(TensorType)
   145  			ndt, ok2 := n.t.(tensor.Dtype)
   146  
   147  			if !(ok1 && ok2 && v.Shape().IsScalarEquiv() && tt.Of == ndt) {
   148  				panic(fmt.Sprintf("TypeError: Want %v, Got %v instead (%T %T)", n.t, t, n.t, t)) // yes this is a runtime error
   149  			}
   150  			v.(tensor.Tensor).Reshape() // rehsape to scalar
   151  
   152  		}
   153  
   154  		n.bind(v)
   155  		if n.shape == nil {
   156  			n.shape = v.Shape()
   157  		}
   158  	}
   159  	return f
   160  }
   161  
   162  // WithGrad is a node construction option that binds the value to the *Node. This function may panic if:
   163  //	- There isn't already a value associated with the node (.boundTo == nil)
   164  //	- The type of the Value does not match the value of the node.
   165  func WithGrad(any interface{}) NodeConsOpt {
   166  	v, t, _, err := anyToValue(any)
   167  	if err != nil {
   168  		panic(err)
   169  	}
   170  	f := func(n *Node) {
   171  		if n.boundTo == nil {
   172  			panic("No value already bound to node")
   173  		}
   174  		if !TypeOf(n.boundTo).Eq(t) {
   175  			panic("Different types ")
   176  		}
   177  
   178  		if dv, ok := n.boundTo.(*dualValue); !ok {
   179  			if err := n.bind(&dualValue{Value: n.boundTo, d: v}); err != nil {
   180  				panic(err)
   181  			}
   182  		} else {
   183  			dv.d = v
   184  		}
   185  	}
   186  	return f
   187  }
   188  
   189  // WithInit is a node construction option to initialize a *Node with the InitWFn provided.
   190  func WithInit(fn InitWFn) NodeConsOpt {
   191  	f := func(n *Node) {
   192  		dt, err := dtypeOf(n.t)
   193  		if err != nil {
   194  			panic(err)
   195  		}
   196  
   197  		var v Value
   198  		v = tensor.New(tensor.WithShape(n.shape...), tensor.WithBacking(fn(dt, n.shape...)))
   199  		WithValue(v)(n)
   200  	}
   201  	return f
   202  }
   203  
   204  // WithShape is a node construction option to initialize a *Node with a particular shape.
   205  // This function panics if the shape's dimensions do not match the specified dimensions of the *Node.
   206  func WithShape(shp ...int) NodeConsOpt {
   207  	s := tensor.Shape(tensor.BorrowInts(len(shp)))
   208  	copy(s, shp)
   209  	f := func(n *Node) {
   210  		if n.t == nil && n.shape == nil {
   211  			n.shape = s
   212  			return
   213  		}
   214  		nd := n.Dims()
   215  		isVec := s.IsColVec() || s.IsRowVec()
   216  		acceptVec := (isVec && (nd == 1))
   217  		sameDims := nd == s.Dims()
   218  		acceptScalar := nd == 0 && scalarEquiv(s)
   219  
   220  		if !acceptVec && !sameDims && !acceptScalar {
   221  			panic(fmt.Sprintf("Node %v, has %d dimensions(Shape: %v). Input shape is %v, which has %d dimensions", n, n.Dims(), n.shape, s, s.Dims()))
   222  		}
   223  		n.shape = s
   224  	}
   225  	return f
   226  }
   227  
   228  // WithGroupName is a node construction option to group a *Node within a particular group. This option is useful for debugging with graphs.
   229  // This function is deprecated and will proabably be remove in the next version.
   230  func WithGroupName(name string) NodeConsOpt {
   231  	f := func(n *Node) {
   232  		if n.group == "" {
   233  			n.group = name
   234  		}
   235  	}
   236  	return f
   237  }
   238  
   239  // withGroup is a node construction option to group a *Node within a particular group. This option is useful for debugging with graphs.
   240  func withGroup(group encoding.Group) NodeConsOpt {
   241  	f := func(n *Node) {
   242  		n.groups = n.groups.Upsert(group)
   243  	}
   244  	return f
   245  }
   246  
   247  // Groups to fulfil the encoding Grouper interface
   248  func (n *Node) Groups() encoding.Groups {
   249  	var isConst bool
   250  	var isInput = n.isInput()
   251  
   252  	if n.op != nil {
   253  		_, isConst = n.op.(constant)
   254  	}
   255  
   256  	switch {
   257  	case isConst:
   258  		n.groups = n.groups.Upsert(encoding.ConstantCluster)
   259  	case isInput:
   260  		n.groups = n.groups.Upsert(encoding.InputCluster)
   261  	default:
   262  		n.groups = n.groups.Upsert(encoding.ExprGraphCluster)
   263  	}
   264  	return n.groups
   265  }
   266  
   267  func newNode(opts ...NodeConsOpt) *Node {
   268  	n := borrowNode()
   269  	n.dataOn = CPU
   270  	n.id = -1
   271  	n.t = nil
   272  	n.shape = nil
   273  
   274  	for _, opt := range opts {
   275  		opt(n)
   276  	}
   277  	n.fix()
   278  
   279  	incrNN()
   280  	return n
   281  }
   282  
   283  // NewUniqueNode creates a new unique node in a graph. If no graph was specified in the construction options then it will just return a graphless node.
   284  func NewUniqueNode(opts ...NodeConsOpt) *Node {
   285  	n := newNode(opts...)
   286  	if n.g == nil {
   287  		return n
   288  	}
   289  	n.fixChildren() // ensure that all the kids are in the graph first
   290  
   291  	m := n.g.AddNode(n)
   292  	if n != m {
   293  		returnNode(n)
   294  	}
   295  	m.fixEdges()
   296  	return m
   297  }
   298  
   299  // ID returns the ID of the node. This satisfies the gonum/graph.Node interface
   300  func (n *Node) ID() int64 { return n.id }
   301  
   302  // Node returns itself. This sorts of monoidal patterns are useful for compositions via interfaces.
   303  func (n *Node) Node() *Node { return n }
   304  
   305  // Nodes returns n as a slice of *Node. Again, this is mostly useful for interfaces
   306  func (n *Node) Nodes() Nodes { return Nodes{n} }
   307  
   308  // Err always returns nil. However, this method is implemented to enable nicer composition of functions
   309  func (n *Node) Err() error { return nil }
   310  
   311  func (n *Node) DataSize() int { return n.Shape().TotalSize() }
   312  
   313  func (n *Node) DerivOf() Nodes { return n.derivOf }
   314  
   315  func (n *Node) Deriv() *Node { return n.deriv }
   316  
   317  // helper functions to help compilation process
   318  func (n *Node) isArg() bool      { return n.op == nil }
   319  func (n *Node) isInput() bool    { return (n.isArg() || n.isRandom()) && !n.isStmt }
   320  func (n *Node) isMutable() bool  { return !n.isInput() && n.op.ReturnsPtr() }
   321  func (n *Node) isConstant() bool { _, ok := n.op.(constant); return ok }
   322  func (n *Node) isRandom() bool   { _, ok := n.op.(randomOp); return ok }
   323  
   324  func (n *Node) isRoot() bool {
   325  	if n.g == nil {
   326  		return true
   327  	}
   328  	return len(n.g.to[n]) == 0
   329  }
   330  
   331  // IsVar returns true if  the node represents a differentiable variable (i.e. it's an argument to the function that is not a statement)
   332  func (n *Node) IsVar() bool { return n.isArg() && !n.isStmt && !n.isConstant() }
   333  
   334  // type related isX() helper methods
   335  
   336  // IsScalar indicates if a node represents a a scalar value. This is based on the type of the node, not the actual value associated with the node
   337  func (n *Node) IsScalar() bool { _, ok := n.t.(tensor.Dtype); return ok }
   338  
   339  // IsVector indicates if a node represents a vector value. This is based on the type of the node, not the actual value associated with the node
   340  func (n *Node) IsVector() bool {
   341  	if t, ok := n.t.(TensorType); ok {
   342  		return t.Dims == 1
   343  	}
   344  
   345  	return false
   346  }
   347  
   348  // IsColVec indicates if a node represents a Column Vector. This is based on the type of the node, not the actual value associated with the node
   349  func (n *Node) IsColVec() bool {
   350  	if _, ok := n.t.(TensorType); ok {
   351  		if n.shape != nil {
   352  			return n.shape.IsColVec()
   353  		}
   354  	}
   355  	return false
   356  }
   357  
   358  // IsRowVec indicates if a node represents a Row Vector. This is based on the type of the node, not the actual value associated with the node
   359  func (n *Node) IsRowVec() bool {
   360  	if _, ok := n.t.(TensorType); ok {
   361  		if n.shape != nil {
   362  			return n.shape.IsRowVec()
   363  		}
   364  	}
   365  	return false
   366  }
   367  
   368  // IsMatrix indicates if a node represents a matrix. This is based on the type of the node, not the actual value associated with the node
   369  func (n *Node) IsMatrix() bool {
   370  	if _, ok := n.t.(TensorType); ok {
   371  		return n.shape.Dims() == 2
   372  	}
   373  	return false
   374  }
   375  
   376  // methods
   377  
   378  // Graph returns the graph of the node
   379  func (n *Node) Graph() *ExprGraph { return n.g }
   380  
   381  // CloneTo clones the node into a new graph. If CloneTo() is called on the same graph as the n, it will return n. The reason this is done is because
   382  // at any given time, every node  should be unique in the *ExprGraph.
   383  //
   384  //TODO: clone children as well (this means that CloneTo() is only currently suitable fo input nodes)
   385  func (n *Node) CloneTo(g *ExprGraph) *Node {
   386  	if n.g != nil && g == n.g {
   387  		return n
   388  	}
   389  	n2 := n.Clone().(*Node)
   390  	n2.g = g
   391  	n2 = g.AddNode(n2)
   392  	return n2
   393  }
   394  
   395  // Clone clones the node. There are some caveats:
   396  //		- the graph is not copied over - the node essentially does not belong to a collection
   397  //		- there is no ID
   398  // 		- the children are not cloned
   399  func (n *Node) Clone() (retVal interface{}) {
   400  	n2 := newNode(In(n.g), WithOp(n.op), WithName(n.name), WithType(n.t))
   401  	if n.shape != nil {
   402  		n2.shape = n.shape.Clone()
   403  		n2.inferredShape = n.inferredShape
   404  	}
   405  
   406  	if n.boundTo != nil {
   407  		var err error
   408  		if n2.boundTo, err = CloneValue(n.boundTo); err != nil {
   409  			panic(err)
   410  		}
   411  	}
   412  
   413  	// reset
   414  	n2.g = nil
   415  
   416  	// other things
   417  	n2.name = n.name
   418  	n2.group = n.group
   419  	n2.dataOn = n.dataOn
   420  	n2.hash = n.hash
   421  
   422  	n2.hashed = n.hashed
   423  	n2.inferredShape = n.inferredShape
   424  	n2.unchanged = n.unchanged
   425  	n2.isStmt = n.isStmt
   426  	n2.ofInterest = n.ofInterest
   427  	return n2
   428  }
   429  
   430  // Value returns the valuse bound to the node. May return nil
   431  func (n *Node) Value() Value {
   432  	if n.isConstant() {
   433  		return n.op.(constant).Value()
   434  	}
   435  	if dv, ok := n.boundTo.(*dualValue); ok {
   436  		return dv.Value
   437  	}
   438  	return n.boundTo
   439  }
   440  
   441  // Grad returns the gradient if there is one.
   442  func (n *Node) Grad() (Value, error) {
   443  	if dv, ok := n.boundTo.(*dualValue); ok {
   444  		return dv.d, nil
   445  	}
   446  	if n.deriv != nil {
   447  		return n.deriv.Value(), nil
   448  	}
   449  
   450  	return nil, errors.Errorf("No Gradient node/value found for %T", n)
   451  }
   452  
   453  // Dims indicates how many dimensions the node's result has
   454  func (n *Node) Dims() int {
   455  	if n.shape != nil {
   456  		return n.shape.Dims()
   457  	}
   458  	switch nt := n.t.(type) {
   459  	case TensorType:
   460  		return nt.Dims
   461  	case tensor.Dtype:
   462  		return 0
   463  	default:
   464  		panic(fmt.Sprintf("Dims undefined for %v(%T)", nt, nt))
   465  	}
   466  }
   467  
   468  // Type returns the type of the node
   469  func (n *Node) Type() hm.Type { return n.t }
   470  
   471  // Dtype returns the dtype of the node
   472  func (n *Node) Dtype() tensor.Dtype {
   473  	dt, err := dtypeOf(n.t)
   474  	if err != nil {
   475  		panic(err)
   476  	}
   477  	return dt
   478  }
   479  
   480  // Shape returns the shape of the node
   481  func (n *Node) Shape() tensor.Shape { return n.shape.Clone() }
   482  
   483  // Strides returns the strides of the value of the node
   484  func (n *Node) Strides() []int {
   485  	if n.boundTo != nil {
   486  		switch v := n.boundTo.(type) {
   487  		case *dualValue:
   488  			return v.Value.(tensor.Tensor).Strides()
   489  		case tensor.Tensor:
   490  			return v.Strides()
   491  		default:
   492  			panic(fmt.Sprintf("Unhandled type for Strides(): %T. Using fallback method and assuming dense tensor types", n.boundTo))
   493  		}
   494  	}
   495  	return n.shape.CalcStrides()
   496  }
   497  
   498  // Device returns the device the data will be on
   499  func (n *Node) Device() Device { return n.dataOn }
   500  
   501  // Op returns the Op of the node
   502  func (n *Node) Op() Op { return n.op }
   503  
   504  // IsVec returns whether this node is a vector
   505  func (n *Node) IsVec() bool { return n.IsVector() }
   506  
   507  // Name returns the name of the node. If a name was specified and it is too long,
   508  // the short name will be used instead (except in inputs)
   509  //
   510  // The short name is typically of the form: OpName(%1, %2 ...), making it read more like a function call
   511  func (n *Node) Name() string {
   512  	if n.name != "" {
   513  		return n.name
   514  	}
   515  
   516  	var buf bytes.Buffer
   517  	fmt.Fprintf(&buf, "%s(", n.op)
   518  	for i, child := range n.children {
   519  		fmt.Fprintf(&buf, "%%%x", child.id)
   520  		if i < len(n.children)-1 {
   521  			buf.WriteString(", ")
   522  		}
   523  	}
   524  	buf.WriteString(")")
   525  	return buf.String()
   526  }
   527  
   528  // WriteHash writes the hash to the provided Hash32.
   529  func (n *Node) WriteHash(h hash.Hash32) {
   530  	fmt.Fprintf(h, "%v%v", n.t, n.shape)
   531  
   532  	if n.isInput() {
   533  		h.Write([]byte(n.name))
   534  	} else {
   535  
   536  		n.op.WriteHash(h)
   537  	}
   538  
   539  	// if len(n.children) == 0 {
   540  	// 	binary.Write(h, binary.LittleEndian, byte(0))
   541  	// }
   542  
   543  	binary.Write(h, binary.LittleEndian, byte(len(n.children)))
   544  	for _, child := range n.children {
   545  		binary.Write(h, binary.LittleEndian, child.Hashcode())
   546  	}
   547  
   548  }
   549  
   550  // Hashcode provides the hash for the tree, assuming that the node is the root of the tree.
   551  // Original implementation was here by Vatine (who's apparently 80 years old and using SO!?!):
   552  //		http://stackoverflow.com/questions/1988665/hashing-a-tree-structure
   553  func (n *Node) Hashcode() uint32 {
   554  	if n.hashed {
   555  		return n.hash
   556  	}
   557  	h := fnv.New32a()
   558  	n.WriteHash(h)
   559  	n.hash = h.Sum32()
   560  	n.hashed = true
   561  	return n.hash
   562  }
   563  
   564  // ToDot returns the graph as a graphviz compatible string.
   565  // DEPRECATED: This function will be removed in the next release, please use the encoding/dot package
   566  func (n *Node) ToDot() string {
   567  	graphName := exprgraphClust
   568  
   569  	g := gographviz.NewEscape()
   570  	g.SetName(graphName)
   571  	g.SetDir(true)
   572  
   573  	g.AddAttr(exprgraphClust, "splines", "spline")
   574  	g.AddAttr(exprgraphClust, "nodesep", "0.5")
   575  	g.AddAttr(exprgraphClust, "ranksep", "1.2 equally")
   576  
   577  	seen := make(map[*Node]string)
   578  	n.dot(g, graphName, seen)
   579  
   580  	return g.String()
   581  }
   582  
   583  // RestrictedToDot prints the graphviz compatible string but does not print the entire tree
   584  // up and down indicates how many levels to look up, and how many levels to look down
   585  func (n *Node) RestrictedToDot(up, down int) string {
   586  	if n.g == nil {
   587  		return n.ToDot()
   588  	}
   589  
   590  	g := n.g
   591  	var ns, upQ, downQ Nodes
   592  
   593  	//	up
   594  	ns = Nodes{n}
   595  	upQ = Nodes{n}
   596  	for l := 0; l < up; l++ {
   597  		origLen := len(upQ)
   598  		for i := 0; i < origLen; i++ {
   599  			qn := upQ[i]
   600  			toQN := sliceNodesToNodes(graph.NodesOf(g.To(qn.ID())))
   601  			upQ = append(upQ, toQN...)
   602  			ns = append(ns, toQN...)
   603  		}
   604  		upQ = upQ[origLen:]
   605  	}
   606  
   607  	// down
   608  	downQ = Nodes{n}
   609  	for d := 0; d < down; d++ {
   610  		origLen := len(downQ)
   611  		for i := 0; i < origLen; i++ {
   612  			qn := downQ[i]
   613  			downQ = append(downQ, qn.children...)
   614  			ns = append(ns, qn.children...)
   615  		}
   616  		downQ = downQ[origLen:]
   617  	}
   618  
   619  	sg := g.subgraph(ns, false)
   620  
   621  	n.ofInterest = true
   622  	defer func() {
   623  		n.ofInterest = false
   624  	}()
   625  	return sg.ToDot()
   626  }
   627  
   628  // String() implements the fmt.Stringer interface
   629  func (n *Node) String() string {
   630  	var buf bytes.Buffer
   631  	if n.Name() != "" {
   632  		fmt.Fprintf(&buf, "%s :: ", n.Name())
   633  	} else {
   634  		fmt.Fprintf(&buf, "%s :: ", n.op)
   635  	}
   636  	if c, ok := n.op.(constant); ok {
   637  		fmt.Fprintf(&buf, "%v{%v}", n.t, c.Value())
   638  	} else {
   639  		fmt.Fprintf(&buf, "%v", n.t)
   640  	}
   641  	return buf.String()
   642  }
   643  
   644  // private methods
   645  
   646  // TODO: check type, check shape, check if needsGrad -> promote to dualValue
   647  func (n *Node) bind(v Value) error {
   648  	if n.boundTo == nil {
   649  		n.boundTo = v
   650  		return nil
   651  	}
   652  
   653  	if dv, ok := n.boundTo.(*dualValue); ok {
   654  		if vdv, ok := v.(*dualValue); ok {
   655  			if vdv == dv {
   656  				return nil
   657  			}
   658  			if n.isRandom() {
   659  				// then simply replace the value in it
   660  				dv.Value = vdv.Value
   661  				return nil
   662  			}
   663  			// n.boundTo = vdv
   664  			// return nil
   665  			panic("Undefined behaviour") // no seriously there literally is no defined behaviour of what should the right thing be. I'll come back to this TODO.
   666  		}
   667  		dv.Value = v
   668  		return nil
   669  	}
   670  
   671  	n.boundTo = v
   672  
   673  	return nil
   674  }
   675  
   676  // bindCopy copies the value if to the bound value.
   677  func (n *Node) bindCopy(v Value) (err error) {
   678  	if n.boundTo == nil {
   679  		var cloned Value
   680  		if cloned, err = CloneValue(v); err != nil {
   681  			return
   682  		}
   683  		n.boundTo = cloned
   684  		return nil
   685  	}
   686  
   687  	var copied Value
   688  	if dv, ok := n.boundTo.(*dualValue); ok {
   689  
   690  		if vdv, ok := v.(*dualValue); ok {
   691  			if vdv == dv {
   692  				return nil // no need to copy!
   693  			}
   694  
   695  			if n.isRandom() {
   696  				// returnValue(dv.Value)
   697  				dv.Value = vdv.Value
   698  				return nil
   699  			}
   700  
   701  			return errors.Errorf("Cannot yet handle bindCopy() of *dualValue into *dualValue") // TODO FIX
   702  		}
   703  		if copied, err = Copy(dv.Value, v); err != nil {
   704  			return errors.Wrapf(err, "Failed to copy while binding to node with *dualValue")
   705  		}
   706  		dv.Value = copied // in case they're scalars
   707  		return nil
   708  	}
   709  	if copied, err = Copy(n.boundTo, v); err != nil {
   710  		return errors.Wrapf(err, "Failed to copy while binding to node")
   711  	}
   712  	n.boundTo = copied // in case it's a scalar
   713  	return nil
   714  }
   715  
   716  // unbind releases the values back to the pool
   717  func (n *Node) unbind() {
   718  	if n.boundTo == nil {
   719  		return
   720  	}
   721  
   722  	if dv, ok := n.boundTo.(*dualValue); ok {
   723  		returnDV(dv)
   724  	}
   725  
   726  	if t, ok := n.boundTo.(tensor.Tensor); ok {
   727  		returnTensor(t)
   728  	}
   729  	n.boundTo = nil
   730  }
   731  
   732  func (n *Node) dotCluster() string {
   733  	var group string
   734  	var isConst bool
   735  	var isInput = n.isInput()
   736  
   737  	if n.op != nil {
   738  		_, isConst = n.op.(constant)
   739  	}
   740  
   741  	switch {
   742  	case isConst:
   743  		group = constantsClust
   744  	case isInput:
   745  		group = inputsClust
   746  	case n.group == "":
   747  		group = exprgraphClust
   748  	default:
   749  		group = n.group
   750  	}
   751  	return group
   752  }
   753  
   754  func (n *Node) dot(g *gographviz.Escape, graphName string, seen map[*Node]string) string {
   755  	var id string
   756  	var ok bool
   757  	if id, ok = seen[n]; !ok {
   758  		id = n.dotString(g, graphName)
   759  		seen[n] = id
   760  	} else {
   761  		return id
   762  	}
   763  
   764  	for i, child := range n.children {
   765  		childID := child.dot(g, graphName, seen)
   766  		edgeAttrs := map[string]string{
   767  			"taillabel":  fmt.Sprintf(" %d ", i+1),
   768  			"labelfloat": "false",
   769  		}
   770  
   771  		g.AddPortEdge(id, id+":anchor:s", childID, childID+":anchor:n", true, edgeAttrs)
   772  	}
   773  	return id
   774  }
   775  
   776  func (n *Node) fix() {
   777  	if n.IsScalar() {
   778  		n.shape = scalarShape
   779  	}
   780  
   781  	if n.isConstant() {
   782  		return
   783  	}
   784  
   785  	if n.g == nil {
   786  		panic(fmt.Sprintf("no graph supplied %v", n))
   787  	}
   788  }
   789  
   790  func (n *Node) fixChildren() {
   791  	if n.g == nil {
   792  		return
   793  	}
   794  
   795  	for i, child := range n.children {
   796  		newChild := n.g.AddNode(child)
   797  		if child != newChild {
   798  			n.children[i] = newChild
   799  		}
   800  	}
   801  }
   802  
   803  func (n *Node) fixEdges() {
   804  	if n.g == nil {
   805  		return
   806  	}
   807  
   808  	if len(n.children) > 0 {
   809  		for _, child := range n.children {
   810  			e := edge{from: n, to: child}
   811  			n.g.SetEdge(e)
   812  		}
   813  	} else {
   814  		n.g.leaves = append(n.g.leaves, n)
   815  	}
   816  }
   817  
   818  func (n *Node) setShape(s tensor.Shape, inferred bool) {
   819  	n.shape = s
   820  	n.inferredShape = inferred
   821  }
   822  
   823  func (n *Node) setGroup(grp string) {
   824  	n.group = grp
   825  }
   826  
   827  func (n *Node) clone(opts ...NodeConsOpt) *Node {
   828  	if n.isInput() {
   829  		return n
   830  	}
   831  
   832  	nn := newNode(WithChildren(n.children),
   833  		WithType(n.t),
   834  		WithOp(n.op),
   835  		WithName(n.name),
   836  		In(n.g),
   837  	)
   838  
   839  	for _, opt := range opts {
   840  		opt(nn)
   841  	}
   842  
   843  	// if the shape is already known...
   844  	if n.shape != nil {
   845  		nn.shape = n.shape
   846  		nn.inferredShape = n.inferredShape
   847  	}
   848  
   849  	return nn
   850  }
   851  
   852  func (n *Node) diffWRT() []bool {
   853  	if sdop, ok := n.op.(SDOp); ok {
   854  		return sdop.DiffWRT(len(n.children))
   855  	}
   856  	return nil
   857  }
   858  
   859  // dfs but does not use channels. useful for extracting paths. used particularly in test
   860  func (n *Node) seqWalk() Nodes {
   861  	retVal := Nodes{n}
   862  	for _, child := range n.children {
   863  		retVal = append(retVal, child.seqWalk()...)
   864  	}
   865  	return retVal
   866  }
   867  
   868  // dotString returns the ID of the node.
   869  func (n *Node) dotString(g *gographviz.Escape, graphName string) string {
   870  	var buf bytes.Buffer
   871  	if err := exprNodeTempl.ExecuteTemplate(&buf, "node", n); err != nil {
   872  		panic(err)
   873  	}
   874  
   875  	id := fmt.Sprintf("Node_%p", n)
   876  	label := buf.String()
   877  	attrs := map[string]string{
   878  		"fontname": "monospace",
   879  		"shape":    "none",
   880  		"label":    label,
   881  	}
   882  
   883  	g.AddNode(graphName, id, attrs)
   884  	return id
   885  }