// github.com/gorgonia/agogo@v0.1.1/dualnet/ermahagerdmonards.go

package dual

import (
	"fmt"

	"github.com/pkg/errors"
	G "gorgonia.org/gorgonia"
	nnops "gorgonia.org/gorgonia/ops/nn"
	"gorgonia.org/tensor"
)

// maebe ("maybe") is a small error-carrying context: once err is set, every
// subsequent graph-building method becomes a no-op, so a whole chain of
// calls needs only one error check at the end.
type maebe struct {
	err error
}

type batchNormOp interface {
	SetTraining()
	SetTesting()
	Reset() error
}

// generic monad... may be useful: do runs f only if no earlier step has
// failed, annotating any new error with a stack trace.
func (m *maebe) do(f func() (*G.Node, error)) (retVal *G.Node) {
	if m.err != nil {
		return nil
	}
	if retVal, m.err = f(); m.err != nil {
		m.err = errors.WithStack(m.err)
	}
	return
}

// conv applies a 2D convolution with filterCount size×size filters, padded
// so that stride-1 convolution preserves the spatial dimensions.
func (m *maebe) conv(input *G.Node, filterCount, size int, name string) (retVal *G.Node) {
	if m.err != nil {
		return nil
	}
	featureCount := input.Shape()[1]
	padding := findPadding(input.Shape()[2], input.Shape()[3], size, size)
	filter := G.NewTensor(input.Graph(), Float, 4, G.WithShape(filterCount, featureCount, size, size), G.WithName("Filter"+name), G.WithInit(G.GlorotU(1.0)))

	// assume well behaved images
	if retVal, m.err = nnops.Conv2d(input, filter, []int{size, size}, padding, []int{1, 1}, []int{1, 1}); m.err != nil {
		m.err = errors.WithStack(m.err)
	}
	return
}

func (m *maebe) batchnorm(input *G.Node) (retVal *G.Node, retOp batchNormOp) {
	if m.err != nil {
		return nil, nil
	}
	// note: the scale and bias nodes will still be created,
	// and they will still be backpropagated
	if retVal, _, _, retOp, m.err = nnops.BatchNorm(input, nil, nil, 0.997, 1e-5); m.err != nil {
		m.err = errors.WithStack(m.err)
	}
	return
}

// res is a conv→batchnorm→ReLU unit with a 3×3 kernel.
func (m *maebe) res(input *G.Node, filterCount int, name string) (*G.Node, batchNormOp) {
	convolved := m.conv(input, filterCount, 3, name)
	normalized, op := m.batchnorm(convolved)
	retVal := m.rectify(normalized)
	return retVal, op
}

// share builds one shared tower layer: two res units over the same input,
// summed and then rectified.
func (m *maebe) share(input *G.Node, filterCount, layer int) (*G.Node, batchNormOp, batchNormOp) {
	layer1, l1Op := m.res(input, filterCount, fmt.Sprintf("Layer1 of Shared Layer %d", layer))
	layer2, l2Op := m.res(input, filterCount, fmt.Sprintf("Layer2 of Shared Layer %d", layer))
	added := m.do(func() (*G.Node, error) { return G.Add(layer1, layer2) })
	retVal := m.rectify(added)
	return retVal, l1Op, l2Op
}

func (m *maebe) linear(input *G.Node, units int, name string) *G.Node {
	if m.err != nil {
		return nil
	}
	// figure out size
	w := G.NewTensor(input.Graph(), Float, 2, G.WithShape(input.Shape()[1], units), G.WithInit(G.GlorotN(1.0)), G.WithName(name+"_w"))
	xw := m.do(func() (*G.Node, error) { return G.Mul(input, w) })
	if m.err != nil {
		return nil // guard: xw is nil if the multiplication failed
	}
	b := G.NewTensor(xw.Graph(), Float, xw.Shape().Dims(), G.WithShape(xw.Shape().Clone()...), G.WithName(name+"_b"), G.WithInit(G.Zeroes()))
	return m.do(func() (*G.Node, error) { return G.Add(xw, b) })
}

func (m *maebe) rectify(input *G.Node) (retVal *G.Node) {
	if m.err != nil {
		return nil
	}
	if retVal, m.err = nnops.Rectify(input); m.err != nil {
		m.err = errors.WithStack(m.err)
	}
	return
}

func (m *maebe) reshape(input *G.Node, to tensor.Shape) (retVal *G.Node) {
	if m.err != nil {
		return nil
	}
	if retVal, m.err = G.Reshape(input, to); m.err != nil {
		m.err = errors.WithStack(m.err)
	}
	return
}

// xent computes a log-free cross-entropy-like loss:
//
//	-mean(target⊙output + (1-target)⊙(1-output))
func (m *maebe) xent(output, target *G.Node) (retVal *G.Node) {
	if m.err != nil {
		return nil
	}
	var one *G.Node
	switch Float {
	case G.Float32:
		one = G.NewConstant(float32(1))
	case G.Float64:
		one = G.NewConstant(float64(1))
	}
	var omy, omout *G.Node
	if omy, m.err = G.Sub(one, target); m.err != nil {
		m.err = errors.WithStack(m.err)
		return nil
	}

	if omout, m.err = G.Sub(one, output); m.err != nil {
		m.err = errors.WithStack(m.err)
		return nil
	}

	var fst, snd *G.Node
	if fst, m.err = G.HadamardProd(target, output); m.err != nil {
		m.err = errors.WithStack(m.err)
		return nil
	}
	if snd, m.err = G.HadamardProd(omy, omout); m.err != nil {
		m.err = errors.WithStack(m.err)
		return nil
	}

	if retVal, m.err = G.Add(fst, snd); m.err != nil {
		m.err = errors.WithStack(m.err)
		return nil
	}
	if retVal, m.err = G.Neg(retVal); m.err != nil {
		m.err = errors.WithStack(m.err)
		return nil
	}
	if retVal, m.err = G.Mean(retVal); m.err != nil {
		m.err = errors.WithStack(m.err)
	}
	return
}

// findPadding returns the "same" padding for a stride-1 convolution; the
// input terms cancel, so each entry reduces to (kernel-1)/2.
func findPadding(inputX, inputY, kernelX, kernelY int) []int {
	return []int{
		(inputX - 1 - inputX + kernelX) / 2,
		(inputY - 1 - inputY + kernelY) / 2,
	}
}
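
// exampleTower is an illustrative sketch, not part of the original file. It
// shows how the maebe context threads through graph construction: helpers are
// called unconditionally, and m.err is inspected once at the end of the
// chain. The input shape (a batch of 2 single-channel 19×19 planes), filter
// count, and output width are assumptions chosen only for illustration.
func exampleTower() (*G.Node, error) {
	g := G.NewGraph()
	var m maebe
	// hypothetical input planes
	input := G.NewTensor(g, Float, 4, G.WithShape(2, 1, 19, 19), G.WithName("X"), G.WithInit(G.GlorotU(1.0)))
	shared, _, _ := m.share(input, 64, 0)               // one shared layer: (2, 64, 19, 19)
	flat := m.reshape(shared, tensor.Shape{2, 64 * 19 * 19}) // flatten for the linear head
	out := m.linear(flat, 10, "out")                    // (2, 10)
	if m.err != nil { // single error check covers every step above
		return nil, m.err
	}
	return out, nil
}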