github.com/gorgonia/agogo@v0.1.1/arena.go (about)

     1  package agogo
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"log"
     8  	"math/rand"
     9  	"runtime"
    10  	"time"
    11  
    12  	"github.com/chewxy/math32"
    13  	dual "github.com/gorgonia/agogo/dualnet"
    14  	"github.com/gorgonia/agogo/game"
    15  	"github.com/gorgonia/agogo/mcts"
    16  )
    17  
    18  // Arena represents a game arena
    19  // Arena fulfils the interface game.MetaState
    20  type Arena struct {
    21  	r    *rand.Rand
    22  	game game.State
    23  	A, B *Agent
    24  
    25  	// state
    26  	currentPlayer *Agent
    27  	conf          mcts.Config
    28  	buf           bytes.Buffer
    29  	logger        *log.Logger
    30  
    31  	// only relevant to training
    32  	name       string
    33  	epoch      int // training epoch
    34  	gameNumber int // which game is this in
    35  
    36  	// when to screw it all and just reinit a new NN
    37  	oldThresh int
    38  	oldCount  int
    39  }
    40  
    41  // MakeArena makes an arena given a game.
    42  func MakeArena(g game.State, a, b Dualer, conf mcts.Config, enc GameEncoder, aug Augmenter, name string) Arena {
    43  	A := &Agent{
    44  		NN:   a.Dual(),
    45  		Enc:  enc,
    46  		name: "A",
    47  	}
    48  	A.MCTS = mcts.New(g, conf, A)
    49  	B := &Agent{
    50  		NN:   b.Dual(),
    51  		Enc:  enc,
    52  		name: "B",
    53  	}
    54  	B.MCTS = mcts.New(g, conf, B)
    55  
    56  	if name == "" {
    57  		name = "UNKNOWN GAME"
    58  	}
    59  
    60  	return Arena{
    61  		r:    rand.New(rand.NewSource(time.Now().UnixNano())),
    62  		game: g,
    63  		A:    A,
    64  		B:    B,
    65  		conf: conf,
    66  		name: name,
    67  
    68  		oldThresh: 10,
    69  	}
    70  }
    71  
    72  // NewArena makes an arena an returns a pointer to the Arena
    73  func NewArena(g game.State, a, b Dualer, conf mcts.Config, enc GameEncoder, aug Augmenter, name string) *Arena {
    74  	ar := MakeArena(g, a, b, conf, enc, aug, name)
    75  	ar.logger = log.New(&ar.buf, "", log.Ltime)
    76  	return &ar
    77  }
    78  
    79  // Play plays a game, and retrns a winner. If it is a draw, the returned colour is None.
    80  func (a *Arena) Play(record bool, enc OutputEncoder, aug Augmenter) (winner game.Player, examples []Example) {
    81  	if a.r.Intn(2) == 0 {
    82  		a.A.Player = game.Player(game.Black)
    83  		a.B.Player = game.Player(game.White)
    84  		a.currentPlayer = a.A
    85  	} else {
    86  		a.A.Player = game.Player(game.White)
    87  		a.B.Player = game.Player(game.Black)
    88  		a.currentPlayer = a.B
    89  	}
    90  
    91  	a.game.SetToMove(a.currentPlayer.Player)
    92  	a.logger.Printf("Playing. Recording %t\n", record)
    93  	a.logger.SetPrefix("\t\t")
    94  	var ended bool
    95  	var passCount int
    96  	for ended, winner = a.game.Ended(); !ended; ended, winner = a.game.Ended() {
    97  
    98  		best := a.currentPlayer.Search(a.game)
    99  		if best.IsPass() {
   100  			passCount++
   101  		} else {
   102  			passCount = 0
   103  		}
   104  		a.logger.Printf("Current Player: %v. Best Move %v\n", a.currentPlayer.Player, best)
   105  		if record {
   106  			boards := a.currentPlayer.Enc(a.game)
   107  			policies := a.currentPlayer.MCTS.Policies(a.game)
   108  			ex := Example{
   109  				Board:  boards,
   110  				Policy: policies,
   111  				// THIS IS A HACK.
   112  				// The value is 1 or -1 depending on player colour, but for now we store the player colour
   113  				Value: float32(a.currentPlayer.Player),
   114  			}
   115  			if validPolicies(policies) {
   116  				if aug != nil {
   117  					examples = append(examples, aug(ex)...)
   118  				} else {
   119  					examples = append(examples, ex)
   120  				}
   121  			}
   122  
   123  		}
   124  
   125  		// policy, value := a.currentPlayer.Infer(a.game)
   126  		// log.Printf("\t\tPlayer %v made Move %v | %1.1v %1.1v", a.currentPlayer.Player, best, policy, value)
   127  		a.game = a.game.Apply(game.PlayerMove{
   128  			Player: a.currentPlayer.Player,
   129  			Single: best,
   130  		})
   131  		a.switchPlayer()
   132  		if enc != nil {
   133  			enc.Encode(a)
   134  		}
   135  		if passCount >= 2 {
   136  			break
   137  		}
   138  	}
   139  	a.logger.SetPrefix("\t")
   140  	a.A.MCTS.Reset()
   141  	a.B.MCTS.Reset()
   142  	if enc != nil {
   143  		log.Printf("\tDone playing")
   144  	}
   145  
   146  	for i := range examples {
   147  		switch {
   148  		case winner == game.Player(game.None):
   149  			examples[i].Value = 0
   150  		case examples[i].Value == float32(winner):
   151  			examples[i].Value = 1
   152  		default:
   153  			examples[i].Value = -1
   154  		}
   155  	}
   156  	var winningAgent *Agent
   157  	switch {
   158  	case winner == game.Player(game.None):
   159  		a.A.Draw++
   160  		a.B.Draw++
   161  	case winner == a.A.Player:
   162  		a.A.Wins++
   163  		a.B.Loss++
   164  		winningAgent = a.A
   165  	case winner == a.B.Player:
   166  		a.B.Wins++
   167  		a.A.Loss++
   168  		winningAgent = a.B
   169  	}
   170  	if !record {
   171  		log.Printf("Winner %v | %p", winner, winningAgent)
   172  	}
   173  	// a.A.MCTS.Reset()
   174  	// a.B.MCTS.Reset()
   175  	a.A.MCTS = mcts.New(a.game, a.conf, a.A)
   176  	a.B.MCTS = mcts.New(a.game, a.conf, a.B)
   177  	runtime.GC()
   178  	return game.Player(game.None), examples
   179  }
   180  
   181  // Epoch returns the current Epoch
   182  func (a *Arena) Epoch() int { return a.epoch }
   183  
   184  // GameNumber returns the
   185  func (a *Arena) GameNumber() int { return a.gameNumber }
   186  
   187  // Name of the game
   188  func (a *Arena) Name() string { return a.name }
   189  
   190  // Score of the player p
   191  func (a *Arena) Score(p game.Player) float64 { return float64(a.game.Score(p)) }
   192  
   193  // State of the game
   194  func (a *Arena) State() game.State { return a.game }
   195  
   196  // Log the MCTS of both players into w
   197  func (a *Arena) Log(w io.Writer) {
   198  	fmt.Fprintf(w, a.buf.String())
   199  	fmt.Fprintf(w, "\nA:\n\n")
   200  	fmt.Fprintln(w, a.A.MCTS.Log())
   201  	fmt.Fprintf(w, "\nB:\n\n")
   202  	fmt.Fprintln(w, a.B.MCTS.Log())
   203  }
   204  
   205  func (a *Arena) newB(conf dual.Config, killedA bool) (err error) {
   206  	if killedA {
   207  		a.oldCount = 0
   208  	}
   209  
   210  	// if a.oldCount >= a.oldThresh {
   211  	// 	a.B.NN = dual.New(conf)
   212  	// 	err = a.B.NN.Init()
   213  	// 	a.oldCount = 0
   214  	// } else {
   215  	// 	a.B.NN, err = a.B.NN.Clone()
   216  	// }
   217  
   218  	a.B.NN = dual.New(conf)
   219  	err = a.B.NN.Init()
   220  
   221  	a.oldCount++
   222  	log.Printf("NewB NN %p", a.B.NN)
   223  	return err
   224  }
   225  
   226  func (a *Arena) switchPlayer() {
   227  	switch a.currentPlayer {
   228  	case a.A:
   229  		a.currentPlayer = a.B
   230  	case a.B:
   231  		a.currentPlayer = a.A
   232  	}
   233  }
   234  
   235  func cloneBoard(a []game.Colour) []game.Colour {
   236  	retVal := make([]game.Colour, len(a))
   237  	copy(retVal, a)
   238  	return retVal
   239  }
   240  
   241  func validPolicies(policy []float32) bool {
   242  	for _, v := range policy {
   243  		if math32.IsInf(v, 0) {
   244  			return false
   245  		}
   246  		if math32.IsNaN(v) {
   247  			return false
   248  		}
   249  	}
   250  	return true
   251  }