github.com/gorgonia/agogo@v0.1.1/mcts/example_test.go (about)

     1  package mcts_test
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"log"
     8  	"math/rand"
     9  	"os"
    10  	"os/signal"
    11  	"time"
    12  
    13  	"github.com/gorgonia/agogo/game"
    14  	"github.com/gorgonia/agogo/game/komi"
    15  	"github.com/gorgonia/agogo/game/mnk"
    16  	"github.com/gorgonia/agogo/mcts"
    17  )
    18  
    19  var (
    20  	Pass = game.Single(-1)
    21  
    22  	Cross  = game.Player(game.Black)
    23  	Nought = game.Player(game.White)
    24  
    25  	r = rand.New(rand.NewSource(1337))
    26  )
    27  
    28  func opponent(p game.Player) game.Player {
    29  	switch p {
    30  	case Cross:
    31  		return Nought
    32  	case Nought:
    33  		return Cross
    34  	}
    35  	panic("Unreachable")
    36  }
    37  
    38  type dummyNN struct{}
    39  
    40  func (dummyNN) Infer(state game.State) (policy []float32, value float32) {
    41  	policy = make([]float32, 10) // 10 because last one is a pass
    42  	switch state.MoveNumber() {
    43  	case 0:
    44  		policy[4] = 0.9
    45  		value = 0.5
    46  	case 1:
    47  		policy[0] = 0.1 // switch colours remember?
    48  		value = 0.5
    49  	case 2:
    50  		policy[2] = 0.9
    51  		value = 8 / 9
    52  	case 3:
    53  		policy[6] = 0.1
    54  		value = 8 / 9
    55  	case 4:
    56  		policy[3] = 0.9
    57  		value = 8 / 9
    58  	case 5:
    59  		policy[5] = 0.1
    60  		value = 0.5
    61  	case 6:
    62  		policy[1] = 0.9
    63  		value = 8 / 9
    64  	case 7:
    65  		policy[7] = 0.1
    66  		value = 0
    67  	case 8:
    68  		policy[8] = 0.9
    69  		value = 0
    70  	}
    71  	return
    72  }
    73  
    74  func Example() {
    75  	g := mnk.TicTacToe()
    76  	conf := mcts.Config{
    77  		PUCT:           1.0,
    78  		M:              3,
    79  		N:              3,
    80  		Timeout:        500 * time.Millisecond,
    81  		PassPreference: mcts.DontPreferPass,
    82  		Budget:         10000,
    83  		DumbPass:       true,
    84  		RandomCount:    0, // this is a deterministic example
    85  	}
    86  	nn := dummyNN{}
    87  	t := mcts.New(g, conf, nn)
    88  	player := Cross
    89  
    90  	var buf bytes.Buffer
    91  	var ended bool
    92  	var winner game.Player
    93  	for ended, winner = g.Ended(); !ended; ended, winner = g.Ended() {
    94  		moveNum := g.MoveNumber()
    95  		best := t.Search(player)
    96  		g = g.Apply(game.PlayerMove{player, best}).(*mnk.MNK)
    97  		fmt.Fprintf(&buf, "Turn %d\n%v---\n", moveNum, g)
    98  		if moveNum == 2 {
    99  			ioutil.WriteFile("fullGraph_tictactoe.dot", []byte(t.ToDot()), 0644)
   100  		}
   101  		player = opponent(player)
   102  	}
   103  
   104  	log.Printf("Playout:\n%v", buf.String())
   105  	fmt.Printf("WINNER %v\n", winner)
   106  
   107  	// the outputs should look something like this (may dfiffer due to random numbers)
   108  	// Turn 0
   109  	// ⎢ · · · ⎥
   110  	// ⎢ · X · ⎥
   111  	// ⎢ · · · ⎥
   112  	// ---
   113  	// Turn 1
   114  	// ⎢ O · · ⎥
   115  	// ⎢ · X · ⎥
   116  	// ⎢ · · · ⎥
   117  	// ---
   118  	// Turn 2
   119  	// ⎢ O · X ⎥
   120  	// ⎢ · X · ⎥
   121  	// ⎢ · · · ⎥
   122  	// ---
   123  	// Turn 3
   124  	// ⎢ O · X ⎥
   125  	// ⎢ · X · ⎥
   126  	// ⎢ O · · ⎥
   127  	// ---
   128  	// Turn 4
   129  	// ⎢ O · X ⎥
   130  	// ⎢ X X · ⎥
   131  	// ⎢ O · · ⎥
   132  	// ---
   133  	// Turn 5
   134  	// ⎢ O · X ⎥
   135  	// ⎢ X X O ⎥
   136  	// ⎢ O · · ⎥
   137  	// ---
   138  	// Turn 6
   139  	// ⎢ O X X ⎥
   140  	// ⎢ X X O ⎥
   141  	// ⎢ O · · ⎥
   142  	// ---
   143  	// Turn 7
   144  	// ⎢ O X X ⎥
   145  	// ⎢ X X O ⎥
   146  	// ⎢ O O · ⎥
   147  	// ---
   148  	// Turn 8
   149  	// ⎢ O X X ⎥
   150  	// ⎢ X X O ⎥
   151  	// ⎢ O O X ⎥
   152  	// ---
   153  
   154  	// Output:
   155  	// WINNER None
   156  }
   157  
   158  type dummyNN2 struct{}
   159  
   160  func (dummyNN2) Infer(state game.State) (policy []float32, value float32) {
   161  	policy = make([]float32, 25)
   162  	for i := range policy {
   163  		policy[i] = 1 / 25.0
   164  	}
   165  	return policy, 1 / 25.0
   166  }
   167  
   168  func Example_Komi() {
   169  	g := komi.New(5, 5, 3)
   170  	conf := mcts.Config{
   171  		PUCT:           1.0,
   172  		M:              5,
   173  		N:              5,
   174  		Timeout:        500 * time.Millisecond,
   175  		PassPreference: mcts.DontPreferPass,
   176  		Budget:         100,
   177  		DumbPass:       true,
   178  		RandomCount:    0, // this is a deterministic example
   179  	}
   180  
   181  	nn := dummyNN2{}
   182  	t := mcts.New(g, conf, nn)
   183  	player := game.Player(game.White)
   184  
   185  	c := make(chan os.Signal, 1)
   186  	signal.Notify(c, os.Interrupt)
   187  	go func() {
   188  		for range c {
   189  			// sig is a ^C, handle it
   190  			log.Println(t.Log())
   191  			// ioutil.WriteFile("fullGraph.dot", []byte(t.ToDot()), 0644)
   192  			os.Exit(1)
   193  		}
   194  	}()
   195  
   196  	var buf bytes.Buffer
   197  	var ended bool
   198  	var winner game.Player
   199  	for ended, winner = g.Ended(); !ended; ended, winner = g.Ended() {
   200  		moveNum := g.MoveNumber()
   201  		player = opponent(player)
   202  		best := t.Search(player)
   203  		g = g.Apply(game.PlayerMove{player, best}).(*komi.Game)
   204  		fmt.Fprintf(&buf, "Turn %d\n%v---\n", moveNum, g)
   205  	}
   206  
   207  	log.Printf("Playout:\n%v", buf.String())
   208  	fmt.Printf("WINNER %v\n", winner)
   209  
   210  	// Output:
   211  	// 0
   212  }