gorgonia.org/gorgonia@v0.9.17/solvers_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"log"
     5  	"math"
     6  	"runtime"
     7  	"testing"
     8  
     9  	"github.com/chewxy/math32"
    10  	"github.com/stretchr/testify/assert"
    11  	"gorgonia.org/dawson"
    12  	"gorgonia.org/tensor"
    13  )
    14  
    15  func clampFloat64(v, min, max float64) float64 {
    16  	if v < min {
    17  		return min
    18  	}
    19  	if v > max {
    20  		return max
    21  	}
    22  	return v
    23  }
    24  
    25  func clampFloat32(v, min, max float32) float32 {
    26  	if v < min {
    27  		return min
    28  	}
    29  	if v > max {
    30  		return max
    31  	}
    32  	return v
    33  }
    34  
    35  func tf64Node() []ValueGrad {
    36  	backingV := []float64{1, 2, 3, 4}
    37  	backingD := []float64{0.5, -10, 10, 0.5}
    38  	v := tensor.New(tensor.WithBacking(backingV), tensor.WithShape(2, 2))
    39  	d := tensor.New(tensor.WithBacking(backingD), tensor.WithShape(2, 2))
    40  
    41  	dv := dvUnit0(v)
    42  	dv.d = d
    43  
    44  	n := new(Node)
    45  	n.boundTo = dv
    46  
    47  	model := []ValueGrad{n}
    48  	return model
    49  }
    50  
    51  func tf32Node() []ValueGrad {
    52  	backingV := []float32{1, 2, 3, 4}
    53  	backingD := []float32{0.5, -10, 10, 0.5}
    54  
    55  	v := tensor.New(tensor.WithBacking(backingV), tensor.WithShape(2, 2))
    56  	d := tensor.New(tensor.WithBacking(backingD), tensor.WithShape(2, 2))
    57  
    58  	dv := dvUnit0(v)
    59  	dv.d = d
    60  
    61  	n := new(Node)
    62  	n.boundTo = dv
    63  
    64  	model := []ValueGrad{n}
    65  	return model
    66  }
    67  
    68  func manualRMSProp64(t *testing.T, s *RMSPropSolver, model []ValueGrad) {
    69  	assert := assert.New(t)
    70  	correct := make([]float64, 4)
    71  	cached := make([]float64, 4)
    72  
    73  	grad0, _ := model[0].Grad()
    74  	backingV := model[0].Value().Data().([]float64)
    75  	backingD := grad0.Data().([]float64)
    76  
    77  	for i := 0; i < 5; i++ {
    78  		for j, v := range backingV {
    79  			grad := backingD[j]
    80  			cw := cached[j]
    81  
    82  			decayed := cw*s.decay + (1.0-s.decay)*grad*grad
    83  			cached[j] = decayed
    84  
    85  			grad = clampFloat64(grad, -s.clip, s.clip)
    86  			upd := -s.eta*grad/math.Sqrt(decayed+s.eps) - s.l2reg*v
    87  			correct[j] = v + upd
    88  		}
    89  
    90  		err := s.Step(model)
    91  		if err != nil {
    92  			t.Error(err)
    93  		}
    94  
    95  		sCache := s.cache[0].Value.(tensor.Tensor)
    96  		assert.Equal(correct, backingV, "Iteration: %d", i)
    97  		assert.Equal(cached, sCache.Data(), "Iteration: %d", i)
    98  
    99  	}
   100  }
   101  
   102  func manualRMSProp32(t *testing.T, s *RMSPropSolver, model []ValueGrad) {
   103  	assert := assert.New(t)
   104  	correct := make([]float32, 4)
   105  	cached := make([]float32, 4)
   106  
   107  	grad0, _ := model[0].Grad()
   108  	backingV := model[0].Value().Data().([]float32)
   109  	backingD := grad0.Data().([]float32)
   110  
   111  	decay := float32(s.decay)
   112  	l2reg := float32(s.l2reg)
   113  	eta := float32(s.eta)
   114  	eps := float32(s.eps)
   115  	clip := float32(s.clip)
   116  
   117  	// NOTE: THIS IS NAUGHTY. A proper comparison using 1e-5  should be used but that causes errors.
   118  	closef32 := func(a, b float32) bool {
   119  		return dawson.ToleranceF32(a, b, 1e-4)
   120  	}
   121  
   122  	for i := 0; i < 5; i++ {
   123  		for j, v := range backingV {
   124  			grad := backingD[j]
   125  			cw := cached[j]
   126  
   127  			decayed := cw*decay + (1.0-decay)*grad*grad
   128  			cached[j] = decayed
   129  
   130  			grad = clampFloat32(grad, -clip, clip)
   131  			upd := -eta*grad/math32.Sqrt(decayed+eps) - l2reg*v
   132  			correct[j] = v + upd
   133  		}
   134  
   135  		err := s.Step(model)
   136  		if err != nil {
   137  			t.Error(err)
   138  		}
   139  
   140  		sCache := s.cache[0].Value.(tensor.Tensor)
   141  		assert.True(dawson.AllClose(correct, backingV, closef32))
   142  		assert.True(dawson.AllClose(cached, sCache.Data().([]float32), closef32))
   143  	}
   144  }
   145  
   146  func TestRMSPropSolverManual(t *testing.T) {
   147  
   148  	stepSize := 0.01
   149  	l2Reg := 0.000001
   150  	clip := 5.0
   151  
   152  	var s *RMSPropSolver
   153  	var model []ValueGrad
   154  
   155  	s = NewRMSPropSolver(WithLearnRate(stepSize), WithL2Reg(l2Reg), WithClip(clip))
   156  	model = tf64Node()
   157  	manualRMSProp64(t, s, model)
   158  
   159  	s = NewRMSPropSolver(WithLearnRate(stepSize), WithL2Reg(l2Reg), WithClip(clip))
   160  	model = tf32Node()
   161  	manualRMSProp32(t, s, model)
   162  
   163  }
   164  
   165  func TestRMSPropSolver(t *testing.T) {
   166  	assert := assert.New(t)
   167  
   168  	z, cost, m, err := model2dRosenbrock(1, 100, -0.5, 0.5)
   169  	defer m.Close()
   170  	const costThreshold = 0.68
   171  	if nil != err {
   172  		t.Fatal(err)
   173  	}
   174  
   175  	solver := NewRMSPropSolver()
   176  
   177  	maxIterations := 1000
   178  
   179  	costFloat := 42.0
   180  	for 0 != maxIterations {
   181  		m.Reset()
   182  		err = m.RunAll()
   183  		if nil != err {
   184  			t.Fatal(err)
   185  		}
   186  
   187  		costFloat = cost.Value().Data().(float64)
   188  		if costThreshold > math.Abs(costFloat) {
   189  			break
   190  		}
   191  
   192  		err = solver.Step([]ValueGrad{z})
   193  		if nil != err {
   194  			t.Fatal(err)
   195  		}
   196  
   197  		maxIterations--
   198  	}
   199  
   200  	assert.InDelta(0, costFloat, costThreshold)
   201  }
   202  
   203  func TestAdaGradSolver(t *testing.T) {
   204  	assert := assert.New(t)
   205  
   206  	z, cost, m, err := model2dSquare(-0.5, 0.5)
   207  	defer m.Close()
   208  	const costThreshold = 0.39
   209  	if nil != err {
   210  		t.Fatal(err)
   211  	}
   212  
   213  	solver := NewAdaGradSolver()
   214  
   215  	maxIterations := 1000
   216  
   217  	costFloat := 42.0
   218  	for 0 != maxIterations {
   219  		m.Reset()
   220  		err = m.RunAll()
   221  		if nil != err {
   222  			t.Fatal(err)
   223  		}
   224  
   225  		costFloat = cost.Value().Data().(float64)
   226  		if costThreshold > math.Abs(costFloat) {
   227  			break
   228  		}
   229  
   230  		err = solver.Step([]ValueGrad{z})
   231  		if nil != err {
   232  			t.Fatal(err)
   233  		}
   234  
   235  		maxIterations--
   236  	}
   237  
   238  	assert.InDelta(0, costFloat, costThreshold)
   239  }
   240  
   241  func TestVanillaSolver(t *testing.T) {
   242  	assert := assert.New(t)
   243  
   244  	z, cost, m, err := model2dRosenbrock(1, 100, -0.5, 0.5)
   245  	defer m.Close()
   246  	const costThreshold = 0.185
   247  	if nil != err {
   248  		t.Fatal(err)
   249  	}
   250  
   251  	solver := NewVanillaSolver()
   252  
   253  	maxIterations := 1000
   254  
   255  	costFloat := 42.0
   256  	for 0 != maxIterations {
   257  		m.Reset()
   258  		err = m.RunAll()
   259  		if nil != err {
   260  			t.Fatal(err)
   261  		}
   262  
   263  		costFloat = cost.Value().Data().(float64)
   264  		if costThreshold > math.Abs(costFloat) {
   265  			break
   266  		}
   267  
   268  		err = solver.Step([]ValueGrad{z})
   269  		if nil != err {
   270  			t.Fatal(err)
   271  		}
   272  
   273  		maxIterations--
   274  	}
   275  
   276  	assert.InDelta(0, costFloat, costThreshold)
   277  }
   278  
   279  func TestMomentum(t *testing.T) {
   280  	assert := assert.New(t)
   281  
   282  	z, cost, m, err := model2dRosenbrock(1, 100, -0.5, 0.5)
   283  	defer m.Close()
   284  	const costThreshold = 0.39
   285  	if nil != err {
   286  		t.Fatal(err)
   287  	}
   288  
   289  	solver := NewMomentum()
   290  
   291  	maxIterations := 1000
   292  
   293  	costFloat := 42.0
   294  	for 0 != maxIterations {
   295  		m.Reset()
   296  		err = m.RunAll()
   297  		if nil != err {
   298  			t.Fatal(err)
   299  		}
   300  
   301  		costFloat = cost.Value().Data().(float64)
   302  		if costThreshold > math.Abs(costFloat) {
   303  			break
   304  		}
   305  
   306  		err = solver.Step([]ValueGrad{z})
   307  		if nil != err {
   308  			t.Fatal(err)
   309  		}
   310  
   311  		maxIterations--
   312  	}
   313  
   314  	assert.InDelta(0, costFloat, costThreshold)
   315  }
   316  
   317  func TestAdamSolver(t *testing.T) {
   318  	assert := assert.New(t)
   319  
   320  	z, cost, m, err := model2dRosenbrock(1, 100, -0.5, 0.5)
   321  	defer m.Close()
   322  	const costThreshold = 0.113
   323  	if nil != err {
   324  		t.Fatal(err)
   325  	}
   326  
   327  	solver := NewAdamSolver()
   328  
   329  	maxIterations := 1000
   330  
   331  	costFloat := 42.0
   332  	for 0 != maxIterations {
   333  		m.Reset()
   334  		err = m.RunAll()
   335  		if nil != err {
   336  			t.Fatal(err)
   337  		}
   338  
   339  		costFloat = cost.Value().Data().(float64)
   340  		if costThreshold > math.Abs(costFloat) {
   341  			break
   342  		}
   343  
   344  		err = solver.Step([]ValueGrad{z})
   345  		if nil != err {
   346  			t.Fatal(err)
   347  		}
   348  
   349  		maxIterations--
   350  	}
   351  
   352  	assert.InDelta(0, costFloat, costThreshold)
   353  }
   354  
   355  func TestBarzilaiBorweinSolver(t *testing.T) {
   356  	assert := assert.New(t)
   357  
   358  	z, cost, m, err := model2dRosenbrock(1, 100, -0.5, 0.5)
   359  	defer m.Close()
   360  	const costThreshold = 0.00002
   361  	if nil != err {
   362  		t.Fatal(err)
   363  	}
   364  
   365  	solver := NewBarzilaiBorweinSolver(WithLearnRate(0.0001))
   366  	iterations := 0
   367  	costFloat := 42.0
   368  
   369  	// NOTE: due to precision issues with floating-point arithmetic,
   370  	// amd64 reaches the minimum expected cost at iteration #198
   371  	// arm64 reaches the minimum expected cost at iteration #210
   372  	// In some other cases arm converges faster than amd
   373  	// See https://github.com/golang/go/issues/18354#issuecomment-267705645
   374  
   375  	for iterations < 250 {
   376  		m.Reset()
   377  		err = m.RunAll()
   378  		if nil != err {
   379  			t.Fatal(err)
   380  		}
   381  
   382  		costFloat = cost.Value().Data().(float64)
   383  		if costThreshold > math.Abs(costFloat) {
   384  			break
   385  		}
   386  
   387  		err = solver.Step([]ValueGrad{z})
   388  		if nil != err {
   389  			t.Fatal(err)
   390  		}
   391  
   392  		iterations++
   393  	}
   394  
   395  	log.Printf("Found minimum cost at iteration %d. arch=%s", iterations, runtime.GOARCH)
   396  
   397  	assert.InDelta(0, costFloat, costThreshold)
   398  }
   399  
   400  // The Rosenbrock function is a non-convex function,
   401  // which is used as a performance test problem for optimization algorithms.
   402  // https://en.wikipedia.org/wiki/Rosenbrock_function
   403  //
   404  // f(x,y) = (a-x)² + b(y-x²)²
   405  // It has a global minimum at (x, y) = (a, a²), where f(x,y) = 0.
   406  // Usually a = 1, b = 100, then the minimum is at x = y = 1
   407  // TODO: There is also an n-dimensional version...see wiki
   408  func model2dRosenbrock(a, b, xInit, yInit float64) (z, cost *Node, machine *tapeMachine, err error) {
   409  	g := NewGraph()
   410  
   411  	z = NewTensor(g, Float64, 1, WithShape(2), WithName("z"))
   412  
   413  	aN := NewConstant(a, WithName("a"))
   414  	bN := NewConstant(b, WithName("b"))
   415  
   416  	xProjFloat := []float64{1, 0}
   417  	xProj := NewConstant(tensor.New(tensor.WithBacking(xProjFloat), tensor.WithShape(2)))
   418  
   419  	yProjFloat := []float64{0, 1}
   420  	yProj := NewConstant(tensor.New(tensor.WithBacking(yProjFloat), tensor.WithShape(2)))
   421  
   422  	x := Must(Mul(z, xProj))
   423  	y := Must(Mul(z, yProj))
   424  
   425  	// First term
   426  
   427  	sqrt1stTerm := Must(Sub(aN, x))
   428  
   429  	firstTerm := Must(Square(sqrt1stTerm))
   430  
   431  	// Second term
   432  
   433  	xSquared := Must(Square(x))
   434  
   435  	yMinusxSquared := Must(Sub(y, xSquared))
   436  
   437  	yMinusxSquaredSqu := Must(Square(yMinusxSquared))
   438  
   439  	secondTerm := Must(Mul(bN, yMinusxSquaredSqu))
   440  
   441  	// cost
   442  	cost = Must(Add(firstTerm, secondTerm))
   443  
   444  	dcost, err := Grad(cost, z)
   445  	if nil != err {
   446  		return nil, nil, nil, err
   447  	}
   448  
   449  	prog, locMap, err := CompileFunction(g, Nodes{z}, Nodes{cost, dcost[0]})
   450  	if nil != err {
   451  		return nil, nil, nil, err
   452  	}
   453  
   454  	machine = NewTapeMachine(g, WithPrecompiled(prog, locMap), BindDualValues(z))
   455  
   456  	err = machine.Let(z, tensor.New(tensor.WithBacking([]float64{xInit, yInit}), tensor.WithShape(2)))
   457  	if nil != err {
   458  		return nil, nil, nil, err
   459  	}
   460  
   461  	return
   462  }
   463  
   464  func model2dSquare(xInit, yInit float64) (z, cost *Node, machine *tapeMachine, err error) {
   465  	g := NewGraph()
   466  
   467  	z = NewTensor(g, Float64, 1, WithShape(2), WithName("z"))
   468  
   469  	// cost
   470  	cost = Must(Mul(z, z))
   471  
   472  	dcost, err := Grad(cost, z)
   473  	if nil != err {
   474  		return nil, nil, nil, err
   475  	}
   476  
   477  	prog, locMap, err := CompileFunction(g, Nodes{z}, Nodes{cost, dcost[0]})
   478  	if nil != err {
   479  		return nil, nil, nil, err
   480  	}
   481  
   482  	machine = NewTapeMachine(g, WithPrecompiled(prog, locMap), BindDualValues(z))
   483  
   484  	err = machine.Let(z, tensor.New(tensor.WithBacking([]float64{xInit, yInit}), tensor.WithShape(2)))
   485  	if nil != err {
   486  		return nil, nil, nil, err
   487  	}
   488  
   489  	return
   490  }