gorgonia.org/gorgonia@v0.9.17/x/vm/machine.go (about) 1 package xvm 2 3 import ( 4 "context" 5 "strconv" 6 "strings" 7 "time" 8 9 "gorgonia.org/gorgonia" 10 ) 11 12 // Machine is a top-level struture that will coordinate the execution of a graph 13 type Machine struct { 14 nodes []*node 15 pubsub *pubsub 16 } 17 18 // NewMachine creates an exeuction machine from an exprgraph 19 func NewMachine(g *gorgonia.ExprGraph) *Machine { 20 if g == nil { 21 return nil 22 } 23 nodesIte := g.Nodes() 24 nodes := make([]*node, 0, nodesIte.Len()) 25 for nodesIte.Next() { 26 n := nodesIte.Node().(*gorgonia.Node) 27 var nn *node 28 topNode := g.To(n.ID()).Len() == 0 29 op := n.Op() 30 switch { 31 case op == nil: 32 nn = newInput(n) 33 case op.Arity() == 0: 34 nn = newInput(n) 35 default: 36 nn = newOp(n, !topNode) 37 } 38 nodes = append(nodes, nn) 39 } 40 m := &Machine{ 41 nodes: nodes, 42 } 43 m.pubsub = createNetwork(nodes, g) 44 return m 45 } 46 47 // createNetwork instantiate all the channels and create the pubsubs 48 func createNetwork(ns []*node, g *gorgonia.ExprGraph) *pubsub { 49 ids := make(map[int64]*node, len(ns)) 50 for i := range ns { 51 ids[ns[i].id] = ns[i] 52 } 53 ps := &pubsub{ 54 publishers: make([]*publisher, 0), 55 subscribers: make([]*subscriber, 0), 56 } 57 // Deal with publishers 58 publishers := make(map[int64]*publisher, len(ns)) 59 for i := range ns { 60 currNode := ns[i] 61 if currNode.outputC == nil { 62 continue 63 } 64 publisher := &publisher{ 65 id: currNode.id, 66 publisher: currNode.outputC, 67 subscribers: make([]chan<- gorgonia.Value, 0), 68 } 69 publishers[currNode.id] = publisher 70 ps.publishers = append(ps.publishers, publisher) 71 } 72 // Deal with subscribers 73 for i := range ns { 74 currNode := ns[i] 75 if currNode.inputC == nil { 76 continue 77 } 78 from := g.From(currNode.id) 79 subscriber := &subscriber{ 80 id: currNode.id, 81 subscriber: currNode.inputC, 82 publishers: make([]<-chan gorgonia.Value, from.Len()), 83 } 84 for i := 0; from.Next(); i++ { 85 pub := publishers[from.Node().ID()] 86 c := make(chan gorgonia.Value, 0) 87 pub.subscribers = append(pub.subscribers, c) 88 89 subscriber.publishers[i] = c 90 } 91 ps.subscribers = append(ps.subscribers, subscriber) 92 } 93 return ps 94 } 95 96 // Run the computation 97 func (m *Machine) Run(ctx context.Context) error { 98 cancel, wg := m.pubsub.run(ctx) 99 err := m.runAllNodes(ctx) 100 cancel() 101 // wait for the infrastructure to settle 102 wg.Wait() 103 return err 104 } 105 106 // Close all the plumbing to avoid leaking 107 func (m *Machine) Close() { 108 allChans := make(map[chan<- gorgonia.Value]struct{}, 0) 109 for _, pub := range m.pubsub.publishers { 110 for i := range pub.subscribers { 111 allChans[pub.subscribers[i]] = struct{}{} 112 } 113 } 114 for ch := range allChans { 115 close(ch) 116 } 117 for _, n := range m.nodes { 118 if n.inputC != nil { 119 close(n.inputC) 120 } 121 if n.outputC != nil { 122 close(n.outputC) 123 } 124 } 125 } 126 127 type nodeError struct { 128 id int64 129 t time.Time 130 err error 131 } 132 133 type nodeErrors []nodeError 134 135 func (e nodeErrors) Error() string { 136 var sb strings.Builder 137 for _, e := range e { 138 sb.WriteString(strconv.Itoa(int(e.id))) 139 sb.WriteString(":") 140 sb.WriteString(e.err.Error()) 141 sb.WriteString("\n") 142 } 143 return sb.String() 144 } 145 146 // Run performs the computation 147 func (m *Machine) runAllNodes(ctx context.Context) error { 148 ctx, cancel := context.WithCancel(ctx) 149 errC := make(chan nodeError, 0) 150 total := len(m.nodes) 151 for i := range m.nodes { 152 go func(n *node) { 153 err := n.Compute(ctx) 154 errC <- nodeError{ 155 id: n.id, 156 t: time.Now(), 157 err: err, 158 } 159 }(m.nodes[i]) 160 } 161 errs := make([]nodeError, 0) 162 for e := range errC { 163 total-- 164 if e.err != nil { 165 errs = append(errs, e) 166 // failfast, on error, cancel 167 cancel() 168 } 169 if total == 0 { 170 break 171 } 172 } 173 cancel() 174 close(errC) 175 if len(errs) != 0 { 176 return nodeErrors(errs) 177 } 178 return nil 179 } 180 181 // GetResult stored in a node 182 func (m *Machine) GetResult(id int64) gorgonia.Value { 183 for i := range m.nodes { 184 if m.nodes[i].id == id { 185 return m.nodes[i].output 186 } 187 } 188 return nil 189 }