github.com/gorgonia/agogo@v0.1.1/arena.go (about) 1 package agogo 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "log" 8 "math/rand" 9 "runtime" 10 "time" 11 12 "github.com/chewxy/math32" 13 dual "github.com/gorgonia/agogo/dualnet" 14 "github.com/gorgonia/agogo/game" 15 "github.com/gorgonia/agogo/mcts" 16 ) 17 18 // Arena represents a game arena 19 // Arena fulfils the interface game.MetaState 20 type Arena struct { 21 r *rand.Rand 22 game game.State 23 A, B *Agent 24 25 // state 26 currentPlayer *Agent 27 conf mcts.Config 28 buf bytes.Buffer 29 logger *log.Logger 30 31 // only relevant to training 32 name string 33 epoch int // training epoch 34 gameNumber int // which game is this in 35 36 // when to screw it all and just reinit a new NN 37 oldThresh int 38 oldCount int 39 } 40 41 // MakeArena makes an arena given a game. 42 func MakeArena(g game.State, a, b Dualer, conf mcts.Config, enc GameEncoder, aug Augmenter, name string) Arena { 43 A := &Agent{ 44 NN: a.Dual(), 45 Enc: enc, 46 name: "A", 47 } 48 A.MCTS = mcts.New(g, conf, A) 49 B := &Agent{ 50 NN: b.Dual(), 51 Enc: enc, 52 name: "B", 53 } 54 B.MCTS = mcts.New(g, conf, B) 55 56 if name == "" { 57 name = "UNKNOWN GAME" 58 } 59 60 return Arena{ 61 r: rand.New(rand.NewSource(time.Now().UnixNano())), 62 game: g, 63 A: A, 64 B: B, 65 conf: conf, 66 name: name, 67 68 oldThresh: 10, 69 } 70 } 71 72 // NewArena makes an arena an returns a pointer to the Arena 73 func NewArena(g game.State, a, b Dualer, conf mcts.Config, enc GameEncoder, aug Augmenter, name string) *Arena { 74 ar := MakeArena(g, a, b, conf, enc, aug, name) 75 ar.logger = log.New(&ar.buf, "", log.Ltime) 76 return &ar 77 } 78 79 // Play plays a game, and retrns a winner. If it is a draw, the returned colour is None. 80 func (a *Arena) Play(record bool, enc OutputEncoder, aug Augmenter) (winner game.Player, examples []Example) { 81 if a.r.Intn(2) == 0 { 82 a.A.Player = game.Player(game.Black) 83 a.B.Player = game.Player(game.White) 84 a.currentPlayer = a.A 85 } else { 86 a.A.Player = game.Player(game.White) 87 a.B.Player = game.Player(game.Black) 88 a.currentPlayer = a.B 89 } 90 91 a.game.SetToMove(a.currentPlayer.Player) 92 a.logger.Printf("Playing. Recording %t\n", record) 93 a.logger.SetPrefix("\t\t") 94 var ended bool 95 var passCount int 96 for ended, winner = a.game.Ended(); !ended; ended, winner = a.game.Ended() { 97 98 best := a.currentPlayer.Search(a.game) 99 if best.IsPass() { 100 passCount++ 101 } else { 102 passCount = 0 103 } 104 a.logger.Printf("Current Player: %v. Best Move %v\n", a.currentPlayer.Player, best) 105 if record { 106 boards := a.currentPlayer.Enc(a.game) 107 policies := a.currentPlayer.MCTS.Policies(a.game) 108 ex := Example{ 109 Board: boards, 110 Policy: policies, 111 // THIS IS A HACK. 112 // The value is 1 or -1 depending on player colour, but for now we store the player colour 113 Value: float32(a.currentPlayer.Player), 114 } 115 if validPolicies(policies) { 116 if aug != nil { 117 examples = append(examples, aug(ex)...) 118 } else { 119 examples = append(examples, ex) 120 } 121 } 122 123 } 124 125 // policy, value := a.currentPlayer.Infer(a.game) 126 // log.Printf("\t\tPlayer %v made Move %v | %1.1v %1.1v", a.currentPlayer.Player, best, policy, value) 127 a.game = a.game.Apply(game.PlayerMove{ 128 Player: a.currentPlayer.Player, 129 Single: best, 130 }) 131 a.switchPlayer() 132 if enc != nil { 133 enc.Encode(a) 134 } 135 if passCount >= 2 { 136 break 137 } 138 } 139 a.logger.SetPrefix("\t") 140 a.A.MCTS.Reset() 141 a.B.MCTS.Reset() 142 if enc != nil { 143 log.Printf("\tDone playing") 144 } 145 146 for i := range examples { 147 switch { 148 case winner == game.Player(game.None): 149 examples[i].Value = 0 150 case examples[i].Value == float32(winner): 151 examples[i].Value = 1 152 default: 153 examples[i].Value = -1 154 } 155 } 156 var winningAgent *Agent 157 switch { 158 case winner == game.Player(game.None): 159 a.A.Draw++ 160 a.B.Draw++ 161 case winner == a.A.Player: 162 a.A.Wins++ 163 a.B.Loss++ 164 winningAgent = a.A 165 case winner == a.B.Player: 166 a.B.Wins++ 167 a.A.Loss++ 168 winningAgent = a.B 169 } 170 if !record { 171 log.Printf("Winner %v | %p", winner, winningAgent) 172 } 173 // a.A.MCTS.Reset() 174 // a.B.MCTS.Reset() 175 a.A.MCTS = mcts.New(a.game, a.conf, a.A) 176 a.B.MCTS = mcts.New(a.game, a.conf, a.B) 177 runtime.GC() 178 return game.Player(game.None), examples 179 } 180 181 // Epoch returns the current Epoch 182 func (a *Arena) Epoch() int { return a.epoch } 183 184 // GameNumber returns the 185 func (a *Arena) GameNumber() int { return a.gameNumber } 186 187 // Name of the game 188 func (a *Arena) Name() string { return a.name } 189 190 // Score of the player p 191 func (a *Arena) Score(p game.Player) float64 { return float64(a.game.Score(p)) } 192 193 // State of the game 194 func (a *Arena) State() game.State { return a.game } 195 196 // Log the MCTS of both players into w 197 func (a *Arena) Log(w io.Writer) { 198 fmt.Fprintf(w, a.buf.String()) 199 fmt.Fprintf(w, "\nA:\n\n") 200 fmt.Fprintln(w, a.A.MCTS.Log()) 201 fmt.Fprintf(w, "\nB:\n\n") 202 fmt.Fprintln(w, a.B.MCTS.Log()) 203 } 204 205 func (a *Arena) newB(conf dual.Config, killedA bool) (err error) { 206 if killedA { 207 a.oldCount = 0 208 } 209 210 // if a.oldCount >= a.oldThresh { 211 // a.B.NN = dual.New(conf) 212 // err = a.B.NN.Init() 213 // a.oldCount = 0 214 // } else { 215 // a.B.NN, err = a.B.NN.Clone() 216 // } 217 218 a.B.NN = dual.New(conf) 219 err = a.B.NN.Init() 220 221 a.oldCount++ 222 log.Printf("NewB NN %p", a.B.NN) 223 return err 224 } 225 226 func (a *Arena) switchPlayer() { 227 switch a.currentPlayer { 228 case a.A: 229 a.currentPlayer = a.B 230 case a.B: 231 a.currentPlayer = a.A 232 } 233 } 234 235 func cloneBoard(a []game.Colour) []game.Colour { 236 retVal := make([]game.Colour, len(a)) 237 copy(retVal, a) 238 return retVal 239 } 240 241 func validPolicies(policy []float32) bool { 242 for _, v := range policy { 243 if math32.IsInf(v, 0) { 244 return false 245 } 246 if math32.IsNaN(v) { 247 return false 248 } 249 } 250 return true 251 }