github.com/gorgonia/agogo@v0.1.1/dualnet/meta.go

package dual

import (
	"bytes"
	"log"
	"math/rand"
	"time"

	"github.com/pkg/errors"
	G "gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
	"gorgonia.org/tensor/native"
)

// Train is a basic trainer: it runs vanilla SGD over the given examples for the given number
// of batches per iteration, reshuffling the training data after every iteration.
func Train(d *Dual, Xs, policies, values *tensor.Dense, batches, iterations int) error {
	m := G.NewTapeMachine(d.g, G.BindDualValues(d.Model()...))
	model := G.NodesToValueGrads(d.Model())
	solver := G.NewVanillaSolver(G.WithLearnRate(0.1))
	var s slicer
	for i := 0; i < iterations; i++ {
		// var cost float32
		for bat := 0; bat < batches; bat++ {
			batchStart := bat * d.Config.BatchSize
			batchEnd := batchStart + d.Config.BatchSize

			Xs2 := s.Slice(Xs, sli(batchStart, batchEnd))
			π := s.Slice(policies, sli(batchStart, batchEnd))
			v := s.Slice(values, sli(batchStart, batchEnd))

			G.Let(d.planes, Xs2)
			G.Let(d.Π, π)
			G.Let(d.V, v)
			if err := m.RunAll(); err != nil {
				return err
			}
			// cost = d.cost.Data().(float32)
			if err := solver.Step(model); err != nil {
				return err
			}
			m.Reset()
			tensor.ReturnTensor(Xs2)
			tensor.ReturnTensor(π)
			tensor.ReturnTensor(v)
		}
		if err := shuffleBatch(Xs, policies, values); err != nil {
			return err
		}
		// TODO: add a channel to send training cost data down
		// log.Printf("%d\t%v", i, cost/float32(batches))
	}
	return nil
}

// shuffleBatch shuffles the rows of Xs, π and v in lockstep, so that each example keeps its
// policy and value labels.
func shuffleBatch(Xs, π, v *tensor.Dense) (err error) {
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	oriXs := Xs.Shape().Clone()
	oriPis := π.Shape().Clone()

	defer func() {
		if r := recover(); r != nil {
			log.Printf("%v %v", Xs.Shape(), π.Shape())
			panic(r)
		}
	}()
	Xs.Reshape(as2D(Xs.Shape())...)
	π.Reshape(as2D(π.Shape())...)

	var matXs, matPis [][]float32
	if matXs, err = native.MatrixF32(Xs); err != nil {
		return errors.Wrapf(err, "shuffle batch failed - matX")
	}
	if matPis, err = native.MatrixF32(π); err != nil {
		return errors.Wrapf(err, "shuffle batch failed - pi")
	}
	vs := v.Data().([]float32)

	// Fisher-Yates shuffle over rows. tmp is sized for an Xs row and is reused for the π rows,
	// which are assumed to be no wider.
	tmp := make([]float32, Xs.Shape()[1])
	for i := range matXs {
		j := r.Intn(i + 1)

		rowI := matXs[i]
		rowJ := matXs[j]
		copy(tmp, rowI)
		copy(rowI, rowJ)
		copy(rowJ, tmp)

		piI := matPis[i]
		piJ := matPis[j]
		copy(tmp, piI)
		copy(piI, piJ)
		copy(piJ, tmp)

		vs[i], vs[j] = vs[j], vs[i]
	}
	Xs.Reshape(oriXs...)
	π.Reshape(oriPis...)

	return nil
}

// as2D collapses a shape into two dimensions, keeping the leading (batch) dimension and
// flattening the rest.
func as2D(s tensor.Shape) tensor.Shape {
	retVal := tensor.BorrowInts(2)
	retVal[0] = s[0]
	retVal[1] = s[1]
	for i := 2; i < len(s); i++ {
		retVal[1] *= s[i]
	}
	return retVal
}
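// The sketch below is illustrative only and not part of the original file: it shows one way a
// caller might drive Train, assuming the examples, policies and values have already been stacked
// into dense tensors whose leading dimension is the number of examples. The helper name
// trainSketch and its epochs parameter are hypothetical.
func trainSketch(d *Dual, Xs, policies, values *tensor.Dense, epochs int) error {
	// Derive the batch count from the leading (example) dimension and the configured batch size;
	// trailing examples that do not fill a whole batch are simply skipped here.
	batches := Xs.Shape()[0] / d.Config.BatchSize
	return Train(d, Xs, policies, values, batches, epochs)
}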
// Inferencer holds the state for a *Dual and a VM. By using an Inferencer there is no need to
// create a new VM every time an inference needs to be done.
type Inferencer struct {
	d *Dual
	m G.VM

	input *tensor.Dense
	buf   *bytes.Buffer
}

// Infer takes a trained *Dual and builds an *Inferencer around a forward-only copy of it, so
// that repeated inference is cheap. If toLog is true, the VM logs its execution into a buffer
// retrievable via ExecLog.
func Infer(d *Dual, actionSpace int, toLog bool) (*Inferencer, error) {
	conf := d.Config
	conf.FwdOnly = true
	conf.BatchSize = actionSpace
	newShape := d.planes.Shape().Clone()
	newShape[0] = actionSpace
	retVal := &Inferencer{
		d:     New(conf),
		input: tensor.New(tensor.WithShape(newShape...), tensor.Of(Float)),
	}
	if err := retVal.d.Init(); err != nil {
		return nil, err
	}
	retVal.d.SetTesting()
	// G.WithInit(G.Zeroes())(retVal.d.planes)

	// copy the trained weights into the forward-only clone
	infModel := retVal.d.Model()
	for i, n := range d.Model() {
		original := n.Value().Data().([]float32)
		cloned := infModel[i].Value().Data().([]float32)
		copy(cloned, original)
	}

	retVal.buf = new(bytes.Buffer)
	if toLog {
		logger := log.New(retVal.buf, "", 0)
		retVal.m = G.NewTapeMachine(retVal.d.g,
			G.WithLogger(logger),
			G.WithWatchlist(),
			G.TraceExec(),
			G.WithValueFmt("%+1.1v"),
			G.WithNaNWatch(),
		)
	} else {
		retVal.m = G.NewTapeMachine(retVal.d.g)
	}
	return retVal, nil
}

// Dual implements Dualer.
func (m *Inferencer) Dual() *Dual { return m.d }

// Infer takes the board, in the form of a []float32, runs inference, and returns the policy and
// the value.
func (m *Inferencer) Infer(board []float32) (policy []float32, value float32, err error) {
	m.buf.Reset()
	for _, op := range m.d.ops {
		op.Reset()
	}

	// copy the board into the preallocated input tensor
	m.input.Zero()
	data := m.input.Data().([]float32)
	copy(data, board)

	m.m.Reset()
	// log.Printf("Let planes %p be input %v", m.d.planes, board)
	m.buf.Reset()
	G.Let(m.d.planes, m.input)
	if err = m.m.RunAll(); err != nil {
		return nil, 0, err
	}
	policy = m.d.policyValue.Data().([]float32)
	value = m.d.value.Data().([]float32)[0]
	// log.Printf("\t%v", policy)
	return policy[:m.d.ActionSpace], value, nil
}

// ExecLog returns the execution log. If the Inferencer was created with toLog = false, it
// returns an empty string.
func (m *Inferencer) ExecLog() string { return m.buf.String() }

// Close closes the underlying VM, because a Gorgonia VM is a resource that needs to be released.
func (m *Inferencer) Close() error { return m.m.Close() }
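// The sketch below is illustrative only and not part of the original file: it shows the intended
// Inferencer workflow, in which the VM is built once via Infer and then reused for many boards.
// The helper name runInference and its trained/boards/actionSpace parameters are hypothetical.
func runInference(trained *Dual, boards [][]float32, actionSpace int) ([][]float32, []float32, error) {
	inf, err := Infer(trained, actionSpace, false) // false: no execution logging
	if err != nil {
		return nil, nil, err
	}
	defer inf.Close()

	policies := make([][]float32, 0, len(boards))
	values := make([]float32, 0, len(boards))
	for _, b := range boards {
		p, v, err := inf.Infer(b)
		if err != nil {
			return nil, nil, err
		}
		// Copy the policy: the returned slice aliases the network's output buffer and is
		// overwritten by the next call to Infer.
		policies = append(policies, append([]float32(nil), p...))
		values = append(values, v)
	}
	return policies, values, nil
}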