gorgonia.org/gorgonia@v0.9.17/examples/286/main.go (about) 1 // 286 is a program to test issue 286 2 3 package main 4 5 import ( 6 "flag" 7 "log" 8 "math/rand" 9 10 _ "net/http/pprof" 11 12 "github.com/pkg/errors" 13 "gorgonia.org/gorgonia" 14 "gorgonia.org/tensor" 15 ) 16 17 var ( 18 epochs = flag.Int("epochs", 10, "Number of epochs to train for") 19 dataset = flag.String("dataset", "train", "Which dataset to train on? Valid options are \"train\" or \"test\"") 20 dtype = flag.String("dtype", "float64", "Which dtype to use") 21 batchsize = flag.Int("batchsize", 10, "Batch size") 22 cpuprofile = flag.String("cpuprofile", "", "CPU profiling") 23 ) 24 25 const loc = "./mnist/" 26 27 var dt tensor.Dtype 28 29 func parseDtype() { 30 switch *dtype { 31 case "float64": 32 dt = tensor.Float64 33 case "float32": 34 dt = tensor.Float32 35 default: 36 log.Fatalf("Unknown dtype: %v", *dtype) 37 } 38 } 39 40 type nn struct { 41 g *gorgonia.ExprGraph 42 w0, w1 *gorgonia.Node 43 44 out *gorgonia.Node 45 predVal gorgonia.Value 46 } 47 48 type sli struct { 49 start, end int 50 } 51 52 func (s sli) Start() int { return s.start } 53 func (s sli) End() int { return s.end } 54 func (s sli) Step() int { return 1 } 55 56 func newNN(g *gorgonia.ExprGraph) *nn { 57 // Create node for w/weight 58 w0 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(784, 300), gorgonia.WithName("w0"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) 59 w1 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(300, 10), gorgonia.WithName("w1"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) 60 return &nn{ 61 g: g, 62 w0: w0, 63 w1: w1, 64 } 65 } 66 67 func (m *nn) learnables() gorgonia.Nodes { 68 return gorgonia.Nodes{m.w0, m.w1} 69 } 70 71 func (m *nn) fwd(x *gorgonia.Node) (err error) { 72 var l0, l1 *gorgonia.Node 73 var l0dot *gorgonia.Node 74 75 // Set first layer to be copy of input 76 l0 = x 77 78 // Dot product of l0 and w0, use as input for ReLU 79 if l0dot, err = gorgonia.Mul(l0, m.w0); err != nil { 80 return errors.Wrap(err, "Unable to multiply l0 and w0") 81 } 82 83 // l0dot := gorgonia.Must(gorgonia.Mul(l0, m.w0)) 84 85 // Build hidden layer out of result 86 l1 = gorgonia.Must(gorgonia.Rectify(l0dot)) 87 88 var out *gorgonia.Node 89 if out, err = gorgonia.Mul(l1, m.w1); err != nil { 90 return errors.Wrapf(err, "Unable to multiply l1 and w1") 91 } 92 93 m.out, err = gorgonia.SoftMax(out) 94 gorgonia.Read(m.out, &m.predVal) 95 return 96 97 } 98 99 func main() { 100 flag.Parse() 101 parseDtype() 102 rand.Seed(7945) 103 104 var err error 105 106 bs := *batchsize 107 g := gorgonia.NewGraph() 108 x := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(bs, 784), gorgonia.WithName("x"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) 109 y := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(bs, 10), gorgonia.WithName("y"), gorgonia.WithInit(gorgonia.GlorotN(1.0))) 110 111 m := newNN(g) 112 if err = m.fwd(x); err != nil { 113 log.Fatalf("%+v", err) 114 } 115 116 losses, err := gorgonia.HadamardProd(m.out, y) 117 if err != nil { 118 log.Fatal(err) 119 } 120 cost := gorgonia.Must(gorgonia.Mean(losses)) 121 cost = gorgonia.Must(gorgonia.Neg(cost)) 122 123 // we wanna track costs 124 var costVal gorgonia.Value 125 gorgonia.Read(cost, &costVal) 126 127 if _, err = gorgonia.Grad(cost, m.learnables()...); err != nil { 128 log.Fatal(err) 129 } 130 131 vm := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(m.learnables()...)) 132 solver := gorgonia.NewRMSPropSolver(gorgonia.WithBatchSize(float64(bs))) 133 defer vm.Close() 134 135 if err = vm.RunAll(); err != nil { 136 log.Fatalf("Failed %v", err) 137 } 138 139 solver.Step(gorgonia.NodesToValueGrads(m.learnables())) 140 vm.Reset() 141 }