gorgonia.org/gorgonia@v0.9.17/node.go (about) 1 package gorgonia 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "fmt" 7 "hash" 8 "hash/fnv" 9 10 "github.com/awalterschulze/gographviz" 11 "github.com/chewxy/hm" 12 "github.com/pkg/errors" 13 "gonum.org/v1/gonum/graph" 14 "gorgonia.org/gorgonia/internal/encoding" 15 "gorgonia.org/tensor" 16 ) 17 18 // A Node is a node in the computation graph 19 type Node struct { 20 // metadata of the node 21 t hm.Type // pruned types only plz 22 shape tensor.Shape 23 24 // this node is the result of applying the op to the children 25 op Op 26 children Nodes // shortcut, instead of having to go through the graph 27 28 // For nicely grouping stuff in graphviz. 29 // TODO: Should this be in *Node? 30 name string 31 // DEPRECATED: the group attribute will be removed in the next version in favor of groups 32 group string 33 // the grouping notion is only useful for exporting to another format 34 groups encoding.Groups 35 36 g *ExprGraph // this node belongs in this graph 37 38 // value bondage 39 // inputs are bound to values directly 40 boundTo Value 41 dataOn Device // where is the data on 42 43 // to track derivations 44 derivOf Nodes 45 deriv *Node 46 47 // for hashing nodes 48 id int64 // id is the ID at which the node is added to the graph 49 hash uint32 50 51 hashed bool 52 inferredShape bool // is shape inferred? 53 unchanged bool // has this node been modified 54 isStmt bool // is this a statement node 55 ofInterest bool // is this node of particular interest? (for debugging) 56 } 57 58 // NodeConsOpt is a function that provides construction options for any Node. 59 type NodeConsOpt func(*Node) 60 61 // WithType is a node construction option to set a node to the specified type. 62 // Types in *Node are immutable once set. If the type has already been specified in the node, 63 // a check will be made to see if the both types are the same. If it isn't, it will panic. 64 func WithType(t hm.Type) NodeConsOpt { 65 f := func(n *Node) { 66 if n.t == nil { 67 n.t = t 68 } else if !n.t.Eq(t) { 69 panic(fmt.Sprintf("Node's type is %v. Asking to construct a Node with %v", n.t, t)) 70 } 71 } 72 return f 73 } 74 75 // WithChildren sets the children of a node to the specified chidren. 76 // This construction option does NOT check if existing children exists, and will overwrite the existing children. 77 func WithChildren(children Nodes) NodeConsOpt { 78 f := func(n *Node) { 79 n.children = children 80 } 81 return f 82 } 83 84 // WithOp is a node construction option to set a node's Op to the specified Op. 85 // `Op`s in `*Node`s are immutable once set and cannot be changed. If the node already has an Op specified 86 // a check will be made to see if the provided Op and the one already specified in the `*Node` is the same - 87 // do note that comparison of Ops is done using the `Hashcode()` method of Ops, and hash collisions MAY occur - 88 // If both ops are different, this function will panic. 89 func WithOp(op Op) NodeConsOpt { 90 f := func(n *Node) { 91 if n.op != nil { 92 if op.Hashcode() != n.op.Hashcode() { 93 panic(fmt.Sprintf("Node Ops are immutable. Cannot set op %v", op)) 94 } 95 return 96 } 97 n.op = op 98 if _, ok := op.(stmtOp); ok { 99 n.isStmt = true 100 } 101 } 102 return f 103 } 104 105 // In is a node construction option to set a node's graph. 106 // A `*Node`'s graph is immutable. If the graph has already been set, a check will be made that the specifiec *Graph 107 // and the *Graph set in *Node are the same. If they are not, the function will panic/ 108 func In(g *ExprGraph) NodeConsOpt { 109 f := func(n *Node) { 110 if n.g != nil { 111 if g != n.g { 112 panic(fmt.Sprintf("Node Graphs are immutable. Cannot set g %v", g)) 113 } 114 } 115 n.g = g 116 } 117 return f 118 } 119 120 // WithName is a node construction option that gives the *Node the provided name. This is especially useful in debugging graphs. 121 func WithName(name string) NodeConsOpt { 122 f := func(n *Node) { 123 n.name = name 124 } 125 return f 126 } 127 128 // WithValue is a node construction option that binds the value to the *Node. This function may panic if: 129 // - Gorgonia was unable to convert interface{} into a Value. 130 // - The type of the Value does not match the type of the nodes. 131 func WithValue(any interface{}) NodeConsOpt { 132 v, t, _, err := anyToValue(any) 133 if err != nil { 134 panic(err) 135 } 136 137 f := func(n *Node) { 138 if n.t == nil { 139 n.t = t 140 } else if !n.t.Eq(t) { 141 // scalars are exempted 142 // if n.t is a scalar type, and the `anyToValue` returns a vector that is scalar-like, 143 // we'll mark it as a scalar 144 tt, ok1 := t.(TensorType) 145 ndt, ok2 := n.t.(tensor.Dtype) 146 147 if !(ok1 && ok2 && v.Shape().IsScalarEquiv() && tt.Of == ndt) { 148 panic(fmt.Sprintf("TypeError: Want %v, Got %v instead (%T %T)", n.t, t, n.t, t)) // yes this is a runtime error 149 } 150 v.(tensor.Tensor).Reshape() // rehsape to scalar 151 152 } 153 154 n.bind(v) 155 if n.shape == nil { 156 n.shape = v.Shape() 157 } 158 } 159 return f 160 } 161 162 // WithGrad is a node construction option that binds the value to the *Node. This function may panic if: 163 // - There isn't already a value associated with the node (.boundTo == nil) 164 // - The type of the Value does not match the value of the node. 165 func WithGrad(any interface{}) NodeConsOpt { 166 v, t, _, err := anyToValue(any) 167 if err != nil { 168 panic(err) 169 } 170 f := func(n *Node) { 171 if n.boundTo == nil { 172 panic("No value already bound to node") 173 } 174 if !TypeOf(n.boundTo).Eq(t) { 175 panic("Different types ") 176 } 177 178 if dv, ok := n.boundTo.(*dualValue); !ok { 179 if err := n.bind(&dualValue{Value: n.boundTo, d: v}); err != nil { 180 panic(err) 181 } 182 } else { 183 dv.d = v 184 } 185 } 186 return f 187 } 188 189 // WithInit is a node construction option to initialize a *Node with the InitWFn provided. 190 func WithInit(fn InitWFn) NodeConsOpt { 191 f := func(n *Node) { 192 dt, err := dtypeOf(n.t) 193 if err != nil { 194 panic(err) 195 } 196 197 var v Value 198 v = tensor.New(tensor.WithShape(n.shape...), tensor.WithBacking(fn(dt, n.shape...))) 199 WithValue(v)(n) 200 } 201 return f 202 } 203 204 // WithShape is a node construction option to initialize a *Node with a particular shape. 205 // This function panics if the shape's dimensions do not match the specified dimensions of the *Node. 206 func WithShape(shp ...int) NodeConsOpt { 207 s := tensor.Shape(tensor.BorrowInts(len(shp))) 208 copy(s, shp) 209 f := func(n *Node) { 210 if n.t == nil && n.shape == nil { 211 n.shape = s 212 return 213 } 214 nd := n.Dims() 215 isVec := s.IsColVec() || s.IsRowVec() 216 acceptVec := (isVec && (nd == 1)) 217 sameDims := nd == s.Dims() 218 acceptScalar := nd == 0 && scalarEquiv(s) 219 220 if !acceptVec && !sameDims && !acceptScalar { 221 panic(fmt.Sprintf("Node %v, has %d dimensions(Shape: %v). Input shape is %v, which has %d dimensions", n, n.Dims(), n.shape, s, s.Dims())) 222 } 223 n.shape = s 224 } 225 return f 226 } 227 228 // WithGroupName is a node construction option to group a *Node within a particular group. This option is useful for debugging with graphs. 229 // This function is deprecated and will proabably be remove in the next version. 230 func WithGroupName(name string) NodeConsOpt { 231 f := func(n *Node) { 232 if n.group == "" { 233 n.group = name 234 } 235 } 236 return f 237 } 238 239 // withGroup is a node construction option to group a *Node within a particular group. This option is useful for debugging with graphs. 240 func withGroup(group encoding.Group) NodeConsOpt { 241 f := func(n *Node) { 242 n.groups = n.groups.Upsert(group) 243 } 244 return f 245 } 246 247 // Groups to fulfil the encoding Grouper interface 248 func (n *Node) Groups() encoding.Groups { 249 var isConst bool 250 var isInput = n.isInput() 251 252 if n.op != nil { 253 _, isConst = n.op.(constant) 254 } 255 256 switch { 257 case isConst: 258 n.groups = n.groups.Upsert(encoding.ConstantCluster) 259 case isInput: 260 n.groups = n.groups.Upsert(encoding.InputCluster) 261 default: 262 n.groups = n.groups.Upsert(encoding.ExprGraphCluster) 263 } 264 return n.groups 265 } 266 267 func newNode(opts ...NodeConsOpt) *Node { 268 n := borrowNode() 269 n.dataOn = CPU 270 n.id = -1 271 n.t = nil 272 n.shape = nil 273 274 for _, opt := range opts { 275 opt(n) 276 } 277 n.fix() 278 279 incrNN() 280 return n 281 } 282 283 // NewUniqueNode creates a new unique node in a graph. If no graph was specified in the construction options then it will just return a graphless node. 284 func NewUniqueNode(opts ...NodeConsOpt) *Node { 285 n := newNode(opts...) 286 if n.g == nil { 287 return n 288 } 289 n.fixChildren() // ensure that all the kids are in the graph first 290 291 m := n.g.AddNode(n) 292 if n != m { 293 returnNode(n) 294 } 295 m.fixEdges() 296 return m 297 } 298 299 // ID returns the ID of the node. This satisfies the gonum/graph.Node interface 300 func (n *Node) ID() int64 { return n.id } 301 302 // Node returns itself. This sorts of monoidal patterns are useful for compositions via interfaces. 303 func (n *Node) Node() *Node { return n } 304 305 // Nodes returns n as a slice of *Node. Again, this is mostly useful for interfaces 306 func (n *Node) Nodes() Nodes { return Nodes{n} } 307 308 // Err always returns nil. However, this method is implemented to enable nicer composition of functions 309 func (n *Node) Err() error { return nil } 310 311 func (n *Node) DataSize() int { return n.Shape().TotalSize() } 312 313 func (n *Node) DerivOf() Nodes { return n.derivOf } 314 315 func (n *Node) Deriv() *Node { return n.deriv } 316 317 // helper functions to help compilation process 318 func (n *Node) isArg() bool { return n.op == nil } 319 func (n *Node) isInput() bool { return (n.isArg() || n.isRandom()) && !n.isStmt } 320 func (n *Node) isMutable() bool { return !n.isInput() && n.op.ReturnsPtr() } 321 func (n *Node) isConstant() bool { _, ok := n.op.(constant); return ok } 322 func (n *Node) isRandom() bool { _, ok := n.op.(randomOp); return ok } 323 324 func (n *Node) isRoot() bool { 325 if n.g == nil { 326 return true 327 } 328 return len(n.g.to[n]) == 0 329 } 330 331 // IsVar returns true if the node represents a differentiable variable (i.e. it's an argument to the function that is not a statement) 332 func (n *Node) IsVar() bool { return n.isArg() && !n.isStmt && !n.isConstant() } 333 334 // type related isX() helper methods 335 336 // IsScalar indicates if a node represents a a scalar value. This is based on the type of the node, not the actual value associated with the node 337 func (n *Node) IsScalar() bool { _, ok := n.t.(tensor.Dtype); return ok } 338 339 // IsVector indicates if a node represents a vector value. This is based on the type of the node, not the actual value associated with the node 340 func (n *Node) IsVector() bool { 341 if t, ok := n.t.(TensorType); ok { 342 return t.Dims == 1 343 } 344 345 return false 346 } 347 348 // IsColVec indicates if a node represents a Column Vector. This is based on the type of the node, not the actual value associated with the node 349 func (n *Node) IsColVec() bool { 350 if _, ok := n.t.(TensorType); ok { 351 if n.shape != nil { 352 return n.shape.IsColVec() 353 } 354 } 355 return false 356 } 357 358 // IsRowVec indicates if a node represents a Row Vector. This is based on the type of the node, not the actual value associated with the node 359 func (n *Node) IsRowVec() bool { 360 if _, ok := n.t.(TensorType); ok { 361 if n.shape != nil { 362 return n.shape.IsRowVec() 363 } 364 } 365 return false 366 } 367 368 // IsMatrix indicates if a node represents a matrix. This is based on the type of the node, not the actual value associated with the node 369 func (n *Node) IsMatrix() bool { 370 if _, ok := n.t.(TensorType); ok { 371 return n.shape.Dims() == 2 372 } 373 return false 374 } 375 376 // methods 377 378 // Graph returns the graph of the node 379 func (n *Node) Graph() *ExprGraph { return n.g } 380 381 // CloneTo clones the node into a new graph. If CloneTo() is called on the same graph as the n, it will return n. The reason this is done is because 382 // at any given time, every node should be unique in the *ExprGraph. 383 // 384 //TODO: clone children as well (this means that CloneTo() is only currently suitable fo input nodes) 385 func (n *Node) CloneTo(g *ExprGraph) *Node { 386 if n.g != nil && g == n.g { 387 return n 388 } 389 n2 := n.Clone().(*Node) 390 n2.g = g 391 n2 = g.AddNode(n2) 392 return n2 393 } 394 395 // Clone clones the node. There are some caveats: 396 // - the graph is not copied over - the node essentially does not belong to a collection 397 // - there is no ID 398 // - the children are not cloned 399 func (n *Node) Clone() (retVal interface{}) { 400 n2 := newNode(In(n.g), WithOp(n.op), WithName(n.name), WithType(n.t)) 401 if n.shape != nil { 402 n2.shape = n.shape.Clone() 403 n2.inferredShape = n.inferredShape 404 } 405 406 if n.boundTo != nil { 407 var err error 408 if n2.boundTo, err = CloneValue(n.boundTo); err != nil { 409 panic(err) 410 } 411 } 412 413 // reset 414 n2.g = nil 415 416 // other things 417 n2.name = n.name 418 n2.group = n.group 419 n2.dataOn = n.dataOn 420 n2.hash = n.hash 421 422 n2.hashed = n.hashed 423 n2.inferredShape = n.inferredShape 424 n2.unchanged = n.unchanged 425 n2.isStmt = n.isStmt 426 n2.ofInterest = n.ofInterest 427 return n2 428 } 429 430 // Value returns the valuse bound to the node. May return nil 431 func (n *Node) Value() Value { 432 if n.isConstant() { 433 return n.op.(constant).Value() 434 } 435 if dv, ok := n.boundTo.(*dualValue); ok { 436 return dv.Value 437 } 438 return n.boundTo 439 } 440 441 // Grad returns the gradient if there is one. 442 func (n *Node) Grad() (Value, error) { 443 if dv, ok := n.boundTo.(*dualValue); ok { 444 return dv.d, nil 445 } 446 if n.deriv != nil { 447 return n.deriv.Value(), nil 448 } 449 450 return nil, errors.Errorf("No Gradient node/value found for %T", n) 451 } 452 453 // Dims indicates how many dimensions the node's result has 454 func (n *Node) Dims() int { 455 if n.shape != nil { 456 return n.shape.Dims() 457 } 458 switch nt := n.t.(type) { 459 case TensorType: 460 return nt.Dims 461 case tensor.Dtype: 462 return 0 463 default: 464 panic(fmt.Sprintf("Dims undefined for %v(%T)", nt, nt)) 465 } 466 } 467 468 // Type returns the type of the node 469 func (n *Node) Type() hm.Type { return n.t } 470 471 // Dtype returns the dtype of the node 472 func (n *Node) Dtype() tensor.Dtype { 473 dt, err := dtypeOf(n.t) 474 if err != nil { 475 panic(err) 476 } 477 return dt 478 } 479 480 // Shape returns the shape of the node 481 func (n *Node) Shape() tensor.Shape { return n.shape.Clone() } 482 483 // Strides returns the strides of the value of the node 484 func (n *Node) Strides() []int { 485 if n.boundTo != nil { 486 switch v := n.boundTo.(type) { 487 case *dualValue: 488 return v.Value.(tensor.Tensor).Strides() 489 case tensor.Tensor: 490 return v.Strides() 491 default: 492 panic(fmt.Sprintf("Unhandled type for Strides(): %T. Using fallback method and assuming dense tensor types", n.boundTo)) 493 } 494 } 495 return n.shape.CalcStrides() 496 } 497 498 // Device returns the device the data will be on 499 func (n *Node) Device() Device { return n.dataOn } 500 501 // Op returns the Op of the node 502 func (n *Node) Op() Op { return n.op } 503 504 // IsVec returns whether this node is a vector 505 func (n *Node) IsVec() bool { return n.IsVector() } 506 507 // Name returns the name of the node. If a name was specified and it is too long, 508 // the short name will be used instead (except in inputs) 509 // 510 // The short name is typically of the form: OpName(%1, %2 ...), making it read more like a function call 511 func (n *Node) Name() string { 512 if n.name != "" { 513 return n.name 514 } 515 516 var buf bytes.Buffer 517 fmt.Fprintf(&buf, "%s(", n.op) 518 for i, child := range n.children { 519 fmt.Fprintf(&buf, "%%%x", child.id) 520 if i < len(n.children)-1 { 521 buf.WriteString(", ") 522 } 523 } 524 buf.WriteString(")") 525 return buf.String() 526 } 527 528 // WriteHash writes the hash to the provided Hash32. 529 func (n *Node) WriteHash(h hash.Hash32) { 530 fmt.Fprintf(h, "%v%v", n.t, n.shape) 531 532 if n.isInput() { 533 h.Write([]byte(n.name)) 534 } else { 535 536 n.op.WriteHash(h) 537 } 538 539 // if len(n.children) == 0 { 540 // binary.Write(h, binary.LittleEndian, byte(0)) 541 // } 542 543 binary.Write(h, binary.LittleEndian, byte(len(n.children))) 544 for _, child := range n.children { 545 binary.Write(h, binary.LittleEndian, child.Hashcode()) 546 } 547 548 } 549 550 // Hashcode provides the hash for the tree, assuming that the node is the root of the tree. 551 // Original implementation was here by Vatine (who's apparently 80 years old and using SO!?!): 552 // http://stackoverflow.com/questions/1988665/hashing-a-tree-structure 553 func (n *Node) Hashcode() uint32 { 554 if n.hashed { 555 return n.hash 556 } 557 h := fnv.New32a() 558 n.WriteHash(h) 559 n.hash = h.Sum32() 560 n.hashed = true 561 return n.hash 562 } 563 564 // ToDot returns the graph as a graphviz compatible string. 565 // DEPRECATED: This function will be removed in the next release, please use the encoding/dot package 566 func (n *Node) ToDot() string { 567 graphName := exprgraphClust 568 569 g := gographviz.NewEscape() 570 g.SetName(graphName) 571 g.SetDir(true) 572 573 g.AddAttr(exprgraphClust, "splines", "spline") 574 g.AddAttr(exprgraphClust, "nodesep", "0.5") 575 g.AddAttr(exprgraphClust, "ranksep", "1.2 equally") 576 577 seen := make(map[*Node]string) 578 n.dot(g, graphName, seen) 579 580 return g.String() 581 } 582 583 // RestrictedToDot prints the graphviz compatible string but does not print the entire tree 584 // up and down indicates how many levels to look up, and how many levels to look down 585 func (n *Node) RestrictedToDot(up, down int) string { 586 if n.g == nil { 587 return n.ToDot() 588 } 589 590 g := n.g 591 var ns, upQ, downQ Nodes 592 593 // up 594 ns = Nodes{n} 595 upQ = Nodes{n} 596 for l := 0; l < up; l++ { 597 origLen := len(upQ) 598 for i := 0; i < origLen; i++ { 599 qn := upQ[i] 600 toQN := sliceNodesToNodes(graph.NodesOf(g.To(qn.ID()))) 601 upQ = append(upQ, toQN...) 602 ns = append(ns, toQN...) 603 } 604 upQ = upQ[origLen:] 605 } 606 607 // down 608 downQ = Nodes{n} 609 for d := 0; d < down; d++ { 610 origLen := len(downQ) 611 for i := 0; i < origLen; i++ { 612 qn := downQ[i] 613 downQ = append(downQ, qn.children...) 614 ns = append(ns, qn.children...) 615 } 616 downQ = downQ[origLen:] 617 } 618 619 sg := g.subgraph(ns, false) 620 621 n.ofInterest = true 622 defer func() { 623 n.ofInterest = false 624 }() 625 return sg.ToDot() 626 } 627 628 // String() implements the fmt.Stringer interface 629 func (n *Node) String() string { 630 var buf bytes.Buffer 631 if n.Name() != "" { 632 fmt.Fprintf(&buf, "%s :: ", n.Name()) 633 } else { 634 fmt.Fprintf(&buf, "%s :: ", n.op) 635 } 636 if c, ok := n.op.(constant); ok { 637 fmt.Fprintf(&buf, "%v{%v}", n.t, c.Value()) 638 } else { 639 fmt.Fprintf(&buf, "%v", n.t) 640 } 641 return buf.String() 642 } 643 644 // private methods 645 646 // TODO: check type, check shape, check if needsGrad -> promote to dualValue 647 func (n *Node) bind(v Value) error { 648 if n.boundTo == nil { 649 n.boundTo = v 650 return nil 651 } 652 653 if dv, ok := n.boundTo.(*dualValue); ok { 654 if vdv, ok := v.(*dualValue); ok { 655 if vdv == dv { 656 return nil 657 } 658 if n.isRandom() { 659 // then simply replace the value in it 660 dv.Value = vdv.Value 661 return nil 662 } 663 // n.boundTo = vdv 664 // return nil 665 panic("Undefined behaviour") // no seriously there literally is no defined behaviour of what should the right thing be. I'll come back to this TODO. 666 } 667 dv.Value = v 668 return nil 669 } 670 671 n.boundTo = v 672 673 return nil 674 } 675 676 // bindCopy copies the value if to the bound value. 677 func (n *Node) bindCopy(v Value) (err error) { 678 if n.boundTo == nil { 679 var cloned Value 680 if cloned, err = CloneValue(v); err != nil { 681 return 682 } 683 n.boundTo = cloned 684 return nil 685 } 686 687 var copied Value 688 if dv, ok := n.boundTo.(*dualValue); ok { 689 690 if vdv, ok := v.(*dualValue); ok { 691 if vdv == dv { 692 return nil // no need to copy! 693 } 694 695 if n.isRandom() { 696 // returnValue(dv.Value) 697 dv.Value = vdv.Value 698 return nil 699 } 700 701 return errors.Errorf("Cannot yet handle bindCopy() of *dualValue into *dualValue") // TODO FIX 702 } 703 if copied, err = Copy(dv.Value, v); err != nil { 704 return errors.Wrapf(err, "Failed to copy while binding to node with *dualValue") 705 } 706 dv.Value = copied // in case they're scalars 707 return nil 708 } 709 if copied, err = Copy(n.boundTo, v); err != nil { 710 return errors.Wrapf(err, "Failed to copy while binding to node") 711 } 712 n.boundTo = copied // in case it's a scalar 713 return nil 714 } 715 716 // unbind releases the values back to the pool 717 func (n *Node) unbind() { 718 if n.boundTo == nil { 719 return 720 } 721 722 if dv, ok := n.boundTo.(*dualValue); ok { 723 returnDV(dv) 724 } 725 726 if t, ok := n.boundTo.(tensor.Tensor); ok { 727 returnTensor(t) 728 } 729 n.boundTo = nil 730 } 731 732 func (n *Node) dotCluster() string { 733 var group string 734 var isConst bool 735 var isInput = n.isInput() 736 737 if n.op != nil { 738 _, isConst = n.op.(constant) 739 } 740 741 switch { 742 case isConst: 743 group = constantsClust 744 case isInput: 745 group = inputsClust 746 case n.group == "": 747 group = exprgraphClust 748 default: 749 group = n.group 750 } 751 return group 752 } 753 754 func (n *Node) dot(g *gographviz.Escape, graphName string, seen map[*Node]string) string { 755 var id string 756 var ok bool 757 if id, ok = seen[n]; !ok { 758 id = n.dotString(g, graphName) 759 seen[n] = id 760 } else { 761 return id 762 } 763 764 for i, child := range n.children { 765 childID := child.dot(g, graphName, seen) 766 edgeAttrs := map[string]string{ 767 "taillabel": fmt.Sprintf(" %d ", i+1), 768 "labelfloat": "false", 769 } 770 771 g.AddPortEdge(id, id+":anchor:s", childID, childID+":anchor:n", true, edgeAttrs) 772 } 773 return id 774 } 775 776 func (n *Node) fix() { 777 if n.IsScalar() { 778 n.shape = scalarShape 779 } 780 781 if n.isConstant() { 782 return 783 } 784 785 if n.g == nil { 786 panic(fmt.Sprintf("no graph supplied %v", n)) 787 } 788 } 789 790 func (n *Node) fixChildren() { 791 if n.g == nil { 792 return 793 } 794 795 for i, child := range n.children { 796 newChild := n.g.AddNode(child) 797 if child != newChild { 798 n.children[i] = newChild 799 } 800 } 801 } 802 803 func (n *Node) fixEdges() { 804 if n.g == nil { 805 return 806 } 807 808 if len(n.children) > 0 { 809 for _, child := range n.children { 810 e := edge{from: n, to: child} 811 n.g.SetEdge(e) 812 } 813 } else { 814 n.g.leaves = append(n.g.leaves, n) 815 } 816 } 817 818 func (n *Node) setShape(s tensor.Shape, inferred bool) { 819 n.shape = s 820 n.inferredShape = inferred 821 } 822 823 func (n *Node) setGroup(grp string) { 824 n.group = grp 825 } 826 827 func (n *Node) clone(opts ...NodeConsOpt) *Node { 828 if n.isInput() { 829 return n 830 } 831 832 nn := newNode(WithChildren(n.children), 833 WithType(n.t), 834 WithOp(n.op), 835 WithName(n.name), 836 In(n.g), 837 ) 838 839 for _, opt := range opts { 840 opt(nn) 841 } 842 843 // if the shape is already known... 844 if n.shape != nil { 845 nn.shape = n.shape 846 nn.inferredShape = n.inferredShape 847 } 848 849 return nn 850 } 851 852 func (n *Node) diffWRT() []bool { 853 if sdop, ok := n.op.(SDOp); ok { 854 return sdop.DiffWRT(len(n.children)) 855 } 856 return nil 857 } 858 859 // dfs but does not use channels. useful for extracting paths. used particularly in test 860 func (n *Node) seqWalk() Nodes { 861 retVal := Nodes{n} 862 for _, child := range n.children { 863 retVal = append(retVal, child.seqWalk()...) 864 } 865 return retVal 866 } 867 868 // dotString returns the ID of the node. 869 func (n *Node) dotString(g *gographviz.Escape, graphName string) string { 870 var buf bytes.Buffer 871 if err := exprNodeTempl.ExecuteTemplate(&buf, "node", n); err != nil { 872 panic(err) 873 } 874 875 id := fmt.Sprintf("Node_%p", n) 876 label := buf.String() 877 attrs := map[string]string{ 878 "fontname": "monospace", 879 "shape": "none", 880 "label": label, 881 } 882 883 g.AddNode(graphName, id, attrs) 884 return id 885 }