gorgonia.org/gorgonia@v0.9.17/x/vm/node.go (about)

     1  package xvm
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  
     7  	"gorgonia.org/gorgonia"
     8  )
     9  
// Doer is the part of gorgonia's Op interface that a node needs in order to
// be executed: the Do method, which computes an output value from the given
// input values.
type Doer interface {
	Do(...gorgonia.Value) (gorgonia.Value, error)
}
    14  
// node is a single execution unit driven by a state machine (see Compute).
// It is built either from an operation (newOp) or from an input value
// (newInput).
type node struct {
	// id is copied from the ID of the gorgonia node it was built from.
	id int64
	// op is the operation run by computeFwd; it is left nil by newInput, in
	// which case the state machine goes straight to emitOutput.
	op Doer
	// output is the value sent on outputC: either the result of op.Do or,
	// for inputs, the value taken from the gorgonia node at construction.
	output gorgonia.Value
	// outputC is the channel the output is emitted on; when nil, nothing is
	// emitted.
	outputC chan gorgonia.Value
	// receivedValues counts the inputs received so far; it is reset to zero
	// by defaultState.
	receivedValues int
	// err holds the error recorded by a state function; Compute returns it.
	err error
	// inputValues holds the operands of op, indexed by their position.
	inputValues []gorgonia.Value
	// inputC is the channel positioned inputs arrive on; when nil, the node
	// does not wait for any input.
	inputC chan ioValue
}
    25  
// ioValue is a value tagged with a position. Because the infrastructure
// cannot guarantee the order in which input values arrive, each value carries
// the position of the operand it fills. This is mandatory for non-commutative
// operations.
type ioValue struct {
	pos int            // index of the operand slot this value fills
	v   gorgonia.Value // the value itself
}
    33  
// stateFn is one step of a node's state machine: it acts on the node and
// returns the next step to run, or nil to stop the machine.
type stateFn func(context.Context, *node) stateFn
    35  
    36  func defaultState(_ context.Context, n *node) stateFn {
    37  	n.receivedValues = 0
    38  	n.err = nil
    39  	if n.op == nil {
    40  		return emitOutput
    41  	}
    42  	return receiveInput
    43  }
    44  
    45  func receiveInput(ctx context.Context, n *node) stateFn {
    46  	// if inputC is nil, it is a variable or a constant, don't
    47  	// wait for any input
    48  	if n.inputC == nil {
    49  		return computeFwd
    50  	}
    51  	select {
    52  	case <-ctx.Done():
    53  		n.err = ctx.Err()
    54  		return nil
    55  	case input := <-n.inputC:
    56  		if input.pos >= len(n.inputValues) {
    57  			n.err = errors.New("bad arity")
    58  			return nil
    59  		}
    60  		n.receivedValues++
    61  		n.inputValues[input.pos] = input.v
    62  		if n.receivedValues < len(n.inputValues) {
    63  			return receiveInput
    64  		}
    65  	}
    66  	return computeFwd
    67  }
    68  
    69  func computeFwd(_ context.Context, n *node) stateFn {
    70  	v, err := n.op.Do(n.inputValues...)
    71  	if err != nil {
    72  		n.err = err
    73  		return nil
    74  	}
    75  	n.output = v
    76  	return emitOutput
    77  }
    78  
    79  func emitOutput(ctx context.Context, n *node) stateFn {
    80  	if n == nil || n.outputC == nil {
    81  		return nil
    82  	}
    83  	select {
    84  	case <-ctx.Done():
    85  		n.err = ctx.Err()
    86  		return nil
    87  	case n.outputC <- n.output:
    88  	}
    89  	return nil
    90  }
    91  
// computeBackward is a stub: it ignores its arguments and returns nil, which
// stops the state machine.
func computeBackward(_ context.Context, _ *node) stateFn {
	return nil
}
    95  
    96  func (n *node) Compute(ctx context.Context) error {
    97  	for state := defaultState; state != nil; {
    98  		t := trace(ctx, nil, n, state)
    99  		state = state(ctx, n)
   100  		trace(ctx, t, nil, nil)
   101  	}
   102  	return n.err
   103  }
   104  
   105  func newOp(n *gorgonia.Node, hasOutputChan bool) *node {
   106  	if n == nil {
   107  		return nil
   108  	}
   109  	var outputC chan gorgonia.Value
   110  	if hasOutputChan {
   111  		outputC = make(chan gorgonia.Value, 0)
   112  
   113  	}
   114  	return &node{
   115  		id:          n.ID(),
   116  		op:          n.Op(),
   117  		inputValues: make([]gorgonia.Value, n.Op().Arity()),
   118  		inputC:      make(chan ioValue, 0),
   119  		outputC:     outputC,
   120  	}
   121  }
   122  
   123  func newInput(n *gorgonia.Node) *node {
   124  	if n == nil {
   125  		return nil
   126  	}
   127  	return &node{
   128  		id:      n.ID(),
   129  		output:  n.Value(),
   130  		outputC: make(chan gorgonia.Value, 0),
   131  	}
   132  }