github.com/gorgonia/agogo@v0.1.1/cmd/tictactoe/main.go (about)

     1  package main
     2  
     3  import (
     4  	"flag"
     5  	"log"
     6  	"os"
     7  	"runtime/pprof"
     8  	"runtime/trace"
     9  	"time"
    10  
    11  	"github.com/gorgonia/agogo"
    12  	dual "github.com/gorgonia/agogo/dualnet"
    13  	"github.com/gorgonia/agogo/game"
    14  	"github.com/gorgonia/agogo/game/mnk"
    15  	"github.com/gorgonia/agogo/mcts"
    16  
    17  	"net/http"
    18  	_ "net/http/pprof"
    19  )
    20  
var (
	// traceFlag turns on execution tracing in main when set to any non-empty
	// value. The value itself is not used as a path: main always writes the
	// trace to "trace.out".
	traceFlag  = flag.String("trace", "", "do a trace")
	// cpuprofile is the path of the file main writes a CPU profile to.
	// An empty value (the default) disables CPU profiling.
	cpuprofile = flag.String("cpuprofile", "", "cpuprofile")
)
    25  
    26  func encodeBoard(a game.State) []float32 {
    27  	board := agogo.EncodeTwoPlayerBoard(a.Board(), nil)
    28  	for i := range board {
    29  		if board[i] == 0 {
    30  			board[i] = 0.001
    31  		}
    32  	}
    33  	playerLayer := make([]float32, len(a.Board()))
    34  	next := a.ToMove()
    35  	if next == game.Player(game.Black) {
    36  		for i := range playerLayer {
    37  			playerLayer[i] = 1
    38  		}
    39  	} else if next == game.Player(game.White) {
    40  		// vecf32.Scale(board, -1)
    41  		for i := range playerLayer {
    42  			playerLayer[i] = -1
    43  		}
    44  	}
    45  	retVal := append(board, playerLayer...)
    46  	return retVal
    47  }
    48  
    49  func main() {
    50  	flag.Parse()
    51  	go func() {
    52  		log.Println(http.ListenAndServe("localhost:6060", nil))
    53  	}()
    54  
    55  	f, err := os.OpenFile("game.gif", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
    56  	if err != nil {
    57  		log.Fatalf("Unable to create gif file: %v", err)
    58  	}
    59  	defer f.Close()
    60  
    61  	conf := agogo.Config{
    62  		Name:            "Tic Tac Toe",
    63  		NNConf:          dual.DefaultConf(3, 3, 10),
    64  		MCTSConf:        mcts.DefaultConfig(3),
    65  		UpdateThreshold: 0.52,
    66  	}
    67  	conf.NNConf.BatchSize = 100
    68  	conf.NNConf.Features = 2 // write a better encoding of the board, and increase features (and that allows you to increase K as well)
    69  	conf.NNConf.K = 3
    70  	conf.NNConf.SharedLayers = 3
    71  	conf.MCTSConf = mcts.Config{
    72  		PUCT:           1.0,
    73  		M:              3,
    74  		N:              3,
    75  		Timeout:        100 * time.Millisecond,
    76  		PassPreference: mcts.DontPreferPass,
    77  		Budget:         1000,
    78  		DumbPass:       true,
    79  		RandomCount:    0,
    80  	}
    81  
    82  	outEnc := game.NewGifEncoder(300, 300)
    83  	outEnc.Writer = f
    84  
    85  	conf.Encoder = encodeBoard
    86  	conf.OutputEncoder = outEnc
    87  
    88  	if *traceFlag != "" {
    89  		f, err := os.Create("trace.out")
    90  		if err != nil {
    91  			log.Fatalf("failed to create trace output file: %v", err)
    92  		}
    93  		defer func() {
    94  			if err := f.Close(); err != nil {
    95  				log.Fatalf("failed to close trace file: %v", err)
    96  			}
    97  		}()
    98  
    99  		if err := trace.Start(f); err != nil {
   100  			log.Fatalf("failed to start trace: %v", err)
   101  		}
   102  
   103  		defer func() {
   104  			<-time.After(10 * time.Second)
   105  			trace.Stop()
   106  		}()
   107  	}
   108  
   109  	if *cpuprofile != "" {
   110  		f, err := os.Create(*cpuprofile)
   111  		if err != nil {
   112  			log.Fatal(err)
   113  		}
   114  		pprof.StartCPUProfile(f)
   115  		defer pprof.StopCPUProfile()
   116  	}
   117  
   118  	g := mnk.TicTacToe()
   119  	a := agogo.New(g, conf)
   120  	a.Learn(5, 30, 200, 30) // 5 epochs, 50 episode, 100 NN iters, 100 games.
   121  	outEnc.Flush()
   122  	a.Save("example.model")
   123  }