gorgonia.org/gorgonia@v0.9.17/example_linearregression_test.go (about)

     1  package gorgonia_test
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  	"math/rand"
     7  	"runtime"
     8  
     9  	. "gorgonia.org/gorgonia"
    10  	"gorgonia.org/tensor"
    11  )
    12  
    13  const (
    14  	vecSize = 1000000
    15  )
    16  
    17  // manually generate a fake dataset which is y=2x+random
    18  func xy(dt tensor.Dtype) (x tensor.Tensor, y tensor.Tensor) {
    19  	var xBack, yBack interface{}
    20  	switch dt {
    21  	case Float32:
    22  		xBack = tensor.Range(tensor.Float32, 1, vecSize+1).([]float32)
    23  		yBackC := tensor.Range(tensor.Float32, 1, vecSize+1).([]float32)
    24  
    25  		for i, v := range yBackC {
    26  			yBackC[i] = v*2 + rand.Float32()
    27  		}
    28  		yBack = yBackC
    29  	case Float64:
    30  		xBack = tensor.Range(tensor.Float64, 1, vecSize+1).([]float64)
    31  		yBackC := tensor.Range(tensor.Float64, 1, vecSize+1).([]float64)
    32  
    33  		for i, v := range yBackC {
    34  			yBackC[i] = v*2 + rand.Float64()
    35  		}
    36  		yBack = yBackC
    37  	}
    38  
    39  	x = tensor.New(tensor.WithBacking(xBack), tensor.WithShape(vecSize))
    40  	y = tensor.New(tensor.WithBacking(yBack), tensor.WithShape(vecSize))
    41  	return
    42  }
    43  
    44  func random(dt tensor.Dtype) interface{} {
    45  	rand.Seed(13370)
    46  	switch dt {
    47  	case tensor.Float32:
    48  		return rand.Float32()
    49  	case tensor.Float64:
    50  		return rand.Float64()
    51  	default:
    52  		panic("Unhandled dtype")
    53  	}
    54  }
    55  
    56  func linregSetup(Float tensor.Dtype) (m, c *Node, machine VM) {
    57  	var xT, yT Value
    58  	xT, yT = xy(Float)
    59  
    60  	g := NewGraph()
    61  	x := NewVector(g, Float, WithShape(vecSize), WithName("x"), WithValue(xT))
    62  	y := NewVector(g, Float, WithShape(vecSize), WithName("y"), WithValue(yT))
    63  	m = NewScalar(g, Float, WithName("m"), WithValue(random(Float)))
    64  	c = NewScalar(g, Float, WithName("c"), WithValue(random(Float)))
    65  
    66  	pred := Must(Add(Must(Mul(x, m)), c))
    67  	se := Must(Square(Must(Sub(pred, y))))
    68  	cost := Must(Mean(se))
    69  
    70  	if _, err := Grad(cost, m, c); err != nil {
    71  		log.Fatalf("Failed to backpropagate: %v", err)
    72  	}
    73  
    74  	// machine := NewLispMachine(g)  // you can use a LispMachine, but it'll be VERY slow.
    75  	machine = NewTapeMachine(g, BindDualValues(m, c))
    76  	return m, c, machine
    77  }
    78  
    79  func linregRun(m, c *Node, machine VM, iter int, autoCleanup bool) (retM, retC Value) {
    80  	if autoCleanup {
    81  		defer machine.Close()
    82  	}
    83  	model := []ValueGrad{m, c}
    84  	solver := NewVanillaSolver(WithLearnRate(0.001), WithClip(5)) // good idea to clip
    85  
    86  	if CUDA {
    87  		runtime.LockOSThread()
    88  		defer runtime.UnlockOSThread()
    89  	}
    90  	var err error
    91  	for i := 0; i < iter; i++ {
    92  		if err = machine.RunAll(); err != nil {
    93  			fmt.Printf("Error during iteration: %v: %v\n", i, err)
    94  			break
    95  		}
    96  
    97  		if err = solver.Step(model); err != nil {
    98  			log.Fatal(err)
    99  		}
   100  
   101  		machine.Reset() // Reset is necessary in a loop like this
   102  	}
   103  	return m.Value(), c.Value()
   104  
   105  }
   106  
   107  func linearRegression(Float tensor.Dtype, iter int) (retM, retC Value) {
   108  	defer runtime.GC()
   109  	m, c, machine := linregSetup(Float)
   110  	return linregRun(m, c, machine, iter, true)
   111  }
   112  
   113  // Linear Regression Example
   114  //
   115  // The formula for a straight line is
   116  //		y = mx + c
   117  // We want to find an `m` and a `c` that fits the equation well. We'll do it in both float32 and float64 to showcase the extensibility of Gorgonia
   118  func Example_linearRegression() {
   119  	var m, c Value
   120  	// Float32
   121  	m, c = linearRegression(Float32, 500)
   122  	fmt.Printf("float32: y = %3.3fx + %3.3f\n", m, c)
   123  
   124  	// Float64
   125  	m, c = linearRegression(Float64, 500)
   126  	fmt.Printf("float64: y = %3.3fx + %3.3f\n", m, c)
   127  
   128  	// Output:
   129  	// float32: y = 2.001x + 2.001
   130  	// float64: y = 2.001x + 2.001
   131  }