github.com/gorgonia/agogo@v0.1.1/mcts/utils.go (about)

     1  package mcts
     2  
     3  import (
     4  	"github.com/chewxy/math32"
     5  	"github.com/gorgonia/agogo/game"
     6  )
     7  
     8  // fancySort sorts the list of nodes under a certain condition of evaluation (i.e. which colour are we considering)
     9  // it sorts in such a way that nils get put at the back
    10  type fancySort struct {
    11  	underEval game.Player
    12  	l         []naughty
    13  	t         *MCTS
    14  }
    15  
    16  func (l fancySort) Len() int      { return len(l.l) }
    17  func (l fancySort) Swap(i, j int) { l.l[i], l.l[j] = l.l[j], l.l[i] }
    18  func (l fancySort) Less(i, j int) bool {
    19  	li := l.t.nodeFromNaughty(l.l[i])
    20  	lj := l.t.nodeFromNaughty(l.l[j])
    21  
    22  	// // push nils to the back
    23  	// switch {
    24  	// case li == nil && lj != nil:
    25  	// 	return false
    26  	// case li != nil && lj == nil:
    27  	// 	return true
    28  	// case li == nil && lj == nil:
    29  	// 	return false
    30  	// }
    31  
    32  	// check if both have the same visits
    33  
    34  	liVisits := li.Visits()
    35  	ljVisits := lj.Visits()
    36  	if liVisits != ljVisits {
    37  		return liVisits > ljVisits
    38  	}
    39  
    40  	// no visits, we sort on score
    41  	if liVisits == 0 {
    42  		return li.Score() > lj.Score()
    43  	}
    44  
    45  	// same visit count. Evaluate
    46  	return li.Evaluate(l.underEval) > lj.Evaluate(l.underEval)
    47  }
    48  
    49  // pair is a tuple of score and coordinate
    50  type pair struct {
    51  	Coord game.Single
    52  	Score float32
    53  }
    54  
    55  // byScore is a sortable list of pairs It sorts the list with best score fist
    56  type byScore []pair
    57  
    58  func (l byScore) Len() int           { return len(l) }
    59  func (l byScore) Less(i, j int) bool { return l[i].Score > l[j].Score }
    60  func (l byScore) Swap(i, j int)      { l[i], l[j] = l[j], l[i] }
    61  
    62  func combinedScore(state game.State) float32 {
    63  	whiteScore := state.Score(White)
    64  	blackScore := state.Score(Black)
    65  	komi := state.AdditionalScore()
    66  	return blackScore - whiteScore - komi
    67  }
    68  
    69  type byMove struct {
    70  	t *MCTS
    71  	l []naughty
    72  }
    73  
    74  func (l byMove) Len() int { return len(l.l) }
    75  func (l byMove) Less(i, j int) bool {
    76  	li := l.t.nodeFromNaughty(l.l[i])
    77  	lj := l.t.nodeFromNaughty(l.l[j])
    78  	return li.move < lj.move
    79  }
    80  func (l byMove) Swap(i, j int) {
    81  	l.l[i], l.l[j] = l.l[j], l.l[i]
    82  }
    83  
    84  func argmax(a []float32) int {
    85  	var retVal int
    86  	var max float32 = math32.Inf(-1)
    87  	for i := range a {
    88  		if a[i] > max {
    89  			max = a[i]
    90  			retVal = i
    91  		}
    92  	}
    93  	return retVal
    94  }