github.com/gorgonia/agogo@v0.1.1/mcts/utils.go (about) 1 package mcts 2 3 import ( 4 "github.com/chewxy/math32" 5 "github.com/gorgonia/agogo/game" 6 ) 7 8 // fancySort sorts the list of nodes under a certain condition of evaluation (i.e. which colour are we considering) 9 // it sorts in such a way that nils get put at the back 10 type fancySort struct { 11 underEval game.Player 12 l []naughty 13 t *MCTS 14 } 15 16 func (l fancySort) Len() int { return len(l.l) } 17 func (l fancySort) Swap(i, j int) { l.l[i], l.l[j] = l.l[j], l.l[i] } 18 func (l fancySort) Less(i, j int) bool { 19 li := l.t.nodeFromNaughty(l.l[i]) 20 lj := l.t.nodeFromNaughty(l.l[j]) 21 22 // // push nils to the back 23 // switch { 24 // case li == nil && lj != nil: 25 // return false 26 // case li != nil && lj == nil: 27 // return true 28 // case li == nil && lj == nil: 29 // return false 30 // } 31 32 // check if both have the same visits 33 34 liVisits := li.Visits() 35 ljVisits := lj.Visits() 36 if liVisits != ljVisits { 37 return liVisits > ljVisits 38 } 39 40 // no visits, we sort on score 41 if liVisits == 0 { 42 return li.Score() > lj.Score() 43 } 44 45 // same visit count. Evaluate 46 return li.Evaluate(l.underEval) > lj.Evaluate(l.underEval) 47 } 48 49 // pair is a tuple of score and coordinate 50 type pair struct { 51 Coord game.Single 52 Score float32 53 } 54 55 // byScore is a sortable list of pairs It sorts the list with best score fist 56 type byScore []pair 57 58 func (l byScore) Len() int { return len(l) } 59 func (l byScore) Less(i, j int) bool { return l[i].Score > l[j].Score } 60 func (l byScore) Swap(i, j int) { l[i], l[j] = l[j], l[i] } 61 62 func combinedScore(state game.State) float32 { 63 whiteScore := state.Score(White) 64 blackScore := state.Score(Black) 65 komi := state.AdditionalScore() 66 return blackScore - whiteScore - komi 67 } 68 69 type byMove struct { 70 t *MCTS 71 l []naughty 72 } 73 74 func (l byMove) Len() int { return len(l.l) } 75 func (l byMove) Less(i, j int) bool { 76 li := l.t.nodeFromNaughty(l.l[i]) 77 lj := l.t.nodeFromNaughty(l.l[j]) 78 return li.move < lj.move 79 } 80 func (l byMove) Swap(i, j int) { 81 l.l[i], l.l[j] = l.l[j], l.l[i] 82 } 83 84 func argmax(a []float32) int { 85 var retVal int 86 var max float32 = math32.Inf(-1) 87 for i := range a { 88 if a[i] > max { 89 max = a[i] 90 retVal = i 91 } 92 } 93 return retVal 94 }