github.com/gorgonia/agogo@v0.1.1/dualnet/dual.go (about) 1 package dual 2 3 import ( 4 "bytes" 5 "encoding/gob" 6 7 G "gorgonia.org/gorgonia" 8 "gorgonia.org/tensor" 9 ) 10 11 var Float = G.Float32 12 13 // Dual is the whole neural network architecture of the dual network. 14 // 15 // The policy and value outputs are shared 16 type Dual struct { 17 Config 18 ops []batchNormOp 19 20 g *G.ExprGraph 21 Π, V *G.Node // pi and value labels. Pi is a matrix of 1s and 0s 22 23 planes *G.Node 24 policyOutput *G.Node 25 valueOutput *G.Node 26 27 policyValue G.Value // policy predicted 28 value G.Value // the actual value predicted 29 cost G.Value // cost, for training recoring 30 } 31 32 // New returns a new, uninitialized *Dual. 33 func New(conf Config) *Dual { 34 retVal := &Dual{ 35 Config: conf, 36 } 37 38 return retVal 39 } 40 41 func (d *Dual) Init() error { 42 d.reset() 43 d.g = G.NewGraph() 44 actionSpace := d.ActionSpace 45 logits, valueOutput := d.fwd(actionSpace) 46 return d.bwd(actionSpace, logits, valueOutput) 47 48 } 49 50 func (d *Dual) fwd(actionSpace int) (logits, valueOutput *G.Node) { 51 boardSize := d.Width * d.Height 52 53 // note, the data should be arranged like so: 54 // BatchSize, Features, Height, Width 55 // because Gorgonia only supports doing convolutions on BCHW format 56 d.planes = G.NewTensor(d.g, Float, 4, G.WithShape(d.BatchSize, d.Features, d.Height, d.Width), G.WithName("Planes")) 57 58 var m maebe 59 initialOut, initalOp := m.res(d.planes, d.K, "Init") 60 d.ops = append(d.ops, initalOp) 61 62 // shared stack 63 sharedOut := initialOut 64 for i := 0; i < d.SharedLayers; i++ { 65 var op1, op2 batchNormOp 66 sharedOut, op1, op2 = m.share(sharedOut, d.K, i) 67 d.ops = append(d.ops, op1, op2) 68 } 69 70 // policy head 71 var batches int 72 policy, pop := m.batchnorm(m.conv(sharedOut, 2, 1, "PolicyHead")) 73 policy = m.rectify(policy) 74 if batches = policy.Shape().TotalSize() / (boardSize * 2); batches == 0 { 75 batches = 1 76 } 77 policy = m.reshape(policy, tensor.Shape{batches, boardSize * 2}) 78 logits = m.linear(policy, actionSpace, "Policy") 79 80 // Read to output which can be used for deciding the policy 81 d.policyOutput = m.do(func() (*G.Node, error) { return G.SoftMax(logits) }) 82 G.Read(d.policyOutput, &d.policyValue) 83 84 // value head 85 value, vop := m.batchnorm(m.conv(sharedOut, 1, 1, "ValueHead")) 86 value = m.rectify(value) 87 batches = value.Shape().TotalSize() / boardSize 88 value = m.reshape(value, tensor.Shape{batches, boardSize}) 89 value = m.linear(value, d.FC, "Value") // value hidden 90 value = m.rectify(value) 91 92 valueOutput = m.linear(value, 1, "ValueOutput") 93 valueOutput = m.reshape(valueOutput, tensor.Shape{valueOutput.Shape().TotalSize()}) 94 95 // Read the output to a value 96 d.valueOutput = m.do(func() (*G.Node, error) { return G.Tanh(valueOutput) }) 97 G.Read(d.valueOutput, &d.value) 98 99 // add ops 100 d.ops = append(d.ops, pop, vop) 101 102 return logits, valueOutput 103 } 104 105 func (d *Dual) bwd(actionSpace int, logits, valueOutput *G.Node) error { 106 if d.FwdOnly { 107 return nil 108 } 109 d.Π = G.NewMatrix(d.g, Float, G.WithShape(d.BatchSize, actionSpace)) 110 d.V = G.NewVector(d.g, Float, G.WithShape(d.BatchSize)) 111 112 var m maebe 113 // policy, value and combined costs 114 var pcost, vcost, ccost *G.Node 115 pcost = m.xent(logits, d.Π) // cross entropy, averaged. 116 vcost = m.do(func() (*G.Node, error) { return G.Sub(valueOutput, d.V) }) 117 vcost = m.do(func() (*G.Node, error) { return G.Square(vcost) }) 118 vcost = m.do(func() (*G.Node, error) { return G.Mean(vcost) }) 119 120 // combined costs 121 ccost = m.do(func() (*G.Node, error) { return G.Add(pcost, vcost) }) 122 if m.err != nil { 123 return m.err 124 } 125 G.Read(ccost, &d.cost) 126 127 if _, err := G.Grad(ccost, d.Model()...); err != nil { 128 return err 129 130 } 131 return nil 132 } 133 134 func (d *Dual) Model() G.Nodes { 135 retVal := make(G.Nodes, 0, d.g.Nodes().Len()) 136 for _, n := range d.g.AllNodes() { 137 if n.IsVar() && n != d.planes && n != d.Π && n != d.V { 138 retVal = append(retVal, n) 139 } 140 } 141 return retVal 142 } 143 144 func (d *Dual) SetTesting() { 145 for _, op := range d.ops { 146 op.SetTesting() 147 } 148 } 149 150 func (d *Dual) Clone() (*Dual, error) { 151 d2 := New(d.Config) 152 if err := d2.Init(); err != nil { 153 return nil, err 154 } 155 156 model := d.Model() 157 model2 := d2.Model() 158 for i, n := range model { 159 if err := G.Let(model2[i], n.Value()); err != nil { 160 return nil, err 161 } 162 } 163 164 return d2, nil 165 } 166 167 // Dual implemented Dualer 168 func (d *Dual) Dual() *Dual { return d } 169 170 func (d *Dual) reset() { 171 d.ops = nil 172 d.g = nil 173 d.Π = nil 174 d.V = nil 175 176 d.planes = nil 177 d.policyOutput = nil 178 } 179 180 func (d *Dual) GobEncode() (retVal []byte, err error) { 181 var buf bytes.Buffer 182 enc := gob.NewEncoder(&buf) 183 for _, n := range d.Model() { 184 v := n.Value() 185 if err = enc.Encode(&v); err != nil { 186 return nil, err 187 } 188 } 189 return buf.Bytes(), nil 190 } 191 192 func (d *Dual) GobDecode(p []byte) error { 193 d.reset() 194 d.Init() 195 196 buf := bytes.NewBuffer(p) 197 dec := gob.NewDecoder(buf) 198 for _, n := range d.Model() { 199 var v G.Value 200 if err := dec.Decode(&v); err != nil { 201 return err 202 } 203 G.Let(n, v) 204 } 205 return nil 206 }