github.com/gorgonia/agogo@v0.1.1/agogo.go (about)

     1  package agogo
     2  
     3  import (
     4  	"encoding/gob"
     5  	"fmt"
     6  	"log"
     7  	"math/rand"
     8  	"os"
     9  	"time"
    10  
    11  	dual "github.com/gorgonia/agogo/dualnet"
    12  	"github.com/gorgonia/agogo/game"
    13  	"github.com/gorgonia/agogo/mcts"
    14  	"github.com/pkg/errors"
    15  	"gorgonia.org/tensor"
    16  )
    17  
// AZ is the top level structure and the entry point of the API.
// It is a wrapper around the MCTS and the neural network that together make up the algorithm.
// AZ stands for AlphaZero.
type AZ struct {
	// state
	Arena
	Statistics
	// useDummy is set to true by New and to false by Load. While it is true,
	// setupSelfPlay calls useDummy on both agents on iteration 0.
	useDummy bool

	// config
	nnConf          dual.Config
	mctsConf        mcts.Config
	enc             GameEncoder
	// aug is the Augmenter handed to Play during self play.
	aug             Augmenter
	// updateThreshold is the share of decisive arena games B has to win
	// (B.Wins / (B.Wins + A.Wins)) before its network replaces A's in Learn.
	updateThreshold float32
	// maxExamples caps the number of self-play examples used per epoch in
	// Learn. A value <= 0 disables the cap.
	maxExamples     int

	// io
	// outEnc is the OutputEncoder handed to Play for arena games in Learn.
	outEnc OutputEncoder
}
    38  
    39  // New AlphaZero structure. It takes a game state (implementing the board, rules, etc.)
    40  // and a configuration to apply to the MCTS and the neural network
    41  func New(g game.State, conf Config) *AZ {
    42  	if !conf.NNConf.IsValid() {
    43  		panic("NNConf is not valid. Unable to proceed")
    44  	}
    45  	if !conf.MCTSConf.IsValid() {
    46  		panic("MCTSConf is not valid. Unable to proceed")
    47  	}
    48  
    49  	a := dual.New(conf.NNConf)
    50  	b := dual.New(conf.NNConf)
    51  
    52  	if err := a.Init(); err != nil {
    53  		panic(fmt.Sprintf("%+v", err))
    54  	}
    55  	if err := b.Init(); err != nil {
    56  		panic(fmt.Sprintf("%+v", err))
    57  	}
    58  
    59  	retVal := &AZ{
    60  		Arena:           MakeArena(g, a, b, conf.MCTSConf, conf.Encoder, conf.Augmenter, conf.Name),
    61  		nnConf:          conf.NNConf,
    62  		mctsConf:        conf.MCTSConf,
    63  		enc:             conf.Encoder,
    64  		outEnc:          conf.OutputEncoder,
    65  		aug:             conf.Augmenter,
    66  		updateThreshold: float32(conf.UpdateThreshold),
    67  		maxExamples:     conf.MaxExamples,
    68  		Statistics:      makeStatistics(),
    69  		useDummy:        true,
    70  	}
    71  	retVal.logger = log.New(&retVal.buf, "", log.Ltime)
    72  	return retVal
    73  }
    74  
    75  func (a *AZ) setupSelfPlay(iter int) {
    76  	var err error
    77  	if err = a.A.SwitchToInference(a.game); err != nil {
    78  		// DO SOMETHING WITH ERROR
    79  	}
    80  	if err = a.B.SwitchToInference(a.game); err != nil {
    81  		// DO SOMETHING WITH ERROR
    82  	}
    83  	if iter == 0 && a.useDummy {
    84  		log.Printf("Using Dummy")
    85  		a.A.useDummy(a.game)
    86  		a.B.useDummy(a.game)
    87  	}
    88  	log.Printf("Set up selfplay: Switch To inference for A. A.NN %p (%T)", a.A.NN, a.A.NN)
    89  	log.Printf("Set up selfplay: Switch To inference for B. B.NN %p (%T)", a.B.NN, a.B.NN)
    90  }
    91  
    92  // SelfPlay plays an episode
    93  func (a *AZ) SelfPlay() []Example {
    94  	_, examples := a.Play(true, nil, a.aug) // don't encode images while selfplay... that'd be boring to watch
    95  	a.game.Reset()
    96  	return examples
    97  }
    98  
    99  // Learn learns for iters. It self-plays for episodes, and then trains a new NN from the self play example.
   100  func (a *AZ) Learn(iters, episodes, nniters, arenaGames int) error {
   101  	var err error
   102  	for a.epoch = 0; a.epoch < iters; a.epoch++ {
   103  		var ex []Example
   104  		log.Printf("Self Play for epoch %d. Player A %p, Player B %p", a.epoch, a.A, a.B)
   105  
   106  		a.buf.Reset()
   107  		a.logger.Printf("Self Play for epoch %d. Player A %p, Player B %p", a.epoch, a.A, a.B)
   108  		a.logger.SetPrefix("\t")
   109  		a.setupSelfPlay(a.epoch)
   110  		for e := 0; e < episodes; e++ {
   111  			log.Printf("\tEpisode %v", e)
   112  			a.logger.Printf("Episode %v\n", e)
   113  			ex = append(ex, a.SelfPlay()...)
   114  		}
   115  		a.logger.SetPrefix("")
   116  		a.buf.Reset()
   117  
   118  		if a.maxExamples > 0 && len(ex) > a.maxExamples {
   119  			shuffleExamples(ex)
   120  			ex = ex[:a.maxExamples]
   121  		}
   122  		Xs, Policies, Values, batches := a.prepareExamples(ex)
   123  
   124  		// // create a new DualNet for B
   125  		// a.B.NN = dual.New(a.nnConf)
   126  		// if err = a.B.NN.Dual().Init(); err != nil {
   127  		// 	return errors.WithMessage(err, "Unable to create new DualNet for B")
   128  		// }
   129  
   130  		if err = dual.Train(a.B.NN, Xs, Policies, Values, batches, nniters); err != nil {
   131  			return errors.WithMessage(err, fmt.Sprintf("Train fail"))
   132  		}
   133  
   134  		a.B.SwitchToInference(a.game)
   135  
   136  		a.A.resetStats()
   137  		a.B.resetStats()
   138  
   139  		a.logger.Printf("Playing Arena")
   140  		a.logger.SetPrefix("\t")
   141  		for a.gameNumber = 0; a.gameNumber < arenaGames; a.gameNumber++ {
   142  			a.logger.Printf("Playing game number %d", a.gameNumber)
   143  			a.Play(false, a.outEnc, nil)
   144  			a.game.Reset()
   145  		}
   146  		a.logger.SetPrefix("")
   147  
   148  		var killedA bool
   149  		log.Printf("A wins %v, loss %v, draw %v\nB wins %v, loss %v, draw %v", a.A.Wins, a.A.Loss, a.A.Draw, a.B.Wins, a.B.Loss, a.B.Draw)
   150  
   151  		// if a.B.Wins/(a.B.Wins+a.B.Loss+a.B.Draw) > a.updateThreshold {
   152  		if a.B.Wins/(a.B.Wins+a.A.Wins) > a.updateThreshold {
   153  			// B wins. Kill A, clean up its resources.
   154  			log.Printf("Kill A %p. New A's NN is %p", a.A.NN, a.B.NN)
   155  			if err = a.A.Close(); err != nil {
   156  				return err
   157  			}
   158  			a.A.NN = a.B.NN
   159  			// clear examples
   160  			ex = ex[:0]
   161  			killedA = true
   162  		}
   163  		a.update(a.A)
   164  		if err = a.newB(a.nnConf, killedA); err != nil {
   165  			return err
   166  		}
   167  	}
   168  	return nil
   169  }
   170  
   171  // Save learning into filenamee
   172  func (a *AZ) Save(filename string) error {
   173  	f, err := os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0544)
   174  	if err != nil {
   175  		return err
   176  	}
   177  	defer f.Close()
   178  
   179  	enc := gob.NewEncoder(f)
   180  	return enc.Encode(a.A.NN)
   181  }
   182  
   183  // Load the Alpha Zero structure from a filename
   184  func (a *AZ) Load(filename string) error {
   185  	f, err := os.Open(filename)
   186  	if err != nil {
   187  		return errors.WithStack(err)
   188  	}
   189  	defer f.Close()
   190  
   191  	a.A.NN = dual.New(a.nnConf)
   192  	a.B.NN = dual.New(a.nnConf)
   193  
   194  	dec := gob.NewDecoder(f)
   195  	if err = dec.Decode(a.A.NN); err != nil {
   196  		return errors.WithStack(err)
   197  	}
   198  
   199  	f.Seek(0, 0)
   200  	dec = gob.NewDecoder(f)
   201  	if err = dec.Decode(a.B.NN); err != nil {
   202  		return errors.WithStack(err)
   203  	}
   204  	a.useDummy = false
   205  	return nil
   206  }
   207  
   208  func (a *AZ) prepareExamples(examples []Example) (Xs, Policies, Values *tensor.Dense, batches int) {
   209  	shuffleExamples(examples)
   210  	batches = len(examples) / a.nnConf.BatchSize
   211  	total := batches * a.nnConf.BatchSize
   212  	var XsBacking, PoliciesBacking, ValuesBacking []float32
   213  	for i, ex := range examples {
   214  		if i >= total {
   215  			break
   216  		}
   217  		XsBacking = append(XsBacking, ex.Board...)
   218  
   219  		start := len(PoliciesBacking)
   220  		PoliciesBacking = append(PoliciesBacking, make([]float32, len(ex.Policy))...)
   221  		copy(PoliciesBacking[start:], ex.Policy)
   222  
   223  		ValuesBacking = append(ValuesBacking, ex.Value)
   224  	}
   225  	// padd out anythihng that is not full
   226  	// board0 := examples[0].Board
   227  	// policy0 := examples[0].Policy
   228  	// rem := len(examples) % a.nnConf.BatchSize
   229  	// if rem != 0 {
   230  	// 	diff := a.nnConf.BatchSize - rem
   231  
   232  	// 	// add padded data
   233  	// 	XsBacking = append(XsBacking, make([]float32, diff*len(board0))...)
   234  	// 	PoliciesBacking = append(PoliciesBacking, make([]float32, diff*len(policy0))...)
   235  	// 	ValuesBacking = append(ValuesBacking, make([]float32, diff)...)
   236  	// }
   237  	// if rem > 0 {
   238  	// 	batches++
   239  	// }
   240  
   241  	actionSpace := a.Arena.game.ActionSpace() + 1 // allow passes
   242  	Xs = tensor.New(tensor.WithBacking(XsBacking), tensor.WithShape(a.nnConf.BatchSize*batches, a.nnConf.Features, a.nnConf.Height, a.nnConf.Width))
   243  	Policies = tensor.New(tensor.WithBacking(PoliciesBacking), tensor.WithShape(a.nnConf.BatchSize*batches, actionSpace))
   244  	Values = tensor.New(tensor.WithBacking(ValuesBacking), tensor.WithShape(a.nnConf.BatchSize*batches))
   245  	return
   246  }
   247  
   248  func shuffleExamples(examples []Example) {
   249  	r := rand.New(rand.NewSource(time.Now().UnixNano()))
   250  	for i := range examples {
   251  		j := r.Intn(i + 1)
   252  		examples[i], examples[j] = examples[j], examples[i]
   253  	}
   254  }