github.com/gorgonia/agogo@v0.1.1/agent.go (about)

     1  package agogo
     2  
     3  import (
     4  	"log"
     5  	"runtime"
     6  	"sync"
     7  
     8  	dual "github.com/gorgonia/agogo/dualnet"
     9  	"github.com/gorgonia/agogo/game"
    10  	"github.com/gorgonia/agogo/mcts"
    11  )
    12  
    13  // An Agent is a player, AI or Human
    14  type Agent struct {
    15  	NN     *dual.Dual
    16  	MCTS   *mcts.MCTS
    17  	Player game.Player
    18  	Enc    GameEncoder
    19  
    20  	// Statistics
    21  	Wins float32
    22  	Loss float32
    23  	Draw float32
    24  	sync.Mutex
    25  
    26  	name    string
    27  	actions int
    28  	inferer chan Inferer
    29  	err     error
    30  }
    31  
    32  func newAgent(a Dualer) *Agent {
    33  	retVal := &Agent{
    34  		NN: a.Dual(),
    35  	}
    36  	return retVal
    37  }
    38  
    39  // SwitchToInference uses the inference mode neural network.
    40  func (a *Agent) SwitchToInference(g game.State) (err error) {
    41  	a.Lock()
    42  	a.inferer = make(chan Inferer, numCPU)
    43  
    44  	for i := 0; i < numCPU; i++ {
    45  		var inf Inferer
    46  		if inf, err = dual.Infer(a.NN, g.ActionSpace(), false); err != nil {
    47  			return err
    48  		}
    49  		a.inferer <- inf
    50  	}
    51  	// a.NN = nil // remove old NN
    52  	a.Unlock()
    53  	return nil
    54  }
    55  
    56  // Infer infers a bunch of moves based on the game state. This is mainly used to implement a Inferer such that the MCTS search can use it.
    57  func (a *Agent) Infer(g game.State) (policy []float32, value float32) {
    58  	input := a.Enc(g)
    59  	inf := <-a.inferer
    60  
    61  	var err error
    62  	policy, value, err = inf.Infer(input)
    63  	if err != nil {
    64  		if el, ok := inf.(ExecLogger); ok {
    65  			log.Println(el.ExecLog())
    66  		}
    67  		panic(err)
    68  	}
    69  	a.inferer <- inf
    70  	return
    71  }
    72  
    73  // Search searches the game state and returns a suggested coordinate.
    74  func (a *Agent) Search(g game.State) game.Single {
    75  	a.MCTS.SetGame(g)
    76  	return a.MCTS.Search(a.Player)
    77  }
    78  
    79  // NNOutput returns the output of the neural network
    80  func (a *Agent) NNOutput(g game.State) (policy []float32, value float32, err error) {
    81  	input := a.Enc(g)
    82  	inf := <-a.inferer
    83  	policy, value, err = inf.Infer(input)
    84  	a.inferer <- inf
    85  	return
    86  }
    87  
    88  func (a *Agent) Close() error {
    89  	close(a.inferer)
    90  	var allErrs manyErr
    91  	for inferer := range a.inferer {
    92  		if err := inferer.Close(); err != nil {
    93  			allErrs = append(allErrs, err)
    94  		}
    95  	}
    96  	if len(allErrs) > 0 {
    97  		return allErrs
    98  	}
    99  	return nil
   100  }
   101  
   102  func (a *Agent) useDummy(g game.State) {
   103  	a.inferer = make(chan Inferer, runtime.NumCPU())
   104  	for i := 0; i < runtime.NumCPU(); i++ {
   105  		a.inferer <- dummyInferer{
   106  			outputSize:    g.ActionSpace(),
   107  			currentPlayer: a.Player,
   108  		}
   109  	}
   110  }
   111  
   112  func (a *Agent) resetStats() {
   113  	a.Lock()
   114  	a.Wins = 0
   115  	a.Loss = 0
   116  	a.Draw = 0
   117  	a.Unlock()
   118  }