github.com/gorgonia/agogo@v0.1.1/dualnet/meta.go (about)

     1  package dual
     2  
     3  import (
     4  	"bytes"
     5  	"log"
     6  	"math/rand"
     7  	"time"
     8  
     9  	"github.com/pkg/errors"
    10  	G "gorgonia.org/gorgonia"
    11  	"gorgonia.org/tensor"
    12  	"gorgonia.org/tensor/native"
    13  )
    14  
    15  // Train is a basic trainer.
    16  func Train(d *Dual, Xs, policies, values *tensor.Dense, batches, iterations int) error {
    17  	m := G.NewTapeMachine(d.g, G.BindDualValues(d.Model()...))
    18  	model := G.NodesToValueGrads(d.Model())
    19  	solver := G.NewVanillaSolver(G.WithLearnRate(0.1))
    20  	var s slicer
    21  	for i := 0; i < iterations; i++ {
    22  		// var cost float32
    23  		for bat := 0; bat < batches; bat++ {
    24  			batchStart := bat * d.Config.BatchSize
    25  			batchEnd := batchStart + d.Config.BatchSize
    26  
    27  			Xs2 := s.Slice(Xs, sli(batchStart, batchEnd))
    28  			π := s.Slice(policies, sli(batchStart, batchEnd))
    29  			v := s.Slice(values, sli(batchStart, batchEnd))
    30  
    31  			G.Let(d.planes, Xs2)
    32  			G.Let(d.Π, π)
    33  			G.Let(d.V, v)
    34  			if err := m.RunAll(); err != nil {
    35  				return err
    36  			}
    37  			// cost = d.cost.Data().(float32)
    38  			if err := solver.Step(model); err != nil {
    39  				return err
    40  			}
    41  			m.Reset()
    42  			tensor.ReturnTensor(Xs2)
    43  			tensor.ReturnTensor(π)
    44  			tensor.ReturnTensor(v)
    45  		}
    46  		if err := shuffleBatch(Xs, policies, values); err != nil {
    47  			return err
    48  		}
    49  		// TODO: add a channel to send training  cost data down
    50  		// log.Printf("%d\t%v", i, cost/float32(batches))
    51  	}
    52  	return nil
    53  }
    54  
    55  // shuffleBatch shuffles the batches.
    56  func shuffleBatch(Xs, π, v *tensor.Dense) (err error) {
    57  	r := rand.New(rand.NewSource(time.Now().UnixNano()))
    58  	oriXs := Xs.Shape().Clone()
    59  	oriPis := π.Shape().Clone()
    60  
    61  	defer func() {
    62  		if r := recover(); r != nil {
    63  			log.Printf("%v %v", Xs.Shape(), π.Shape())
    64  			panic(r)
    65  		}
    66  	}()
    67  	Xs.Reshape(as2D(Xs.Shape())...)
    68  	π.Reshape(as2D(π.Shape())...)
    69  
    70  	var matXs, matPis [][]float32
    71  	if matXs, err = native.MatrixF32(Xs); err != nil {
    72  		return errors.Wrapf(err, "shuffle batch failed - matX")
    73  	}
    74  	if matPis, err = native.MatrixF32(π); err != nil {
    75  		return errors.Wrapf(err, "shuffle batch failed - pi")
    76  	}
    77  	vs := v.Data().([]float32)
    78  
    79  	tmp := make([]float32, Xs.Shape()[1])
    80  	for i := range matXs {
    81  		j := r.Intn(i + 1)
    82  
    83  		rowI := matXs[i]
    84  		rowJ := matXs[j]
    85  		copy(tmp, rowI)
    86  		copy(rowI, rowJ)
    87  		copy(rowJ, tmp)
    88  
    89  		piI := matPis[i]
    90  		piJ := matPis[j]
    91  		copy(tmp, piI)
    92  		copy(piI, piJ)
    93  		copy(piJ, tmp)
    94  
    95  		vs[i], vs[j] = vs[j], vs[i]
    96  	}
    97  	Xs.Reshape(oriXs...)
    98  	π.Reshape(oriPis...)
    99  
   100  	return nil
   101  }
   102  
   103  func as2D(s tensor.Shape) tensor.Shape {
   104  	retVal := tensor.BorrowInts(2)
   105  	retVal[0] = s[0]
   106  	retVal[1] = s[1]
   107  	for i := 2; i < len(s); i++ {
   108  		retVal[1] *= s[i]
   109  	}
   110  	return retVal
   111  }
   112  
