github.com/gorgonia/agogo@v0.1.1/mcts/node.go (about) 1 package mcts 2 3 import ( 4 "fmt" 5 "sort" 6 "sync/atomic" 7 8 "github.com/chewxy/math32" 9 "github.com/gorgonia/agogo/game" 10 ) 11 12 type Status uint32 13 14 const ( 15 Invalid Status = iota 16 Active 17 Pruned 18 ) 19 20 func (a Status) String() string { 21 switch a { 22 case Invalid: 23 return "Invalid" 24 case Active: 25 return "Active" 26 case Pruned: 27 return "Pruned" 28 } 29 return "UNKNOWN STATUS" 30 } 31 32 type Node struct { 33 34 // atomic access only pl0x 35 move int32 // should be game.Single 36 visits uint32 // visits to this node - N(s, a) in the literature 37 status uint32 // status 38 39 // float32s 40 blackScores uint32 // actually float32. 41 virtualLoss uint32 // actually float32. a virtual loss number - 0 or 3 42 minPSARatioChildren uint32 // actually float32. minimum P(s,a) ratio for the children. Default to 2 43 score uint32 // Policy estimate for taking the move above (from NN) 44 value uint32 // value from the neural network 45 46 // naughty things 47 id naughty // index to the children allocation 48 tree uintptr // pointer to the tree 49 } 50 51 func (n *Node) Format(s fmt.State, c rune) { 52 fmt.Fprintf(s, "{NodeID: %v Move: %v, Score: %v, Value %v Visits %v minPSARatioChildren %v Status: %v}", n.id, n.Move(), n.Score(), n.value, n.Visits(), n.minPSARatioChildren, Status(n.status)) 53 } 54 55 // AddChild adds a child to the node 56 func (n *Node) AddChild(child naughty) { 57 tree := treeFromUintptr(n.tree) 58 tree.Lock() 59 tree.children[n.id] = append(tree.children[n.id], child) 60 tree.Unlock() 61 } 62 63 // IsFirstVisit returns true if this node hasn't ever been visited 64 func (n *Node) IsNotVisited() bool { 65 visits := atomic.LoadUint32(&n.visits) 66 return visits == 0 67 } 68 69 // Update updates the accumulated score 70 func (n *Node) Update(score float32) { 71 t := treeFromUintptr(n.tree) 72 t.Lock() 73 atomic.AddUint32(&n.visits, 1) 74 n.accumulate(score) 75 t.Unlock() 76 } 77 78 // BlackScores returns the scores for black 79 func (n *Node) BlackScores() float32 { 80 blackScores := atomic.LoadUint32(&n.blackScores) 81 return math32.Float32frombits(blackScores) 82 } 83 84 // Move gets the move associated with the node 85 func (n *Node) Move() game.Single { return game.Single(atomic.LoadInt32(&n.move)) } 86 87 // Score returns the score 88 func (n *Node) Score() float32 { 89 v := atomic.LoadUint32(&n.score) 90 return math32.Float32frombits(v) 91 } 92 93 // Value returns the predicted value (probability of winning from the NN) of the given node 94 func (n *Node) Value() float32 { 95 v := atomic.LoadUint32(&n.value) 96 return math32.Float32frombits(v) 97 } 98 99 func (n *Node) Visits() uint32 { return atomic.LoadUint32(&n.visits) } 100 101 // Activate activates the node 102 func (n *Node) Activate() { atomic.StoreUint32(&n.status, uint32(Active)) } 103 104 // Prune prunes the node 105 func (n *Node) Prune() { atomic.StoreUint32(&n.status, uint32(Pruned)) } 106 107 // Invalidate invalidates the node 108 func (n *Node) Invalidate() { atomic.StoreUint32(&n.status, uint32(Invalid)) } 109 110 // IsValid returns true if it's valid 111 func (n *Node) IsValid() bool { 112 status := atomic.LoadUint32(&n.status) 113 return Status(status) != Invalid 114 } 115 116 // IsActive returns true if the node is active 117 func (n *Node) IsActive() bool { 118 status := atomic.LoadUint32(&n.status) 119 return Status(status) == Active 120 } 121 122 // IsPruned returns true if the node has been pruned. 123 func (n *Node) IsPruned() bool { 124 status := atomic.LoadUint32(&n.status) 125 return Status(status) == Pruned 126 } 127 128 // HasChildren returns true if the node has children 129 func (n *Node) HasChildren() bool { return n.MinPsaRatio() <= 1 } 130 131 // IsExpandable returns true if the node is exandable. It may not be for memory reasons. 132 func (n *Node) IsExpandable(minPsaRatio float32) bool { return minPsaRatio < n.MinPsaRatio() } 133 134 func (n *Node) VirtualLoss() float32 { 135 v := atomic.LoadUint32(&n.virtualLoss) 136 return math32.Float32frombits(v) 137 } 138 139 func (n *Node) MinPsaRatio() float32 { 140 v := atomic.LoadUint32(&n.minPSARatioChildren) 141 return math32.Float32frombits(v) 142 } 143 144 func (n *Node) ID() int { return int(n.id) } 145 146 // Evaluate evaluates a move made by a player 147 func (n *Node) Evaluate(player game.Player) float32 { 148 visits := n.Visits() 149 blackScores := n.BlackScores() 150 if player == White { 151 blackScores += n.VirtualLoss() 152 } 153 154 score := blackScores / float32(visits) 155 if player == White { 156 score = 1 - score 157 } 158 return score 159 } 160 161 // NNEvaluate returns the result of the NN evaluation of the colour. 162 func (n *Node) NNEvaluate(player game.Player) float32 { 163 if player == White { 164 return 1.0 - n.Value() 165 } 166 return n.Value() 167 } 168 169 // Select selects the child of the given Colour 170 func (n *Node) Select(of game.Player) naughty { 171 // sumScore is the sum of scores of the node that has been visited by the policy 172 var sumScore float32 173 var parentVisits uint32 174 175 tree := treeFromUintptr(n.tree) 176 177 children := tree.Children(n.id) 178 for _, kid := range children { 179 child := tree.nodeFromNaughty(kid) 180 if child.IsValid() { 181 visits := child.Visits() 182 parentVisits += visits 183 if visits > 0 { 184 sumScore += child.Score() 185 } 186 } 187 } 188 189 // the upper bound formula is as such 190 // U(s, a) = Q(s, a) + tree.PUCT * P(s, a) * ((sqrt(parent visits))/ (1+visits to this node)) 191 // 192 // where 193 // U(s, a) = upper confidence bound given state and action 194 // Q(s, a) = reward of taking the action given the state 195 // P(s, a) = iniital probability/estimate of taking an action from the state given according to the policy 196 // 197 // in the following code, 198 // psa = P(s, a) 199 // qsa = Q(s, a) 200 // 201 // Given the state and action is already known and encoded into Node itself,it doesn't have to be a function 202 // like in most MCTS tutorials. This allows it to be slightly more performant (i.e. a AoS-ish data structure) 203 204 var best naughty 205 var bestValue float32 = math32.Inf(-1) 206 fpu := n.NNEvaluate(of) // first play urgency is the value predicted by the NN 207 numerator := math32.Sqrt(float32(parentVisits)) // in order to find the stochastic policy, we need to normalize the count 208 209 for _, kid := range children { 210 child := tree.nodeFromNaughty(kid) 211 if !child.IsActive() { 212 continue 213 } 214 215 qsa := fpu // the initial Q is what the NN predicts 216 visits := child.Visits() 217 if visits > 0 { 218 qsa = child.Evaluate(of) // but if this node has been visited before, Q from the node is used. 219 } 220 psa := child.Score() 221 denominator := 1.0 + float32(visits) 222 lastTerm := (numerator / denominator) 223 puct := tree.PUCT * psa * lastTerm 224 usa := qsa + puct 225 226 if usa > bestValue { 227 bestValue = usa 228 best = kid 229 } 230 } 231 232 if best == nilNode { 233 panic("Cannot return nil") 234 } 235 // log.Printf("SELECT %v. Best %v - %v", of, best, tree.nodeFromNaughty(best)) 236 return best 237 } 238 239 // BestChild returns the best scoring child. Note that fancySort has all sorts of heuristics 240 func (n *Node) BestChild(player game.Player) naughty { 241 tree := treeFromUintptr(n.tree) 242 children := tree.Children(n.id) 243 244 sort.Sort(fancySort{player, children, tree}) 245 return children[len(children)-1] 246 } 247 248 func (n *Node) addVirtualLoss() { 249 t := treeFromUintptr(n.tree) 250 t.Lock() 251 atomic.StoreUint32(&n.virtualLoss, virtualLoss1) 252 t.Unlock() 253 } 254 255 func (n *Node) undoVirtualLoss() { 256 t := treeFromUintptr(n.tree) 257 t.Lock() 258 atomic.StoreUint32(&n.virtualLoss, 0) 259 t.Unlock() 260 } 261 262 // accumulate adds to the score atomically 263 func (n *Node) accumulate(score float32) { 264 blackScores := atomic.LoadUint32(&n.blackScores) 265 evals := math32.Float32frombits(blackScores) 266 evals += score 267 blackScores = math32.Float32bits(evals) 268 atomic.StoreUint32(&n.blackScores, blackScores) 269 270 } 271 272 // countChildren counts the number of children node a node has and number of grandkids recursively 273 func (n *Node) countChildren() (retVal int) { 274 tree := treeFromUintptr(n.tree) 275 children := tree.Children(n.id) 276 for _, kid := range children { 277 child := tree.nodeFromNaughty(kid) 278 if child.IsActive() { 279 retVal += child.countChildren() 280 } 281 retVal++ // plus the child itself 282 } 283 return 284 } 285 286 // findChild finds the first child that has the wanted move 287 func (n *Node) findChild(move game.Single) naughty { 288 tree := treeFromUintptr(n.tree) 289 children := tree.Children(n.id) 290 for _, kid := range children { 291 child := tree.nodeFromNaughty(kid) 292 if game.Single(child.Move()) == move { 293 return kid 294 } 295 } 296 return nilNode 297 } 298 299 func (n *Node) reset() { 300 atomic.StoreInt32(&n.move, -1) 301 atomic.StoreUint32(&n.visits, 0) 302 atomic.StoreUint32(&n.status, 0) 303 atomic.StoreUint32(&n.blackScores, 0) 304 atomic.StoreUint32(&n.minPSARatioChildren, defaultMinPsaRatio) 305 atomic.StoreUint32(&n.score, 0) 306 atomic.StoreUint32(&n.value, 0) 307 atomic.StoreUint32(&n.virtualLoss, 0) 308 }