github.com/gorgonia/agogo@v0.1.1/mcts/search.go (about)

     1  package mcts
     2  
     3  import (
     4  	"context"
     5  	"runtime"
     6  	"sort"
     7  	"sync"
     8  	"sync/atomic"
     9  	"time"
    10  
    11  	"github.com/chewxy/math32"
    12  	"github.com/gorgonia/agogo/game"
    13  )
    14  
    15  /*
Here lies the majority of the MCTS search code; node.go and tree.go handle the data structures.

Right now the code is very specific to the game of Go. Ideally we would export the right pieces so that a
search for any other game could be written on top of the same data structures.
    20  */
    21  
    22  const (
    23  	MAXTREESIZE = 25000000 // a tree is at max allowed this many nodes - at about 56 bytes per node that is 1.2GB of memory required
    24  )
    25  
    26  func opponent(p game.Player) game.Player {
    27  	switch p {
    28  	case Black:
    29  		return White
    30  	case White:
    31  		return Black
    32  	}
    33  	panic("Unreachable")
    34  }
    35  
    36  // Result is a NaN tagged floating point, used to represent the reuslts.
    37  type Result float32
    38  
    39  const (
    40  	noResultBits = 0x7FE00000
    41  )
    42  
    43  func noResult() Result {
    44  	return Result(math32.Float32frombits(noResultBits))
    45  }
    46  
    47  // isNullResult returns true if the Result (a NaN tagged number) is noResult
    48  func isNullResult(r Result) bool {
    49  	b := math32.Float32bits(float32(r))
    50  	return b == noResultBits
    51  }
    52  
    53  type searchState struct {
    54  	tree          uintptr
    55  	current, prev game.State
    56  	root          naughty
    57  	depth         int
    58  
    59  	wg *sync.WaitGroup
    60  
    61  	// config
    62  	maxPlayouts, maxVisits, maxDepth int
    63  }
    64  
    65  func (s *searchState) nodeCount() int32 {
    66  	t := treeFromUintptr(s.tree)
    67  	return atomic.LoadInt32(&t.nc)
    68  }
    69  
    70  func (s *searchState) incrementPlayout() {
    71  	t := treeFromUintptr(s.tree)
    72  	atomic.AddInt32(&t.playouts, 1)
    73  }
    74  
    75  func (s *searchState) isRunning() bool {
    76  	t := treeFromUintptr(s.tree)
    77  	running := t.running.Load().(bool)
    78  	return running && t.nodeCount() < MAXTREESIZE
    79  }
    80  
    81  func (s *searchState) minPsaRatio() float32 {
    82  	ratio := float32(s.nodeCount()) / float32(MAXTREESIZE)
    83  	switch {
    84  	case ratio > 0.95:
    85  		return 0.01
    86  	case ratio > 0.5:
    87  		return 0.001
    88  	}
    89  	return 0
    90  }
    91  
    92  func (t *MCTS) Search(player game.Player) (retVal game.Single) {
    93  	t.log("SEARCH. Player %v\n%v", player, t.current)
    94  	t.updateRoot()
    95  	t.current.SetToMove(player)
    96  	boardHash := t.current.Hash()
    97  
    98  	// freeables
    99  	// if t.current.MoveNumber() == 1 {
   100  
   101  	// t.log("Acquiring lock ")
   102  	t.Lock()
   103  	for _, f := range t.freeables {
   104  		t.free(f)
   105  	}
   106  	t.Unlock()
   107  	// }
   108  
   109  	t.prepareRoot(player, t.current)
   110  	root := t.nodeFromNaughty(t.root)
   111  
   112  	ch := make(chan *searchState, runtime.NumCPU())
   113  	var wg sync.WaitGroup
   114  	for i := 0; i < runtime.NumCPU(); i++ {
   115  		ss := &searchState{
   116  			tree:     ptrFromTree(t),
   117  			current:  t.current,
   118  			root:     t.root,
   119  			maxDepth: t.M * t.N,
   120  			wg:       &wg,
   121  		}
   122  		ch <- ss
   123  	}
   124  
   125  	var iter int32
   126  	t.running.Store(true)
   127  	ctx, cancel := context.WithCancel(context.Background())
   128  	for i := 0; i < runtime.NumCPU(); i++ {
   129  		wg.Add(1)
   130  		go doSearch(t.root, &iter, ch, ctx, &wg)
   131  	}
   132  	<-time.After(t.Timeout)
   133  	cancel()
   134  
   135  	// TODO
   136  	// reactivate all pruned children
   137  	wg.Wait()
   138  	close(ch)
   139  
   140  	root = t.nodeFromNaughty(t.root)
   141  	if !root.HasChildren() {
   142  		policy, _ := t.nn.Infer(t.current)
   143  		moveID := argmax(policy)
   144  		if moveID > t.current.ActionSpace() {
   145  			return Pass
   146  		}
   147  		t.log("Returning Early. Best %v", moveID)
   148  		return game.Single(moveID)
   149  	}
   150  
   151  	retVal = t.bestMove()
   152  	t.prev = t.current.Clone().(game.State)
   153  	t.log("Move Number %d, Iterations %d Playouts: %v Nodes: %v. Best: %v", t.current.MoveNumber(), iter, t.playouts, len(t.nodes), retVal)
   154  	t.log("DUMMY")
   155  	// log.Printf("\n%v", t.prev)
   156  	// log.Printf("\tIterations %d Playouts: %v Nodes: %v. Best move %v Player %v", iter, t.playouts, len(t.nodes), retVal, player)
   157  
   158  	// update the cached policies.
   159  	// Again, nothing like having side effects to what appears to be a straightforwards
   160  	// pure function eh?
   161  	t.cachedPolicies[sa{boardHash, retVal}]++
   162  
   163  	return retVal
   164  }
   165  
   166  func doSearch(start naughty, iterBudget *int32, ch chan *searchState, ctx context.Context, wg *sync.WaitGroup) {
   167  	defer wg.Done()
   168  
   169  loop:
   170  	for {
   171  		select {
   172  		case s := <-ch:
   173  			current := s.current.Clone().(game.State)
   174  			root := start
   175  			res := s.pipeline(current, root)
   176  			if !isNullResult(res) {
   177  				s.incrementPlayout()
   178  			}
   179  
   180  			t := treeFromUintptr(s.tree)
   181  			val := atomic.AddInt32(iterBudget, 1)
   182  
   183  			if val > t.Budget {
   184  				t.running.Store(false)
   185  			}
   186  			// running := t.running.Load().(bool)
   187  			// running = running && !s.stopThinking( /*TODO*/ )
   188  			// running = running && s.hasAlternateMoves( /*TODO*/ )
   189  			if s.depth == s.maxDepth {
   190  				// reset s for another bout of playouts
   191  				s.root = t.root
   192  				s.current = t.current
   193  				s.depth = 0
   194  			}
   195  			ch <- s
   196  		case <-ctx.Done():
   197  			break loop
   198  		}
   199  	}
   200  
   201  	return
   202  }
   203  
   204  // pipeline is a recursive MCTS pipeline:
   205  //	SELECT, EXPAND, SIMULATE, BACKPROPAGATE.
   206  //
   207  // Because of the recursive nature, the pipeline is altered a bit to be this:
   208  //	EXPAND and SIMULATE, SELECT and RECURSE, BACKPROPAGATE.
   209  func (s *searchState) pipeline(current game.State, start naughty) (retVal Result) {
   210  	retVal = noResult()
   211  	s.depth++
   212  	if s.depth > s.maxDepth {
   213  		s.depth--
   214  		return
   215  	}
   216  
   217  	player := current.ToMove()
   218  	nodeCount := s.nodeCount()
   219  
   220  	t := treeFromUintptr(s.tree)
   221  	n := t.nodeFromNaughty(start)
   222  	n.addVirtualLoss()
   223  	t.log("\t%p PIPELINE: %v", s, n)
   224  
   225  	// EXPAND and SIMULATE
   226  	isExpandable := n.IsExpandable(0)
   227  	if isExpandable && current.Passes() >= 2 {
   228  		retVal = Result(combinedScore(current))
   229  	} else if isExpandable && nodeCount < MAXTREESIZE {
   230  		hadChildren := n.HasChildren()
   231  		value, ok := s.expandAndSimulate(start, current, s.minPsaRatio())
   232  		if !hadChildren && ok {
   233  			retVal = Result(value)
   234  		}
   235  	}
   236  
   237  	// SELECT and RECURSE
   238  	if n.HasChildren() && isNullResult(retVal) {
   239  		next := t.nodeFromNaughty(n.Select(player))
   240  		move := next.Move()
   241  		pm := game.PlayerMove{player, move}
   242  
   243  		// Check should check Superko. If it's superko, the node should be invalidated
   244  		if current.Check(pm) {
   245  			current = current.Apply(pm).(game.State)
   246  			retVal = s.pipeline(current, next.id)
   247  		}
   248  	}
   249  
   250  	// BACKPROPAGATE
   251  	if !isNullResult(retVal) {
   252  		n.Update(float32(retVal)) // nothing says non functional programs like side effects. Insert more functional programming circle jerk here.
   253  	}
   254  	n.undoVirtualLoss()
   255  	s.depth--
   256  	return retVal
   257  }
   258  
   259  func (s *searchState) expandAndSimulate(parent naughty, state game.State, minPsaRatio float32) (value float32, ok bool) {
   260  	t := treeFromUintptr(s.tree)
   261  	n := t.nodeFromNaughty(parent)
   262  
   263  	t.log("\t\t%p Expand and Simulate. Parent Move: %v. Player: %v. Move number %d\n%v", s, n.Move(), state.ToMove(), state.MoveNumber(), state)
   264  	if !n.IsExpandable(minPsaRatio) {
   265  		t.log("\t\tNot expandable. MinPSA Ratio %v", minPsaRatio)
   266  		return 0, false
   267  	}
   268  
   269  	if state.Passes() >= 2 {
   270  		t.log("\t\t%p Passes >= 2", s)
   271  		return 0, false
   272  	}
   273  	// get scored moves
   274  	var policy []float32              // boardSize + 1
   275  	policy, value = t.nn.Infer(state) // get policy probability, value from neural network
   276  	passProb := policy[len(policy)-1] // probability of a pass is the last in the policy
   277  	player := state.ToMove()
   278  	if player == White {
   279  		value = 1 - value
   280  	}
   281  
   282  	var nodelist []pair
   283  	var legalSum float32
   284  
   285  	for i := 0; i < s.current.ActionSpace(); i++ {
   286  		if state.Check(game.PlayerMove{player, game.Single(i)}) {
   287  			nodelist = append(nodelist, pair{Score: policy[i], Coord: game.Single(i)})
   288  			legalSum += policy[i]
   289  		}
   290  	}
   291  	t.log("\t\t%p Available Moves %d: %v", s, len(nodelist), nodelist)
   292  
   293  	if state.Check(game.PlayerMove{player, Pass}) {
   294  		nodelist = append(nodelist, pair{Score: passProb, Coord: Pass})
   295  		legalSum += passProb
   296  	}
   297  
   298  	if legalSum > math32.SmallestNonzeroFloat32 {
   299  		// re normalize
   300  		for i := range nodelist {
   301  			nodelist[i].Score /= legalSum
   302  		}
   303  	} else {
   304  		prob := 1 / float32(len(nodelist))
   305  		for i := range nodelist {
   306  			nodelist[i].Score = prob
   307  		}
   308  	}
   309  
   310  	if len(nodelist) == 0 {
   311  		t.log("\t\tNodelist is empty")
   312  		return value, true
   313  	}
   314  	sort.Sort(byScore(nodelist))
   315  	maxPsa := nodelist[0].Score
   316  	oldMinPsa := maxPsa * n.MinPsaRatio()
   317  	newMinPsa := maxPsa * minPsaRatio
   318  
   319  	var skippedChildren bool
   320  	for _, p := range nodelist {
   321  		if p.Score < newMinPsa {
   322  			t.log("\t\tp.score %v <  %v", p.Score, newMinPsa)
   323  			skippedChildren = true
   324  		} else if p.Score < oldMinPsa {
   325  			if nn := n.findChild(p.Coord); nn == nilNode {
   326  				nn := t.New(p.Coord, p.Score, value)
   327  				n.AddChild(nn)
   328  			}
   329  		}
   330  	}
   331  	t.log("\t\t%p skipped children? %v", s, skippedChildren)
   332  	if skippedChildren {
   333  		atomic.StoreUint32(&n.minPSARatioChildren, math32.Float32bits(minPsaRatio))
   334  	} else {
   335  		// if no children were skipped, then all that can be expanded has been expanded
   336  		atomic.StoreUint32(&n.minPSARatioChildren, 0)
   337  	}
   338  	return value, true
   339  }
   340  
   341  func (t *MCTS) bestMove() game.Single {
   342  	player := t.current.ToMove()
   343  	moveNum := t.current.MoveNumber()
   344  
   345  	children := t.children[t.root]
   346  	t.log("%p Children: ", &t.searchState)
   347  	for _, child := range children {
   348  		nc := t.nodeFromNaughty(child)
   349  		t.log("\t\t\t%v", nc)
   350  	}
   351  	t.log("%v", t.current)
   352  	t.childLock[t.root].Lock()
   353  	sort.Sort(fancySort{underEval: player, l: children, t: t})
   354  	t.childLock[t.root].Unlock()
   355  
   356  	if moveNum < t.Config.RandomCount {
   357  		t.randomizeChildren(t.root)
   358  	}
   359  	if len(children) == 0 {
   360  		t.log("Board\n%v |%v", t.current, t.nodeFromNaughty(t.root))
   361  		return Pass
   362  	}
   363  
   364  	firstChild := t.nodeFromNaughty(children[0])
   365  	bestMove := firstChild.Move()
   366  	bestScore := firstChild.Evaluate(player)
   367  
   368  	root := t.nodeFromNaughty(t.root)
   369  	switch {
   370  	case t.Config.PassPreference == DontPreferPass && bestMove.IsPass():
   371  		bestMove, bestScore = t.noPassBestMove(bestMove, bestScore, t.root, t.current, player)
   372  	case !t.Config.DumbPass && bestMove.IsPass():
   373  		score := root.Score()
   374  		if (score > 0 && player == White) || (score < 0 && player == Black) {
   375  			// passing will cause a loss. Let's find an alternative
   376  			bestMove, bestScore = t.noPassBestMove(bestMove, bestScore, t.root, t.current, player)
   377  		}
   378  	case !t.Config.DumbPass && t.current.LastMove().IsPass():
   379  		score := root.Score()
   380  		if (score > 0 && player == White) || (score < 0 && player == Black) {
   381  			// passing loses. Play on.
   382  		} else {
   383  			bestMove = Pass
   384  		}
   385  	}
   386  	if bestMove.IsPass() && t.shouldResign(bestScore, player) {
   387  		bestMove = Resign
   388  	}
   389  	return bestMove
   390  }
   391  
   392  func (t *MCTS) prepareRoot(player game.Player, state game.State) {
   393  	root := t.nodeFromNaughty(t.root)
   394  	hadChildren := len(t.children[t.root]) > 0
   395  	expandable := root.IsExpandable(0)
   396  	var value float32
   397  	if expandable {
   398  		value, _ = t.expandAndSimulate(t.root, state, t.minPsaRatio())
   399  	}
   400  
   401  	if hadChildren {
   402  		value = root.Evaluate(player)
   403  	} else {
   404  		root.Update(value)
   405  		if player == White {
   406  			// DO SOMETHING
   407  		}
   408  	}
   409  
   410  	// disable any children that is not suitable to be used
   411  	// children := t.children[t.root]
   412  	// for _, child := range children {
   413  	// 	c := t.nodeFromNaughty(child)
   414  	// 	if !t.searchState.current.Check(game.PlayerMove{player, game.Single(c.move)}) {
   415  	// 		log.Printf("Invalidating %v", c.move)
   416  	// 		c.Invalidate()
   417  	// 	}
   418  	// }
   419  }
   420  
   421  // newRootState moves the search state to use a new root state. It returns true when a new root state was created.
   422  //
   423  // As a side effect, the freeables list is also updated.
   424  func (t *MCTS) newRootState() bool {
   425  	if t.root == nilNode || t.prev == nil {
   426  		t.log("No root")
   427  		return false // no current state. Cannot advance to new state
   428  	}
   429  	depth := t.current.MoveNumber() - t.prev.MoveNumber()
   430  	if depth < 0 {
   431  		t.log("depth < 0")
   432  		return false // oops too far
   433  	}
   434  
   435  	tmp := t.current.Clone().(game.State)
   436  	for i := 0; i < depth; i++ {
   437  		tmp.UndoLastMove()
   438  	}
   439  	if !tmp.Eq(t.prev) {
   440  		return false // they're not the same tree - a new root needs to be created
   441  	}
   442  	// try to replay tmp
   443  	t.log("depth %v", depth)
   444  	for i := 0; i < depth; i++ {
   445  		tmp.Fwd()
   446  		move := tmp.LastMove()
   447  
   448  		oldRoot := t.root
   449  		oldRootNode := t.nodeFromNaughty(oldRoot)
   450  		newRoot := oldRootNode.findChild(move.Single)
   451  		if newRoot == nilNode {
   452  			return false
   453  		}
   454  		t.Lock()
   455  		t.root = newRoot
   456  		t.Unlock()
   457  		t.cleanup(oldRoot, newRoot)
   458  
   459  		t.prev = t.prev.Apply(move).(game.State)
   460  	}
   461  
   462  	if t.current.MoveNumber() != t.prev.MoveNumber() {
   463  		return false
   464  	}
   465  	if !t.current.Eq(t.prev) {
   466  		return false
   467  	}
   468  	return true
   469  }
   470  
   471  // updateRoot updates the root after searching for a new root state.
   472  // If no new root state can be found, a new Node indicating a PASS move is made.
   473  func (t *MCTS) updateRoot() {
   474  	t.freeables = t.freeables[:0]
   475  	player := t.searchState.current.ToMove()
   476  	if !t.newRootState() || t.searchState.root == nilNode {
   477  		// search for the first useful
   478  		if ok := t.searchState.current.Check(game.PlayerMove{player, Pass}); ok {
   479  			t.searchState.root = t.New(Pass, 0, 0)
   480  		} else {
   481  			for i := 0; i < t.searchState.current.ActionSpace(); i++ {
   482  				if t.searchState.current.Check(game.PlayerMove{player, game.Single(i)}) {
   483  					t.searchState.root = t.New(game.Single(i), 0, 0)
   484  					break
   485  				}
   486  			}
   487  		}
   488  	}
   489  	t.log("freables %d", len(t.freeables))
   490  	t.searchState.prev = nil
   491  	root := t.nodeFromNaughty(t.searchState.root)
   492  	atomic.StoreInt32(&t.nc, int32(root.countChildren()))
   493  
   494  	// if root has no children
   495  	children := t.Children(t.searchState.root)
   496  	if len(children) == 0 {
   497  		atomic.StoreUint32(&root.minPSARatioChildren, defaultMinPsaRatio)
   498  	}
   499  
   500  }
   501  
   502  func (t *MCTS) shouldResign(bestScore float32, player game.Player) bool {
   503  
   504  	if t.Config.PassPreference == DontResign {
   505  		return false
   506  	}
   507  	if t.Config.ResignPercentage == 0 {
   508  		return false
   509  	}
   510  	squares := t.Config.M * t.Config.N
   511  	threshold := squares / 4
   512  	moveNumber := t.current.MoveNumber()
   513  	if moveNumber <= threshold {
   514  		// too early to resign
   515  		return false
   516  	}
   517  
   518  	var resignThreshold float32
   519  	if t.Config.ResignPercentage < 0 {
   520  		resignThreshold = 0.1
   521  	} else {
   522  		resignThreshold = t.Config.ResignPercentage
   523  	}
   524  
   525  	if bestScore > resignThreshold {
   526  		return false
   527  	}
   528  	// TODO handicap
   529  	// handicap := t.current.Handicap()
   530  	// if handicap > 0 && player == White && t.Config.ResignPercentage < 0 {
   531  
   532  	// }
   533  	return true
   534  }
   535  
   536  // noPass finds aa child that is NOT a pass move that is valid (i.e. not in eye states for example)
   537  func (t *MCTS) noPass(of naughty, state game.State, player game.Player) naughty {
   538  	children := t.children[of]
   539  	for _, kid := range children {
   540  		child := t.nodeFromNaughty(kid)
   541  		move := child.Move()
   542  
   543  		// in Go games, this also checks for eye-ish situations
   544  		ok := state.Check(game.PlayerMove{player, move})
   545  		if !move.IsPass() && ok {
   546  			return kid
   547  		}
   548  	}
   549  	return nilNode
   550  }
   551  
   552  func (t *MCTS) noPassBestMove(bestMove game.Single, bestScore float32, of naughty, state game.State, player game.Player) (game.Single, float32) {
   553  	nopass := t.noPass(of, state, player)
   554  	if nopass.isValid() {
   555  		np := t.nodeFromNaughty(nopass)
   556  		bestMove = np.Move()
   557  		bestScore = 1
   558  		if !np.IsNotVisited() {
   559  			bestScore = np.Evaluate(player)
   560  		}
   561  	}
   562  	return bestMove, bestScore
   563  }