github.com/gorgonia/agogo@v0.1.1/mcts/tree.go (about) 1 package mcts 2 3 import ( 4 "math/rand" 5 "runtime" 6 "sync" 7 "sync/atomic" 8 "time" 9 10 "github.com/chewxy/math32" 11 "github.com/gorgonia/agogo/game" 12 ) 13 14 // Config is the structure to configure the MCTS multitree (poorly named Tree) 15 type Config struct { 16 // PUCT is the proportion of polynomial upper confidence trees to keep. Between 1 and 0 17 PUCT float32 18 Timeout time.Duration 19 20 // M, N represents the height and width. 21 M, N int 22 RandomCount int // if the move number is less than this, we should randomize 23 Budget int32 // iteration budget 24 RandomMinVisits uint32 25 RandomTemperature float32 26 DumbPass bool 27 ResignPercentage float32 28 PassPreference PassPreference 29 } 30 31 func DefaultConfig(boardSize int) Config { 32 return Config{ 33 PUCT: 1.0, 34 Timeout: 100 * time.Millisecond, 35 M: boardSize, 36 N: boardSize, 37 DumbPass: true, 38 PassPreference: DontPreferPass, 39 Budget: 10000, 40 } 41 } 42 43 func (c Config) IsValid() bool { 44 return c.PUCT > 0 && c.PUCT <= 1 45 } 46 47 // sa is a state-action tuple, used for storing results 48 type sa struct { 49 s game.Zobrist // a zobrist hash of the board 50 a game.Single 51 } 52 53 // MCTS is essentially a "global" manager of sorts for the memories. The goal is to build MCTS without much pointer chasing. 54 type MCTS struct { 55 sync.RWMutex 56 Config 57 nn Inferencer 58 rand *rand.Rand 59 60 // memory related fields 61 nodes []Node 62 // children map[naughty][]naughty 63 children [][]naughty 64 childLock []sync.Mutex 65 66 freelist []naughty 67 freeables []naughty // list of nodes that can be freed 68 69 // global searchState 70 searchState 71 playouts, nc int32 // atomic pls 72 running atomic.Value 73 74 // global policy values - useful for building policy vectors 75 cachedPolicies map[sa]float32 76 77 lumberjack 78 } 79 80 func New(game game.State, conf Config, nn Inferencer) *MCTS { 81 retVal := &MCTS{ 82 Config: conf, 83 nn: nn, 84 rand: rand.New(rand.NewSource(time.Now().UnixNano())), 85 86 nodes: make([]Node, 0, 12288), 87 // children: make(map[naughty][]naughty), 88 children: make([][]naughty, 0, 12288), 89 childLock: make([]sync.Mutex, 0, 12288), 90 91 searchState: searchState{ 92 root: nilNode, 93 current: game, 94 }, 95 96 cachedPolicies: make(map[sa]float32), 97 lumberjack: makeLumberJack(), 98 } 99 go retVal.start() 100 retVal.searchState.tree = ptrFromTree(retVal) 101 retVal.searchState.maxDepth = retVal.M * retVal.N 102 return retVal 103 } 104 105 // New creates a new node 106 func (t *MCTS) New(move game.Single, score, value float32) (retVal naughty) { 107 n := t.alloc() 108 N := t.nodeFromNaughty(n) 109 atomic.StoreInt32(&N.move, int32(move)) 110 atomic.StoreUint32(&N.visits, 1) 111 atomic.StoreUint32(&N.status, uint32(Active)) 112 atomic.StoreUint32(&N.score, math32.Float32bits(score)) 113 atomic.StoreUint32(&N.value, math32.Float32bits(value)) 114 115 // log.Printf("New node %p - %v", N, N) 116 return n 117 } 118 119 // SetGame sets the game 120 func (t *MCTS) SetGame(g game.State) { 121 t.Lock() 122 t.current = g 123 t.Unlock() 124 } 125 126 func (t *MCTS) Nodes() int { return len(t.nodes) } 127 128 func (t *MCTS) Policies(g game.State) []float32 { 129 hash := g.Hash() 130 var sum float32 131 actionSpacePlusPass := g.ActionSpace() + 1 132 retVal := make([]float32, actionSpacePlusPass) 133 for i := 0; i < actionSpacePlusPass; i++ { 134 prob := t.cachedPolicies[sa{s: hash, a: game.Single(i)}] 135 retVal[i] = prob 136 sum += prob 137 } 138 for i := 0; i < len(retVal); i++ { 139 retVal[i] /= sum 140 } 141 return retVal 142 } 143 144 // alloc tries to get a node from the free list. If none is found a new node is allocated into the master arena 145 func (t *MCTS) alloc() naughty { 146 t.Lock() 147 l := len(t.freelist) 148 if l == 0 { 149 N := Node{ 150 tree: ptrFromTree(t), 151 id: naughty(len(t.nodes)), 152 153 minPSARatioChildren: defaultMinPsaRatio, 154 } 155 t.nodes = append(t.nodes, N) 156 t.children = append(t.children, make([]naughty, 0, t.M*t.N+1)) 157 t.childLock = append(t.childLock, sync.Mutex{}) 158 n := naughty(len(t.nodes) - 1) 159 t.Unlock() 160 return n 161 } 162 163 i := t.freelist[l-1] 164 t.freelist = t.freelist[:l-1] 165 t.Unlock() 166 return naughty(i) 167 } 168 169 // free puts the node back into the freelist. 170 // 171 // Because the there isn't really strong reference tracking, there may be 172 // use-after-free issues. Therefore it's absolutely vital that any calls to free() 173 // has to be done with careful consideration. 174 func (t *MCTS) free(n naughty) { 175 // delete(t.children, n) 176 t.children[int(n)] = t.children[int(n)][:0] 177 t.freelist = append(t.freelist, n) 178 N := &t.nodes[int(n)] 179 N.reset() 180 } 181 182 // cleanup cleans up the graph (WORK IN PROGRESS) 183 func (t *MCTS) cleanup(oldRoot, newRoot naughty) { 184 children := t.Children(oldRoot) 185 // we aint going down other paths, those nodes can be freed 186 for _, kid := range children { 187 if kid != newRoot { 188 t.nodeFromNaughty(kid).Invalidate() 189 t.freeables = append(t.freeables, kid) 190 t.cleanChildren(kid) 191 } 192 } 193 t.Lock() 194 t.children[oldRoot] = t.children[oldRoot][:1] 195 t.children[oldRoot][0] = newRoot 196 t.Unlock() 197 } 198 199 func (t *MCTS) cleanChildren(root naughty) { 200 children := t.Children(root) 201 for _, kid := range children { 202 t.nodeFromNaughty(kid).Invalidate() 203 t.freeables = append(t.freeables, kid) 204 t.cleanChildren(kid) // recursively clean children 205 } 206 t.Lock() 207 t.children[root] = t.children[root][:0] // empty it 208 t.Unlock() 209 } 210 211 // randomizeChildren proportionally randomizes the children nodes by proportion of the visit 212 func (t *MCTS) randomizeChildren(of naughty) { 213 var accum, norm float32 214 var accumVector []float32 215 children := t.Children(of) 216 for _, kid := range children { 217 child := t.nodeFromNaughty(kid) 218 visits := child.Visits() 219 if norm == 0 { 220 norm = float32(visits) 221 222 // nonsensical options 223 if visits <= t.Config.RandomMinVisits { 224 return 225 } 226 } 227 if visits > t.Config.RandomMinVisits { 228 accum += math32.Pow(float32(visits)/norm, 1/t.Config.RandomTemperature) 229 accumVector = append(accumVector, accum) 230 } 231 } 232 rnd := t.rand.Float32() * accum // uniform distro: rnd() * (max-min) + min 233 var index int 234 for i, a := range accumVector { 235 if rnd < a { 236 index = i 237 break 238 } 239 } 240 if index == 0 { 241 return 242 } 243 244 for i := 0; i < len(children)-index; i++ { 245 children[i], children[i+index] = children[i+index], children[i] 246 } 247 } 248 249 func (t *MCTS) Reset() { 250 t.Lock() 251 defer t.Unlock() 252 253 t.freelist = t.freelist[:0] 254 t.freeables = t.freeables[:0] 255 for i := range t.nodes { 256 t.nodes[i].move = -1 257 t.nodes[i].visits = 0 258 t.nodes[i].status = 0 259 t.nodes[i].blackScores = 0 260 t.nodes[i].virtualLoss = 0 261 t.nodes[i].minPSARatioChildren = 0 262 t.nodes[i].score = 0 263 t.nodes[i].value = 0 264 265 t.freelist = append(t.freelist, t.nodes[i].id) 266 } 267 268 for i := range t.children { 269 t.children[i] = t.children[i][:0] 270 } 271 272 t.playouts = 0 273 t.nodes = t.nodes[:0] 274 t.cachedPolicies = make(map[sa]float32) 275 runtime.GC() 276 }