gorgonia.org/gorgonia@v0.9.17/x/vm/node.go (about) 1 package xvm 2 3 import ( 4 "context" 5 "errors" 6 7 "gorgonia.org/gorgonia" 8 ) 9 10 // Doer is implementing the Do method of gorgonia's Op interface 11 type Doer interface { 12 Do(...gorgonia.Value) (gorgonia.Value, error) 13 } 14 15 type node struct { 16 id int64 17 op Doer 18 output gorgonia.Value 19 outputC chan gorgonia.Value 20 receivedValues int 21 err error 22 inputValues []gorgonia.Value 23 inputC chan ioValue 24 } 25 26 // ioValue is a value with a position. as the infrastructure cannot guaranty the 27 // order of the input values, we use this structure carrying the position of the operator. 28 // this is mandatory for non commutative operations 29 type ioValue struct { 30 pos int 31 v gorgonia.Value 32 } 33 34 type stateFn func(context.Context, *node) stateFn 35 36 func defaultState(_ context.Context, n *node) stateFn { 37 n.receivedValues = 0 38 n.err = nil 39 if n.op == nil { 40 return emitOutput 41 } 42 return receiveInput 43 } 44 45 func receiveInput(ctx context.Context, n *node) stateFn { 46 // if inputC is nil, it is a variable or a constant, don't 47 // wait for any input 48 if n.inputC == nil { 49 return computeFwd 50 } 51 select { 52 case <-ctx.Done(): 53 n.err = ctx.Err() 54 return nil 55 case input := <-n.inputC: 56 if input.pos >= len(n.inputValues) { 57 n.err = errors.New("bad arity") 58 return nil 59 } 60 n.receivedValues++ 61 n.inputValues[input.pos] = input.v 62 if n.receivedValues < len(n.inputValues) { 63 return receiveInput 64 } 65 } 66 return computeFwd 67 } 68 69 func computeFwd(_ context.Context, n *node) stateFn { 70 v, err := n.op.Do(n.inputValues...) 71 if err != nil { 72 n.err = err 73 return nil 74 } 75 n.output = v 76 return emitOutput 77 } 78 79 func emitOutput(ctx context.Context, n *node) stateFn { 80 if n == nil || n.outputC == nil { 81 return nil 82 } 83 select { 84 case <-ctx.Done(): 85 n.err = ctx.Err() 86 return nil 87 case n.outputC <- n.output: 88 } 89 return nil 90 } 91 92 func computeBackward(_ context.Context, _ *node) stateFn { 93 return nil 94 } 95 96 func (n *node) Compute(ctx context.Context) error { 97 for state := defaultState; state != nil; { 98 t := trace(ctx, nil, n, state) 99 state = state(ctx, n) 100 trace(ctx, t, nil, nil) 101 } 102 return n.err 103 } 104 105 func newOp(n *gorgonia.Node, hasOutputChan bool) *node { 106 if n == nil { 107 return nil 108 } 109 var outputC chan gorgonia.Value 110 if hasOutputChan { 111 outputC = make(chan gorgonia.Value, 0) 112 113 } 114 return &node{ 115 id: n.ID(), 116 op: n.Op(), 117 inputValues: make([]gorgonia.Value, n.Op().Arity()), 118 inputC: make(chan ioValue, 0), 119 outputC: outputC, 120 } 121 } 122 123 func newInput(n *gorgonia.Node) *node { 124 if n == nil { 125 return nil 126 } 127 return &node{ 128 id: n.ID(), 129 output: n.Value(), 130 outputC: make(chan gorgonia.Value, 0), 131 } 132 }