github.com/gorgonia/agogo@v0.1.1/cmd/tictactoe/main.go (about) 1 package main 2 3 import ( 4 "flag" 5 "log" 6 "os" 7 "runtime/pprof" 8 "runtime/trace" 9 "time" 10 11 "github.com/gorgonia/agogo" 12 dual "github.com/gorgonia/agogo/dualnet" 13 "github.com/gorgonia/agogo/game" 14 "github.com/gorgonia/agogo/game/mnk" 15 "github.com/gorgonia/agogo/mcts" 16 17 "net/http" 18 _ "net/http/pprof" 19 ) 20 21 var ( 22 traceFlag = flag.String("trace", "", "do a trace") 23 cpuprofile = flag.String("cpuprofile", "", "cpuprofile") 24 ) 25 26 func encodeBoard(a game.State) []float32 { 27 board := agogo.EncodeTwoPlayerBoard(a.Board(), nil) 28 for i := range board { 29 if board[i] == 0 { 30 board[i] = 0.001 31 } 32 } 33 playerLayer := make([]float32, len(a.Board())) 34 next := a.ToMove() 35 if next == game.Player(game.Black) { 36 for i := range playerLayer { 37 playerLayer[i] = 1 38 } 39 } else if next == game.Player(game.White) { 40 // vecf32.Scale(board, -1) 41 for i := range playerLayer { 42 playerLayer[i] = -1 43 } 44 } 45 retVal := append(board, playerLayer...) 46 return retVal 47 } 48 49 func main() { 50 flag.Parse() 51 go func() { 52 log.Println(http.ListenAndServe("localhost:6060", nil)) 53 }() 54 55 f, err := os.OpenFile("game.gif", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) 56 if err != nil { 57 log.Fatalf("Unable to create gif file: %v", err) 58 } 59 defer f.Close() 60 61 conf := agogo.Config{ 62 Name: "Tic Tac Toe", 63 NNConf: dual.DefaultConf(3, 3, 10), 64 MCTSConf: mcts.DefaultConfig(3), 65 UpdateThreshold: 0.52, 66 } 67 conf.NNConf.BatchSize = 100 68 conf.NNConf.Features = 2 // write a better encoding of the board, and increase features (and that allows you to increase K as well) 69 conf.NNConf.K = 3 70 conf.NNConf.SharedLayers = 3 71 conf.MCTSConf = mcts.Config{ 72 PUCT: 1.0, 73 M: 3, 74 N: 3, 75 Timeout: 100 * time.Millisecond, 76 PassPreference: mcts.DontPreferPass, 77 Budget: 1000, 78 DumbPass: true, 79 RandomCount: 0, 80 } 81 82 outEnc := game.NewGifEncoder(300, 300) 83 outEnc.Writer = f 84 85 conf.Encoder = encodeBoard 86 conf.OutputEncoder = outEnc 87 88 if *traceFlag != "" { 89 f, err := os.Create("trace.out") 90 if err != nil { 91 log.Fatalf("failed to create trace output file: %v", err) 92 } 93 defer func() { 94 if err := f.Close(); err != nil { 95 log.Fatalf("failed to close trace file: %v", err) 96 } 97 }() 98 99 if err := trace.Start(f); err != nil { 100 log.Fatalf("failed to start trace: %v", err) 101 } 102 103 defer func() { 104 <-time.After(10 * time.Second) 105 trace.Stop() 106 }() 107 } 108 109 if *cpuprofile != "" { 110 f, err := os.Create(*cpuprofile) 111 if err != nil { 112 log.Fatal(err) 113 } 114 pprof.StartCPUProfile(f) 115 defer pprof.StopCPUProfile() 116 } 117 118 g := mnk.TicTacToe() 119 a := agogo.New(g, conf) 120 a.Learn(5, 30, 200, 30) // 5 epochs, 50 episode, 100 NN iters, 100 games. 121 outEnc.Flush() 122 a.Save("example.model") 123 }