github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/optimize/listsearch_test.go (about) 1 // Copyright ©2018 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package optimize 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "github.com/jingcheng-WU/gonum/floats" 13 "github.com/jingcheng-WU/gonum/mat" 14 "github.com/jingcheng-WU/gonum/optimize/functions" 15 ) 16 17 func TestListSearch(t *testing.T) { 18 t.Parallel() 19 rnd := rand.New(rand.NewSource(1)) 20 for cas, test := range []struct { 21 r, c int 22 shortEvals int 23 fun func([]float64) float64 24 }{ 25 { 26 r: 100, 27 c: 10, 28 fun: functions.ExtendedRosenbrock{}.Func, 29 }, 30 } { 31 // Generate a random list of items. 32 r, c := test.r, test.c 33 locs := mat.NewDense(r, c, nil) 34 for i := 0; i < r; i++ { 35 for j := 0; j < c; j++ { 36 locs.Set(i, j, rnd.NormFloat64()) 37 } 38 } 39 40 // Evaluate all of the items in the list and find the minimum value. 41 fs := make([]float64, r) 42 for i := 0; i < r; i++ { 43 fs[i] = test.fun(locs.RawRowView(i)) 44 } 45 minIdx := floats.MinIdx(fs) 46 47 // Check that the global minimum is found under normal conditions. 48 p := Problem{Func: test.fun} 49 method := &ListSearch{ 50 Locs: locs, 51 } 52 settings := &Settings{ 53 Converger: NeverTerminate{}, 54 } 55 initX := make([]float64, c) 56 result, err := Minimize(p, initX, settings, method) 57 if err != nil { 58 t.Errorf("cas %v: error optimizing: %s", cas, err) 59 } 60 if result.Status != MethodConverge { 61 t.Errorf("cas %v: status should be MethodConverge", cas) 62 } 63 if !floats.Equal(result.X, locs.RawRowView(minIdx)) { 64 t.Errorf("cas %v: did not find minimum of whole list", cas) 65 } 66 67 // Check that the optimization works concurrently. 68 concurrent := 6 69 settings.Concurrent = concurrent 70 result, err = Minimize(p, initX, settings, method) 71 if err != nil { 72 t.Errorf("cas %v: error optimizing: %s", cas, err) 73 } 74 if result.Status != MethodConverge { 75 t.Errorf("cas %v: status should be MethodConverge", cas) 76 } 77 if !floats.Equal(result.X, locs.RawRowView(minIdx)) { 78 t.Errorf("cas %v: did not find minimum of whole list concurrent", cas) 79 } 80 81 // Check that the optimization works concurrently with more than the number of samples. 82 settings.Concurrent = test.r + concurrent 83 result, err = Minimize(p, initX, settings, method) 84 if err != nil { 85 t.Errorf("cas %v: error optimizing: %s", cas, err) 86 } 87 if result.Status != MethodConverge { 88 t.Errorf("cas %v: status should be MethodConverge", cas) 89 } 90 if !floats.Equal(result.X, locs.RawRowView(minIdx)) { 91 t.Errorf("cas %v: did not find minimum of whole list concurrent", cas) 92 } 93 94 // Check that cleanup happens properly by setting the minimum location 95 // to the last sample. 96 swapSamples(locs, fs, minIdx, test.r-1) 97 minIdx = test.r - 1 98 settings.Concurrent = concurrent 99 result, err = Minimize(p, initX, settings, method) 100 if err != nil { 101 t.Errorf("cas %v: error optimizing: %s", cas, err) 102 } 103 if result.Status != MethodConverge { 104 t.Errorf("cas %v: status should be MethodConverge", cas) 105 } 106 if !floats.Equal(result.X, locs.RawRowView(minIdx)) { 107 t.Errorf("cas %v: did not find minimum of whole list last sample", cas) 108 } 109 110 // Test that the correct optimum is found when the optimization ends early. 111 // Note that the above test swapped the list minimum to the last sample, 112 // so it's guaranteed that the minimum of the shortened list is not the 113 // same as the minimum of the whole list. 114 evals := test.r / 3 115 minIdxFirst := floats.MinIdx(fs[:evals]) 116 settings.Concurrent = 0 117 settings.FuncEvaluations = evals 118 result, err = Minimize(p, initX, settings, method) 119 if err != nil { 120 t.Errorf("cas %v: error optimizing: %s", cas, err) 121 } 122 if result.Status != FunctionEvaluationLimit { 123 t.Errorf("cas %v: status was not FunctionEvaluationLimit", cas) 124 } 125 if !floats.Equal(result.X, locs.RawRowView(minIdxFirst)) { 126 t.Errorf("cas %v: did not find minimum of shortened list serial", cas) 127 } 128 129 // Test the same but concurrently. We can't guarantee a specific number 130 // of function evaluations concurrently, so make sure that the list optimum 131 // is not between [evals:evals+concurrent] 132 for floats.MinIdx(fs[:evals]) != floats.MinIdx(fs[:evals+concurrent]) { 133 // Swap the minimum index with a random element. 134 minIdxFirst := floats.MinIdx(fs[:evals+concurrent]) 135 new := rnd.Intn(evals) 136 swapSamples(locs, fs, minIdxFirst, new) 137 } 138 139 minIdxFirst = floats.MinIdx(fs[:evals]) 140 settings.Concurrent = concurrent 141 result, err = Minimize(p, initX, settings, method) 142 if err != nil { 143 t.Errorf("cas %v: error optimizing: %s", cas, err) 144 } 145 if result.Status != FunctionEvaluationLimit { 146 t.Errorf("cas %v: status was not FunctionEvaluationLimit", cas) 147 } 148 if !floats.Equal(result.X, locs.RawRowView(minIdxFirst)) { 149 t.Errorf("cas %v: did not find minimum of shortened list concurrent", cas) 150 } 151 } 152 } 153 154 func swapSamples(m *mat.Dense, f []float64, i, j int) { 155 f[i], f[j] = f[j], f[i] 156 row := mat.Row(nil, i, m) 157 m.SetRow(i, m.RawRowView(j)) 158 m.SetRow(j, row) 159 }