github.com/gorgonia/agogo@v0.1.1/agogo.go (about) 1 package agogo 2 3 import ( 4 "encoding/gob" 5 "fmt" 6 "log" 7 "math/rand" 8 "os" 9 "time" 10 11 dual "github.com/gorgonia/agogo/dualnet" 12 "github.com/gorgonia/agogo/game" 13 "github.com/gorgonia/agogo/mcts" 14 "github.com/pkg/errors" 15 "gorgonia.org/tensor" 16 ) 17 18 // AZ is the top level structure and the entry point of the API. 19 // It it a wrapper around the MTCS and the NeeuralNework that composes the algorithm. 20 // AZ stands for AlphaZero 21 type AZ struct { 22 // state 23 Arena 24 Statistics 25 useDummy bool 26 27 // config 28 nnConf dual.Config 29 mctsConf mcts.Config 30 enc GameEncoder 31 aug Augmenter 32 updateThreshold float32 33 maxExamples int 34 35 // io 36 outEnc OutputEncoder 37 } 38 39 // New AlphaZero structure. It takes a game state (implementing the board, rules, etc.) 40 // and a configuration to apply to the MCTS and the neural network 41 func New(g game.State, conf Config) *AZ { 42 if !conf.NNConf.IsValid() { 43 panic("NNConf is not valid. Unable to proceed") 44 } 45 if !conf.MCTSConf.IsValid() { 46 panic("MCTSConf is not valid. Unable to proceed") 47 } 48 49 a := dual.New(conf.NNConf) 50 b := dual.New(conf.NNConf) 51 52 if err := a.Init(); err != nil { 53 panic(fmt.Sprintf("%+v", err)) 54 } 55 if err := b.Init(); err != nil { 56 panic(fmt.Sprintf("%+v", err)) 57 } 58 59 retVal := &AZ{ 60 Arena: MakeArena(g, a, b, conf.MCTSConf, conf.Encoder, conf.Augmenter, conf.Name), 61 nnConf: conf.NNConf, 62 mctsConf: conf.MCTSConf, 63 enc: conf.Encoder, 64 outEnc: conf.OutputEncoder, 65 aug: conf.Augmenter, 66 updateThreshold: float32(conf.UpdateThreshold), 67 maxExamples: conf.MaxExamples, 68 Statistics: makeStatistics(), 69 useDummy: true, 70 } 71 retVal.logger = log.New(&retVal.buf, "", log.Ltime) 72 return retVal 73 } 74 75 func (a *AZ) setupSelfPlay(iter int) { 76 var err error 77 if err = a.A.SwitchToInference(a.game); err != nil { 78 // DO SOMETHING WITH ERROR 79 } 80 if err = a.B.SwitchToInference(a.game); err != nil { 81 // DO SOMETHING WITH ERROR 82 } 83 if iter == 0 && a.useDummy { 84 log.Printf("Using Dummy") 85 a.A.useDummy(a.game) 86 a.B.useDummy(a.game) 87 } 88 log.Printf("Set up selfplay: Switch To inference for A. A.NN %p (%T)", a.A.NN, a.A.NN) 89 log.Printf("Set up selfplay: Switch To inference for B. B.NN %p (%T)", a.B.NN, a.B.NN) 90 } 91 92 // SelfPlay plays an episode 93 func (a *AZ) SelfPlay() []Example { 94 _, examples := a.Play(true, nil, a.aug) // don't encode images while selfplay... that'd be boring to watch 95 a.game.Reset() 96 return examples 97 } 98 99 // Learn learns for iters. It self-plays for episodes, and then trains a new NN from the self play example. 100 func (a *AZ) Learn(iters, episodes, nniters, arenaGames int) error { 101 var err error 102 for a.epoch = 0; a.epoch < iters; a.epoch++ { 103 var ex []Example 104 log.Printf("Self Play for epoch %d. Player A %p, Player B %p", a.epoch, a.A, a.B) 105 106 a.buf.Reset() 107 a.logger.Printf("Self Play for epoch %d. Player A %p, Player B %p", a.epoch, a.A, a.B) 108 a.logger.SetPrefix("\t") 109 a.setupSelfPlay(a.epoch) 110 for e := 0; e < episodes; e++ { 111 log.Printf("\tEpisode %v", e) 112 a.logger.Printf("Episode %v\n", e) 113 ex = append(ex, a.SelfPlay()...) 114 } 115 a.logger.SetPrefix("") 116 a.buf.Reset() 117 118 if a.maxExamples > 0 && len(ex) > a.maxExamples { 119 shuffleExamples(ex) 120 ex = ex[:a.maxExamples] 121 } 122 Xs, Policies, Values, batches := a.prepareExamples(ex) 123 124 // // create a new DualNet for B 125 // a.B.NN = dual.New(a.nnConf) 126 // if err = a.B.NN.Dual().Init(); err != nil { 127 // return errors.WithMessage(err, "Unable to create new DualNet for B") 128 // } 129 130 if err = dual.Train(a.B.NN, Xs, Policies, Values, batches, nniters); err != nil { 131 return errors.WithMessage(err, fmt.Sprintf("Train fail")) 132 } 133 134 a.B.SwitchToInference(a.game) 135 136 a.A.resetStats() 137 a.B.resetStats() 138 139 a.logger.Printf("Playing Arena") 140 a.logger.SetPrefix("\t") 141 for a.gameNumber = 0; a.gameNumber < arenaGames; a.gameNumber++ { 142 a.logger.Printf("Playing game number %d", a.gameNumber) 143 a.Play(false, a.outEnc, nil) 144 a.game.Reset() 145 } 146 a.logger.SetPrefix("") 147 148 var killedA bool 149 log.Printf("A wins %v, loss %v, draw %v\nB wins %v, loss %v, draw %v", a.A.Wins, a.A.Loss, a.A.Draw, a.B.Wins, a.B.Loss, a.B.Draw) 150 151 // if a.B.Wins/(a.B.Wins+a.B.Loss+a.B.Draw) > a.updateThreshold { 152 if a.B.Wins/(a.B.Wins+a.A.Wins) > a.updateThreshold { 153 // B wins. Kill A, clean up its resources. 154 log.Printf("Kill A %p. New A's NN is %p", a.A.NN, a.B.NN) 155 if err = a.A.Close(); err != nil { 156 return err 157 } 158 a.A.NN = a.B.NN 159 // clear examples 160 ex = ex[:0] 161 killedA = true 162 } 163 a.update(a.A) 164 if err = a.newB(a.nnConf, killedA); err != nil { 165 return err 166 } 167 } 168 return nil 169 } 170 171 // Save learning into filenamee 172 func (a *AZ) Save(filename string) error { 173 f, err := os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0544) 174 if err != nil { 175 return err 176 } 177 defer f.Close() 178 179 enc := gob.NewEncoder(f) 180 return enc.Encode(a.A.NN) 181 } 182 183 // Load the Alpha Zero structure from a filename 184 func (a *AZ) Load(filename string) error { 185 f, err := os.Open(filename) 186 if err != nil { 187 return errors.WithStack(err) 188 } 189 defer f.Close() 190 191 a.A.NN = dual.New(a.nnConf) 192 a.B.NN = dual.New(a.nnConf) 193 194 dec := gob.NewDecoder(f) 195 if err = dec.Decode(a.A.NN); err != nil { 196 return errors.WithStack(err) 197 } 198 199 f.Seek(0, 0) 200 dec = gob.NewDecoder(f) 201 if err = dec.Decode(a.B.NN); err != nil { 202 return errors.WithStack(err) 203 } 204 a.useDummy = false 205 return nil 206 } 207 208 func (a *AZ) prepareExamples(examples []Example) (Xs, Policies, Values *tensor.Dense, batches int) { 209 shuffleExamples(examples) 210 batches = len(examples) / a.nnConf.BatchSize 211 total := batches * a.nnConf.BatchSize 212 var XsBacking, PoliciesBacking, ValuesBacking []float32 213 for i, ex := range examples { 214 if i >= total { 215 break 216 } 217 XsBacking = append(XsBacking, ex.Board...) 218 219 start := len(PoliciesBacking) 220 PoliciesBacking = append(PoliciesBacking, make([]float32, len(ex.Policy))...) 221 copy(PoliciesBacking[start:], ex.Policy) 222 223 ValuesBacking = append(ValuesBacking, ex.Value) 224 } 225 // padd out anythihng that is not full 226 // board0 := examples[0].Board 227 // policy0 := examples[0].Policy 228 // rem := len(examples) % a.nnConf.BatchSize 229 // if rem != 0 { 230 // diff := a.nnConf.BatchSize - rem 231 232 // // add padded data 233 // XsBacking = append(XsBacking, make([]float32, diff*len(board0))...) 234 // PoliciesBacking = append(PoliciesBacking, make([]float32, diff*len(policy0))...) 235 // ValuesBacking = append(ValuesBacking, make([]float32, diff)...) 236 // } 237 // if rem > 0 { 238 // batches++ 239 // } 240 241 actionSpace := a.Arena.game.ActionSpace() + 1 // allow passes 242 Xs = tensor.New(tensor.WithBacking(XsBacking), tensor.WithShape(a.nnConf.BatchSize*batches, a.nnConf.Features, a.nnConf.Height, a.nnConf.Width)) 243 Policies = tensor.New(tensor.WithBacking(PoliciesBacking), tensor.WithShape(a.nnConf.BatchSize*batches, actionSpace)) 244 Values = tensor.New(tensor.WithBacking(ValuesBacking), tensor.WithShape(a.nnConf.BatchSize*batches)) 245 return 246 } 247 248 func shuffleExamples(examples []Example) { 249 r := rand.New(rand.NewSource(time.Now().UnixNano())) 250 for i := range examples { 251 j := r.Intn(i + 1) 252 examples[i], examples[j] = examples[j], examples[i] 253 } 254 }