// Inferencer is a struct that holds the state for a *Dual and a VM. By using an Inferencer struct,
// there is no longer a need to create a VM every time an inference needs to be done.
//
// Build one with Infer and release it with Close.
type Inferencer struct {
	// d is the inference-only network built by Infer (FwdOnly set, BatchSize = actionSpace).
	d *Dual
	// m is the VM over d's graph, created once by Infer and reused by every Infer call.
	m G.VM

	// input is the preallocated input tensor; each Infer call zeroes it and copies the board in.
	input *tensor.Dense
	// buf receives the VM's log output when Infer was called with toLog; read it via ExecLog.
	buf   *bytes.Buffer
}
   122  
   123  // Infer takes a trained *Dual, and creates a interence data structure such that it'd be easy to infer
   124  func Infer(d *Dual, actionSpace int, toLog bool) (*Inferencer, error) {
   125  	conf := d.Config
   126  	conf.FwdOnly = true
   127  	conf.BatchSize = actionSpace
   128  	newShape := d.planes.Shape().Clone()
   129  	newShape[0] = actionSpace
   130  	retVal := &Inferencer{
   131  		d:     New(conf),
   132  		input: tensor.New(tensor.WithShape(newShape...), tensor.Of(Float)),
   133  	}
   134  	if err := retVal.d.Init(); err != nil {
   135  		return nil, err
   136  	}
   137  	retVal.d.SetTesting()
   138  	// G.WithInit(G.Zeroes())(retVal.d.planes)
   139  
   140  	infModel := retVal.d.Model()
   141  	for i, n := range d.Model() {
   142  		original := n.Value().Data().([]float32)
   143  		cloned := infModel[i].Value().Data().([]float32)
   144  		copy(cloned, original)
   145  	}
   146  
   147  	retVal.buf = new(bytes.Buffer)
   148  	if toLog {
   149  		logger := log.New(retVal.buf, "", 0)
   150  		retVal.m = G.NewTapeMachine(retVal.d.g,
   151  			G.WithLogger(logger),
   152  			G.WithWatchlist(),
   153  			G.TraceExec(),
   154  			G.WithValueFmt("%+1.1v"),
   155  			G.WithNaNWatch(),
   156  		)
   157  	} else {
   158  		retVal.m = G.NewTapeMachine(retVal.d.g)
   159  	}
   160  	return retVal, nil
   161  }
   162  
   163  // Dual implements Dualer
   164  func (m *Inferencer) Dual() *Dual { return m.d }
   165  
   166  // Infer takes the board, in form of a []float32, and runs inference, and returns the value
   167  func (m *Inferencer) Infer(board []float32) (policy []float32, value float32, err error) {
   168  	m.buf.Reset()
   169  	for _, op := range m.d.ops {
   170  		op.Reset()
   171  	}
   172  
   173  	// copy board to the provided preallocated input tensor
   174  	m.input.Zero()
   175  	data := m.input.Data().([]float32)
   176  	copy(data, board)
   177  
   178  	m.m.Reset()
   179  	// log.Printf("Let planes %p be input %v", m.d.planes, board)
   180  	m.buf.Reset()
   181  	G.Let(m.d.planes, m.input)
   182  	if err = m.m.RunAll(); err != nil {
   183  		return nil, 0, err
   184  	}
   185  	policy = m.d.policyValue.Data().([]float32)
   186  	value = m.d.value.Data().([]float32)[0]
   187  	// log.Printf("\t%v", policy)
   188  	return policy[:m.d.ActionSpace], value, nil
   189  }
   190  
   191  // ExecLog returns the execution log. If Infer was called with toLog = false, then it will return an empty string
   192  func (m *Inferencer) ExecLog() string { return m.buf.String() }
   193  
   194  // Close implements a closer, because well, a gorgonia VM is a resource.
   195  func (m *Inferencer) Close() error { return m.m.Close() }