github.com/gorgonia/agogo@v0.1.1/mcts/search.go (about) 1 package mcts 2 3 import ( 4 "context" 5 "runtime" 6 "sort" 7 "sync" 8 "sync/atomic" 9 "time" 10 11 "github.com/chewxy/math32" 12 "github.com/gorgonia/agogo/game" 13 ) 14 15 /* 16 Here lies the majority of the MCTS search code, while node.go and tree.go handles the data structure stuff. 17 18 Right now the code is very specific to the game of Go. Ideally we'd be able to export the correct things and make it 19 so that a search can be written for any other games but uses the same data structures 20 */ 21 22 const ( 23 MAXTREESIZE = 25000000 // a tree is at max allowed this many nodes - at about 56 bytes per node that is 1.2GB of memory required 24 ) 25 26 func opponent(p game.Player) game.Player { 27 switch p { 28 case Black: 29 return White 30 case White: 31 return Black 32 } 33 panic("Unreachable") 34 } 35 36 // Result is a NaN tagged floating point, used to represent the reuslts. 37 type Result float32 38 39 const ( 40 noResultBits = 0x7FE00000 41 ) 42 43 func noResult() Result { 44 return Result(math32.Float32frombits(noResultBits)) 45 } 46 47 // isNullResult returns true if the Result (a NaN tagged number) is noResult 48 func isNullResult(r Result) bool { 49 b := math32.Float32bits(float32(r)) 50 return b == noResultBits 51 } 52 53 type searchState struct { 54 tree uintptr 55 current, prev game.State 56 root naughty 57 depth int 58 59 wg *sync.WaitGroup 60 61 // config 62 maxPlayouts, maxVisits, maxDepth int 63 } 64 65 func (s *searchState) nodeCount() int32 { 66 t := treeFromUintptr(s.tree) 67 return atomic.LoadInt32(&t.nc) 68 } 69 70 func (s *searchState) incrementPlayout() { 71 t := treeFromUintptr(s.tree) 72 atomic.AddInt32(&t.playouts, 1) 73 } 74 75 func (s *searchState) isRunning() bool { 76 t := treeFromUintptr(s.tree) 77 running := t.running.Load().(bool) 78 return running && t.nodeCount() < MAXTREESIZE 79 } 80 81 func (s *searchState) minPsaRatio() float32 { 82 ratio := float32(s.nodeCount()) / float32(MAXTREESIZE) 83 switch { 84 case ratio > 0.95: 85 return 0.01 86 case ratio > 0.5: 87 return 0.001 88 } 89 return 0 90 } 91 92 func (t *MCTS) Search(player game.Player) (retVal game.Single) { 93 t.log("SEARCH. Player %v\n%v", player, t.current) 94 t.updateRoot() 95 t.current.SetToMove(player) 96 boardHash := t.current.Hash() 97 98 // freeables 99 // if t.current.MoveNumber() == 1 { 100 101 // t.log("Acquiring lock ") 102 t.Lock() 103 for _, f := range t.freeables { 104 t.free(f) 105 } 106 t.Unlock() 107 // } 108 109 t.prepareRoot(player, t.current) 110 root := t.nodeFromNaughty(t.root) 111 112 ch := make(chan *searchState, runtime.NumCPU()) 113 var wg sync.WaitGroup 114 for i := 0; i < runtime.NumCPU(); i++ { 115 ss := &searchState{ 116 tree: ptrFromTree(t), 117 current: t.current, 118 root: t.root, 119 maxDepth: t.M * t.N, 120 wg: &wg, 121 } 122 ch <- ss 123 } 124 125 var iter int32 126 t.running.Store(true) 127 ctx, cancel := context.WithCancel(context.Background()) 128 for i := 0; i < runtime.NumCPU(); i++ { 129 wg.Add(1) 130 go doSearch(t.root, &iter, ch, ctx, &wg) 131 } 132 <-time.After(t.Timeout) 133 cancel() 134 135 // TODO 136 // reactivate all pruned children 137 wg.Wait() 138 close(ch) 139 140 root = t.nodeFromNaughty(t.root) 141 if !root.HasChildren() { 142 policy, _ := t.nn.Infer(t.current) 143 moveID := argmax(policy) 144 if moveID > t.current.ActionSpace() { 145 return Pass 146 } 147 t.log("Returning Early. Best %v", moveID) 148 return game.Single(moveID) 149 } 150 151 retVal = t.bestMove() 152 t.prev = t.current.Clone().(game.State) 153 t.log("Move Number %d, Iterations %d Playouts: %v Nodes: %v. Best: %v", t.current.MoveNumber(), iter, t.playouts, len(t.nodes), retVal) 154 t.log("DUMMY") 155 // log.Printf("\n%v", t.prev) 156 // log.Printf("\tIterations %d Playouts: %v Nodes: %v. Best move %v Player %v", iter, t.playouts, len(t.nodes), retVal, player) 157 158 // update the cached policies. 159 // Again, nothing like having side effects to what appears to be a straightforwards 160 // pure function eh? 161 t.cachedPolicies[sa{boardHash, retVal}]++ 162 163 return retVal 164 } 165 166 func doSearch(start naughty, iterBudget *int32, ch chan *searchState, ctx context.Context, wg *sync.WaitGroup) { 167 defer wg.Done() 168 169 loop: 170 for { 171 select { 172 case s := <-ch: 173 current := s.current.Clone().(game.State) 174 root := start 175 res := s.pipeline(current, root) 176 if !isNullResult(res) { 177 s.incrementPlayout() 178 } 179 180 t := treeFromUintptr(s.tree) 181 val := atomic.AddInt32(iterBudget, 1) 182 183 if val > t.Budget { 184 t.running.Store(false) 185 } 186 // running := t.running.Load().(bool) 187 // running = running && !s.stopThinking( /*TODO*/ ) 188 // running = running && s.hasAlternateMoves( /*TODO*/ ) 189 if s.depth == s.maxDepth { 190 // reset s for another bout of playouts 191 s.root = t.root 192 s.current = t.current 193 s.depth = 0 194 } 195 ch <- s 196 case <-ctx.Done(): 197 break loop 198 } 199 } 200 201 return 202 } 203 204 // pipeline is a recursive MCTS pipeline: 205 // SELECT, EXPAND, SIMULATE, BACKPROPAGATE. 206 // 207 // Because of the recursive nature, the pipeline is altered a bit to be this: 208 // EXPAND and SIMULATE, SELECT and RECURSE, BACKPROPAGATE. 209 func (s *searchState) pipeline(current game.State, start naughty) (retVal Result) { 210 retVal = noResult() 211 s.depth++ 212 if s.depth > s.maxDepth { 213 s.depth-- 214 return 215 } 216 217 player := current.ToMove() 218 nodeCount := s.nodeCount() 219 220 t := treeFromUintptr(s.tree) 221 n := t.nodeFromNaughty(start) 222 n.addVirtualLoss() 223 t.log("\t%p PIPELINE: %v", s, n) 224 225 // EXPAND and SIMULATE 226 isExpandable := n.IsExpandable(0) 227 if isExpandable && current.Passes() >= 2 { 228 retVal = Result(combinedScore(current)) 229 } else if isExpandable && nodeCount < MAXTREESIZE { 230 hadChildren := n.HasChildren() 231 value, ok := s.expandAndSimulate(start, current, s.minPsaRatio()) 232 if !hadChildren && ok { 233 retVal = Result(value) 234 } 235 } 236 237 // SELECT and RECURSE 238 if n.HasChildren() && isNullResult(retVal) { 239 next := t.nodeFromNaughty(n.Select(player)) 240 move := next.Move() 241 pm := game.PlayerMove{player, move} 242 243 // Check should check Superko. If it's superko, the node should be invalidated 244 if current.Check(pm) { 245 current = current.Apply(pm).(game.State) 246 retVal = s.pipeline(current, next.id) 247 } 248 } 249 250 // BACKPROPAGATE 251 if !isNullResult(retVal) { 252 n.Update(float32(retVal)) // nothing says non functional programs like side effects. Insert more functional programming circle jerk here. 253 } 254 n.undoVirtualLoss() 255 s.depth-- 256 return retVal 257 } 258 259 func (s *searchState) expandAndSimulate(parent naughty, state game.State, minPsaRatio float32) (value float32, ok bool) { 260 t := treeFromUintptr(s.tree) 261 n := t.nodeFromNaughty(parent) 262 263 t.log("\t\t%p Expand and Simulate. Parent Move: %v. Player: %v. Move number %d\n%v", s, n.Move(), state.ToMove(), state.MoveNumber(), state) 264 if !n.IsExpandable(minPsaRatio) { 265 t.log("\t\tNot expandable. MinPSA Ratio %v", minPsaRatio) 266 return 0, false 267 } 268 269 if state.Passes() >= 2 { 270 t.log("\t\t%p Passes >= 2", s) 271 return 0, false 272 } 273 // get scored moves 274 var policy []float32 // boardSize + 1 275 policy, value = t.nn.Infer(state) // get policy probability, value from neural network 276 passProb := policy[len(policy)-1] // probability of a pass is the last in the policy 277 player := state.ToMove() 278 if player == White { 279 value = 1 - value 280 } 281 282 var nodelist []pair 283 var legalSum float32 284 285 for i := 0; i < s.current.ActionSpace(); i++ { 286 if state.Check(game.PlayerMove{player, game.Single(i)}) { 287 nodelist = append(nodelist, pair{Score: policy[i], Coord: game.Single(i)}) 288 legalSum += policy[i] 289 } 290 } 291 t.log("\t\t%p Available Moves %d: %v", s, len(nodelist), nodelist) 292 293 if state.Check(game.PlayerMove{player, Pass}) { 294 nodelist = append(nodelist, pair{Score: passProb, Coord: Pass}) 295 legalSum += passProb 296 } 297 298 if legalSum > math32.SmallestNonzeroFloat32 { 299 // re normalize 300 for i := range nodelist { 301 nodelist[i].Score /= legalSum 302 } 303 } else { 304 prob := 1 / float32(len(nodelist)) 305 for i := range nodelist { 306 nodelist[i].Score = prob 307 } 308 } 309 310 if len(nodelist) == 0 { 311 t.log("\t\tNodelist is empty") 312 return value, true 313 } 314 sort.Sort(byScore(nodelist)) 315 maxPsa := nodelist[0].Score 316 oldMinPsa := maxPsa * n.MinPsaRatio() 317 newMinPsa := maxPsa * minPsaRatio 318 319 var skippedChildren bool 320 for _, p := range nodelist { 321 if p.Score < newMinPsa { 322 t.log("\t\tp.score %v < %v", p.Score, newMinPsa) 323 skippedChildren = true 324 } else if p.Score < oldMinPsa { 325 if nn := n.findChild(p.Coord); nn == nilNode { 326 nn := t.New(p.Coord, p.Score, value) 327 n.AddChild(nn) 328 } 329 } 330 } 331 t.log("\t\t%p skipped children? %v", s, skippedChildren) 332 if skippedChildren { 333 atomic.StoreUint32(&n.minPSARatioChildren, math32.Float32bits(minPsaRatio)) 334 } else { 335 // if no children were skipped, then all that can be expanded has been expanded 336 atomic.StoreUint32(&n.minPSARatioChildren, 0) 337 } 338 return value, true 339 } 340 341 func (t *MCTS) bestMove() game.Single { 342 player := t.current.ToMove() 343 moveNum := t.current.MoveNumber() 344 345 children := t.children[t.root] 346 t.log("%p Children: ", &t.searchState) 347 for _, child := range children { 348 nc := t.nodeFromNaughty(child) 349 t.log("\t\t\t%v", nc) 350 } 351 t.log("%v", t.current) 352 t.childLock[t.root].Lock() 353 sort.Sort(fancySort{underEval: player, l: children, t: t}) 354 t.childLock[t.root].Unlock() 355 356 if moveNum < t.Config.RandomCount { 357 t.randomizeChildren(t.root) 358 } 359 if len(children) == 0 { 360 t.log("Board\n%v |%v", t.current, t.nodeFromNaughty(t.root)) 361 return Pass 362 } 363 364 firstChild := t.nodeFromNaughty(children[0]) 365 bestMove := firstChild.Move() 366 bestScore := firstChild.Evaluate(player) 367 368 root := t.nodeFromNaughty(t.root) 369 switch { 370 case t.Config.PassPreference == DontPreferPass && bestMove.IsPass(): 371 bestMove, bestScore = t.noPassBestMove(bestMove, bestScore, t.root, t.current, player) 372 case !t.Config.DumbPass && bestMove.IsPass(): 373 score := root.Score() 374 if (score > 0 && player == White) || (score < 0 && player == Black) { 375 // passing will cause a loss. Let's find an alternative 376 bestMove, bestScore = t.noPassBestMove(bestMove, bestScore, t.root, t.current, player) 377 } 378 case !t.Config.DumbPass && t.current.LastMove().IsPass(): 379 score := root.Score() 380 if (score > 0 && player == White) || (score < 0 && player == Black) { 381 // passing loses. Play on. 382 } else { 383 bestMove = Pass 384 } 385 } 386 if bestMove.IsPass() && t.shouldResign(bestScore, player) { 387 bestMove = Resign 388 } 389 return bestMove 390 } 391 392 func (t *MCTS) prepareRoot(player game.Player, state game.State) { 393 root := t.nodeFromNaughty(t.root) 394 hadChildren := len(t.children[t.root]) > 0 395 expandable := root.IsExpandable(0) 396 var value float32 397 if expandable { 398 value, _ = t.expandAndSimulate(t.root, state, t.minPsaRatio()) 399 } 400 401 if hadChildren { 402 value = root.Evaluate(player) 403 } else { 404 root.Update(value) 405 if player == White { 406 // DO SOMETHING 407 } 408 } 409 410 // disable any children that is not suitable to be used 411 // children := t.children[t.root] 412 // for _, child := range children { 413 // c := t.nodeFromNaughty(child) 414 // if !t.searchState.current.Check(game.PlayerMove{player, game.Single(c.move)}) { 415 // log.Printf("Invalidating %v", c.move) 416 // c.Invalidate() 417 // } 418 // } 419 } 420 421 // newRootState moves the search state to use a new root state. It returns true when a new root state was created. 422 // 423 // As a side effect, the freeables list is also updated. 424 func (t *MCTS) newRootState() bool { 425 if t.root == nilNode || t.prev == nil { 426 t.log("No root") 427 return false // no current state. Cannot advance to new state 428 } 429 depth := t.current.MoveNumber() - t.prev.MoveNumber() 430 if depth < 0 { 431 t.log("depth < 0") 432 return false // oops too far 433 } 434 435 tmp := t.current.Clone().(game.State) 436 for i := 0; i < depth; i++ { 437 tmp.UndoLastMove() 438 } 439 if !tmp.Eq(t.prev) { 440 return false // they're not the same tree - a new root needs to be created 441 } 442 // try to replay tmp 443 t.log("depth %v", depth) 444 for i := 0; i < depth; i++ { 445 tmp.Fwd() 446 move := tmp.LastMove() 447 448 oldRoot := t.root 449 oldRootNode := t.nodeFromNaughty(oldRoot) 450 newRoot := oldRootNode.findChild(move.Single) 451 if newRoot == nilNode { 452 return false 453 } 454 t.Lock() 455 t.root = newRoot 456 t.Unlock() 457 t.cleanup(oldRoot, newRoot) 458 459 t.prev = t.prev.Apply(move).(game.State) 460 } 461 462 if t.current.MoveNumber() != t.prev.MoveNumber() { 463 return false 464 } 465 if !t.current.Eq(t.prev) { 466 return false 467 } 468 return true 469 } 470 471 // updateRoot updates the root after searching for a new root state. 472 // If no new root state can be found, a new Node indicating a PASS move is made. 473 func (t *MCTS) updateRoot() { 474 t.freeables = t.freeables[:0] 475 player := t.searchState.current.ToMove() 476 if !t.newRootState() || t.searchState.root == nilNode { 477 // search for the first useful 478 if ok := t.searchState.current.Check(game.PlayerMove{player, Pass}); ok { 479 t.searchState.root = t.New(Pass, 0, 0) 480 } else { 481 for i := 0; i < t.searchState.current.ActionSpace(); i++ { 482 if t.searchState.current.Check(game.PlayerMove{player, game.Single(i)}) { 483 t.searchState.root = t.New(game.Single(i), 0, 0) 484 break 485 } 486 } 487 } 488 } 489 t.log("freables %d", len(t.freeables)) 490 t.searchState.prev = nil 491 root := t.nodeFromNaughty(t.searchState.root) 492 atomic.StoreInt32(&t.nc, int32(root.countChildren())) 493 494 // if root has no children 495 children := t.Children(t.searchState.root) 496 if len(children) == 0 { 497 atomic.StoreUint32(&root.minPSARatioChildren, defaultMinPsaRatio) 498 } 499 500 } 501 502 func (t *MCTS) shouldResign(bestScore float32, player game.Player) bool { 503 504 if t.Config.PassPreference == DontResign { 505 return false 506 } 507 if t.Config.ResignPercentage == 0 { 508 return false 509 } 510 squares := t.Config.M * t.Config.N 511 threshold := squares / 4 512 moveNumber := t.current.MoveNumber() 513 if moveNumber <= threshold { 514 // too early to resign 515 return false 516 } 517 518 var resignThreshold float32 519 if t.Config.ResignPercentage < 0 { 520 resignThreshold = 0.1 521 } else { 522 resignThreshold = t.Config.ResignPercentage 523 } 524 525 if bestScore > resignThreshold { 526 return false 527 } 528 // TODO handicap 529 // handicap := t.current.Handicap() 530 // if handicap > 0 && player == White && t.Config.ResignPercentage < 0 { 531 532 // } 533 return true 534 } 535 536 // noPass finds aa child that is NOT a pass move that is valid (i.e. not in eye states for example) 537 func (t *MCTS) noPass(of naughty, state game.State, player game.Player) naughty { 538 children := t.children[of] 539 for _, kid := range children { 540 child := t.nodeFromNaughty(kid) 541 move := child.Move() 542 543 // in Go games, this also checks for eye-ish situations 544 ok := state.Check(game.PlayerMove{player, move}) 545 if !move.IsPass() && ok { 546 return kid 547 } 548 } 549 return nilNode 550 } 551 552 func (t *MCTS) noPassBestMove(bestMove game.Single, bestScore float32, of naughty, state game.State, player game.Player) (game.Single, float32) { 553 nopass := t.noPass(of, state, player) 554 if nopass.isValid() { 555 np := t.nodeFromNaughty(nopass) 556 bestMove = np.Move() 557 bestScore = 1 558 if !np.IsNotVisited() { 559 bestScore = np.Evaluate(player) 560 } 561 } 562 return bestMove, bestScore 563 }