
     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     5  package optimize
import (
	"fmt"
	"math"
	"reflect"
	"testing"

	"gonum.org/v1/gonum/optimize/functions"
)
    16  func TestMoreThuente(t *testing.T) {
    17  	t.Parallel()
    18  	d := 0.001
    19  	c := 0.001
    20  	ls := &MoreThuente{
    21  		DecreaseFactor:  d,
    22  		CurvatureFactor: c,
    23  	}
    24  	testLinesearcher(t, ls, d, c, true)
    25  }
    27  func TestBisection(t *testing.T) {
    28  	t.Parallel()
    29  	c := 0.1
    30  	ls := &Bisection{
    31  		CurvatureFactor: c,
    32  	}
    33  	testLinesearcher(t, ls, 0, c, true)
    34  }
    36  func TestBacktracking(t *testing.T) {
    37  	t.Parallel()
    38  	d := 0.001
    39  	ls := &Backtracking{
    40  		DecreaseFactor: d,
    41  	}
    42  	testLinesearcher(t, ls, d, 0, false)
    43  }
    45  type funcGrader interface {
    46  	Func([]float64) float64
    47  	Grad([]float64, []float64)
    48  }
    50  type linesearcherTest struct {
    51  	name string
    52  	f    func(float64) float64
    53  	g    func(float64) float64
    54  }
    56  func newLinesearcherTest(name string, fg funcGrader) linesearcherTest {
    57  	grad := make([]float64, 1)
    58  	return linesearcherTest{
    59  		name: name,
    60  		f: func(x float64) float64 {
    61  			return fg.Func([]float64{x})
    62  		},
    63  		g: func(x float64) float64 {
    64  			fg.Grad(grad, []float64{x})
    65  			return grad[0]
    66  		},
    67  	}
    68  }
    70  func testLinesearcher(t *testing.T, ls Linesearcher, decrease, curvature float64, strongWolfe bool) {
    71  	for i, prob := range []linesearcherTest{
    72  		newLinesearcherTest("Concave-to-the-right function", functions.ConcaveRight{}),
    73  		newLinesearcherTest("Concave-to-the-left function", functions.ConcaveLeft{}),
    74  		newLinesearcherTest("Plassmann wiggly function (l=39, beta=0.01)", functions.Plassmann{L: 39, Beta: 0.01}),
    75  		newLinesearcherTest("Yanai-Ozawa-Kaneko function (beta1=0.001, beta2=0.001)", functions.YanaiOzawaKaneko{Beta1: 0.001, Beta2: 0.001}),
    76  		newLinesearcherTest("Yanai-Ozawa-Kaneko function (beta1=0.01, beta2=0.001)", functions.YanaiOzawaKaneko{Beta1: 0.01, Beta2: 0.001}),
    77  		newLinesearcherTest("Yanai-Ozawa-Kaneko function (beta1=0.001, beta2=0.01)", functions.YanaiOzawaKaneko{Beta1: 0.001, Beta2: 0.01}),
    78  	} {
    79  		for _, initStep := range []float64{0.001, 0.1, 1, 10, 1000} {
    80  			prefix := fmt.Sprintf("test %d (%v started from %v)", i,, initStep)
    82  			f0 := prob.f(0)
    83  			g0 := prob.g(0)
    84  			if g0 >= 0 {
    85  				panic("bad test function")
    86  			}
    88  			op := ls.Init(f0, g0, initStep)
    89  			if !op.isEvaluation() {
    90  				t.Errorf("%v: Linesearcher.Init returned non-evaluating operation %v", prefix, op)
    91  				continue
    92  			}
    94  			var (
    95  				err  error
    96  				k    int
    97  				f, g float64
    98  				step float64
    99  			)
   100  		loop:
   101  			for {
   102  				switch op {
   103  				case MajorIteration:
   104  					if f > f0+step*decrease*g0 {
   105  						t.Errorf("%v: %v found step %v that does not satisfy the sufficient decrease condition",
   106  							prefix, reflect.TypeOf(ls), step)
   107  					}
   108  					if strongWolfe && math.Abs(g) > curvature*(-g0) {
   109  						t.Errorf("%v: %v found step %v that does not satisfy the curvature condition",
   110  							prefix, reflect.TypeOf(ls), step)
   111  					}
   112  					break loop
   113  				case FuncEvaluation:
   114  					f = prob.f(step)
   115  				case GradEvaluation:
   116  					g = prob.g(step)
   117  				case FuncEvaluation | GradEvaluation:
   118  					f = prob.f(step)
   119  					g = prob.g(step)
   120  				default:
   121  					t.Errorf("%v: Linesearcher returned an invalid operation %v", prefix, op)
   122  					break loop
   123  				}
   125  				k++
   126  				if k == 1000 {
   127  					t.Errorf("%v: %v did not finish", prefix, reflect.TypeOf(ls))
   128  					break
   129  				}
   131  				op, step, err = ls.Iterate(f, g)
   132  				if err != nil {
   133  					t.Errorf("%v: %v failed at step %v with %v", prefix, reflect.TypeOf(ls), step, err)
   134  					break
   135  				}
   136  			}
   137  		}
   138  	}
   139  }