github.com/gorgonia/agogo@v0.1.1/mcts/example_test.go (about) 1 package mcts_test 2 3 import ( 4 "bytes" 5 "fmt" 6 "io/ioutil" 7 "log" 8 "math/rand" 9 "os" 10 "os/signal" 11 "time" 12 13 "github.com/gorgonia/agogo/game" 14 "github.com/gorgonia/agogo/game/komi" 15 "github.com/gorgonia/agogo/game/mnk" 16 "github.com/gorgonia/agogo/mcts" 17 ) 18 19 var ( 20 Pass = game.Single(-1) 21 22 Cross = game.Player(game.Black) 23 Nought = game.Player(game.White) 24 25 r = rand.New(rand.NewSource(1337)) 26 ) 27 28 func opponent(p game.Player) game.Player { 29 switch p { 30 case Cross: 31 return Nought 32 case Nought: 33 return Cross 34 } 35 panic("Unreachable") 36 } 37 38 type dummyNN struct{} 39 40 func (dummyNN) Infer(state game.State) (policy []float32, value float32) { 41 policy = make([]float32, 10) // 10 because last one is a pass 42 switch state.MoveNumber() { 43 case 0: 44 policy[4] = 0.9 45 value = 0.5 46 case 1: 47 policy[0] = 0.1 // switch colours remember? 48 value = 0.5 49 case 2: 50 policy[2] = 0.9 51 value = 8 / 9 52 case 3: 53 policy[6] = 0.1 54 value = 8 / 9 55 case 4: 56 policy[3] = 0.9 57 value = 8 / 9 58 case 5: 59 policy[5] = 0.1 60 value = 0.5 61 case 6: 62 policy[1] = 0.9 63 value = 8 / 9 64 case 7: 65 policy[7] = 0.1 66 value = 0 67 case 8: 68 policy[8] = 0.9 69 value = 0 70 } 71 return 72 } 73 74 func Example() { 75 g := mnk.TicTacToe() 76 conf := mcts.Config{ 77 PUCT: 1.0, 78 M: 3, 79 N: 3, 80 Timeout: 500 * time.Millisecond, 81 PassPreference: mcts.DontPreferPass, 82 Budget: 10000, 83 DumbPass: true, 84 RandomCount: 0, // this is a deterministic example 85 } 86 nn := dummyNN{} 87 t := mcts.New(g, conf, nn) 88 player := Cross 89 90 var buf bytes.Buffer 91 var ended bool 92 var winner game.Player 93 for ended, winner = g.Ended(); !ended; ended, winner = g.Ended() { 94 moveNum := g.MoveNumber() 95 best := t.Search(player) 96 g = g.Apply(game.PlayerMove{player, best}).(*mnk.MNK) 97 fmt.Fprintf(&buf, "Turn %d\n%v---\n", moveNum, g) 98 if moveNum == 2 { 99 ioutil.WriteFile("fullGraph_tictactoe.dot", []byte(t.ToDot()), 0644) 100 } 101 player = opponent(player) 102 } 103 104 log.Printf("Playout:\n%v", buf.String()) 105 fmt.Printf("WINNER %v\n", winner) 106 107 // the outputs should look something like this (may dfiffer due to random numbers) 108 // Turn 0 109 // ⎢ · · · ⎥ 110 // ⎢ · X · ⎥ 111 // ⎢ · · · ⎥ 112 // --- 113 // Turn 1 114 // ⎢ O · · ⎥ 115 // ⎢ · X · ⎥ 116 // ⎢ · · · ⎥ 117 // --- 118 // Turn 2 119 // ⎢ O · X ⎥ 120 // ⎢ · X · ⎥ 121 // ⎢ · · · ⎥ 122 // --- 123 // Turn 3 124 // ⎢ O · X ⎥ 125 // ⎢ · X · ⎥ 126 // ⎢ O · · ⎥ 127 // --- 128 // Turn 4 129 // ⎢ O · X ⎥ 130 // ⎢ X X · ⎥ 131 // ⎢ O · · ⎥ 132 // --- 133 // Turn 5 134 // ⎢ O · X ⎥ 135 // ⎢ X X O ⎥ 136 // ⎢ O · · ⎥ 137 // --- 138 // Turn 6 139 // ⎢ O X X ⎥ 140 // ⎢ X X O ⎥ 141 // ⎢ O · · ⎥ 142 // --- 143 // Turn 7 144 // ⎢ O X X ⎥ 145 // ⎢ X X O ⎥ 146 // ⎢ O O · ⎥ 147 // --- 148 // Turn 8 149 // ⎢ O X X ⎥ 150 // ⎢ X X O ⎥ 151 // ⎢ O O X ⎥ 152 // --- 153 154 // Output: 155 // WINNER None 156 } 157 158 type dummyNN2 struct{} 159 160 func (dummyNN2) Infer(state game.State) (policy []float32, value float32) { 161 policy = make([]float32, 25) 162 for i := range policy { 163 policy[i] = 1 / 25.0 164 } 165 return policy, 1 / 25.0 166 } 167 168 func Example_Komi() { 169 g := komi.New(5, 5, 3) 170 conf := mcts.Config{ 171 PUCT: 1.0, 172 M: 5, 173 N: 5, 174 Timeout: 500 * time.Millisecond, 175 PassPreference: mcts.DontPreferPass, 176 Budget: 100, 177 DumbPass: true, 178 RandomCount: 0, // this is a deterministic example 179 } 180 181 nn := dummyNN2{} 182 t := mcts.New(g, conf, nn) 183 player := game.Player(game.White) 184 185 c := make(chan os.Signal, 1) 186 signal.Notify(c, os.Interrupt) 187 go func() { 188 for range c { 189 // sig is a ^C, handle it 190 log.Println(t.Log()) 191 // ioutil.WriteFile("fullGraph.dot", []byte(t.ToDot()), 0644) 192 os.Exit(1) 193 } 194 }() 195 196 var buf bytes.Buffer 197 var ended bool 198 var winner game.Player 199 for ended, winner = g.Ended(); !ended; ended, winner = g.Ended() { 200 moveNum := g.MoveNumber() 201 player = opponent(player) 202 best := t.Search(player) 203 g = g.Apply(game.PlayerMove{player, best}).(*komi.Game) 204 fmt.Fprintf(&buf, "Turn %d\n%v---\n", moveNum, g) 205 } 206 207 log.Printf("Playout:\n%v", buf.String()) 208 fmt.Printf("WINNER %v\n", winner) 209 210 // Output: 211 // 0 212 }