gorgonia.org/gorgonia@v0.9.17/x/vm/machine.go (about)

     1  package xvm
     2  
     3  import (
     4  	"context"
     5  	"strconv"
     6  	"strings"
     7  	"time"
     8  
     9  	"gorgonia.org/gorgonia"
    10  )
    11  
    12  // Machine is a top-level struture that will coordinate the execution of a graph
    13  type Machine struct {
    14  	nodes  []*node
    15  	pubsub *pubsub
    16  }
    17  
    18  // NewMachine creates an exeuction machine from an exprgraph
    19  func NewMachine(g *gorgonia.ExprGraph) *Machine {
    20  	if g == nil {
    21  		return nil
    22  	}
    23  	nodesIte := g.Nodes()
    24  	nodes := make([]*node, 0, nodesIte.Len())
    25  	for nodesIte.Next() {
    26  		n := nodesIte.Node().(*gorgonia.Node)
    27  		var nn *node
    28  		topNode := g.To(n.ID()).Len() == 0
    29  		op := n.Op()
    30  		switch {
    31  		case op == nil:
    32  			nn = newInput(n)
    33  		case op.Arity() == 0:
    34  			nn = newInput(n)
    35  		default:
    36  			nn = newOp(n, !topNode)
    37  		}
    38  		nodes = append(nodes, nn)
    39  	}
    40  	m := &Machine{
    41  		nodes: nodes,
    42  	}
    43  	m.pubsub = createNetwork(nodes, g)
    44  	return m
    45  }
    46  
    47  // createNetwork instantiate all the channels and create the pubsubs
    48  func createNetwork(ns []*node, g *gorgonia.ExprGraph) *pubsub {
    49  	ids := make(map[int64]*node, len(ns))
    50  	for i := range ns {
    51  		ids[ns[i].id] = ns[i]
    52  	}
    53  	ps := &pubsub{
    54  		publishers:  make([]*publisher, 0),
    55  		subscribers: make([]*subscriber, 0),
    56  	}
    57  	// Deal with publishers
    58  	publishers := make(map[int64]*publisher, len(ns))
    59  	for i := range ns {
    60  		currNode := ns[i]
    61  		if currNode.outputC == nil {
    62  			continue
    63  		}
    64  		publisher := &publisher{
    65  			id:          currNode.id,
    66  			publisher:   currNode.outputC,
    67  			subscribers: make([]chan<- gorgonia.Value, 0),
    68  		}
    69  		publishers[currNode.id] = publisher
    70  		ps.publishers = append(ps.publishers, publisher)
    71  	}
    72  	// Deal with subscribers
    73  	for i := range ns {
    74  		currNode := ns[i]
    75  		if currNode.inputC == nil {
    76  			continue
    77  		}
    78  		from := g.From(currNode.id)
    79  		subscriber := &subscriber{
    80  			id:         currNode.id,
    81  			subscriber: currNode.inputC,
    82  			publishers: make([]<-chan gorgonia.Value, from.Len()),
    83  		}
    84  		for i := 0; from.Next(); i++ {
    85  			pub := publishers[from.Node().ID()]
    86  			c := make(chan gorgonia.Value, 0)
    87  			pub.subscribers = append(pub.subscribers, c)
    88  
    89  			subscriber.publishers[i] = c
    90  		}
    91  		ps.subscribers = append(ps.subscribers, subscriber)
    92  	}
    93  	return ps
    94  }
    95  
    96  // Run the computation
    97  func (m *Machine) Run(ctx context.Context) error {
    98  	cancel, wg := m.pubsub.run(ctx)
    99  	err := m.runAllNodes(ctx)
   100  	cancel()
   101  	// wait for the infrastructure to settle
   102  	wg.Wait()
   103  	return err
   104  }
   105  
   106  // Close all the plumbing to avoid leaking
   107  func (m *Machine) Close() {
   108  	allChans := make(map[chan<- gorgonia.Value]struct{}, 0)
   109  	for _, pub := range m.pubsub.publishers {
   110  		for i := range pub.subscribers {
   111  			allChans[pub.subscribers[i]] = struct{}{}
   112  		}
   113  	}
   114  	for ch := range allChans {
   115  		close(ch)
   116  	}
   117  	for _, n := range m.nodes {
   118  		if n.inputC != nil {
   119  			close(n.inputC)
   120  		}
   121  		if n.outputC != nil {
   122  			close(n.outputC)
   123  		}
   124  	}
   125  }
   126  
   127  type nodeError struct {
   128  	id  int64
   129  	t   time.Time
   130  	err error
   131  }
   132  
   133  type nodeErrors []nodeError
   134  
   135  func (e nodeErrors) Error() string {
   136  	var sb strings.Builder
   137  	for _, e := range e {
   138  		sb.WriteString(strconv.Itoa(int(e.id)))
   139  		sb.WriteString(":")
   140  		sb.WriteString(e.err.Error())
   141  		sb.WriteString("\n")
   142  	}
   143  	return sb.String()
   144  }
   145  
   146  // Run performs the computation
   147  func (m *Machine) runAllNodes(ctx context.Context) error {
   148  	ctx, cancel := context.WithCancel(ctx)
   149  	errC := make(chan nodeError, 0)
   150  	total := len(m.nodes)
   151  	for i := range m.nodes {
   152  		go func(n *node) {
   153  			err := n.Compute(ctx)
   154  			errC <- nodeError{
   155  				id:  n.id,
   156  				t:   time.Now(),
   157  				err: err,
   158  			}
   159  		}(m.nodes[i])
   160  	}
   161  	errs := make([]nodeError, 0)
   162  	for e := range errC {
   163  		total--
   164  		if e.err != nil {
   165  			errs = append(errs, e)
   166  			// failfast, on error, cancel
   167  			cancel()
   168  		}
   169  		if total == 0 {
   170  			break
   171  		}
   172  	}
   173  	cancel()
   174  	close(errC)
   175  	if len(errs) != 0 {
   176  		return nodeErrors(errs)
   177  	}
   178  	return nil
   179  }
   180  
   181  // GetResult stored in a node
   182  func (m *Machine) GetResult(id int64) gorgonia.Value {
   183  	for i := range m.nodes {
   184  		if m.nodes[i].id == id {
   185  			return m.nodes[i].output
   186  		}
   187  	}
   188  	return nil
   189  }