gorgonia.org/gorgonia@v0.9.17/example_linearregression_test.go (about) 1 package gorgonia_test 2 3 import ( 4 "fmt" 5 "log" 6 "math/rand" 7 "runtime" 8 9 . "gorgonia.org/gorgonia" 10 "gorgonia.org/tensor" 11 ) 12 13 const ( 14 vecSize = 1000000 15 ) 16 17 // manually generate a fake dataset which is y=2x+random 18 func xy(dt tensor.Dtype) (x tensor.Tensor, y tensor.Tensor) { 19 var xBack, yBack interface{} 20 switch dt { 21 case Float32: 22 xBack = tensor.Range(tensor.Float32, 1, vecSize+1).([]float32) 23 yBackC := tensor.Range(tensor.Float32, 1, vecSize+1).([]float32) 24 25 for i, v := range yBackC { 26 yBackC[i] = v*2 + rand.Float32() 27 } 28 yBack = yBackC 29 case Float64: 30 xBack = tensor.Range(tensor.Float64, 1, vecSize+1).([]float64) 31 yBackC := tensor.Range(tensor.Float64, 1, vecSize+1).([]float64) 32 33 for i, v := range yBackC { 34 yBackC[i] = v*2 + rand.Float64() 35 } 36 yBack = yBackC 37 } 38 39 x = tensor.New(tensor.WithBacking(xBack), tensor.WithShape(vecSize)) 40 y = tensor.New(tensor.WithBacking(yBack), tensor.WithShape(vecSize)) 41 return 42 } 43 44 func random(dt tensor.Dtype) interface{} { 45 rand.Seed(13370) 46 switch dt { 47 case tensor.Float32: 48 return rand.Float32() 49 case tensor.Float64: 50 return rand.Float64() 51 default: 52 panic("Unhandled dtype") 53 } 54 } 55 56 func linregSetup(Float tensor.Dtype) (m, c *Node, machine VM) { 57 var xT, yT Value 58 xT, yT = xy(Float) 59 60 g := NewGraph() 61 x := NewVector(g, Float, WithShape(vecSize), WithName("x"), WithValue(xT)) 62 y := NewVector(g, Float, WithShape(vecSize), WithName("y"), WithValue(yT)) 63 m = NewScalar(g, Float, WithName("m"), WithValue(random(Float))) 64 c = NewScalar(g, Float, WithName("c"), WithValue(random(Float))) 65 66 pred := Must(Add(Must(Mul(x, m)), c)) 67 se := Must(Square(Must(Sub(pred, y)))) 68 cost := Must(Mean(se)) 69 70 if _, err := Grad(cost, m, c); err != nil { 71 log.Fatalf("Failed to backpropagate: %v", err) 72 } 73 74 // machine := NewLispMachine(g) // you can use a LispMachine, but it'll be VERY slow. 75 machine = NewTapeMachine(g, BindDualValues(m, c)) 76 return m, c, machine 77 } 78 79 func linregRun(m, c *Node, machine VM, iter int, autoCleanup bool) (retM, retC Value) { 80 if autoCleanup { 81 defer machine.Close() 82 } 83 model := []ValueGrad{m, c} 84 solver := NewVanillaSolver(WithLearnRate(0.001), WithClip(5)) // good idea to clip 85 86 if CUDA { 87 runtime.LockOSThread() 88 defer runtime.UnlockOSThread() 89 } 90 var err error 91 for i := 0; i < iter; i++ { 92 if err = machine.RunAll(); err != nil { 93 fmt.Printf("Error during iteration: %v: %v\n", i, err) 94 break 95 } 96 97 if err = solver.Step(model); err != nil { 98 log.Fatal(err) 99 } 100 101 machine.Reset() // Reset is necessary in a loop like this 102 } 103 return m.Value(), c.Value() 104 105 } 106 107 func linearRegression(Float tensor.Dtype, iter int) (retM, retC Value) { 108 defer runtime.GC() 109 m, c, machine := linregSetup(Float) 110 return linregRun(m, c, machine, iter, true) 111 } 112 113 // Linear Regression Example 114 // 115 // The formula for a straight line is 116 // y = mx + c 117 // We want to find an `m` and a `c` that fits the equation well. We'll do it in both float32 and float64 to showcase the extensibility of Gorgonia 118 func Example_linearRegression() { 119 var m, c Value 120 // Float32 121 m, c = linearRegression(Float32, 500) 122 fmt.Printf("float32: y = %3.3fx + %3.3f\n", m, c) 123 124 // Float64 125 m, c = linearRegression(Float64, 500) 126 fmt.Printf("float64: y = %3.3fx + %3.3f\n", m, c) 127 128 // Output: 129 // float32: y = 2.001x + 2.001 130 // float64: y = 2.001x + 2.001 131 }