github.com/gorgonia/agogo@v0.1.1/mcts/node.go (about)

     1  package mcts
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  	"sync/atomic"
     7  
     8  	"github.com/chewxy/math32"
     9  	"github.com/gorgonia/agogo/game"
    10  )
    11  
    12  type Status uint32
    13  
    14  const (
    15  	Invalid Status = iota
    16  	Active
    17  	Pruned
    18  )
    19  
    20  func (a Status) String() string {
    21  	switch a {
    22  	case Invalid:
    23  		return "Invalid"
    24  	case Active:
    25  		return "Active"
    26  	case Pruned:
    27  		return "Pruned"
    28  	}
    29  	return "UNKNOWN STATUS"
    30  }
    31  
// Node is a node in the search tree. Its mutable state is kept in 32-bit
// words that are read and written through sync/atomic; float32 quantities are
// stored as their bit patterns in uint32 fields and decoded with
// math32.Float32frombits by the accessor methods.
type Node struct {

	// The fields in this group and the next must only be accessed atomically.
	move   int32  // the move associated with this node; should be game.Single
	visits uint32 // visits to this node - N(s, a) in the literature
	status uint32 // a Status value stored as its underlying uint32

	// float32 values stored as raw bits
	blackScores         uint32 // actually float32. running sum of the scores passed to Update
	virtualLoss         uint32 // actually float32. a virtual loss number - 0 or 3
	minPSARatioChildren uint32 // actually float32. minimum P(s,a) ratio for the children. Default to 2
	score               uint32 // actually float32. policy estimate for taking the move above (from NN)
	value               uint32 // actually float32. value from the neural network

	// naughty things
	id   naughty // index into the tree's children allocation for this node
	tree uintptr // pointer to the tree, decoded with treeFromUintptr
}
    50  
    51  func (n *Node) Format(s fmt.State, c rune) {
    52  	fmt.Fprintf(s, "{NodeID: %v Move: %v, Score: %v, Value %v Visits %v minPSARatioChildren %v Status: %v}", n.id, n.Move(), n.Score(), n.value, n.Visits(), n.minPSARatioChildren, Status(n.status))
    53  }
    54  
    55  // AddChild adds a child to the node
    56  func (n *Node) AddChild(child naughty) {
    57  	tree := treeFromUintptr(n.tree)
    58  	tree.Lock()
    59  	tree.children[n.id] = append(tree.children[n.id], child)
    60  	tree.Unlock()
    61  }
    62  
    63  // IsFirstVisit returns true if this node hasn't ever been visited
    64  func (n *Node) IsNotVisited() bool {
    65  	visits := atomic.LoadUint32(&n.visits)
    66  	return visits == 0
    67  }
    68  
    69  // Update updates the accumulated score
    70  func (n *Node) Update(score float32) {
    71  	t := treeFromUintptr(n.tree)
    72  	t.Lock()
    73  	atomic.AddUint32(&n.visits, 1)
    74  	n.accumulate(score)
    75  	t.Unlock()
    76  }
    77  
    78  // BlackScores returns the scores for black
    79  func (n *Node) BlackScores() float32 {
    80  	blackScores := atomic.LoadUint32(&n.blackScores)
    81  	return math32.Float32frombits(blackScores)
    82  }
    83  
    84  // Move gets the move associated with the node
    85  func (n *Node) Move() game.Single { return game.Single(atomic.LoadInt32(&n.move)) }
    86  
    87  // Score returns the score
    88  func (n *Node) Score() float32 {
    89  	v := atomic.LoadUint32(&n.score)
    90  	return math32.Float32frombits(v)
    91  }
    92  
    93  // Value returns the predicted value (probability of winning from the NN) of the given node
    94  func (n *Node) Value() float32 {
    95  	v := atomic.LoadUint32(&n.value)
    96  	return math32.Float32frombits(v)
    97  }
    98  
    99  func (n *Node) Visits() uint32 { return atomic.LoadUint32(&n.visits) }
   100  
   101  // Activate activates the node
   102  func (n *Node) Activate() { atomic.StoreUint32(&n.status, uint32(Active)) }
   103  
   104  // Prune prunes the node
   105  func (n *Node) Prune() { atomic.StoreUint32(&n.status, uint32(Pruned)) }
   106  
   107  // Invalidate invalidates the node
   108  func (n *Node) Invalidate() { atomic.StoreUint32(&n.status, uint32(Invalid)) }
   109  
   110  // IsValid returns true if it's valid
   111  func (n *Node) IsValid() bool {
   112  	status := atomic.LoadUint32(&n.status)
   113  	return Status(status) != Invalid
   114  }
   115  
   116  // IsActive returns true if the node is active
   117  func (n *Node) IsActive() bool {
   118  	status := atomic.LoadUint32(&n.status)
   119  	return Status(status) == Active
   120  }
   121  
   122  // IsPruned returns true if the node has been pruned.
   123  func (n *Node) IsPruned() bool {
   124  	status := atomic.LoadUint32(&n.status)
   125  	return Status(status) == Pruned
   126  }
   127  
   128  // HasChildren returns true if the node has children
   129  func (n *Node) HasChildren() bool { return n.MinPsaRatio() <= 1 }
   130  
   131  // IsExpandable returns true if the node is exandable. It may not be for memory reasons.
   132  func (n *Node) IsExpandable(minPsaRatio float32) bool { return minPsaRatio < n.MinPsaRatio() }
   133  
   134  func (n *Node) VirtualLoss() float32 {
   135  	v := atomic.LoadUint32(&n.virtualLoss)
   136  	return math32.Float32frombits(v)
   137  }
   138  
   139  func (n *Node) MinPsaRatio() float32 {
   140  	v := atomic.LoadUint32(&n.minPSARatioChildren)
   141  	return math32.Float32frombits(v)
   142  }
   143  
   144  func (n *Node) ID() int { return int(n.id) }
   145  
   146  // Evaluate evaluates a move made by a player
   147  func (n *Node) Evaluate(player game.Player) float32 {
   148  	visits := n.Visits()
   149  	blackScores := n.BlackScores()
   150  	if player == White {
   151  		blackScores += n.VirtualLoss()
   152  	}
   153  
   154  	score := blackScores / float32(visits)
   155  	if player == White {
   156  		score = 1 - score
   157  	}
   158  	return score
   159  }
   160  
   161  // NNEvaluate returns the result of the NN evaluation of the colour.
   162  func (n *Node) NNEvaluate(player game.Player) float32 {
   163  	if player == White {
   164  		return 1.0 - n.Value()
   165  	}
   166  	return n.Value()
   167  }
   168  
   169  // Select selects the child of the given Colour
   170  func (n *Node) Select(of game.Player) naughty {
   171  	// sumScore is the sum of scores of the node that has been visited by the policy
   172  	var sumScore float32
   173  	var parentVisits uint32
   174  
   175  	tree := treeFromUintptr(n.tree)
   176  
   177  	children := tree.Children(n.id)
   178  	for _, kid := range children {
   179  		child := tree.nodeFromNaughty(kid)
   180  		if child.IsValid() {
   181  			visits := child.Visits()
   182  			parentVisits += visits
   183  			if visits > 0 {
   184  				sumScore += child.Score()
   185  			}
   186  		}
   187  	}
   188  
   189  	// the upper bound formula is as such
   190  	// U(s, a) = Q(s, a) + tree.PUCT * P(s, a) * ((sqrt(parent visits))/ (1+visits to this node))
   191  	//
   192  	// where
   193  	// U(s, a) = upper confidence bound given state and action
   194  	// Q(s, a) = reward of taking the action given the state
   195  	// P(s, a) = iniital probability/estimate of taking an action from the state given according to the policy
   196  	//
   197  	// in the following code,
   198  	// psa = P(s, a)
   199  	// qsa = Q(s, a)
   200  	//
   201  	// Given the state and action is already known and encoded into Node itself,it doesn't have to be a function
   202  	// like in most MCTS tutorials. This allows it to be slightly more performant (i.e. a AoS-ish data structure)
   203  
   204  	var best naughty
   205  	var bestValue float32 = math32.Inf(-1)
   206  	fpu := n.NNEvaluate(of)                         // first play urgency is the value predicted by the NN
   207  	numerator := math32.Sqrt(float32(parentVisits)) // in order to find the stochastic policy, we need to normalize the count
   208  
   209  	for _, kid := range children {
   210  		child := tree.nodeFromNaughty(kid)
   211  		if !child.IsActive() {
   212  			continue
   213  		}
   214  
   215  		qsa := fpu // the initial Q is what the NN predicts
   216  		visits := child.Visits()
   217  		if visits > 0 {
   218  			qsa = child.Evaluate(of) // but if this node has been visited before, Q from the node is used.
   219  		}
   220  		psa := child.Score()
   221  		denominator := 1.0 + float32(visits)
   222  		lastTerm := (numerator / denominator)
   223  		puct := tree.PUCT * psa * lastTerm
   224  		usa := qsa + puct
   225  
   226  		if usa > bestValue {
   227  			bestValue = usa
   228  			best = kid
   229  		}
   230  	}
   231  
   232  	if best == nilNode {
   233  		panic("Cannot return nil")
   234  	}
   235  	// log.Printf("SELECT %v. Best %v - %v", of, best, tree.nodeFromNaughty(best))
   236  	return best
   237  }
   238  
   239  // BestChild returns the best scoring child. Note that fancySort has all sorts of heuristics
   240  func (n *Node) BestChild(player game.Player) naughty {
   241  	tree := treeFromUintptr(n.tree)
   242  	children := tree.Children(n.id)
   243  
   244  	sort.Sort(fancySort{player, children, tree})
   245  	return children[len(children)-1]
   246  }
   247  
   248  func (n *Node) addVirtualLoss() {
   249  	t := treeFromUintptr(n.tree)
   250  	t.Lock()
   251  	atomic.StoreUint32(&n.virtualLoss, virtualLoss1)
   252  	t.Unlock()
   253  }
   254  
   255  func (n *Node) undoVirtualLoss() {
   256  	t := treeFromUintptr(n.tree)
   257  	t.Lock()
   258  	atomic.StoreUint32(&n.virtualLoss, 0)
   259  	t.Unlock()
   260  }
   261  
   262  // accumulate adds to the score atomically
   263  func (n *Node) accumulate(score float32) {
   264  	blackScores := atomic.LoadUint32(&n.blackScores)
   265  	evals := math32.Float32frombits(blackScores)
   266  	evals += score
   267  	blackScores = math32.Float32bits(evals)
   268  	atomic.StoreUint32(&n.blackScores, blackScores)
   269  
   270  }
   271  
   272  // countChildren counts the number of children node a node has and number of grandkids recursively
   273  func (n *Node) countChildren() (retVal int) {
   274  	tree := treeFromUintptr(n.tree)
   275  	children := tree.Children(n.id)
   276  	for _, kid := range children {
   277  		child := tree.nodeFromNaughty(kid)
   278  		if child.IsActive() {
   279  			retVal += child.countChildren()
   280  		}
   281  		retVal++ // plus the child itself
   282  	}
   283  	return
   284  }
   285  
   286  // findChild finds the first child that has the wanted move
   287  func (n *Node) findChild(move game.Single) naughty {
   288  	tree := treeFromUintptr(n.tree)
   289  	children := tree.Children(n.id)
   290  	for _, kid := range children {
   291  		child := tree.nodeFromNaughty(kid)
   292  		if game.Single(child.Move()) == move {
   293  			return kid
   294  		}
   295  	}
   296  	return nilNode
   297  }
   298  
   299  func (n *Node) reset() {
   300  	atomic.StoreInt32(&n.move, -1)
   301  	atomic.StoreUint32(&n.visits, 0)
   302  	atomic.StoreUint32(&n.status, 0)
   303  	atomic.StoreUint32(&n.blackScores, 0)
   304  	atomic.StoreUint32(&n.minPSARatioChildren, defaultMinPsaRatio)
   305  	atomic.StoreUint32(&n.score, 0)
   306  	atomic.StoreUint32(&n.value, 0)
   307  	atomic.StoreUint32(&n.virtualLoss, 0)
   308  }