gorgonia.org/gorgonia@v0.9.17/examples/286/main.go (about)

     1  // 286 is a program to test issue 286
     2  
     3  package main
     4  
     5  import (
     6  	"flag"
     7  	"log"
     8  	"math/rand"
     9  
    10  	_ "net/http/pprof"
    11  
    12  	"github.com/pkg/errors"
    13  	"gorgonia.org/gorgonia"
    14  	"gorgonia.org/tensor"
    15  )
    16  
    17  var (
    18  	epochs     = flag.Int("epochs", 10, "Number of epochs to train for")
    19  	dataset    = flag.String("dataset", "train", "Which dataset to train on? Valid options are \"train\" or \"test\"")
    20  	dtype      = flag.String("dtype", "float64", "Which dtype to use")
    21  	batchsize  = flag.Int("batchsize", 10, "Batch size")
    22  	cpuprofile = flag.String("cpuprofile", "", "CPU profiling")
    23  )
    24  
    25  const loc = "./mnist/"
    26  
    27  var dt tensor.Dtype
    28  
    29  func parseDtype() {
    30  	switch *dtype {
    31  	case "float64":
    32  		dt = tensor.Float64
    33  	case "float32":
    34  		dt = tensor.Float32
    35  	default:
    36  		log.Fatalf("Unknown dtype: %v", *dtype)
    37  	}
    38  }
    39  
    40  type nn struct {
    41  	g      *gorgonia.ExprGraph
    42  	w0, w1 *gorgonia.Node
    43  
    44  	out     *gorgonia.Node
    45  	predVal gorgonia.Value
    46  }
    47  
    48  type sli struct {
    49  	start, end int
    50  }
    51  
    52  func (s sli) Start() int { return s.start }
    53  func (s sli) End() int   { return s.end }
    54  func (s sli) Step() int  { return 1 }
    55  
    56  func newNN(g *gorgonia.ExprGraph) *nn {
    57  	// Create node for w/weight
    58  	w0 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(784, 300), gorgonia.WithName("w0"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
    59  	w1 := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(300, 10), gorgonia.WithName("w1"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
    60  	return &nn{
    61  		g:  g,
    62  		w0: w0,
    63  		w1: w1,
    64  	}
    65  }
    66  
    67  func (m *nn) learnables() gorgonia.Nodes {
    68  	return gorgonia.Nodes{m.w0, m.w1}
    69  }
    70  
    71  func (m *nn) fwd(x *gorgonia.Node) (err error) {
    72  	var l0, l1 *gorgonia.Node
    73  	var l0dot *gorgonia.Node
    74  
    75  	// Set first layer to be copy of input
    76  	l0 = x
    77  
    78  	// Dot product of l0 and w0, use as input for ReLU
    79  	if l0dot, err = gorgonia.Mul(l0, m.w0); err != nil {
    80  		return errors.Wrap(err, "Unable to multiply l0 and w0")
    81  	}
    82  
    83  	// l0dot := gorgonia.Must(gorgonia.Mul(l0, m.w0))
    84  
    85  	// Build hidden layer out of result
    86  	l1 = gorgonia.Must(gorgonia.Rectify(l0dot))
    87  
    88  	var out *gorgonia.Node
    89  	if out, err = gorgonia.Mul(l1, m.w1); err != nil {
    90  		return errors.Wrapf(err, "Unable to multiply l1 and w1")
    91  	}
    92  
    93  	m.out, err = gorgonia.SoftMax(out)
    94  	gorgonia.Read(m.out, &m.predVal)
    95  	return
    96  
    97  }
    98  
    99  func main() {
   100  	flag.Parse()
   101  	parseDtype()
   102  	rand.Seed(7945)
   103  
   104  	var err error
   105  
   106  	bs := *batchsize
   107  	g := gorgonia.NewGraph()
   108  	x := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(bs, 784), gorgonia.WithName("x"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
   109  	y := gorgonia.NewMatrix(g, dt, gorgonia.WithShape(bs, 10), gorgonia.WithName("y"), gorgonia.WithInit(gorgonia.GlorotN(1.0)))
   110  
   111  	m := newNN(g)
   112  	if err = m.fwd(x); err != nil {
   113  		log.Fatalf("%+v", err)
   114  	}
   115  
   116  	losses, err := gorgonia.HadamardProd(m.out, y)
   117  	if err != nil {
   118  		log.Fatal(err)
   119  	}
   120  	cost := gorgonia.Must(gorgonia.Mean(losses))
   121  	cost = gorgonia.Must(gorgonia.Neg(cost))
   122  
   123  	// we wanna track costs
   124  	var costVal gorgonia.Value
   125  	gorgonia.Read(cost, &costVal)
   126  
   127  	if _, err = gorgonia.Grad(cost, m.learnables()...); err != nil {
   128  		log.Fatal(err)
   129  	}
   130  
   131  	vm := gorgonia.NewTapeMachine(g, gorgonia.BindDualValues(m.learnables()...))
   132  	solver := gorgonia.NewRMSPropSolver(gorgonia.WithBatchSize(float64(bs)))
   133  	defer vm.Close()
   134  
   135  	if err = vm.RunAll(); err != nil {
   136  		log.Fatalf("Failed %v", err)
   137  	}
   138  
   139  	solver.Step(gorgonia.NodesToValueGrads(m.learnables()))
   140  	vm.Reset()
   141  }