github.com/gorgonia/agogo@v0.1.1/mcts/tree.go (about)

     1  package mcts
     2  
     3  import (
     4  	"math/rand"
     5  	"runtime"
     6  	"sync"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/chewxy/math32"
    11  	"github.com/gorgonia/agogo/game"
    12  )
    13  
    14  // Config is the structure to configure the MCTS multitree (poorly named Tree)
    15  type Config struct {
    16  	// PUCT is the proportion of polynomial upper confidence trees to keep. Between 1 and 0
    17  	PUCT    float32
    18  	Timeout time.Duration
    19  
    20  	// M, N represents the height and width.
    21  	M, N              int
    22  	RandomCount       int   // if the move number is less than this, we should randomize
    23  	Budget            int32 // iteration budget
    24  	RandomMinVisits   uint32
    25  	RandomTemperature float32
    26  	DumbPass          bool
    27  	ResignPercentage  float32
    28  	PassPreference    PassPreference
    29  }
    30  
    31  func DefaultConfig(boardSize int) Config {
    32  	return Config{
    33  		PUCT:           1.0,
    34  		Timeout:        100 * time.Millisecond,
    35  		M:              boardSize,
    36  		N:              boardSize,
    37  		DumbPass:       true,
    38  		PassPreference: DontPreferPass,
    39  		Budget:         10000,
    40  	}
    41  }
    42  
    43  func (c Config) IsValid() bool {
    44  	return c.PUCT > 0 && c.PUCT <= 1
    45  }
    46  
    47  // sa is a state-action tuple, used for storing results
    48  type sa struct {
    49  	s game.Zobrist // a zobrist hash of the board
    50  	a game.Single
    51  }
    52  
    53  // MCTS is essentially a "global" manager of sorts for the memories. The goal is to build MCTS without much pointer chasing.
    54  type MCTS struct {
    55  	sync.RWMutex
    56  	Config
    57  	nn   Inferencer
    58  	rand *rand.Rand
    59  
    60  	// memory related fields
    61  	nodes []Node
    62  	// children  map[naughty][]naughty
    63  	children  [][]naughty
    64  	childLock []sync.Mutex
    65  
    66  	freelist  []naughty
    67  	freeables []naughty // list of nodes that can be freed
    68  
    69  	// global searchState
    70  	searchState
    71  	playouts, nc int32 // atomic pls
    72  	running      atomic.Value
    73  
    74  	// global policy values - useful for building policy vectors
    75  	cachedPolicies map[sa]float32
    76  
    77  	lumberjack
    78  }
    79  
    80  func New(game game.State, conf Config, nn Inferencer) *MCTS {
    81  	retVal := &MCTS{
    82  		Config: conf,
    83  		nn:     nn,
    84  		rand:   rand.New(rand.NewSource(time.Now().UnixNano())),
    85  
    86  		nodes: make([]Node, 0, 12288),
    87  		// children: make(map[naughty][]naughty),
    88  		children:  make([][]naughty, 0, 12288),
    89  		childLock: make([]sync.Mutex, 0, 12288),
    90  
    91  		searchState: searchState{
    92  			root:    nilNode,
    93  			current: game,
    94  		},
    95  
    96  		cachedPolicies: make(map[sa]float32),
    97  		lumberjack:     makeLumberJack(),
    98  	}
    99  	go retVal.start()
   100  	retVal.searchState.tree = ptrFromTree(retVal)
   101  	retVal.searchState.maxDepth = retVal.M * retVal.N
   102  	return retVal
   103  }
   104  
   105  // New creates a new node
   106  func (t *MCTS) New(move game.Single, score, value float32) (retVal naughty) {
   107  	n := t.alloc()
   108  	N := t.nodeFromNaughty(n)
   109  	atomic.StoreInt32(&N.move, int32(move))
   110  	atomic.StoreUint32(&N.visits, 1)
   111  	atomic.StoreUint32(&N.status, uint32(Active))
   112  	atomic.StoreUint32(&N.score, math32.Float32bits(score))
   113  	atomic.StoreUint32(&N.value, math32.Float32bits(value))
   114  
   115  	// log.Printf("New node %p - %v", N, N)
   116  	return n
   117  }
   118  
   119  // SetGame sets the game
   120  func (t *MCTS) SetGame(g game.State) {
   121  	t.Lock()
   122  	t.current = g
   123  	t.Unlock()
   124  }
   125  
   126  func (t *MCTS) Nodes() int { return len(t.nodes) }
   127  
   128  func (t *MCTS) Policies(g game.State) []float32 {
   129  	hash := g.Hash()
   130  	var sum float32
   131  	actionSpacePlusPass := g.ActionSpace() + 1
   132  	retVal := make([]float32, actionSpacePlusPass)
   133  	for i := 0; i < actionSpacePlusPass; i++ {
   134  		prob := t.cachedPolicies[sa{s: hash, a: game.Single(i)}]
   135  		retVal[i] = prob
   136  		sum += prob
   137  	}
   138  	for i := 0; i < len(retVal); i++ {
   139  		retVal[i] /= sum
   140  	}
   141  	return retVal
   142  }
   143  
   144  // alloc tries to get a node from the free list. If none is found a new node is allocated into the master arena
   145  func (t *MCTS) alloc() naughty {
   146  	t.Lock()
   147  	l := len(t.freelist)
   148  	if l == 0 {
   149  		N := Node{
   150  			tree: ptrFromTree(t),
   151  			id:   naughty(len(t.nodes)),
   152  
   153  			minPSARatioChildren: defaultMinPsaRatio,
   154  		}
   155  		t.nodes = append(t.nodes, N)
   156  		t.children = append(t.children, make([]naughty, 0, t.M*t.N+1))
   157  		t.childLock = append(t.childLock, sync.Mutex{})
   158  		n := naughty(len(t.nodes) - 1)
   159  		t.Unlock()
   160  		return n
   161  	}
   162  
   163  	i := t.freelist[l-1]
   164  	t.freelist = t.freelist[:l-1]
   165  	t.Unlock()
   166  	return naughty(i)
   167  }
   168  
   169  // free puts the node back into the freelist.
   170  //
   171  // Because the there isn't really strong reference tracking, there may be
   172  // use-after-free issues. Therefore it's absolutely vital that any calls to free()
   173  // has to be done with careful consideration.
   174  func (t *MCTS) free(n naughty) {
   175  	// delete(t.children, n)
   176  	t.children[int(n)] = t.children[int(n)][:0]
   177  	t.freelist = append(t.freelist, n)
   178  	N := &t.nodes[int(n)]
   179  	N.reset()
   180  }
   181  
   182  // cleanup cleans up the graph (WORK IN PROGRESS)
   183  func (t *MCTS) cleanup(oldRoot, newRoot naughty) {
   184  	children := t.Children(oldRoot)
   185  	// we aint going down other paths, those nodes can be freed
   186  	for _, kid := range children {
   187  		if kid != newRoot {
   188  			t.nodeFromNaughty(kid).Invalidate()
   189  			t.freeables = append(t.freeables, kid)
   190  			t.cleanChildren(kid)
   191  		}
   192  	}
   193  	t.Lock()
   194  	t.children[oldRoot] = t.children[oldRoot][:1]
   195  	t.children[oldRoot][0] = newRoot
   196  	t.Unlock()
   197  }
   198  
   199  func (t *MCTS) cleanChildren(root naughty) {
   200  	children := t.Children(root)
   201  	for _, kid := range children {
   202  		t.nodeFromNaughty(kid).Invalidate()
   203  		t.freeables = append(t.freeables, kid)
   204  		t.cleanChildren(kid) // recursively clean children
   205  	}
   206  	t.Lock()
   207  	t.children[root] = t.children[root][:0] // empty it
   208  	t.Unlock()
   209  }
   210  
   211  // randomizeChildren proportionally randomizes the children nodes by proportion of the visit
   212  func (t *MCTS) randomizeChildren(of naughty) {
   213  	var accum, norm float32
   214  	var accumVector []float32
   215  	children := t.Children(of)
   216  	for _, kid := range children {
   217  		child := t.nodeFromNaughty(kid)
   218  		visits := child.Visits()
   219  		if norm == 0 {
   220  			norm = float32(visits)
   221  
   222  			// nonsensical options
   223  			if visits <= t.Config.RandomMinVisits {
   224  				return
   225  			}
   226  		}
   227  		if visits > t.Config.RandomMinVisits {
   228  			accum += math32.Pow(float32(visits)/norm, 1/t.Config.RandomTemperature)
   229  			accumVector = append(accumVector, accum)
   230  		}
   231  	}
   232  	rnd := t.rand.Float32() * accum // uniform distro: rnd() * (max-min) + min
   233  	var index int
   234  	for i, a := range accumVector {
   235  		if rnd < a {
   236  			index = i
   237  			break
   238  		}
   239  	}
   240  	if index == 0 {
   241  		return
   242  	}
   243  
   244  	for i := 0; i < len(children)-index; i++ {
   245  		children[i], children[i+index] = children[i+index], children[i]
   246  	}
   247  }
   248  
   249  func (t *MCTS) Reset() {
   250  	t.Lock()
   251  	defer t.Unlock()
   252  
   253  	t.freelist = t.freelist[:0]
   254  	t.freeables = t.freeables[:0]
   255  	for i := range t.nodes {
   256  		t.nodes[i].move = -1
   257  		t.nodes[i].visits = 0
   258  		t.nodes[i].status = 0
   259  		t.nodes[i].blackScores = 0
   260  		t.nodes[i].virtualLoss = 0
   261  		t.nodes[i].minPSARatioChildren = 0
   262  		t.nodes[i].score = 0
   263  		t.nodes[i].value = 0
   264  
   265  		t.freelist = append(t.freelist, t.nodes[i].id)
   266  	}
   267  
   268  	for i := range t.children {
   269  		t.children[i] = t.children[i][:0]
   270  	}
   271  
   272  	t.playouts = 0
   273  	t.nodes = t.nodes[:0]
   274  	t.cachedPolicies = make(map[sa]float32)
   275  	runtime.GC()
   276  }