gorgonia.org/gorgonia@v0.9.17/example_concurrent_training_test.go (about) 1 package gorgonia_test 2 3 import ( 4 "fmt" 5 "runtime" 6 "sync" 7 8 . "gorgonia.org/gorgonia" 9 "gorgonia.org/tensor" 10 ) 11 12 const ( 13 // rows = 373127 14 // cols = 53 15 16 // We'll use a nice even sized batch size, instead of weird prime numbers 17 rows = 30000 18 cols = 5 19 batchSize = 100 20 epochs = 10 21 ) 22 23 type concurrentTrainer struct { 24 g *ExprGraph 25 x, y *Node 26 vm VM 27 // cost Value 28 29 batchSize int 30 epoch int // number of epochs done 31 } 32 33 func newConcurrentTrainer() *concurrentTrainer { 34 g := NewGraph() 35 x := NewMatrix(g, Float64, WithShape(batchSize, cols), WithName("x")) 36 y := NewVector(g, Float64, WithShape(batchSize), WithName("y")) 37 xT := Must(Transpose(x)) 38 z := Must(Mul(xT, y)) 39 sz := Must(Sum(z)) 40 41 // Read(sz, &ct.cost) 42 Grad(sz, x, y) 43 vm := NewTapeMachine(g, BindDualValues(x, y)) 44 45 return &concurrentTrainer{ 46 g: g, 47 x: x, 48 y: y, 49 vm: vm, 50 51 batchSize: batchSize, 52 epoch: -1, 53 } 54 } 55 56 type cost struct { 57 Nodes []ValueGrad 58 VM // placed for debugging purposes. In real life use you can just use a channel of Nodes 59 60 // cost Value 61 } 62 63 func (t *concurrentTrainer) train(x, y Value, costChan chan cost, wg *sync.WaitGroup) { 64 Let(t.x, x) 65 Let(t.y, y) 66 if err := t.vm.RunAll(); err != nil { 67 panic("HELP") 68 } 69 70 costChan <- cost{ 71 []ValueGrad{t.x, t.y}, 72 t.vm, 73 // t.cost, 74 } 75 76 t.vm.Reset() 77 wg.Done() 78 } 79 80 func trainEpoch(bs []batch, ts []*concurrentTrainer, threads int) { 81 // costs := make([]float64, 0, len(bs)) 82 chunks := len(bs) / len(ts) 83 for chunk := 0; chunk <= chunks; chunk++ { 84 costChan := make(chan cost, len(bs)) 85 86 var wg sync.WaitGroup 87 for i, t := range ts { 88 idx := chunk*threads + i 89 if idx >= len(bs) { 90 break 91 } 92 b := bs[idx] 93 94 wg.Add(1) 95 go t.train(b.xs, b.ys, costChan, &wg) 96 } 97 wg.Wait() 98 close(costChan) 99 100 solver := NewVanillaSolver(WithLearnRate(0.01), WithBatchSize(batchSize)) 101 for cost := range costChan { 102 // y := cost.Nodes[1].Value() 103 // yG, _ := cost.Nodes[1].Grad() 104 // c := cost.cost.Data().(float64) 105 // costs = append(costs, c) 106 solver.Step(cost.Nodes) 107 } 108 } 109 110 // var avg float64 111 // for _, c := range costs { 112 // avg += c 113 // } 114 // avg /= float64(len(costs)) 115 } 116 117 type batch struct { 118 xs Value 119 ys Value 120 } 121 122 func prep() (x, y Value, bs []batch) { 123 xV := tensor.New(tensor.WithShape(rows, cols), tensor.WithBacking(tensor.Range(Float64, 0, cols*rows))) 124 yV := tensor.New(tensor.WithShape(rows), tensor.WithBacking(tensor.Range(Float64, 0, rows))) 125 126 // prep the data: y = ΣnX, where n = col ID, x ∈ X = colID / 100 127 xData := xV.Data().([]float64) 128 yData := yV.Data().([]float64) 129 for r := 0; r < rows; r++ { 130 var sum float64 131 for c := 0; c < cols; c++ { 132 idx := r*cols + c 133 fc := float64(c) 134 v := fc * fc / 100 135 xData[idx] = v 136 sum += v 137 } 138 yData[r] = sum 139 } 140 141 // batch the examples up into their respective batchSize 142 for i := 0; i < rows; i += batchSize { 143 xVS, _ := xV.Slice(S(i, i+batchSize)) 144 yVS, _ := yV.Slice(S(i, i+batchSize)) 145 b := batch{xVS, yVS} 146 bs = append(bs, b) 147 } 148 return xV, yV, bs 149 } 150 151 func concurrentTraining(xV, yV Value, bs []batch, es int) { 152 threads := runtime.NumCPU() 153 154 ts := make([]*concurrentTrainer, threads) 155 for chunk := 0; chunk < threads; chunk++ { 156 trainer := newConcurrentTrainer() 157 ts[chunk] = trainer 158 defer trainer.vm.Close() 159 } 160 161 for e := 0; e < es; e++ { 162 trainEpoch(bs, ts, threads) 163 } 164 } 165 166 func nonConcurrentTraining(xV, yV Value, es int) { 167 g := NewGraph() 168 x := NewMatrix(g, Float64, WithValue(xV)) 169 y := NewVector(g, Float64, WithValue(yV)) 170 xT := Must(Transpose(x)) 171 z := Must(Mul(xT, y)) 172 sz := Must(Sum(z)) 173 Grad(sz, x, y) 174 vm := NewTapeMachine(g, BindDualValues(x, y)) 175 176 Let(x, xV) 177 Let(y, yV) 178 solver := NewVanillaSolver(WithLearnRate(0.01), WithBatchSize(batchSize)) 179 for i := 0; i < es; i++ { 180 vm.RunAll() 181 solver.Step([]ValueGrad{x, y}) 182 vm.Reset() 183 runtime.GC() 184 } 185 } 186 187 func Example_concurrentTraining() { 188 xV, yV, bs := prep() 189 concurrentTraining(xV, yV, bs, epochs) 190 191 fmt.Printf("x:\n%1.1v", xV) 192 fmt.Printf("y:\n%1.1v", yV) 193 194 // Output: 195 // x: 196 // ⎡-0.0003 0.01 0.04 0.09 0.2⎤ 197 // ⎢-0.0003 0.01 0.04 0.09 0.2⎥ 198 // ⎢-0.0003 0.01 0.04 0.09 0.2⎥ 199 // ⎢-0.0003 0.01 0.04 0.09 0.2⎥ 200 // . 201 // . 202 // . 203 // ⎢-0.0003 0.01 0.04 0.09 0.2⎥ 204 // ⎢-0.0003 0.01 0.04 0.09 0.2⎥ 205 // ⎢-0.0003 0.01 0.04 0.09 0.2⎥ 206 // ⎣-0.0003 0.01 0.04 0.09 0.2⎦ 207 // y: 208 // [0.3 0.3 0.3 0.3 ... 0.3 0.3 0.3 0.3] 209 210 } 211 212 func Example_nonConcurrentTraining() { 213 xV, yV, _ := prep() 214 nonConcurrentTraining(xV, yV, epochs) 215 216 fmt.Printf("x:\n%1.1v", xV) 217 fmt.Printf("y:\n%1.1v", yV) 218 219 //Output: 220 // x: 221 // ⎡-0.0003 0.01 0.04 0.09 0.2⎤ 222 // ⎢-0.0003 0.01 0.04 0.09 0.2⎥ 223 // ⎢-0.0003 0.01 0.04 0.09 0.2⎥ 224 // ⎢-0.0003 0.01 0.04 0.09 0.2⎥ 225 // . 226 // . 227 // . 228 // ⎢-0.0003 0.01 0.04 0.09 0.2⎥ 229 // ⎢-0.0003 0.01 0.04 0.09 0.2⎥ 230 // ⎢-0.0003 0.01 0.04 0.09 0.2⎥ 231 // ⎣-0.0003 0.01 0.04 0.09 0.2⎦ 232 // y: 233 // [0.3 0.3 0.3 0.3 ... 0.3 0.3 0.3 0.3] 234 235 }