github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/optimize/listsearch_test.go (about)

     1  // Copyright ©2018 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package optimize
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"github.com/jingcheng-WU/gonum/floats"
    13  	"github.com/jingcheng-WU/gonum/mat"
    14  	"github.com/jingcheng-WU/gonum/optimize/functions"
    15  )
    16  
    17  func TestListSearch(t *testing.T) {
    18  	t.Parallel()
    19  	rnd := rand.New(rand.NewSource(1))
    20  	for cas, test := range []struct {
    21  		r, c       int
    22  		shortEvals int
    23  		fun        func([]float64) float64
    24  	}{
    25  		{
    26  			r:   100,
    27  			c:   10,
    28  			fun: functions.ExtendedRosenbrock{}.Func,
    29  		},
    30  	} {
    31  		// Generate a random list of items.
    32  		r, c := test.r, test.c
    33  		locs := mat.NewDense(r, c, nil)
    34  		for i := 0; i < r; i++ {
    35  			for j := 0; j < c; j++ {
    36  				locs.Set(i, j, rnd.NormFloat64())
    37  			}
    38  		}
    39  
    40  		// Evaluate all of the items in the list and find the minimum value.
    41  		fs := make([]float64, r)
    42  		for i := 0; i < r; i++ {
    43  			fs[i] = test.fun(locs.RawRowView(i))
    44  		}
    45  		minIdx := floats.MinIdx(fs)
    46  
    47  		// Check that the global minimum is found under normal conditions.
    48  		p := Problem{Func: test.fun}
    49  		method := &ListSearch{
    50  			Locs: locs,
    51  		}
    52  		settings := &Settings{
    53  			Converger: NeverTerminate{},
    54  		}
    55  		initX := make([]float64, c)
    56  		result, err := Minimize(p, initX, settings, method)
    57  		if err != nil {
    58  			t.Errorf("cas %v: error optimizing: %s", cas, err)
    59  		}
    60  		if result.Status != MethodConverge {
    61  			t.Errorf("cas %v: status should be MethodConverge", cas)
    62  		}
    63  		if !floats.Equal(result.X, locs.RawRowView(minIdx)) {
    64  			t.Errorf("cas %v: did not find minimum of whole list", cas)
    65  		}
    66  
    67  		// Check that the optimization works concurrently.
    68  		concurrent := 6
    69  		settings.Concurrent = concurrent
    70  		result, err = Minimize(p, initX, settings, method)
    71  		if err != nil {
    72  			t.Errorf("cas %v: error optimizing: %s", cas, err)
    73  		}
    74  		if result.Status != MethodConverge {
    75  			t.Errorf("cas %v: status should be MethodConverge", cas)
    76  		}
    77  		if !floats.Equal(result.X, locs.RawRowView(minIdx)) {
    78  			t.Errorf("cas %v: did not find minimum of whole list concurrent", cas)
    79  		}
    80  
    81  		// Check that the optimization works concurrently with more than the number of samples.
    82  		settings.Concurrent = test.r + concurrent
    83  		result, err = Minimize(p, initX, settings, method)
    84  		if err != nil {
    85  			t.Errorf("cas %v: error optimizing: %s", cas, err)
    86  		}
    87  		if result.Status != MethodConverge {
    88  			t.Errorf("cas %v: status should be MethodConverge", cas)
    89  		}
    90  		if !floats.Equal(result.X, locs.RawRowView(minIdx)) {
    91  			t.Errorf("cas %v: did not find minimum of whole list concurrent", cas)
    92  		}
    93  
    94  		// Check that cleanup happens properly by setting the minimum location
    95  		// to the last sample.
    96  		swapSamples(locs, fs, minIdx, test.r-1)
    97  		minIdx = test.r - 1
    98  		settings.Concurrent = concurrent
    99  		result, err = Minimize(p, initX, settings, method)
   100  		if err != nil {
   101  			t.Errorf("cas %v: error optimizing: %s", cas, err)
   102  		}
   103  		if result.Status != MethodConverge {
   104  			t.Errorf("cas %v: status should be MethodConverge", cas)
   105  		}
   106  		if !floats.Equal(result.X, locs.RawRowView(minIdx)) {
   107  			t.Errorf("cas %v: did not find minimum of whole list last sample", cas)
   108  		}
   109  
   110  		// Test that the correct optimum is found when the optimization ends early.
   111  		// Note that the above test swapped the list minimum to the last sample,
   112  		// so it's guaranteed that the minimum of the shortened list is not the
   113  		// same as the minimum of the whole list.
   114  		evals := test.r / 3
   115  		minIdxFirst := floats.MinIdx(fs[:evals])
   116  		settings.Concurrent = 0
   117  		settings.FuncEvaluations = evals
   118  		result, err = Minimize(p, initX, settings, method)
   119  		if err != nil {
   120  			t.Errorf("cas %v: error optimizing: %s", cas, err)
   121  		}
   122  		if result.Status != FunctionEvaluationLimit {
   123  			t.Errorf("cas %v: status was not FunctionEvaluationLimit", cas)
   124  		}
   125  		if !floats.Equal(result.X, locs.RawRowView(minIdxFirst)) {
   126  			t.Errorf("cas %v: did not find minimum of shortened list serial", cas)
   127  		}
   128  
   129  		// Test the same but concurrently. We can't guarantee a specific number
   130  		// of function evaluations concurrently, so make sure that the list optimum
   131  		// is not between [evals:evals+concurrent]
   132  		for floats.MinIdx(fs[:evals]) != floats.MinIdx(fs[:evals+concurrent]) {
   133  			// Swap the minimum index with a random element.
   134  			minIdxFirst := floats.MinIdx(fs[:evals+concurrent])
   135  			new := rnd.Intn(evals)
   136  			swapSamples(locs, fs, minIdxFirst, new)
   137  		}
   138  
   139  		minIdxFirst = floats.MinIdx(fs[:evals])
   140  		settings.Concurrent = concurrent
   141  		result, err = Minimize(p, initX, settings, method)
   142  		if err != nil {
   143  			t.Errorf("cas %v: error optimizing: %s", cas, err)
   144  		}
   145  		if result.Status != FunctionEvaluationLimit {
   146  			t.Errorf("cas %v: status was not FunctionEvaluationLimit", cas)
   147  		}
   148  		if !floats.Equal(result.X, locs.RawRowView(minIdxFirst)) {
   149  			t.Errorf("cas %v: did not find minimum of shortened list concurrent", cas)
   150  		}
   151  	}
   152  }
   153  
   154  func swapSamples(m *mat.Dense, f []float64, i, j int) {
   155  	f[i], f[j] = f[j], f[i]
   156  	row := mat.Row(nil, i, m)
   157  	m.SetRow(i, m.RawRowView(j))
   158  	m.SetRow(j, row)
   159  }