github.com/gorgonia/agogo@v0.1.1/dualnet/dual.go (about)

     1  package dual
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/gob"
     6  
     7  	G "gorgonia.org/gorgonia"
     8  	"gorgonia.org/tensor"
     9  )
    10  
// Float is the dtype given to the tensors this file creates directly: the
// input planes and the Π and V label nodes.
var Float = G.Float32
    12  
// Dual is the whole neural network architecture of the dual network.
//
// The policy head and the value head are both built on top of the same
// shared stack of layers (see fwd).
type Dual struct {
	Config
	// ops collects the batch-norm ops created while building the graph so
	// that SetTesting can be applied to all of them.
	ops []batchNormOp

	g    *G.ExprGraph
	Π, V *G.Node // pi and value labels. Pi is a matrix of 1s and 0s

	planes       *G.Node // input node, shaped (BatchSize, Features, Height, Width)
	policyOutput *G.Node // softmax over the policy logits
	valueOutput  *G.Node // tanh over the value head's final linear output

	policyValue G.Value // policy predicted
	value       G.Value // the actual value predicted
	cost        G.Value // cost, for training recording
}
    31  
    32  // New returns a new, uninitialized *Dual.
    33  func New(conf Config) *Dual {
    34  	retVal := &Dual{
    35  		Config: conf,
    36  	}
    37  
    38  	return retVal
    39  }
    40  
    41  func (d *Dual) Init() error {
    42  	d.reset()
    43  	d.g = G.NewGraph()
    44  	actionSpace := d.ActionSpace
    45  	logits, valueOutput := d.fwd(actionSpace)
    46  	return d.bwd(actionSpace, logits, valueOutput)
    47  
    48  }
    49  
    50  func (d *Dual) fwd(actionSpace int) (logits, valueOutput *G.Node) {
    51  	boardSize := d.Width * d.Height
    52  
    53  	// note, the data should be arranged like so:
    54  	//	BatchSize, Features, Height, Width
    55  	// because Gorgonia only supports doing convolutions on BCHW format
    56  	d.planes = G.NewTensor(d.g, Float, 4, G.WithShape(d.BatchSize, d.Features, d.Height, d.Width), G.WithName("Planes"))
    57  
    58  	var m maebe
    59  	initialOut, initalOp := m.res(d.planes, d.K, "Init")
    60  	d.ops = append(d.ops, initalOp)
    61  
    62  	// shared stack
    63  	sharedOut := initialOut
    64  	for i := 0; i < d.SharedLayers; i++ {
    65  		var op1, op2 batchNormOp
    66  		sharedOut, op1, op2 = m.share(sharedOut, d.K, i)
    67  		d.ops = append(d.ops, op1, op2)
    68  	}
    69  
    70  	// policy head
    71  	var batches int
    72  	policy, pop := m.batchnorm(m.conv(sharedOut, 2, 1, "PolicyHead"))
    73  	policy = m.rectify(policy)
    74  	if batches = policy.Shape().TotalSize() / (boardSize * 2); batches == 0 {
    75  		batches = 1
    76  	}
    77  	policy = m.reshape(policy, tensor.Shape{batches, boardSize * 2})
    78  	logits = m.linear(policy, actionSpace, "Policy")
    79  
    80  	// Read to output which can be used for deciding the policy
    81  	d.policyOutput = m.do(func() (*G.Node, error) { return G.SoftMax(logits) })
    82  	G.Read(d.policyOutput, &d.policyValue)
    83  
    84  	// value head
    85  	value, vop := m.batchnorm(m.conv(sharedOut, 1, 1, "ValueHead"))
    86  	value = m.rectify(value)
    87  	batches = value.Shape().TotalSize() / boardSize
    88  	value = m.reshape(value, tensor.Shape{batches, boardSize})
    89  	value = m.linear(value, d.FC, "Value") // value hidden
    90  	value = m.rectify(value)
    91  
    92  	valueOutput = m.linear(value, 1, "ValueOutput")
    93  	valueOutput = m.reshape(valueOutput, tensor.Shape{valueOutput.Shape().TotalSize()})
    94  
    95  	// Read the output to a value
    96  	d.valueOutput = m.do(func() (*G.Node, error) { return G.Tanh(valueOutput) })
    97  	G.Read(d.valueOutput, &d.value)
    98  
    99  	// add ops
   100  	d.ops = append(d.ops, pop, vop)
   101  
   102  	return logits, valueOutput
   103  }
   104  
   105  func (d *Dual) bwd(actionSpace int, logits, valueOutput *G.Node) error {
   106  	if d.FwdOnly {
   107  		return nil
   108  	}
   109  	d.Π = G.NewMatrix(d.g, Float, G.WithShape(d.BatchSize, actionSpace))
   110  	d.V = G.NewVector(d.g, Float, G.WithShape(d.BatchSize))
   111  
   112  	var m maebe
   113  	// policy, value and combined costs
   114  	var pcost, vcost, ccost *G.Node
   115  	pcost = m.xent(logits, d.Π) // cross entropy, averaged.
   116  	vcost = m.do(func() (*G.Node, error) { return G.Sub(valueOutput, d.V) })
   117  	vcost = m.do(func() (*G.Node, error) { return G.Square(vcost) })
   118  	vcost = m.do(func() (*G.Node, error) { return G.Mean(vcost) })
   119  
   120  	// combined costs
   121  	ccost = m.do(func() (*G.Node, error) { return G.Add(pcost, vcost) })
   122  	if m.err != nil {
   123  		return m.err
   124  	}
   125  	G.Read(ccost, &d.cost)
   126  
   127  	if _, err := G.Grad(ccost, d.Model()...); err != nil {
   128  		return err
   129  
   130  	}
   131  	return nil
   132  }
   133  
   134  func (d *Dual) Model() G.Nodes {
   135  	retVal := make(G.Nodes, 0, d.g.Nodes().Len())
   136  	for _, n := range d.g.AllNodes() {
   137  		if n.IsVar() && n != d.planes && n != d.Π && n != d.V {
   138  			retVal = append(retVal, n)
   139  		}
   140  	}
   141  	return retVal
   142  }
   143  
   144  func (d *Dual) SetTesting() {
   145  	for _, op := range d.ops {
   146  		op.SetTesting()
   147  	}
   148  }
   149  
   150  func (d *Dual) Clone() (*Dual, error) {
   151  	d2 := New(d.Config)
   152  	if err := d2.Init(); err != nil {
   153  		return nil, err
   154  	}
   155  
   156  	model := d.Model()
   157  	model2 := d2.Model()
   158  	for i, n := range model {
   159  		if err := G.Let(model2[i], n.Value()); err != nil {
   160  			return nil, err
   161  		}
   162  	}
   163  
   164  	return d2, nil
   165  }
   166  
   167  // Dual implemented Dualer
   168  func (d *Dual) Dual() *Dual { return d }
   169  
   170  func (d *Dual) reset() {
   171  	d.ops = nil
   172  	d.g = nil
   173  	d.Π = nil
   174  	d.V = nil
   175  
   176  	d.planes = nil
   177  	d.policyOutput = nil
   178  }
   179  
   180  func (d *Dual) GobEncode() (retVal []byte, err error) {
   181  	var buf bytes.Buffer
   182  	enc := gob.NewEncoder(&buf)
   183  	for _, n := range d.Model() {
   184  		v := n.Value()
   185  		if err = enc.Encode(&v); err != nil {
   186  			return nil, err
   187  		}
   188  	}
   189  	return buf.Bytes(), nil
   190  }
   191  
   192  func (d *Dual) GobDecode(p []byte) error {
   193  	d.reset()
   194  	d.Init()
   195  
   196  	buf := bytes.NewBuffer(p)
   197  	dec := gob.NewDecoder(buf)
   198  	for _, n := range d.Model() {
   199  		var v G.Value
   200  		if err := dec.Decode(&v); err != nil {
   201  			return err
   202  		}
   203  		G.Let(n, v)
   204  	}
   205  	return nil
   206  }