github.com/gorgonia/agogo@v0.1.1/agent.go (about) 1 package agogo 2 3 import ( 4 "log" 5 "runtime" 6 "sync" 7 8 dual "github.com/gorgonia/agogo/dualnet" 9 "github.com/gorgonia/agogo/game" 10 "github.com/gorgonia/agogo/mcts" 11 ) 12 13 // An Agent is a player, AI or Human 14 type Agent struct { 15 NN *dual.Dual 16 MCTS *mcts.MCTS 17 Player game.Player 18 Enc GameEncoder 19 20 // Statistics 21 Wins float32 22 Loss float32 23 Draw float32 24 sync.Mutex 25 26 name string 27 actions int 28 inferer chan Inferer 29 err error 30 } 31 32 func newAgent(a Dualer) *Agent { 33 retVal := &Agent{ 34 NN: a.Dual(), 35 } 36 return retVal 37 } 38 39 // SwitchToInference uses the inference mode neural network. 40 func (a *Agent) SwitchToInference(g game.State) (err error) { 41 a.Lock() 42 a.inferer = make(chan Inferer, numCPU) 43 44 for i := 0; i < numCPU; i++ { 45 var inf Inferer 46 if inf, err = dual.Infer(a.NN, g.ActionSpace(), false); err != nil { 47 return err 48 } 49 a.inferer <- inf 50 } 51 // a.NN = nil // remove old NN 52 a.Unlock() 53 return nil 54 } 55 56 // Infer infers a bunch of moves based on the game state. This is mainly used to implement a Inferer such that the MCTS search can use it. 57 func (a *Agent) Infer(g game.State) (policy []float32, value float32) { 58 input := a.Enc(g) 59 inf := <-a.inferer 60 61 var err error 62 policy, value, err = inf.Infer(input) 63 if err != nil { 64 if el, ok := inf.(ExecLogger); ok { 65 log.Println(el.ExecLog()) 66 } 67 panic(err) 68 } 69 a.inferer <- inf 70 return 71 } 72 73 // Search searches the game state and returns a suggested coordinate. 74 func (a *Agent) Search(g game.State) game.Single { 75 a.MCTS.SetGame(g) 76 return a.MCTS.Search(a.Player) 77 } 78 79 // NNOutput returns the output of the neural network 80 func (a *Agent) NNOutput(g game.State) (policy []float32, value float32, err error) { 81 input := a.Enc(g) 82 inf := <-a.inferer 83 policy, value, err = inf.Infer(input) 84 a.inferer <- inf 85 return 86 } 87 88 func (a *Agent) Close() error { 89 close(a.inferer) 90 var allErrs manyErr 91 for inferer := range a.inferer { 92 if err := inferer.Close(); err != nil { 93 allErrs = append(allErrs, err) 94 } 95 } 96 if len(allErrs) > 0 { 97 return allErrs 98 } 99 return nil 100 } 101 102 func (a *Agent) useDummy(g game.State) { 103 a.inferer = make(chan Inferer, runtime.NumCPU()) 104 for i := 0; i < runtime.NumCPU(); i++ { 105 a.inferer <- dummyInferer{ 106 outputSize: g.ActionSpace(), 107 currentPlayer: a.Player, 108 } 109 } 110 } 111 112 func (a *Agent) resetStats() { 113 a.Lock() 114 a.Wins = 0 115 a.Loss = 0 116 a.Draw = 0 117 a.Unlock() 118 }